diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/PropertiesBasedSpringJUnit4ClassRunnerAppCtxTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/PropertiesBasedSpringJUnit4ClassRunnerAppCtxTests.java new file mode 100644 index 0000000000000000000000000000000000000000..32c07858f28c07b7d248591bb0c7fdbb06d63ddb --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/PropertiesBasedSpringJUnit4ClassRunnerAppCtxTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import java.util.Properties; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.support.GenericPropertiesContextLoader; +import org.springframework.tests.sample.beans.Pet; + +import static org.junit.Assert.*; + +/** + *

+ * JUnit 4 based test class, which verifies the expected functionality of + * {@link SpringRunner} in conjunction with support for application contexts + * loaded from Java {@link Properties} files. Specifically, the + * {@link ContextConfiguration#loader() loader} attribute of {@code ContextConfiguration} + * and the + * {@link org.springframework.test.context.support.GenericPropertiesContextLoader#getResourceSuffix() + * resourceSuffix} property of {@code GenericPropertiesContextLoader} are tested. + *

+ *

+ * Since no {@link ContextConfiguration#locations() locations} are explicitly defined, the + * {@code resourceSuffix} is set to "-context.properties", and since default + * resource locations will be detected by default, this test class's dependencies will be + * injected via {@link Autowired annotation-based autowiring} from beans defined in the + * {@link ApplicationContext} loaded from the default classpath resource: " + * {@code /org/springframework/test/junit4/PropertiesBasedSpringJUnit4ClassRunnerAppCtxTests-context.properties} + * ". + *

+ * + * @author Sam Brannen + * @since 2.5 + * @see GenericPropertiesContextLoader + * @see SpringJUnit4ClassRunnerAppCtxTests + */ +@RunWith(SpringRunner.class) +@ContextConfiguration(loader = GenericPropertiesContextLoader.class) +public class PropertiesBasedSpringJUnit4ClassRunnerAppCtxTests { + + @Autowired + private Pet cat; + + @Autowired + private String testString; + + + @Test + public void verifyAnnotationAutowiredFields() { + assertNotNull("The cat field should have been autowired.", this.cat); + assertEquals("Garfield", this.cat.getName()); + + assertNotNull("The testString field should have been autowired.", this.testString); + assertEquals("Test String", this.testString); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/RelativePathSpringJUnit4ClassRunnerAppCtxTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/RelativePathSpringJUnit4ClassRunnerAppCtxTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5a64bdab56078ea1998bf4bec002a3e7014bb360 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/RelativePathSpringJUnit4ClassRunnerAppCtxTests.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2007 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import org.springframework.test.context.ContextConfiguration; + +/** + * Extension of {@link SpringJUnit4ClassRunnerAppCtxTests}, which verifies that + * we can specify an explicit, relative path location for our + * application context. + * + * @author Sam Brannen + * @since 2.5 + * @see SpringJUnit4ClassRunnerAppCtxTests + * @see AbsolutePathSpringJUnit4ClassRunnerAppCtxTests + */ +@ContextConfiguration(locations = { "SpringJUnit4ClassRunnerAppCtxTests-context.xml" }) +public class RelativePathSpringJUnit4ClassRunnerAppCtxTests extends SpringJUnit4ClassRunnerAppCtxTests { + /* all tests are in the parent class. */ +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/RepeatedSpringRunnerTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/RepeatedSpringRunnerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a53fb1854b7d59aac3a9b4e89c00de49135a17a0 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/RepeatedSpringRunnerTests.java @@ -0,0 +1,197 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import java.io.IOException; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runner.Runner; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +import org.springframework.test.annotation.Repeat; +import org.springframework.test.annotation.Timed; +import org.springframework.test.context.TestExecutionListeners; +import org.springframework.util.ClassUtils; + +import static org.junit.Assert.*; +import static org.springframework.test.context.junit4.JUnitTestingUtils.*; + +/** + * Verifies proper handling of the following in conjunction with the + * {@link SpringRunner}: + * + * + * @author Sam Brannen + * @since 3.0 + */ +@RunWith(Parameterized.class) +public class RepeatedSpringRunnerTests { + + protected static final AtomicInteger invocationCount = new AtomicInteger(); + + private final Class testClass; + + private final int expectedFailureCount; + private final int expectedStartedCount; + private final int expectedFinishedCount; + private final int expectedInvocationCount; + + + @Parameters(name = "{0}") + public static Object[][] repetitionData() { + return new Object[][] {// + { NonAnnotatedRepeatedTestCase.class.getSimpleName(), 0, 1, 1, 1 },// + { DefaultRepeatValueRepeatedTestCase.class.getSimpleName(), 0, 1, 1, 1 },// + { NegativeRepeatValueRepeatedTestCase.class.getSimpleName(), 0, 1, 1, 1 },// + { RepeatedFiveTimesRepeatedTestCase.class.getSimpleName(), 0, 1, 1, 5 },// + { RepeatedFiveTimesViaMetaAnnotationRepeatedTestCase.class.getSimpleName(), 0, 1, 1, 5 },// + { TimedRepeatedTestCase.class.getSimpleName(), 3, 4, 4, (5 + 1 + 4 + 10) } // + }; + } + + public RepeatedSpringRunnerTests(String testClassName, int expectedFailureCount, + int expectedTestStartedCount, int expectedTestFinishedCount, int expectedInvocationCount) throws Exception { + this.testClass = ClassUtils.forName(getClass().getName() + "." + testClassName, getClass().getClassLoader()); + this.expectedFailureCount = expectedFailureCount; + this.expectedStartedCount = expectedTestStartedCount; + this.expectedFinishedCount = expectedTestFinishedCount; + this.expectedInvocationCount = expectedInvocationCount; + } + + protected Class getRunnerClass() { + return SpringRunner.class; + } + + @Test + public void assertRepetitions() throws Exception { + invocationCount.set(0); + + runTestsAndAssertCounters(getRunnerClass(), this.testClass, expectedStartedCount, expectedFailureCount, + expectedFinishedCount, 0, 0); + + assertEquals("invocations for [" + testClass + "]:", expectedInvocationCount, invocationCount.get()); + } + + + @TestExecutionListeners({}) + public abstract static class AbstractRepeatedTestCase { + + protected void incrementInvocationCount() throws IOException { + invocationCount.incrementAndGet(); + } + } + + public static final class NonAnnotatedRepeatedTestCase extends AbstractRepeatedTestCase { + + @Test + @Timed(millis = 10000) + public void nonAnnotated() throws Exception { + incrementInvocationCount(); + } + } + + public static final class DefaultRepeatValueRepeatedTestCase extends AbstractRepeatedTestCase { + + @Test + @Repeat + @Timed(millis = 10000) + public void defaultRepeatValue() throws Exception { + incrementInvocationCount(); + } + } + + public static final class NegativeRepeatValueRepeatedTestCase extends AbstractRepeatedTestCase { + + @Test + @Repeat(-5) + @Timed(millis = 10000) + public void negativeRepeatValue() throws Exception { + incrementInvocationCount(); + } + } + + public static final class RepeatedFiveTimesRepeatedTestCase extends AbstractRepeatedTestCase { + + @Test + @Repeat(5) + public void repeatedFiveTimes() throws Exception { + incrementInvocationCount(); + } + } + + @Repeat(5) + @Retention(RetentionPolicy.RUNTIME) + private static @interface RepeatedFiveTimes { + } + + public static final class RepeatedFiveTimesViaMetaAnnotationRepeatedTestCase extends AbstractRepeatedTestCase { + + @Test + @RepeatedFiveTimes + public void repeatedFiveTimes() throws Exception { + incrementInvocationCount(); + } + } + + /** + * Unit tests for claims raised in SPR-6011. + */ + @Ignore("TestCase classes are run manually by the enclosing test class") + public static final class TimedRepeatedTestCase extends AbstractRepeatedTestCase { + + @Test + @Timed(millis = 1000) + @Repeat(5) + public void repeatedFiveTimesButDoesNotExceedTimeout() throws Exception { + incrementInvocationCount(); + } + + @Test + @Timed(millis = 10) + @Repeat(1) + public void singleRepetitionExceedsTimeout() throws Exception { + incrementInvocationCount(); + Thread.sleep(15); + } + + @Test + @Timed(millis = 20) + @Repeat(4) + public void firstRepetitionOfManyExceedsTimeout() throws Exception { + incrementInvocationCount(); + Thread.sleep(25); + } + + @Test + @Timed(millis = 100) + @Repeat(10) + public void collectiveRepetitionsExceedTimeout() throws Exception { + incrementInvocationCount(); + Thread.sleep(11); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/RollbackOverrideDefaultRollbackFalseRollbackAnnotationTransactionalTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/RollbackOverrideDefaultRollbackFalseRollbackAnnotationTransactionalTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8d989f00191c718b4492a3050f623a26df43db8d --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/RollbackOverrideDefaultRollbackFalseRollbackAnnotationTransactionalTests.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import javax.sql.DataSource; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.test.annotation.Rollback; + +import static org.junit.Assert.*; +import static org.springframework.test.transaction.TransactionTestUtils.*; + +/** + * Extension of {@link DefaultRollbackFalseRollbackAnnotationTransactionalTests} + * which tests method-level rollback override behavior via the + * {@link Rollback @Rollback} annotation. + * + * @author Sam Brannen + * @since 4.2 + * @see Rollback + */ +public class RollbackOverrideDefaultRollbackFalseRollbackAnnotationTransactionalTests extends + DefaultRollbackFalseRollbackAnnotationTransactionalTests { + + private static int originalNumRows; + + private static JdbcTemplate jdbcTemplate; + + + @Autowired + public void setDataSource(DataSource dataSource) { + jdbcTemplate = new JdbcTemplate(dataSource); + } + + + @Before + @Override + public void verifyInitialTestData() { + originalNumRows = clearPersonTable(jdbcTemplate); + assertEquals("Adding bob", 1, addPerson(jdbcTemplate, BOB)); + assertEquals("Verifying the initial number of rows in the person table.", 1, + countRowsInPersonTable(jdbcTemplate)); + } + + @Test + @Rollback + @Override + public void modifyTestDataWithinTransaction() { + assertInTransaction(true); + assertEquals("Deleting bob", 1, deletePerson(jdbcTemplate, BOB)); + assertEquals("Adding jane", 1, addPerson(jdbcTemplate, JANE)); + assertEquals("Adding sue", 1, addPerson(jdbcTemplate, SUE)); + assertEquals("Verifying the number of rows in the person table within a transaction.", 2, + countRowsInPersonTable(jdbcTemplate)); + } + + @AfterClass + public static void verifyFinalTestData() { + assertEquals("Verifying the final number of rows in the person table after all tests.", originalNumRows, + countRowsInPersonTable(jdbcTemplate)); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/RollbackOverrideDefaultRollbackFalseTransactionalTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/RollbackOverrideDefaultRollbackFalseTransactionalTests.java new file mode 100644 index 0000000000000000000000000000000000000000..819e924d0d7a3566a72d94bd59466f5df1483e6b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/RollbackOverrideDefaultRollbackFalseTransactionalTests.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import javax.sql.DataSource; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.test.annotation.Rollback; + +import static org.junit.Assert.*; +import static org.springframework.test.transaction.TransactionTestUtils.*; + +/** + * Extension of {@link DefaultRollbackFalseRollbackAnnotationTransactionalTests} + * which tests method-level rollback override behavior via the + * {@link Rollback @Rollback} annotation. + * + * @author Sam Brannen + * @since 2.5 + * @see Rollback + */ +public class RollbackOverrideDefaultRollbackFalseTransactionalTests + extends DefaultRollbackFalseRollbackAnnotationTransactionalTests { + + private static int originalNumRows; + + private static JdbcTemplate jdbcTemplate; + + + @Autowired + @Override + public void setDataSource(DataSource dataSource) { + jdbcTemplate = new JdbcTemplate(dataSource); + } + + @Before + @Override + public void verifyInitialTestData() { + originalNumRows = clearPersonTable(jdbcTemplate); + assertEquals("Adding bob", 1, addPerson(jdbcTemplate, BOB)); + assertEquals("Verifying the initial number of rows in the person table.", 1, + countRowsInPersonTable(jdbcTemplate)); + } + + @Test + @Rollback + @Override + public void modifyTestDataWithinTransaction() { + assertInTransaction(true); + assertEquals("Deleting bob", 1, deletePerson(jdbcTemplate, BOB)); + assertEquals("Adding jane", 1, addPerson(jdbcTemplate, JANE)); + assertEquals("Adding sue", 1, addPerson(jdbcTemplate, SUE)); + assertEquals("Verifying the number of rows in the person table within a transaction.", 2, + countRowsInPersonTable(jdbcTemplate)); + } + + @AfterClass + public static void verifyFinalTestData() { + assertEquals("Verifying the final number of rows in the person table after all tests.", originalNumRows, + countRowsInPersonTable(jdbcTemplate)); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/RollbackOverrideDefaultRollbackTrueRollbackAnnotationTransactionalTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/RollbackOverrideDefaultRollbackTrueRollbackAnnotationTransactionalTests.java new file mode 100644 index 0000000000000000000000000000000000000000..7fb9a144642ff93d51e257c372a97d7177fba5c6 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/RollbackOverrideDefaultRollbackTrueRollbackAnnotationTransactionalTests.java @@ -0,0 +1,80 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import javax.sql.DataSource; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.test.annotation.Rollback; + +import static org.junit.Assert.*; +import static org.springframework.test.transaction.TransactionTestUtils.*; + +/** + * Extension of {@link DefaultRollbackTrueRollbackAnnotationTransactionalTests} + * which tests method-level rollback override behavior via the + * {@link Rollback @Rollback} annotation. + * + * @author Sam Brannen + * @since 4.2 + * @see Rollback + */ +public class RollbackOverrideDefaultRollbackTrueRollbackAnnotationTransactionalTests extends + DefaultRollbackTrueRollbackAnnotationTransactionalTests { + + private static JdbcTemplate jdbcTemplate; + + + @Autowired + @Override + public void setDataSource(DataSource dataSource) { + jdbcTemplate = new JdbcTemplate(dataSource); + } + + + @Before + @Override + public void verifyInitialTestData() { + clearPersonTable(jdbcTemplate); + assertEquals("Adding bob", 1, addPerson(jdbcTemplate, BOB)); + assertEquals("Verifying the initial number of rows in the person table.", 1, + countRowsInPersonTable(jdbcTemplate)); + } + + @Test + @Rollback(false) + @Override + public void modifyTestDataWithinTransaction() { + assertInTransaction(true); + assertEquals("Adding jane", 1, addPerson(jdbcTemplate, JANE)); + assertEquals("Adding sue", 1, addPerson(jdbcTemplate, SUE)); + assertEquals("Verifying the number of rows in the person table within a transaction.", 3, + countRowsInPersonTable(jdbcTemplate)); + } + + @AfterClass + public static void verifyFinalTestData() { + assertEquals("Verifying the final number of rows in the person table after all tests.", 3, + countRowsInPersonTable(jdbcTemplate)); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/RollbackOverrideDefaultRollbackTrueTransactionalTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/RollbackOverrideDefaultRollbackTrueTransactionalTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e0067c649e6a4f7e861692eeb2fce811f43fff67 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/RollbackOverrideDefaultRollbackTrueTransactionalTests.java @@ -0,0 +1,78 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import javax.sql.DataSource; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.test.annotation.Rollback; + +import static org.junit.Assert.*; +import static org.springframework.test.transaction.TransactionTestUtils.*; + +/** + * Extension of {@link DefaultRollbackTrueRollbackAnnotationTransactionalTests} + * which tests method-level rollback override behavior via the + * {@link Rollback @Rollback} annotation. + * + * @author Sam Brannen + * @since 2.5 + * @see Rollback + */ +public class RollbackOverrideDefaultRollbackTrueTransactionalTests + extends DefaultRollbackTrueRollbackAnnotationTransactionalTests { + + private static JdbcTemplate jdbcTemplate; + + + @Autowired + public void setDataSource(DataSource dataSource) { + jdbcTemplate = new JdbcTemplate(dataSource); + } + + @Before + @Override + public void verifyInitialTestData() { + clearPersonTable(jdbcTemplate); + assertEquals("Adding bob", 1, addPerson(jdbcTemplate, BOB)); + assertEquals("Verifying the initial number of rows in the person table.", 1, + countRowsInPersonTable(jdbcTemplate)); + } + + @Test + @Rollback(false) + @Override + public void modifyTestDataWithinTransaction() { + assertInTransaction(true); + assertEquals("Adding jane", 1, addPerson(jdbcTemplate, JANE)); + assertEquals("Adding sue", 1, addPerson(jdbcTemplate, SUE)); + assertEquals("Verifying the number of rows in the person table within a transaction.", 3, + countRowsInPersonTable(jdbcTemplate)); + } + + @AfterClass + public static void verifyFinalTestData() { + assertEquals("Verifying the final number of rows in the person table after all tests.", 3, + countRowsInPersonTable(jdbcTemplate)); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/SpringJUnit47ClassRunnerRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/SpringJUnit47ClassRunnerRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d933f41b75bf148bdede919cdded04147529ecd2 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/SpringJUnit47ClassRunnerRuleTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; + +import org.springframework.test.context.TestExecutionListeners; + +import static org.junit.Assert.*; + +/** + * Verifies support for JUnit 4.7 {@link Rule Rules} in conjunction with the + * {@link SpringRunner}. The body of this test class is taken from the + * JUnit 4.7 release notes. + * + * @author JUnit 4.7 Team + * @author Sam Brannen + * @since 3.0 + */ +@RunWith(SpringRunner.class) +@TestExecutionListeners( {}) +public class SpringJUnit47ClassRunnerRuleTests { + + @Rule + public TestName name = new TestName(); + + + @Test + public void testA() { + assertEquals("testA", name.getMethodName()); + } + + @Test + public void testB() { + assertEquals("testB", name.getMethodName()); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/SpringJUnit4ClassRunnerAppCtxTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/SpringJUnit4ClassRunnerAppCtxTests.java new file mode 100644 index 0000000000000000000000000000000000000000..37e2f0b5df7a80b5ab48497626bc1017ab85d3a4 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/SpringJUnit4ClassRunnerAppCtxTests.java @@ -0,0 +1,229 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import javax.annotation.Resource; +import javax.inject.Inject; +import javax.inject.Named; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.BeanNameAware; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.TestExecutionListeners; +import org.springframework.test.context.support.DependencyInjectionTestExecutionListener; +import org.springframework.test.context.support.GenericXmlContextLoader; +import org.springframework.tests.sample.beans.Employee; +import org.springframework.tests.sample.beans.Pet; + +import static org.junit.Assert.*; + +/** + * SpringJUnit4ClassRunnerAppCtxTests serves as a proof of concept + * JUnit 4 based test class, which verifies the expected functionality of + * {@link SpringRunner} in conjunction with the following: + * + * + * + *

Since no application context resource + * {@link ContextConfiguration#locations() locations} are explicitly declared + * and since the {@link ContextConfiguration#loader() ContextLoader} is left set + * to the default value of {@link GenericXmlContextLoader}, this test class's + * dependencies will be injected via {@link Autowired @Autowired}, + * {@link Inject @Inject}, and {@link Resource @Resource} from beans defined in + * the {@link ApplicationContext} loaded from the default classpath resource: + * {@value #DEFAULT_CONTEXT_RESOURCE_PATH}. + * + * @author Sam Brannen + * @since 2.5 + * @see AbsolutePathSpringJUnit4ClassRunnerAppCtxTests + * @see RelativePathSpringJUnit4ClassRunnerAppCtxTests + * @see InheritedConfigSpringJUnit4ClassRunnerAppCtxTests + */ +@RunWith(SpringRunner.class) +@ContextConfiguration +@TestExecutionListeners(DependencyInjectionTestExecutionListener.class) +public class SpringJUnit4ClassRunnerAppCtxTests implements ApplicationContextAware, BeanNameAware, InitializingBean { + + /** + * Default resource path for the application context configuration for + * {@link SpringJUnit4ClassRunnerAppCtxTests}: {@value #DEFAULT_CONTEXT_RESOURCE_PATH} + */ + public static final String DEFAULT_CONTEXT_RESOURCE_PATH = + "/org/springframework/test/context/junit4/SpringJUnit4ClassRunnerAppCtxTests-context.xml"; + + + private Employee employee; + + @Autowired + private Pet autowiredPet; + + @Inject + private Pet injectedPet; + + @Autowired(required = false) + protected Long nonrequiredLong; + + @Resource + protected String foo; + + protected String bar; + + @Value("enigma") + private String literalFieldValue; + + @Value("#{2 == (1+1)}") + private Boolean spelFieldValue; + + private String literalParameterValue; + + private Boolean spelParameterValue; + + @Autowired + @Qualifier("quux") + protected String quux; + + @Inject + @Named("quux") + protected String namedQuux; + + private String beanName; + + private ApplicationContext applicationContext; + + private boolean beanInitialized = false; + + + @Autowired + protected void setEmployee(Employee employee) { + this.employee = employee; + } + + @Resource + protected void setBar(String bar) { + this.bar = bar; + } + + @Autowired + public void setLiteralParameterValue(@Value("enigma") String literalParameterValue) { + this.literalParameterValue = literalParameterValue; + } + + @Autowired + public void setSpelParameterValue(@Value("#{2 == (1+1)}") Boolean spelParameterValue) { + this.spelParameterValue = spelParameterValue; + } + + @Override + public void setBeanName(String beanName) { + this.beanName = beanName; + } + + @Override + public void setApplicationContext(ApplicationContext applicationContext) { + this.applicationContext = applicationContext; + } + + @Override + public void afterPropertiesSet() { + this.beanInitialized = true; + } + + + @Test + public void verifyBeanNameSet() { + assertTrue("The bean name of this test instance should have been set due to BeanNameAware semantics.", + this.beanName.startsWith(getClass().getName())); + } + + @Test + public void verifyApplicationContextSet() { + assertNotNull("The application context should have been set due to ApplicationContextAware semantics.", + this.applicationContext); + } + + @Test + public void verifyBeanInitialized() { + assertTrue("This test bean should have been initialized due to InitializingBean semantics.", + this.beanInitialized); + } + + @Test + public void verifyAnnotationAutowiredAndInjectedFields() { + assertNull("The nonrequiredLong field should NOT have been autowired.", this.nonrequiredLong); + assertEquals("The quux field should have been autowired via @Autowired and @Qualifier.", "Quux", this.quux); + assertEquals("The namedFoo field should have been injected via @Inject and @Named.", "Quux", this.namedQuux); + assertSame("@Autowired/@Qualifier and @Inject/@Named quux should be the same object.", this.quux, this.namedQuux); + + assertNotNull("The pet field should have been autowired.", this.autowiredPet); + assertNotNull("The pet field should have been injected.", this.injectedPet); + assertEquals("Fido", this.autowiredPet.getName()); + assertEquals("Fido", this.injectedPet.getName()); + assertSame("@Autowired and @Inject pet should be the same object.", this.autowiredPet, this.injectedPet); + } + + @Test + public void verifyAnnotationAutowiredMethods() { + assertNotNull("The employee setter method should have been autowired.", this.employee); + assertEquals("John Smith", this.employee.getName()); + } + + @Test + public void verifyAutowiredAtValueFields() { + assertNotNull("Literal @Value field should have been autowired", this.literalFieldValue); + assertNotNull("SpEL @Value field should have been autowired.", this.spelFieldValue); + assertEquals("enigma", this.literalFieldValue); + assertEquals(Boolean.TRUE, this.spelFieldValue); + } + + @Test + public void verifyAutowiredAtValueMethods() { + assertNotNull("Literal @Value method parameter should have been autowired.", this.literalParameterValue); + assertNotNull("SpEL @Value method parameter should have been autowired.", this.spelParameterValue); + assertEquals("enigma", this.literalParameterValue); + assertEquals(Boolean.TRUE, this.spelParameterValue); + } + + @Test + public void verifyResourceAnnotationInjectedFields() { + assertEquals("The foo field should have been injected via @Resource.", "Foo", this.foo); + } + + @Test + public void verifyResourceAnnotationInjectedMethods() { + assertEquals("The bar method should have been wired via @Resource.", "Bar", this.bar); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/SpringJUnit4ClassRunnerTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/SpringJUnit4ClassRunnerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..563594ff4250c755ebad23e6d9631f95d1b33f1a --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/SpringJUnit4ClassRunnerTests.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +import org.junit.Test; +import org.junit.runners.model.FrameworkMethod; + +import org.springframework.test.annotation.Timed; +import org.springframework.test.context.TestContextManager; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link SpringJUnit4ClassRunner}. + * + * @author Sam Brannen + * @author Rick Evans + * @since 2.5 + */ +public class SpringJUnit4ClassRunnerTests { + + @Test(expected = Exception.class) + public void checkThatExceptionsAreNotSilentlySwallowed() throws Exception { + SpringJUnit4ClassRunner runner = new SpringJUnit4ClassRunner(getClass()) { + + @Override + protected TestContextManager createTestContextManager(Class clazz) { + return new TestContextManager(clazz) { + + @Override + public void prepareTestInstance(Object testInstance) { + throw new RuntimeException( + "This RuntimeException should be caught and wrapped in an Exception."); + } + }; + } + }; + runner.createTest(); + } + + @Test + public void getSpringTimeoutViaMetaAnnotation() throws Exception { + SpringJUnit4ClassRunner runner = new SpringJUnit4ClassRunner(getClass()); + long timeout = runner.getSpringTimeout(new FrameworkMethod(getClass().getDeclaredMethod( + "springTimeoutWithMetaAnnotation"))); + assertEquals(10, timeout); + } + + @Test + public void getSpringTimeoutViaMetaAnnotationWithOverride() throws Exception { + SpringJUnit4ClassRunner runner = new SpringJUnit4ClassRunner(getClass()); + long timeout = runner.getSpringTimeout(new FrameworkMethod(getClass().getDeclaredMethod( + "springTimeoutWithMetaAnnotationAndOverride"))); + assertEquals(42, timeout); + } + + // ------------------------------------------------------------------------- + + @MetaTimed + void springTimeoutWithMetaAnnotation() { + /* no-op */ + } + + @MetaTimedWithOverride(millis = 42) + void springTimeoutWithMetaAnnotationAndOverride() { + /* no-op */ + } + + + @Timed(millis = 10) + @Retention(RetentionPolicy.RUNTIME) + private static @interface MetaTimed { + } + + @Timed(millis = 1000) + @Retention(RetentionPolicy.RUNTIME) + private static @interface MetaTimedWithOverride { + + long millis() default 1000; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/SpringJUnit4TestSuite.java b/spring-test/src/test/java/org/springframework/test/context/junit4/SpringJUnit4TestSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..994f349624eca77ccf88e14b66e4fe362b04028f --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/SpringJUnit4TestSuite.java @@ -0,0 +1,119 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; +import org.springframework.test.context.cache.ClassLevelDirtiesContextTests; +import org.springframework.test.context.cache.SpringRunnerContextCacheTests; +import org.springframework.test.context.jdbc.IsolatedTransactionModeSqlScriptsTests; +import org.springframework.test.context.junit4.annotation.AnnotationConfigSpringJUnit4ClassRunnerAppCtxTests; +import org.springframework.test.context.junit4.annotation.BeanOverridingDefaultConfigClassesInheritedTests; +import org.springframework.test.context.junit4.annotation.BeanOverridingExplicitConfigClassesInheritedTests; +import org.springframework.test.context.junit4.annotation.DefaultConfigClassesBaseTests; +import org.springframework.test.context.junit4.annotation.DefaultConfigClassesInheritedTests; +import org.springframework.test.context.junit4.annotation.DefaultLoaderBeanOverridingDefaultConfigClassesInheritedTests; +import org.springframework.test.context.junit4.annotation.DefaultLoaderBeanOverridingExplicitConfigClassesInheritedTests; +import org.springframework.test.context.junit4.annotation.DefaultLoaderDefaultConfigClassesBaseTests; +import org.springframework.test.context.junit4.annotation.DefaultLoaderDefaultConfigClassesInheritedTests; +import org.springframework.test.context.junit4.annotation.DefaultLoaderExplicitConfigClassesBaseTests; +import org.springframework.test.context.junit4.annotation.DefaultLoaderExplicitConfigClassesInheritedTests; +import org.springframework.test.context.junit4.annotation.ExplicitConfigClassesBaseTests; +import org.springframework.test.context.junit4.annotation.ExplicitConfigClassesInheritedTests; +import org.springframework.test.context.junit4.orm.HibernateSessionFlushingTests; +import org.springframework.test.context.junit4.profile.annotation.DefaultProfileAnnotationConfigTests; +import org.springframework.test.context.junit4.profile.annotation.DevProfileAnnotationConfigTests; +import org.springframework.test.context.junit4.profile.annotation.DevProfileResolverAnnotationConfigTests; +import org.springframework.test.context.junit4.profile.xml.DefaultProfileXmlConfigTests; +import org.springframework.test.context.junit4.profile.xml.DevProfileResolverXmlConfigTests; +import org.springframework.test.context.junit4.profile.xml.DevProfileXmlConfigTests; +import org.springframework.test.context.transaction.programmatic.ProgrammaticTxMgmtTests; + +/** + * JUnit test suite for tests involving {@link SpringRunner} and the + * Spring TestContext Framework; only intended to be run manually as a + * convenience. + * + *

This test suite serves a dual purpose of verifying that tests run with + * {@link SpringRunner} can be used in conjunction with JUnit's + * {@link Suite} runner. + * + *

Note that tests included in this suite will be executed at least twice if + * run from an automated build process, test runner, etc. that is not configured + * to exclude tests based on a {@code "*TestSuite.class"} pattern match. + * + * @author Sam Brannen + * @since 2.5 + */ +@RunWith(Suite.class) +// Note: the following 'multi-line' layout is for enhanced code readability. +@SuiteClasses({// +StandardJUnit4FeaturesTests.class,// + StandardJUnit4FeaturesSpringRunnerTests.class,// + SpringJUnit47ClassRunnerRuleTests.class,// + AnnotationConfigSpringJUnit4ClassRunnerAppCtxTests.class,// + DefaultConfigClassesBaseTests.class,// + DefaultConfigClassesInheritedTests.class,// + BeanOverridingDefaultConfigClassesInheritedTests.class,// + ExplicitConfigClassesBaseTests.class,// + ExplicitConfigClassesInheritedTests.class,// + BeanOverridingExplicitConfigClassesInheritedTests.class,// + DefaultLoaderDefaultConfigClassesBaseTests.class,// + DefaultLoaderDefaultConfigClassesInheritedTests.class,// + DefaultLoaderBeanOverridingDefaultConfigClassesInheritedTests.class,// + DefaultLoaderExplicitConfigClassesBaseTests.class,// + DefaultLoaderExplicitConfigClassesInheritedTests.class,// + DefaultLoaderBeanOverridingExplicitConfigClassesInheritedTests.class,// + DefaultProfileAnnotationConfigTests.class,// + DevProfileAnnotationConfigTests.class,// + DevProfileResolverAnnotationConfigTests.class,// + DefaultProfileXmlConfigTests.class,// + DevProfileXmlConfigTests.class,// + DevProfileResolverXmlConfigTests.class,// + ExpectedExceptionSpringRunnerTests.class,// + TimedSpringRunnerTests.class,// + RepeatedSpringRunnerTests.class,// + EnabledAndIgnoredSpringRunnerTests.class,// + HardCodedProfileValueSourceSpringRunnerTests.class,// + SpringJUnit4ClassRunnerAppCtxTests.class,// + ClassPathResourceSpringJUnit4ClassRunnerAppCtxTests.class,// + AbsolutePathSpringJUnit4ClassRunnerAppCtxTests.class,// + RelativePathSpringJUnit4ClassRunnerAppCtxTests.class,// + MultipleResourcesSpringJUnit4ClassRunnerAppCtxTests.class,// + InheritedConfigSpringJUnit4ClassRunnerAppCtxTests.class,// + PropertiesBasedSpringJUnit4ClassRunnerAppCtxTests.class,// + CustomDefaultContextLoaderClassSpringRunnerTests.class,// + SpringRunnerContextCacheTests.class,// + ClassLevelDirtiesContextTests.class,// + ParameterizedDependencyInjectionTests.class,// + ConcreteTransactionalJUnit4SpringContextTests.class,// + ClassLevelTransactionalSpringRunnerTests.class,// + MethodLevelTransactionalSpringRunnerTests.class,// + DefaultRollbackTrueRollbackAnnotationTransactionalTests.class,// + DefaultRollbackFalseRollbackAnnotationTransactionalTests.class,// + RollbackOverrideDefaultRollbackTrueTransactionalTests.class,// + RollbackOverrideDefaultRollbackFalseTransactionalTests.class,// + BeforeAndAfterTransactionAnnotationTests.class,// + TimedTransactionalSpringRunnerTests.class,// + ProgrammaticTxMgmtTests.class,// + IsolatedTransactionModeSqlScriptsTests.class,// + HibernateSessionFlushingTests.class // +}) +public class SpringJUnit4TestSuite { + /* this test case consists entirely of tests loaded as a suite. */ +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/StandardJUnit4FeaturesSpringRunnerTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/StandardJUnit4FeaturesSpringRunnerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3390b772a9e08bb28b1280f2d70343d6911cc682 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/StandardJUnit4FeaturesSpringRunnerTests.java @@ -0,0 +1,44 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import org.junit.runner.RunWith; + +import org.springframework.test.context.TestExecutionListeners; + +/** + *

+ * Simple unit test to verify that {@link SpringRunner} does not + * hinder correct functionality of standard JUnit 4.4+ testing features. + *

+ *

+ * Note that {@link TestExecutionListeners @TestExecutionListeners} is + * explicitly configured with an empty list, thus disabling all default + * listeners. + *

+ * + * @author Sam Brannen + * @since 2.5 + * @see StandardJUnit4FeaturesTests + */ +@RunWith(SpringRunner.class) +@TestExecutionListeners({}) +public class StandardJUnit4FeaturesSpringRunnerTests extends StandardJUnit4FeaturesTests { + + /* All tests are in the parent class... */ + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/StandardJUnit4FeaturesTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/StandardJUnit4FeaturesTests.java new file mode 100644 index 0000000000000000000000000000000000000000..03344d6069541b79b879cfd52d7865ce48431847 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/StandardJUnit4FeaturesTests.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import java.util.ArrayList; + +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +import static org.junit.Assert.*; +import static org.junit.Assume.*; + +/** + * Simple unit test to verify the expected functionality of standard JUnit 4.4+ + * testing features. + *

+ * Currently testing: {@link Test @Test} (including expected exceptions and + * timeouts), {@link BeforeClass @BeforeClass}, {@link Before @Before}, and + * assumptions. + *

+ *

+ * Due to the fact that JUnit does not guarantee a particular ordering of test + * method execution, the following are currently not tested: + * {@link org.junit.AfterClass @AfterClass} and {@link org.junit.After @After}. + *

+ * + * @author Sam Brannen + * @since 2.5 + * @see StandardJUnit4FeaturesSpringRunnerTests + */ +public class StandardJUnit4FeaturesTests { + + private static int staticBeforeCounter = 0; + + + @BeforeClass + public static void incrementStaticBeforeCounter() { + StandardJUnit4FeaturesTests.staticBeforeCounter++; + } + + + private int beforeCounter = 0; + + + @Test + @Ignore + public void alwaysFailsButShouldBeIgnored() { + fail("The body of an ignored test should never be executed!"); + } + + @Test + public void alwaysSucceeds() { + assertTrue(true); + } + + @Test(expected = IndexOutOfBoundsException.class) + public void expectingAnIndexOutOfBoundsException() { + new ArrayList<>().get(1); + } + + @Test + public void failedAssumptionShouldPrecludeImminentFailure() { + assumeTrue(false); + fail("A failed assumption should preclude imminent failure!"); + } + + @Before + public void incrementBeforeCounter() { + this.beforeCounter++; + } + + @Test(timeout = 10000) + public void noOpShouldNotTimeOut() { + /* no-op */ + } + + @Test + public void verifyBeforeAnnotation() { + assertEquals(1, this.beforeCounter); + } + + @Test + public void verifyBeforeClassAnnotation() { + // Instead of testing for equality to 1, we just assert that the value + // was incremented at least once, since this test class may serve as a + // parent class to other tests in a suite, etc. + assertTrue(StandardJUnit4FeaturesTests.staticBeforeCounter > 0); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/TimedSpringRunnerTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/TimedSpringRunnerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..82931405d097d18ffcd5e60dfb5f7078e1cdc7bd --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/TimedSpringRunnerTests.java @@ -0,0 +1,124 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runner.Runner; +import org.junit.runners.JUnit4; + +import org.springframework.test.annotation.Timed; +import org.springframework.test.context.TestExecutionListeners; + +import static org.springframework.test.context.junit4.JUnitTestingUtils.*; + +/** + * Verifies proper handling of the following in conjunction with the + * {@link SpringRunner}: + * + * + * @author Sam Brannen + * @since 3.0 + */ +@RunWith(JUnit4.class) +public class TimedSpringRunnerTests { + + protected Class getTestCase() { + return TimedSpringRunnerTestCase.class; + } + + protected Class getRunnerClass() { + return SpringRunner.class; + } + + @Test + public void timedTests() throws Exception { + runTestsAndAssertCounters(getRunnerClass(), getTestCase(), 7, 5, 7, 0, 0); + } + + + @Ignore("TestCase classes are run manually by the enclosing test class") + @TestExecutionListeners({}) + public static class TimedSpringRunnerTestCase { + + // Should Pass. + @Test(timeout = 2000) + public void jUnitTimeoutWithNoOp() { + /* no-op */ + } + + // Should Pass. + @Test + @Timed(millis = 2000) + public void springTimeoutWithNoOp() { + /* no-op */ + } + + // Should Fail due to timeout. + @Test(timeout = 100) + public void jUnitTimeoutWithSleep() throws Exception { + Thread.sleep(200); + } + + // Should Fail due to timeout. + @Test + @Timed(millis = 100) + public void springTimeoutWithSleep() throws Exception { + Thread.sleep(200); + } + + // Should Fail due to timeout. + @Test + @MetaTimed + public void springTimeoutWithSleepAndMetaAnnotation() throws Exception { + Thread.sleep(200); + } + + // Should Fail due to timeout. + @Test + @MetaTimedWithOverride(millis = 100) + public void springTimeoutWithSleepAndMetaAnnotationAndOverride() throws Exception { + Thread.sleep(200); + } + + // Should Fail due to duplicate configuration. + @Test(timeout = 200) + @Timed(millis = 200) + public void springAndJUnitTimeouts() { + /* no-op */ + } + } + + @Timed(millis = 100) + @Retention(RetentionPolicy.RUNTIME) + private static @interface MetaTimed { + } + + @Timed(millis = 1000) + @Retention(RetentionPolicy.RUNTIME) + private static @interface MetaTimedWithOverride { + long millis() default 1000; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/TimedTransactionalSpringRunnerTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/TimedTransactionalSpringRunnerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..36f75ff346730c9015ec318d305e407d871bbe84 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/TimedTransactionalSpringRunnerTests.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.test.annotation.Repeat; +import org.springframework.test.annotation.Timed; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.transaction.annotation.Propagation; +import org.springframework.transaction.annotation.Transactional; + +import static org.springframework.test.transaction.TransactionTestUtils.*; + +/** + * JUnit 4 based integration test which verifies support of Spring's + * {@link Transactional @Transactional} annotation in conjunction + * with {@link Timed @Timed} and JUnit 4's {@link Test#timeout() + * timeout} attribute. + * + * @author Sam Brannen + * @since 2.5 + */ +@RunWith(SpringRunner.class) +@ContextConfiguration("transactionalTests-context.xml") +@Transactional +public class TimedTransactionalSpringRunnerTests { + + @Test + @Timed(millis = 10000) + @Repeat(5) + public void transactionalWithSpringTimeout() { + assertInTransaction(true); + } + + @Test(timeout = 10000) + @Repeat(5) + public void transactionalWithJUnitTimeout() { + assertInTransaction(true); + } + + @Test + @Transactional(propagation = Propagation.NOT_SUPPORTED) + @Timed(millis = 10000) + @Repeat(5) + public void notTransactionalWithSpringTimeout() { + assertInTransaction(false); + } + + @Test(timeout = 10000) + @Transactional(propagation = Propagation.NOT_SUPPORTED) + @Repeat(5) + public void notTransactionalWithJUnitTimeout() { + assertInTransaction(false); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/TrackingRunListener.java b/spring-test/src/test/java/org/springframework/test/context/junit4/TrackingRunListener.java new file mode 100644 index 0000000000000000000000000000000000000000..89607794f54d1d004aeb7252adef0abfb4dbc8c2 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/TrackingRunListener.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4; + +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.runner.Description; +import org.junit.runner.notification.Failure; +import org.junit.runner.notification.RunListener; + +/** + * Simple {@link RunListener} which tracks how many times certain JUnit callback + * methods were called: only intended for the integration test suite. + * + * @author Sam Brannen + * @since 3.0 + */ +public class TrackingRunListener extends RunListener { + + private final AtomicInteger testFailureCount = new AtomicInteger(); + + private final AtomicInteger testStartedCount = new AtomicInteger(); + + private final AtomicInteger testFinishedCount = new AtomicInteger(); + + private final AtomicInteger testAssumptionFailureCount = new AtomicInteger(); + + private final AtomicInteger testIgnoredCount = new AtomicInteger(); + + + public int getTestFailureCount() { + return this.testFailureCount.get(); + } + + public int getTestStartedCount() { + return this.testStartedCount.get(); + } + + public int getTestFinishedCount() { + return this.testFinishedCount.get(); + } + + public int getTestAssumptionFailureCount() { + return this.testAssumptionFailureCount.get(); + } + + public int getTestIgnoredCount() { + return this.testIgnoredCount.get(); + } + + @Override + public void testFailure(Failure failure) throws Exception { + this.testFailureCount.incrementAndGet(); + } + + @Override + public void testStarted(Description description) throws Exception { + this.testStartedCount.incrementAndGet(); + } + + @Override + public void testFinished(Description description) throws Exception { + this.testFinishedCount.incrementAndGet(); + } + + @Override + public void testAssumptionFailure(Failure failure) { + this.testAssumptionFailureCount.incrementAndGet(); + } + + @Override + public void testIgnored(Description description) throws Exception { + this.testIgnoredCount.incrementAndGet(); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/AciTestSuite.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/AciTestSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..30d70eefd4f12ae79123a58810800668701c17ad --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/AciTestSuite.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.test.context.junit4.aci.annotation.InitializerWithoutConfigFilesOrClassesTests; +import org.springframework.test.context.junit4.aci.annotation.MergedInitializersAnnotationConfigTests; +import org.springframework.test.context.junit4.aci.annotation.MultipleInitializersAnnotationConfigTests; +import org.springframework.test.context.junit4.aci.annotation.OrderedInitializersAnnotationConfigTests; +import org.springframework.test.context.junit4.aci.annotation.OverriddenInitializersAnnotationConfigTests; +import org.springframework.test.context.junit4.aci.annotation.SingleInitializerAnnotationConfigTests; +import org.springframework.test.context.junit4.aci.xml.MultipleInitializersXmlConfigTests; + +/** + * Convenience test suite for integration tests that verify support for + * {@link ApplicationContextInitializer ApplicationContextInitializers} (ACIs) + * in the TestContext framework. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(Suite.class) +// Note: the following 'multi-line' layout is for enhanced code readability. +@SuiteClasses({// + MultipleInitializersXmlConfigTests.class,// + SingleInitializerAnnotationConfigTests.class,// + MultipleInitializersAnnotationConfigTests.class,// + MergedInitializersAnnotationConfigTests.class,// + OverriddenInitializersAnnotationConfigTests.class,// + OrderedInitializersAnnotationConfigTests.class,// + InitializerWithoutConfigFilesOrClassesTests.class // +}) +public class AciTestSuite { +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/DevProfileInitializer.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/DevProfileInitializer.java new file mode 100644 index 0000000000000000000000000000000000000000..16a4ce4ec5ae9d6a370dcf12b3b335b27f635802 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/DevProfileInitializer.java @@ -0,0 +1,32 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci; + +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.support.GenericApplicationContext; + +/** + * @author Sam Brannen + * @since 3.2 + */ +public class DevProfileInitializer implements ApplicationContextInitializer { + + @Override + public void initialize(GenericApplicationContext applicationContext) { + applicationContext.getEnvironment().setActiveProfiles("dev"); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/FooBarAliasInitializer.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/FooBarAliasInitializer.java new file mode 100644 index 0000000000000000000000000000000000000000..d37fa16d28bd9dfd674e7a9818276d7eefe59b14 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/FooBarAliasInitializer.java @@ -0,0 +1,32 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci; + +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.support.GenericApplicationContext; + +/** + * @author Sam Brannen + * @since 3.2 + */ +public class FooBarAliasInitializer implements ApplicationContextInitializer { + + @Override + public void initialize(GenericApplicationContext applicationContext) { + applicationContext.registerAlias("foo", "bar"); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/BarConfig.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/BarConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..0cca3ed54705dbe40edf72d70ecac0bcf168cf9c --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/BarConfig.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci.annotation; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +/** + * @author Sam Brannen + * @since 4.3 + */ +@Configuration +class BarConfig { + + @Bean + String bar() { + return "bar"; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/DevProfileConfig.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/DevProfileConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..aef2ae2c098dbccf0f1b448e4a98a6794b594d66 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/DevProfileConfig.java @@ -0,0 +1,35 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci.annotation; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Profile; + +/** + * @author Sam Brannen + * @since 3.2 + */ +@Configuration +@Profile("dev") +class DevProfileConfig { + + @Bean + public String baz() { + return "dev profile config"; + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/FooConfig.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/FooConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..11c1d51130c12b4548834d3047405e8405859be6 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/FooConfig.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci.annotation; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +/** + * @author Sam Brannen + * @since 4.3 + */ +@Configuration +class FooConfig { + + @Bean + String foo() { + return "foo"; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/GlobalConfig.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/GlobalConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..49bf8278dab866403c8699af5e0a0c5debcf3d23 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/GlobalConfig.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci.annotation; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +/** + * @author Sam Brannen + * @since 3.2 + */ +@Configuration +class GlobalConfig { + + @Bean + public String foo() { + return "foo"; + } + + @Bean + public String baz() { + return "global config"; + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/InitializerConfiguredViaMetaAnnotationTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/InitializerConfiguredViaMetaAnnotationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..32c9a5cef3e9eab7f54ede24325bce448d15add0 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/InitializerConfiguredViaMetaAnnotationTests.java @@ -0,0 +1,92 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.util.List; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.annotation.AnnotatedBeanDefinitionReader; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.core.annotation.AliasFor; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringRunner; +import org.springframework.test.context.junit4.aci.annotation.InitializerConfiguredViaMetaAnnotationTests.ComposedContextConfiguration; +import org.springframework.test.context.support.AnnotationConfigContextLoader; + +import static org.junit.Assert.assertEquals; + +/** + * Integration test that demonstrates how to register one or more {@code @Configuration} + * classes via an {@link ApplicationContextInitializer} in a composed annotation so + * that certain {@code @Configuration} classes are always registered whenever the composed + * annotation is used, even if the composed annotation is used to declare additional + * {@code @Configuration} classes. + * + *

This class has been implemented in response to the following Stack Overflow question: + * + * Can {@code @ContextConfiguration} in a custom annotation be merged? + * + * @author Sam Brannen + * @since 4.3 + */ +@RunWith(SpringRunner.class) +@ComposedContextConfiguration(BarConfig.class) +public class InitializerConfiguredViaMetaAnnotationTests { + + @Autowired + String foo; + + @Autowired + String bar; + + @Autowired + List strings; + + + @Test + public void beansFromInitializerAndComposedAnnotation() { + assertEquals(2, strings.size()); + assertEquals("foo", foo); + assertEquals("bar", bar); + } + + + static class FooConfigInitializer implements ApplicationContextInitializer { + + @Override + public void initialize(GenericApplicationContext applicationContext) { + new AnnotatedBeanDefinitionReader(applicationContext).register(FooConfig.class); + } + } + + @ContextConfiguration(loader = AnnotationConfigContextLoader.class, initializers = FooConfigInitializer.class) + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.TYPE) + @interface ComposedContextConfiguration { + + @AliasFor(annotation = ContextConfiguration.class, attribute = "classes") + Class[] value() default {}; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/InitializerWithoutConfigFilesOrClassesTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/InitializerWithoutConfigFilesOrClassesTests.java new file mode 100644 index 0000000000000000000000000000000000000000..97eaf08403b3b40d138079abf859105075f1e293 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/InitializerWithoutConfigFilesOrClassesTests.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci.annotation; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.annotation.AnnotatedBeanDefinitionReader; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.junit4.aci.annotation.InitializerWithoutConfigFilesOrClassesTests.EntireAppInitializer; + +import static org.junit.Assert.*; + +/** + * Integration test that verifies support for {@link ApplicationContextInitializer + * ApplicationContextInitializers} in the TestContext framework when the test class + * declares neither XML configuration files nor annotated configuration classes. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(initializers = EntireAppInitializer.class) +public class InitializerWithoutConfigFilesOrClassesTests { + + @Autowired + private String foo; + + + @Test + public void foo() { + assertEquals("foo", foo); + } + + + static class EntireAppInitializer implements ApplicationContextInitializer { + + @Override + public void initialize(GenericApplicationContext applicationContext) { + new AnnotatedBeanDefinitionReader(applicationContext).register(GlobalConfig.class); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/MergedInitializersAnnotationConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/MergedInitializersAnnotationConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..7f03061fdb5cee057a5a913baae2f9de79289d3b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/MergedInitializersAnnotationConfigTests.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci.annotation; + +import org.junit.Test; + +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.aci.DevProfileInitializer; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for {@link ApplicationContextInitializer + * ApplicationContextInitializers} in conjunction with annotation-driven + * configuration in the TestContext framework. + * + * @author Sam Brannen + * @since 3.2 + */ +@ContextConfiguration(initializers = DevProfileInitializer.class) +public class MergedInitializersAnnotationConfigTests extends SingleInitializerAnnotationConfigTests { + + @Override + @Test + public void activeBeans() { + assertEquals("foo", foo); + assertEquals("foo", bar); + assertEquals("dev profile config", baz); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/MultipleInitializersAnnotationConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/MultipleInitializersAnnotationConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..41234930ace96f5eef51deaed186eb2079f690da --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/MultipleInitializersAnnotationConfigTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci.annotation; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.junit4.aci.DevProfileInitializer; +import org.springframework.test.context.junit4.aci.FooBarAliasInitializer; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for {@link ApplicationContextInitializer + * ApplicationContextInitializers} in conjunction with annotation-driven + * configuration in the TestContext framework. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(classes = { GlobalConfig.class, DevProfileConfig.class }, initializers = { + FooBarAliasInitializer.class, DevProfileInitializer.class }) +public class MultipleInitializersAnnotationConfigTests { + + @Autowired + private String foo, bar, baz; + + + @Test + public void activeBeans() { + assertEquals("foo", foo); + assertEquals("foo", bar); + assertEquals("dev profile config", baz); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/OrderedInitializersAnnotationConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/OrderedInitializersAnnotationConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..257db2c6a625bb9f0ce60a5e795b106a788c815b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/OrderedInitializersAnnotationConfigTests.java @@ -0,0 +1,149 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci.annotation; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Profile; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.junit4.aci.annotation.OrderedInitializersAnnotationConfigTests.ConfigOne; +import org.springframework.test.context.junit4.aci.annotation.OrderedInitializersAnnotationConfigTests.ConfigTwo; +import org.springframework.test.context.junit4.aci.annotation.OrderedInitializersAnnotationConfigTests.GlobalConfig; +import org.springframework.test.context.junit4.aci.annotation.OrderedInitializersAnnotationConfigTests.OrderedOneInitializer; +import org.springframework.test.context.junit4.aci.annotation.OrderedInitializersAnnotationConfigTests.OrderedTwoInitializer; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify that any {@link ApplicationContextInitializer + * ApplicationContextInitializers} implementing + * {@link org.springframework.core.Ordered Ordered} or marked with + * {@link org.springframework.core.annotation.Order @Order} will be sorted + * appropriately in conjunction with annotation-driven configuration in the + * TestContext framework. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +// Note: the ordering of the config classes is intentionally: global, two, one. +// Note: the ordering of the initializers is intentionally: two, one. +@ContextConfiguration(classes = { GlobalConfig.class, ConfigTwo.class, ConfigOne.class }, initializers = { + OrderedTwoInitializer.class, OrderedOneInitializer.class }) +public class OrderedInitializersAnnotationConfigTests { + + private static final String PROFILE_GLOBAL = "global"; + private static final String PROFILE_ONE = "one"; + private static final String PROFILE_TWO = "two"; + + @Autowired + private String foo, bar, baz; + + + @Test + public void activeBeans() { + assertEquals(PROFILE_GLOBAL, foo); + assertEquals(PROFILE_GLOBAL, bar); + assertEquals(PROFILE_TWO, baz); + } + + + // ------------------------------------------------------------------------- + + @Configuration + static class GlobalConfig { + + @Bean + public String foo() { + return PROFILE_GLOBAL; + } + + @Bean + public String bar() { + return PROFILE_GLOBAL; + } + + @Bean + public String baz() { + return PROFILE_GLOBAL; + } + } + + @Configuration + @Profile(PROFILE_ONE) + static class ConfigOne { + + @Bean + public String foo() { + return PROFILE_ONE; + } + + @Bean + public String bar() { + return PROFILE_ONE; + } + + @Bean + public String baz() { + return PROFILE_ONE; + } + } + + @Configuration + @Profile(PROFILE_TWO) + static class ConfigTwo { + + @Bean + public String baz() { + return PROFILE_TWO; + } + } + + // ------------------------------------------------------------------------- + + static class OrderedOneInitializer implements ApplicationContextInitializer, Ordered { + + @Override + public void initialize(GenericApplicationContext applicationContext) { + applicationContext.getEnvironment().setActiveProfiles(PROFILE_ONE); + } + + @Override + public int getOrder() { + return 1; + } + } + + @Order(2) + static class OrderedTwoInitializer implements ApplicationContextInitializer { + + @Override + public void initialize(GenericApplicationContext applicationContext) { + applicationContext.getEnvironment().setActiveProfiles(PROFILE_TWO); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/OverriddenInitializersAnnotationConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/OverriddenInitializersAnnotationConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..97be4af896cb9abcc6c496542e7b35d64546d9d4 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/OverriddenInitializersAnnotationConfigTests.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci.annotation; + +import org.junit.Test; + +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.aci.DevProfileInitializer; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for {@link ApplicationContextInitializer + * ApplicationContextInitializers} in conjunction with annotation-driven + * configuration in the TestContext framework. + * + * @author Sam Brannen + * @since 3.2 + */ +@ContextConfiguration(initializers = DevProfileInitializer.class, inheritInitializers = false) +public class OverriddenInitializersAnnotationConfigTests extends SingleInitializerAnnotationConfigTests { + + @Test + @Override + public void activeBeans() { + assertEquals("foo", foo); + assertNull(bar); + assertEquals("dev profile config", baz); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/PropertySourcesInitializerTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/PropertySourcesInitializerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..6627ee6219bb7c566656326e0b554d1a9f0d7558 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/PropertySourcesInitializerTests.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci.annotation; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.env.PropertySource; +import org.springframework.mock.env.MockPropertySource; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify that a {@link PropertySource} can be set via a + * custom {@link ApplicationContextInitializer} in the Spring TestContext Framework. + * + * @author Sam Brannen + * @since 4.1 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(initializers = PropertySourcesInitializerTests.PropertySourceInitializer.class) +public class PropertySourcesInitializerTests { + + @Configuration + static class Config { + + @Value("${enigma}") + // The following can also be used to directly access the + // environment instead of relying on placeholder replacement. + // @Value("#{ environment['enigma'] }") + private String enigma; + + + @Bean + public String enigma() { + return enigma; + } + + } + + + @Autowired + private String enigma; + + + @Test + public void customPropertySourceConfiguredViaContextInitializer() { + assertEquals("foo", enigma); + } + + + public static class PropertySourceInitializer implements + ApplicationContextInitializer { + + @Override + public void initialize(ConfigurableApplicationContext applicationContext) { + applicationContext.getEnvironment().getPropertySources().addFirst( + new MockPropertySource().withProperty("enigma", "foo")); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/SingleInitializerAnnotationConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/SingleInitializerAnnotationConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..4dd57adee7d6cd28bf709b0423cb36355e77bfe7 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/annotation/SingleInitializerAnnotationConfigTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci.annotation; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.junit4.aci.FooBarAliasInitializer; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for {@link ApplicationContextInitializer + * ApplicationContextInitializers} in conjunction with annotation-driven + * configuration in the TestContext framework. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(classes = { GlobalConfig.class, DevProfileConfig.class }, initializers = FooBarAliasInitializer.class) +public class SingleInitializerAnnotationConfigTests { + + @Autowired + protected String foo; + + @Autowired(required = false) + @Qualifier("bar") + protected String bar; + + @Autowired + protected String baz; + + + @Test + public void activeBeans() { + assertEquals("foo", foo); + assertEquals("foo", bar); + assertEquals("global config", baz); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/aci/xml/MultipleInitializersXmlConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/xml/MultipleInitializersXmlConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3737976875fca9524517c372fa06ef2cf34ab3c1 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/aci/xml/MultipleInitializersXmlConfigTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.aci.xml; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.junit4.aci.DevProfileInitializer; +import org.springframework.test.context.junit4.aci.FooBarAliasInitializer; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for {@link ApplicationContextInitializer + * ApplicationContextInitializers} in conjunction with XML configuration files + * in the TestContext framework. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(initializers = { FooBarAliasInitializer.class, DevProfileInitializer.class }) +public class MultipleInitializersXmlConfigTests { + + @Autowired + private String foo, bar, baz; + + + @Test + public void activeBeans() { + assertEquals("foo", foo); + assertEquals("foo", bar); + assertEquals("dev profile config", baz); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/AnnotationConfigSpringJUnit4ClassRunnerAppCtxTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/AnnotationConfigSpringJUnit4ClassRunnerAppCtxTests.java new file mode 100644 index 0000000000000000000000000000000000000000..cc368b2862b24ddec8514d45205f451f4b080fc0 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/AnnotationConfigSpringJUnit4ClassRunnerAppCtxTests.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation; + +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunnerAppCtxTests; + +/** + * Integration tests that verify support for configuration classes in + * the Spring TestContext Framework. + * + *

Furthermore, by extending {@link SpringJUnit4ClassRunnerAppCtxTests}, + * this class also verifies support for several basic features of the + * Spring TestContext Framework. See JavaDoc in + * {@code SpringJUnit4ClassRunnerAppCtxTests} for details. + * + *

Configuration will be loaded from {@link PojoAndStringConfig}. + * + * @author Sam Brannen + * @since 3.1 + */ +@ContextConfiguration(classes = PojoAndStringConfig.class, inheritLocations = false) +public class AnnotationConfigSpringJUnit4ClassRunnerAppCtxTests extends SpringJUnit4ClassRunnerAppCtxTests { + /* all tests are in the parent class. */ +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/AnnotationConfigTestSuite.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/AnnotationConfigTestSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..f5f16ea9937186ee5ca5f3d24221a1c601fbc6d5 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/AnnotationConfigTestSuite.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2011 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** + * JUnit test suite for annotation-driven configuration class + * support in the Spring TestContext Framework. + * + * @author Sam Brannen + * @since 3.1 + */ +@RunWith(Suite.class) +// Note: the following 'multi-line' layout is for enhanced code readability. +@SuiteClasses({// +AnnotationConfigSpringJUnit4ClassRunnerAppCtxTests.class,// + DefaultConfigClassesBaseTests.class,// + DefaultConfigClassesInheritedTests.class,// + BeanOverridingDefaultConfigClassesInheritedTests.class,// + ExplicitConfigClassesBaseTests.class,// + ExplicitConfigClassesInheritedTests.class,// + BeanOverridingExplicitConfigClassesInheritedTests.class,// + DefaultLoaderDefaultConfigClassesBaseTests.class,// + DefaultLoaderDefaultConfigClassesInheritedTests.class,// + DefaultLoaderBeanOverridingDefaultConfigClassesInheritedTests.class,// + DefaultLoaderExplicitConfigClassesBaseTests.class,// + DefaultLoaderExplicitConfigClassesInheritedTests.class,// + DefaultLoaderBeanOverridingExplicitConfigClassesInheritedTests.class // +}) +public class AnnotationConfigTestSuite { +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/BeanOverridingDefaultConfigClassesInheritedTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/BeanOverridingDefaultConfigClassesInheritedTests.java new file mode 100644 index 0000000000000000000000000000000000000000..7e517f68b8ff389ff89bd6995e6b7c5913560e3c --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/BeanOverridingDefaultConfigClassesInheritedTests.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation; + +import org.junit.Test; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.tests.sample.beans.Employee; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for configuration classes in + * the Spring TestContext Framework. + * + *

Configuration will be loaded from {@link DefaultConfigClassesBaseTests.ContextConfiguration} + * and {@link BeanOverridingDefaultConfigClassesInheritedTests.ContextConfiguration}. + * + * @author Sam Brannen + * @since 3.1 + */ +@ContextConfiguration +public class BeanOverridingDefaultConfigClassesInheritedTests extends DefaultConfigClassesBaseTests { + + @Configuration + static class ContextConfiguration { + + @Bean + public Employee employee() { + Employee employee = new Employee(); + employee.setName("Yoda"); + employee.setAge(900); + employee.setCompany("The Force"); + return employee; + } + } + + + @Test + @Override + public void verifyEmployeeSetFromBaseContextConfig() { + assertNotNull("The employee should have been autowired.", this.employee); + assertEquals("The employee bean should have been overridden.", "Yoda", this.employee.getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/BeanOverridingExplicitConfigClassesInheritedTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/BeanOverridingExplicitConfigClassesInheritedTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b60421dd53ffb3f863165cad613a1d57f6430c55 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/BeanOverridingExplicitConfigClassesInheritedTests.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation; + +import org.junit.Test; + +import org.springframework.test.context.ContextConfiguration; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for configuration classes in + * the Spring TestContext Framework. + * + *

Configuration will be loaded from {@link DefaultConfigClassesBaseTests.ContextConfiguration} + * and {@link BeanOverridingDefaultConfigClassesInheritedTests.ContextConfiguration}. + * + * @author Sam Brannen + * @since 3.1 + */ +@ContextConfiguration(classes = BeanOverridingDefaultConfigClassesInheritedTests.ContextConfiguration.class) +public class BeanOverridingExplicitConfigClassesInheritedTests extends ExplicitConfigClassesBaseTests { + + @Test + @Override + public void verifyEmployeeSetFromBaseContextConfig() { + assertNotNull("The employee should have been autowired.", this.employee); + assertEquals("The employee bean should have been overridden.", "Yoda", this.employee.getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultConfigClassesBaseTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultConfigClassesBaseTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c9d45dbb5768b7aa85a22eaa4aab7c79e85462b0 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultConfigClassesBaseTests.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.support.AnnotationConfigContextLoader; +import org.springframework.tests.sample.beans.Employee; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for configuration classes in + * the Spring TestContext Framework. + * + *

Configuration will be loaded from {@link DefaultConfigClassesBaseTests.ContextConfiguration}. + * + * @author Sam Brannen + * @since 3.1 + * @see DefaultLoaderDefaultConfigClassesBaseTests + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(loader = AnnotationConfigContextLoader.class) +public class DefaultConfigClassesBaseTests { + + @Configuration + static class ContextConfiguration { + + @Bean + public Employee employee() { + Employee employee = new Employee(); + employee.setName("John Smith"); + employee.setAge(42); + employee.setCompany("Acme Widgets, Inc."); + return employee; + } + } + + + @Autowired + protected Employee employee; + + + @Test + public void verifyEmployeeSetFromBaseContextConfig() { + assertNotNull("The employee field should have been autowired.", this.employee); + assertEquals("John Smith", this.employee.getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultConfigClassesInheritedTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultConfigClassesInheritedTests.java new file mode 100644 index 0000000000000000000000000000000000000000..ca35da919ea6de7e3f2be73480453a0e07d596c0 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultConfigClassesInheritedTests.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation; + +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.tests.sample.beans.Pet; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for configuration classes in + * the Spring TestContext Framework. + * + *

Configuration will be loaded from {@link DefaultConfigClassesBaseTests.ContextConfiguration} + * and {@link DefaultConfigClassesInheritedTests.ContextConfiguration}. + * + * @author Sam Brannen + * @since 3.1 + */ +@ContextConfiguration +public class DefaultConfigClassesInheritedTests extends DefaultConfigClassesBaseTests { + + @Configuration + static class ContextConfiguration { + + @Bean + public Pet pet() { + return new Pet("Fido"); + } + } + + + @Autowired + private Pet pet; + + + @Test + public void verifyPetSetFromExtendedContextConfig() { + assertNotNull("The pet should have been autowired.", this.pet); + assertEquals("Fido", this.pet.getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderBeanOverridingDefaultConfigClassesInheritedTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderBeanOverridingDefaultConfigClassesInheritedTests.java new file mode 100644 index 0000000000000000000000000000000000000000..23c6dd4e8c2b9f3b14508017fcc7148ca7e63ce9 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderBeanOverridingDefaultConfigClassesInheritedTests.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation; + +import org.junit.Test; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.support.DelegatingSmartContextLoader; +import org.springframework.tests.sample.beans.Employee; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for configuration classes in + * the Spring TestContext Framework in conjunction with the + * {@link DelegatingSmartContextLoader}. + * + * @author Sam Brannen + * @since 3.1 + */ +@ContextConfiguration +public class DefaultLoaderBeanOverridingDefaultConfigClassesInheritedTests extends + DefaultLoaderDefaultConfigClassesBaseTests { + + @Configuration + static class Config { + + @Bean + public Employee employee() { + Employee employee = new Employee(); + employee.setName("Yoda"); + employee.setAge(900); + employee.setCompany("The Force"); + return employee; + } + } + + + @Test + @Override + public void verifyEmployeeSetFromBaseContextConfig() { + assertNotNull("The employee should have been autowired.", this.employee); + assertEquals("The employee bean should have been overridden.", "Yoda", this.employee.getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderBeanOverridingExplicitConfigClassesInheritedTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderBeanOverridingExplicitConfigClassesInheritedTests.java new file mode 100644 index 0000000000000000000000000000000000000000..de79c131ebd44d64e2a82db263a56864fe22a7ba --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderBeanOverridingExplicitConfigClassesInheritedTests.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation; + +import org.junit.Test; + +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.support.DelegatingSmartContextLoader; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for configuration classes in + * the Spring TestContext Framework in conjunction with the + * {@link DelegatingSmartContextLoader}. + * + * @author Sam Brannen + * @since 3.1 + */ +@ContextConfiguration(classes = DefaultLoaderBeanOverridingDefaultConfigClassesInheritedTests.Config.class) +public class DefaultLoaderBeanOverridingExplicitConfigClassesInheritedTests extends + DefaultLoaderExplicitConfigClassesBaseTests { + + @Test + @Override + public void verifyEmployeeSetFromBaseContextConfig() { + assertNotNull("The employee should have been autowired.", this.employee); + assertEquals("The employee bean should have been overridden.", "Yoda", this.employee.getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderDefaultConfigClassesBaseTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderDefaultConfigClassesBaseTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5c738d7efb42f7044821d3df9ccbec4c6ab64ec5 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderDefaultConfigClassesBaseTests.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.support.DelegatingSmartContextLoader; +import org.springframework.tests.sample.beans.Employee; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for configuration classes in + * the Spring TestContext Framework in conjunction with the + * {@link DelegatingSmartContextLoader}. + * + * @author Sam Brannen + * @since 3.1 + * @see DefaultConfigClassesBaseTests + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +public class DefaultLoaderDefaultConfigClassesBaseTests { + + @Configuration + static class Config { + + @Bean + public Employee employee() { + Employee employee = new Employee(); + employee.setName("John Smith"); + employee.setAge(42); + employee.setCompany("Acme Widgets, Inc."); + return employee; + } + } + + + @Autowired + protected Employee employee; + + + @Test + public void verifyEmployeeSetFromBaseContextConfig() { + assertNotNull("The employee field should have been autowired.", this.employee); + assertEquals("John Smith", this.employee.getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderDefaultConfigClassesInheritedTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderDefaultConfigClassesInheritedTests.java new file mode 100644 index 0000000000000000000000000000000000000000..05eabe5d77a0bc9f7ef5b5004c3eae34512fc72d --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderDefaultConfigClassesInheritedTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation; + +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.support.DelegatingSmartContextLoader; +import org.springframework.tests.sample.beans.Pet; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for configuration classes in + * the Spring TestContext Framework in conjunction with the + * {@link DelegatingSmartContextLoader}. + * + * @author Sam Brannen + * @since 3.1 + */ +@ContextConfiguration +public class DefaultLoaderDefaultConfigClassesInheritedTests extends DefaultLoaderDefaultConfigClassesBaseTests { + + @Configuration + static class Config { + + @Bean + public Pet pet() { + return new Pet("Fido"); + } + } + + + @Autowired + private Pet pet; + + + @Test + public void verifyPetSetFromExtendedContextConfig() { + assertNotNull("The pet should have been autowired.", this.pet); + assertEquals("Fido", this.pet.getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderExplicitConfigClassesBaseTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderExplicitConfigClassesBaseTests.java new file mode 100644 index 0000000000000000000000000000000000000000..feec6a94e3c3aa011a45a759865989a351648e6f --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderExplicitConfigClassesBaseTests.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.support.DelegatingSmartContextLoader; +import org.springframework.tests.sample.beans.Employee; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for configuration classes in + * the Spring TestContext Framework in conjunction with the + * {@link DelegatingSmartContextLoader}. + * + * @author Sam Brannen + * @since 3.1 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(classes = DefaultLoaderDefaultConfigClassesBaseTests.Config.class) +public class DefaultLoaderExplicitConfigClassesBaseTests { + + @Autowired + protected Employee employee; + + + @Test + public void verifyEmployeeSetFromBaseContextConfig() { + assertNotNull("The employee should have been autowired.", this.employee); + assertEquals("John Smith", this.employee.getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderExplicitConfigClassesInheritedTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderExplicitConfigClassesInheritedTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5cdec6353d2d06924c9ada368aaa0dd8cd3591fe --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/DefaultLoaderExplicitConfigClassesInheritedTests.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.support.DelegatingSmartContextLoader; +import org.springframework.tests.sample.beans.Pet; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for configuration classes in + * the Spring TestContext Framework in conjunction with the + * {@link DelegatingSmartContextLoader}. + * + * @author Sam Brannen + * @since 3.1 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(classes = DefaultLoaderDefaultConfigClassesInheritedTests.Config.class) +public class DefaultLoaderExplicitConfigClassesInheritedTests extends DefaultLoaderExplicitConfigClassesBaseTests { + + @Autowired + private Pet pet; + + + @Test + public void verifyPetSetFromExtendedContextConfig() { + assertNotNull("The pet should have been autowired.", this.pet); + assertEquals("Fido", this.pet.getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/ExplicitConfigClassesBaseTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/ExplicitConfigClassesBaseTests.java new file mode 100644 index 0000000000000000000000000000000000000000..cc6d56ba689145b2f0a983fd48ac2b61c8b2bdd3 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/ExplicitConfigClassesBaseTests.java @@ -0,0 +1,53 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.support.AnnotationConfigContextLoader; +import org.springframework.tests.sample.beans.Employee; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for configuration classes in + * the Spring TestContext Framework. + * + *

Configuration will be loaded from {@link DefaultConfigClassesBaseTests.ContextConfiguration}. + * + * @author Sam Brannen + * @since 3.1 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(loader = AnnotationConfigContextLoader.class, classes = DefaultConfigClassesBaseTests.ContextConfiguration.class) +public class ExplicitConfigClassesBaseTests { + + @Autowired + protected Employee employee; + + + @Test + public void verifyEmployeeSetFromBaseContextConfig() { + assertNotNull("The employee should have been autowired.", this.employee); + assertEquals("John Smith", this.employee.getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/ExplicitConfigClassesInheritedTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/ExplicitConfigClassesInheritedTests.java new file mode 100644 index 0000000000000000000000000000000000000000..1cebcca22da3a6123790562114126674a9be119e --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/ExplicitConfigClassesInheritedTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.support.AnnotationConfigContextLoader; +import org.springframework.tests.sample.beans.Pet; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for configuration classes in + * the Spring TestContext Framework. + * + *

Configuration will be loaded from {@link DefaultConfigClassesInheritedTests.ContextConfiguration} + * and {@link DefaultConfigClassesBaseTests.ContextConfiguration}. + * + * @author Sam Brannen + * @since 3.1 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(loader = AnnotationConfigContextLoader.class, classes = DefaultConfigClassesInheritedTests.ContextConfiguration.class) +public class ExplicitConfigClassesInheritedTests extends ExplicitConfigClassesBaseTests { + + @Autowired + private Pet pet; + + + @Test + public void verifyPetSetFromExtendedContextConfig() { + assertNotNull("The pet should have been autowired.", this.pet); + assertEquals("Fido", this.pet.getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/PojoAndStringConfig.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/PojoAndStringConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..fab6e51dbf47ff3776a393db85f9373d83c19cd6 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/PojoAndStringConfig.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.tests.sample.beans.Employee; +import org.springframework.tests.sample.beans.Pet; + +/** + * ApplicationContext configuration class for various integration tests. + * + *

The beans defined in this configuration class map directly to the + * beans defined in {@code SpringJUnit4ClassRunnerAppCtxTests-context.xml}. + * Consequently, the application contexts loaded from these two sources + * should be identical with regard to bean definitions. + * + * @author Sam Brannen + * @since 3.1 + */ +@Configuration +public class PojoAndStringConfig { + + @Bean + public Employee employee() { + Employee employee = new Employee(); + employee.setName("John Smith"); + employee.setAge(42); + employee.setCompany("Acme Widgets, Inc."); + return employee; + } + + @Bean + public Pet pet() { + return new Pet("Fido"); + } + + @Bean + public String foo() { + return "Foo"; + } + + @Bean + public String bar() { + return "Bar"; + } + + @Bean + public String quux() { + return "Quux"; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfig.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..e084468e39bbf9f8aa6a6af216c508feb08f933d --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfig.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation.meta; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Profile; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.ActiveProfilesResolver; +import org.springframework.test.context.ContextConfiguration; + +/** + * Custom configuration annotation with meta-annotation attribute overrides for + * {@link ContextConfiguration#classes} and {@link ActiveProfiles#resolver} and + * with default configuration local to the composed annotation. + * + * @author Sam Brannen + * @since 4.0.3 + */ +@ContextConfiguration +@ActiveProfiles +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfig { + + @Configuration + @Profile("dev") + static class DevConfig { + + @Bean + public String foo() { + return "Dev Foo"; + } + } + + @Configuration + @Profile("prod") + static class ProductionConfig { + + @Bean + public String foo() { + return "Production Foo"; + } + } + + @Configuration + @Profile("resolver") + static class ResolverConfig { + + @Bean + public String foo() { + return "Resolver Foo"; + } + } + + static class CustomResolver implements ActiveProfilesResolver { + + @Override + public String[] resolve(Class testClass) { + return testClass.getSimpleName().equals("ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfigTests") ? new String[] { "resolver" } + : new String[] {}; + } + } + + + Class[] classes() default { DevConfig.class, ProductionConfig.class, ResolverConfig.class }; + + Class resolver() default CustomResolver.class; + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d14534829242968b337900d6bf1cfd6c86742408 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfigTests.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation.meta; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import static org.junit.Assert.*; + +/** + * Integration tests for meta-annotation attribute override support, relying on + * default attribute values defined in {@link ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfig}. + * + * @author Sam Brannen + * @since 4.0.3 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfig +public class ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfigTests { + + @Autowired + private String foo; + + + @Test + public void foo() { + assertEquals("Resolver Foo", foo); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfigWithOverridesTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfigWithOverridesTests.java new file mode 100644 index 0000000000000000000000000000000000000000..20515d9b998b8dcfb5b172593b92fc7857da4edd --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfigWithOverridesTests.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation.meta; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Profile; +import org.springframework.test.context.ActiveProfilesResolver; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import static org.junit.Assert.*; + +/** + * Integration tests for meta-annotation attribute override support, overriding + * default attribute values defined in {@link ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfig}. + * + * @author Sam Brannen + * @since 4.0.3 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfig(classes = LocalDevConfig.class, resolver = DevResolver.class) +public class ConfigClassesAndProfileResolverWithCustomDefaultsMetaConfigWithOverridesTests { + + @Autowired + private String foo; + + + @Test + public void foo() { + assertEquals("Local Dev Foo", foo); + } +} + +@Configuration +@Profile("dev") +class LocalDevConfig { + + @Bean + public String foo() { + return "Local Dev Foo"; + } +} + +class DevResolver implements ActiveProfilesResolver { + + @Override + public String[] resolve(Class testClass) { + // Checking that the "test class" name ends with "*Tests" ensures that an actual + // test class is passed to this method as opposed to a "*Config" class which would + // imply that we likely have been passed the 'declaringClass' instead of the + // 'rootDeclaringClass'. + return testClass.getName().endsWith("Tests") ? new String[] { "dev" } : new String[] {}; + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfilesMetaConfig.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfilesMetaConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..cbba94088998156504aea65ba60e78ebb8cd4e46 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfilesMetaConfig.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation.meta; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.ContextConfiguration; + +/** + * Custom configuration annotation with meta-annotation attribute overrides for + * {@link ContextConfiguration#classes} and {@link ActiveProfiles#profiles} and + * no default configuration local to the composed annotation. + * + * @author Sam Brannen + * @since 4.0.3 + */ +@ContextConfiguration +@ActiveProfiles +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface ConfigClassesAndProfilesMetaConfig { + + Class[] classes() default {}; + + String[] profiles() default {}; + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfilesMetaConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfilesMetaConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0b5c2df10c5cd6ccb4099022a1f185ccb80ac429 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfilesMetaConfigTests.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation.meta; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Profile; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import static org.junit.Assert.*; + +/** + * Integration tests for meta-annotation attribute override support, demonstrating + * that the test class is used as the declaring class when detecting default + * configuration classes for the declaration of {@code @ContextConfiguration}. + * + * @author Sam Brannen + * @since 4.0.3 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ConfigClassesAndProfilesMetaConfig(profiles = "dev") +public class ConfigClassesAndProfilesMetaConfigTests { + + @Configuration + @Profile("dev") + static class DevConfig { + + @Bean + public String foo() { + return "Local Dev Foo"; + } + } + + @Configuration + @Profile("prod") + static class ProductionConfig { + + @Bean + public String foo() { + return "Local Production Foo"; + } + } + + + @Autowired + private String foo; + + + @Test + public void foo() { + assertEquals("Local Dev Foo", foo); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfilesWithCustomDefaultsMetaConfig.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfilesWithCustomDefaultsMetaConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..b953200f5de21514d0d78c20d0d50978bd540b8b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfilesWithCustomDefaultsMetaConfig.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation.meta; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Profile; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.ContextConfiguration; + +/** + * Custom configuration annotation with meta-annotation attribute overrides for + * {@link ContextConfiguration#classes} and {@link ActiveProfiles#profiles} and + * with default configuration local to the composed annotation. + * + * @author Sam Brannen + * @since 4.0 + */ +@ContextConfiguration +@ActiveProfiles +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface ConfigClassesAndProfilesWithCustomDefaultsMetaConfig { + + @Configuration + @Profile("dev") + static class DevConfig { + + @Bean + public String foo() { + return "Dev Foo"; + } + } + + @Configuration + @Profile("prod") + static class ProductionConfig { + + @Bean + public String foo() { + return "Production Foo"; + } + } + + + Class[] classes() default { DevConfig.class, ProductionConfig.class }; + + String[] profiles() default "dev"; + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfilesWithCustomDefaultsMetaConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfilesWithCustomDefaultsMetaConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..096716529832df807c25726a3e465812fa030d0a --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfilesWithCustomDefaultsMetaConfigTests.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation.meta; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import static org.junit.Assert.*; + +/** + * Integration tests for meta-annotation attribute override support, relying on + * default attribute values defined in {@link ConfigClassesAndProfilesWithCustomDefaultsMetaConfig}. + * + * @author Sam Brannen + * @since 4.0 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ConfigClassesAndProfilesWithCustomDefaultsMetaConfig +public class ConfigClassesAndProfilesWithCustomDefaultsMetaConfigTests { + + @Autowired + private String foo; + + + @Test + public void foo() { + assertEquals("Dev Foo", foo); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfilesWithCustomDefaultsMetaConfigWithOverridesTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfilesWithCustomDefaultsMetaConfigWithOverridesTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3502f1c92f0a0a7f52b535b9b9f03b50245b733b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/ConfigClassesAndProfilesWithCustomDefaultsMetaConfigWithOverridesTests.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation.meta; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.junit4.annotation.PojoAndStringConfig; +import org.springframework.tests.sample.beans.Employee; +import org.springframework.tests.sample.beans.Pet; + +import static org.junit.Assert.*; + +/** + * Integration tests for meta-annotation attribute override support, overriding + * default attribute values defined in {@link ConfigClassesAndProfilesWithCustomDefaultsMetaConfig}. + * + * @author Sam Brannen + * @since 4.0 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ConfigClassesAndProfilesWithCustomDefaultsMetaConfig(classes = { PojoAndStringConfig.class, + ConfigClassesAndProfilesWithCustomDefaultsMetaConfig.ProductionConfig.class }, profiles = "prod") +public class ConfigClassesAndProfilesWithCustomDefaultsMetaConfigWithOverridesTests { + + @Autowired + private String foo; + + @Autowired + private Pet pet; + + @Autowired + protected Employee employee; + + + @Test + public void verifyEmployee() { + assertNotNull("The employee should have been autowired.", this.employee); + assertEquals("John Smith", this.employee.getName()); + } + + @Test + public void verifyPet() { + assertNotNull("The pet should have been autowired.", this.pet); + assertEquals("Fido", this.pet.getName()); + } + + @Test + public void verifyFoo() { + assertEquals("Production Foo", this.foo); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/MetaMetaConfig.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/MetaMetaConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..be59c1119f024433d0a7e3149aa1f726f7f812af --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/MetaMetaConfig.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation.meta; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.test.context.ActiveProfiles; + +/** + * Custom configuration annotation that is itself meta-annotated with {@link + * ConfigClassesAndProfilesWithCustomDefaultsMetaConfig} and {@link ActiveProfiles}. + * + * @author Sam Brannen + * @since 4.0.3 + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +@ConfigClassesAndProfilesWithCustomDefaultsMetaConfig +// Override default "dev" profile from +// @ConfigClassesAndProfilesWithCustomDefaultsMetaConfig: +@ActiveProfiles("prod") +public @interface MetaMetaConfig { + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/MetaMetaConfigDefaultsTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/MetaMetaConfigDefaultsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3d2e60ee9c20e4e44e2656db69aac2934557da84 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/annotation/meta/MetaMetaConfigDefaultsTests.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.annotation.meta; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import static org.junit.Assert.*; + +/** + * Integration tests for meta-meta-annotation support, relying on default attribute + * values defined in {@link ConfigClassesAndProfilesWithCustomDefaultsMetaConfig} and + * overrides in {@link MetaMetaConfig}. + * + * @author Sam Brannen + * @since 4.0.3 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@MetaMetaConfig +public class MetaMetaConfigDefaultsTests { + + @Autowired + private String foo; + + + @Test + public void foo() { + assertEquals("Production Foo", foo); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/concurrency/SpringJUnit4ConcurrencyTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/concurrency/SpringJUnit4ConcurrencyTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8f8bc9bc9a5bb0c2abe38fd418e890e3e0ffc73d --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/concurrency/SpringJUnit4ConcurrencyTests.java @@ -0,0 +1,122 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.concurrency; + +import java.lang.annotation.Annotation; +import java.util.Arrays; + +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.experimental.ParallelComputer; + +import org.springframework.test.context.hierarchies.web.DispatcherWacRootWacEarTests; +import org.springframework.test.context.junit4.InheritedConfigSpringJUnit4ClassRunnerAppCtxTests; +import org.springframework.test.context.junit4.MethodLevelTransactionalSpringRunnerTests; +import org.springframework.test.context.junit4.SpringJUnit47ClassRunnerRuleTests; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunnerAppCtxTests; +import org.springframework.test.context.junit4.SpringRunner; +import org.springframework.test.context.junit4.TimedTransactionalSpringRunnerTests; +import org.springframework.test.context.junit4.rules.BaseAppCtxRuleTests; +import org.springframework.test.context.junit4.rules.BasicAnnotationConfigWacSpringRuleTests; +import org.springframework.test.context.junit4.rules.SpringClassRule; +import org.springframework.test.context.junit4.rules.SpringMethodRule; +import org.springframework.test.context.web.RequestAndSessionScopedBeansWacTests; +import org.springframework.test.context.web.socket.WebSocketServletServerContainerFactoryBeanTests; +import org.springframework.test.web.client.samples.SampleTests; +import org.springframework.test.web.servlet.samples.context.JavaConfigTests; +import org.springframework.test.web.servlet.samples.context.WebAppResourceTests; +import org.springframework.tests.Assume; +import org.springframework.tests.TestGroup; +import org.springframework.util.ReflectionUtils; + +import static org.springframework.core.annotation.AnnotatedElementUtils.*; +import static org.springframework.test.context.junit4.JUnitTestingUtils.*; + +/** + * Concurrency tests for the {@link SpringRunner}, {@link SpringClassRule}, and + * {@link SpringMethodRule} that use JUnit 4's experimental {@link ParallelComputer} + * to execute tests in parallel. + * + *

The tests executed by this test class come from a hand-picked collection of test + * classes within the test suite that is intended to cover most categories of tests + * that are currently supported by the TestContext Framework on JUnit 4. + * + *

The chosen test classes intentionally do not include any classes that + * fall under the following categories. + * + *

+ * + *

NOTE: these tests only run if the {@link TestGroup#LONG_RUNNING + * LONG_RUNNING} test group is enabled. + * + * @author Sam Brannen + * @since 5.0 + * @see org.springframework.test.context.TestContextConcurrencyTests + */ +public class SpringJUnit4ConcurrencyTests { + + private final Class[] testClasses = new Class[] { + // Basics + SpringJUnit4ClassRunnerAppCtxTests.class, + InheritedConfigSpringJUnit4ClassRunnerAppCtxTests.class, + SpringJUnit47ClassRunnerRuleTests.class, + BaseAppCtxRuleTests.class, + // Transactional + MethodLevelTransactionalSpringRunnerTests.class, + TimedTransactionalSpringRunnerTests.class, + // Web and Scopes + DispatcherWacRootWacEarTests.class, + BasicAnnotationConfigWacSpringRuleTests.class, + RequestAndSessionScopedBeansWacTests.class, + WebSocketServletServerContainerFactoryBeanTests.class, + // Spring MVC Test + JavaConfigTests.class, + WebAppResourceTests.class, + SampleTests.class + }; + + + @BeforeClass + public static void abortIfLongRunningTestGroupIsNotEnabled() { + Assume.group(TestGroup.LONG_RUNNING); + } + + @Test + public void runAllTestsConcurrently() throws Exception { + final int FAILED = 0; + final int ABORTED = 0; + final int IGNORED = countAnnotatedMethods(Ignore.class); + final int TESTS = countAnnotatedMethods(Test.class) - IGNORED; + + runTestsAndAssertCounters(new ParallelComputer(true, true), TESTS, FAILED, TESTS, IGNORED, ABORTED, + this.testClasses); + } + + private int countAnnotatedMethods(Class annotationType) { + return (int) Arrays.stream(this.testClasses) + .map(ReflectionUtils::getUniqueDeclaredMethods) + .flatMap(Arrays::stream) + .filter(method -> hasAnnotation(method, annotationType)) + .count(); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/hybrid/HybridContextLoader.java b/spring-test/src/test/java/org/springframework/test/context/junit4/hybrid/HybridContextLoader.java new file mode 100644 index 0000000000000000000000000000000000000000..02f74de9b8a8f1f9662bf464b0fdee26dea65579 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/hybrid/HybridContextLoader.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.hybrid; + +import org.springframework.beans.factory.support.BeanDefinitionReader; +import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; +import org.springframework.context.annotation.AnnotatedBeanDefinitionReader; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.test.context.ContextConfigurationAttributes; +import org.springframework.test.context.MergedContextConfiguration; +import org.springframework.test.context.SmartContextLoader; +import org.springframework.test.context.support.AbstractGenericContextLoader; +import org.springframework.util.Assert; + +import static org.springframework.test.context.support.AnnotationConfigContextLoaderUtils.*; + +/** + * Hybrid {@link SmartContextLoader} that supports path-based and class-based + * resources simultaneously. + *

This test loader is inspired by Spring Boot. + *

Detects defaults for XML configuration and annotated classes. + *

Beans from XML configuration always override those from annotated classes. + * + * @author Sam Brannen + * @since 4.0.4 + */ +public class HybridContextLoader extends AbstractGenericContextLoader { + + @Override + protected void validateMergedContextConfiguration(MergedContextConfiguration mergedConfig) { + Assert.isTrue(mergedConfig.hasClasses() || mergedConfig.hasLocations(), getClass().getSimpleName() + + " requires either classes or locations"); + } + + @Override + public void processContextConfiguration(ContextConfigurationAttributes configAttributes) { + // Detect default XML configuration files: + super.processContextConfiguration(configAttributes); + + // Detect default configuration classes: + if (!configAttributes.hasClasses() && isGenerateDefaultLocations()) { + configAttributes.setClasses(detectDefaultConfigurationClasses(configAttributes.getDeclaringClass())); + } + } + + @Override + protected void loadBeanDefinitions(GenericApplicationContext context, MergedContextConfiguration mergedConfig) { + // Order doesn't matter: always wins over @Bean. + new XmlBeanDefinitionReader(context).loadBeanDefinitions(mergedConfig.getLocations()); + new AnnotatedBeanDefinitionReader(context).register(mergedConfig.getClasses()); + } + + @Override + protected BeanDefinitionReader createBeanDefinitionReader(GenericApplicationContext context) { + throw new UnsupportedOperationException(getClass().getSimpleName() + " doesn't support this"); + } + + @Override + protected String getResourceSuffix() { + return "-context.xml"; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/hybrid/HybridContextLoaderTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/hybrid/HybridContextLoaderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..fedfff06b995129aaa7080e5ab6566af4de83153 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/hybrid/HybridContextLoaderTests.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.hybrid; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.SmartContextLoader; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import static org.junit.Assert.*; + +/** + * Integration tests for hybrid {@link SmartContextLoader} implementations that + * support path-based and class-based resources simultaneously, as is done in + * Spring Boot. + * + * @author Sam Brannen + * @since 4.0.4 + * @see HybridContextLoader + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(loader = HybridContextLoader.class) +public class HybridContextLoaderTests { + + @Configuration + static class Config { + + @Bean + public String fooFromJava() { + return "Java"; + } + + @Bean + public String enigma() { + return "enigma from Java"; + } + } + + + @Autowired + private String fooFromXml; + + @Autowired + private String fooFromJava; + + @Autowired + private String enigma; + + + @Test + public void verifyContentsOfHybridApplicationContext() { + assertEquals("XML", fooFromXml); + assertEquals("Java", fooFromJava); + + // Note: the XML bean definition for "enigma" always wins since + // ConfigurationClassBeanDefinitionReader.isOverriddenByExistingDefinition() + // lets XML bean definitions override those "discovered" later via an + // @Bean method. + assertEquals("enigma from XML", enigma); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/nested/NestedTestsWithSpringRulesTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/nested/NestedTestsWithSpringRulesTests.java new file mode 100644 index 0000000000000000000000000000000000000000..66d36cafc24e0c1d5a9e73e070b5da5c9f235e9e --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/nested/NestedTestsWithSpringRulesTests.java @@ -0,0 +1,98 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.nested; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.nested.NestedTestsWithSpringRulesTests.TopLevelConfig; +import org.springframework.test.context.junit4.rules.SpringClassRule; +import org.springframework.test.context.junit4.rules.SpringMethodRule; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import de.bechte.junit.runners.context.HierarchicalContextRunner; + +/** + * JUnit 4 based integration tests for nested test classes that are + * executed via a custom JUnit 4 {@link HierarchicalContextRunner} and Spring's + * {@link SpringClassRule} and {@link SpringMethodRule} support. + * + * @author Sam Brannen + * @since 5.0 + * @see org.springframework.test.context.junit.jupiter.nested.NestedTestsWithSpringAndJUnitJupiterTestCase + */ +@RunWith(HierarchicalContextRunner.class) +@ContextConfiguration(classes = TopLevelConfig.class) +public class NestedTestsWithSpringRulesTests extends SpringRuleConfigurer { + + @Autowired + String foo; + + + @Test + public void topLevelTest() { + assertEquals("foo", foo); + } + + + @ContextConfiguration(classes = NestedConfig.class) + public class NestedTestCase extends SpringRuleConfigurer { + + @Autowired + String bar; + + + @Test + public void nestedTest() throws Exception { + // Note: the following would fail since TestExecutionListeners in + // the Spring TestContext Framework are not applied to the enclosing + // instance of an inner test class. + // + // assertEquals("foo", foo); + + assertNull("@Autowired field in enclosing instance should be null.", foo); + assertEquals("bar", bar); + } + } + + // ------------------------------------------------------------------------- + + @Configuration + public static class TopLevelConfig { + + @Bean + String foo() { + return "foo"; + } + } + + @Configuration + public static class NestedConfig { + + @Bean + String bar() { + return "bar"; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/nested/SpringRuleConfigurer.java b/spring-test/src/test/java/org/springframework/test/context/junit4/nested/SpringRuleConfigurer.java new file mode 100644 index 0000000000000000000000000000000000000000..983a276f43dc60d588693f712a25c5b4fc51a051 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/nested/SpringRuleConfigurer.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.nested; + +import org.junit.ClassRule; +import org.junit.Rule; + +import org.springframework.test.context.junit4.rules.SpringClassRule; +import org.springframework.test.context.junit4.rules.SpringMethodRule; + +/** + * Abstract base test class that preconfigures the {@link SpringClassRule} and + * {@link SpringMethodRule}. + * + * @author Sam Brannen + * @since 5.0 + */ +public abstract class SpringRuleConfigurer { + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/orm/HibernateSessionFlushingTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/orm/HibernateSessionFlushingTests.java new file mode 100644 index 0000000000000000000000000000000000000000..f809ab198441f2ef9403c282371695bf0d26f024 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/orm/HibernateSessionFlushingTests.java @@ -0,0 +1,133 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.orm; + +import javax.persistence.PersistenceException; + +import org.hibernate.Session; +import org.hibernate.SessionFactory; +import org.hibernate.exception.ConstraintViolationException; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests; +import org.springframework.test.context.junit4.orm.domain.DriversLicense; +import org.springframework.test.context.junit4.orm.domain.Person; +import org.springframework.test.context.junit4.orm.service.PersonService; +import org.springframework.transaction.annotation.Transactional; + +import static org.junit.Assert.*; +import static org.springframework.test.transaction.TransactionTestUtils.*; + +/** + * Transactional integration tests regarding manual session flushing with + * Hibernate. + * + * @author Sam Brannen + * @author Juergen Hoeller + * @author Vlad Mihalcea + * @since 3.0 + */ +@ContextConfiguration +public class HibernateSessionFlushingTests extends AbstractTransactionalJUnit4SpringContextTests { + + private static final String SAM = "Sam"; + private static final String JUERGEN = "Juergen"; + + @Autowired + private PersonService personService; + + @Autowired + private SessionFactory sessionFactory; + + + @Before + public void setup() { + assertInTransaction(true); + assertNotNull("PersonService should have been autowired.", personService); + assertNotNull("SessionFactory should have been autowired.", sessionFactory); + } + + + @Test + public void findSam() { + Person sam = personService.findByName(SAM); + assertNotNull("Should be able to find Sam", sam); + DriversLicense driversLicense = sam.getDriversLicense(); + assertNotNull("Sam's driver's license should not be null", driversLicense); + assertEquals("Verifying Sam's driver's license number", Long.valueOf(1234), driversLicense.getNumber()); + } + + @Test // SPR-16956 + @Transactional(readOnly = true) + public void findSamWithReadOnlySession() { + Person sam = personService.findByName(SAM); + sam.setName("Vlad"); + // By setting setDefaultReadOnly(true), the user can no longer modify any entity... + Session session = sessionFactory.getCurrentSession(); + session.flush(); + session.refresh(sam); + assertEquals("Sam", sam.getName()); + } + + @Test + public void saveJuergenWithDriversLicense() { + DriversLicense driversLicense = new DriversLicense(2L, 2222L); + Person juergen = new Person(JUERGEN, driversLicense); + int numRows = countRowsInTable("person"); + personService.save(juergen); + assertEquals("Verifying number of rows in the 'person' table.", numRows + 1, countRowsInTable("person")); + assertNotNull("Should be able to save and retrieve Juergen", personService.findByName(JUERGEN)); + assertNotNull("Juergen's ID should have been set", juergen.getId()); + } + + @Test(expected = ConstraintViolationException.class) + public void saveJuergenWithNullDriversLicense() { + personService.save(new Person(JUERGEN)); + } + + @Test + // no expected exception! + public void updateSamWithNullDriversLicenseWithoutSessionFlush() { + updateSamWithNullDriversLicense(); + // False positive, since an exception will be thrown once the session is + // finally flushed (i.e., in production code) + } + + @Test(expected = ConstraintViolationException.class) + public void updateSamWithNullDriversLicenseWithSessionFlush() throws Throwable { + updateSamWithNullDriversLicense(); + // Manual flush is required to avoid false positive in test + try { + sessionFactory.getCurrentSession().flush(); + } + catch (PersistenceException ex) { + // Wrapped in Hibernate 5.2, with the constraint violation as cause + throw ex.getCause(); + } + } + + private void updateSamWithNullDriversLicense() { + Person sam = personService.findByName(SAM); + assertNotNull("Should be able to find Sam", sam); + sam.setDriversLicense(null); + personService.save(sam); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/orm/domain/DriversLicense.java b/spring-test/src/test/java/org/springframework/test/context/junit4/orm/domain/DriversLicense.java new file mode 100644 index 0000000000000000000000000000000000000000..8f3d9df0591544945f35ac16df0594f722374196 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/orm/domain/DriversLicense.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.orm.domain; + +/** + * DriversLicense POJO. + * + * @author Sam Brannen + * @since 3.0 + */ +public class DriversLicense { + + private Long id; + + private Long number; + + + public DriversLicense() { + } + + public DriversLicense(Long number) { + this(null, number); + } + + public DriversLicense(Long id, Long number) { + this.id = id; + this.number = number; + } + + public Long getId() { + return this.id; + } + + protected void setId(Long id) { + this.id = id; + } + + public Long getNumber() { + return this.number; + } + + public void setNumber(Long number) { + this.number = number; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/orm/domain/Person.java b/spring-test/src/test/java/org/springframework/test/context/junit4/orm/domain/Person.java new file mode 100644 index 0000000000000000000000000000000000000000..e0348fd1bc888d85f2403dfd2a313aa623613121 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/orm/domain/Person.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.orm.domain; + +/** + * Person POJO. + * + * @author Sam Brannen + * @since 3.0 + */ +public class Person { + + private Long id; + private String name; + private DriversLicense driversLicense; + + + public Person() { + } + + public Person(Long id) { + this(id, null, null); + } + + public Person(String name) { + this(name, null); + } + + public Person(String name, DriversLicense driversLicense) { + this(null, name, driversLicense); + } + + public Person(Long id, String name, DriversLicense driversLicense) { + this.id = id; + this.name = name; + this.driversLicense = driversLicense; + } + + public Long getId() { + return this.id; + } + + protected void setId(Long id) { + this.id = id; + } + + public String getName() { + return this.name; + } + + public void setName(String name) { + this.name = name; + } + + public DriversLicense getDriversLicense() { + return this.driversLicense; + } + + public void setDriversLicense(DriversLicense driversLicense) { + this.driversLicense = driversLicense; + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/orm/repository/PersonRepository.java b/spring-test/src/test/java/org/springframework/test/context/junit4/orm/repository/PersonRepository.java new file mode 100644 index 0000000000000000000000000000000000000000..b75f511f937782316ec3b1d304988b92c241f73a --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/orm/repository/PersonRepository.java @@ -0,0 +1,33 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.orm.repository; + +import org.springframework.test.context.junit4.orm.domain.Person; + +/** + * Person Repository API. + * + * @author Sam Brannen + * @since 3.0 + */ +public interface PersonRepository { + + Person findByName(String name); + + Person save(Person person); + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/orm/repository/hibernate/HibernatePersonRepository.java b/spring-test/src/test/java/org/springframework/test/context/junit4/orm/repository/hibernate/HibernatePersonRepository.java new file mode 100644 index 0000000000000000000000000000000000000000..317ec5a5fdc22b2da4e9845dc679fdd82ceeceb8 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/orm/repository/hibernate/HibernatePersonRepository.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.orm.repository.hibernate; + +import org.hibernate.SessionFactory; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Repository; +import org.springframework.test.context.junit4.orm.domain.Person; +import org.springframework.test.context.junit4.orm.repository.PersonRepository; + +/** + * Hibernate implementation of the {@link PersonRepository} API. + * + * @author Sam Brannen + * @since 3.0 + */ +@Repository +public class HibernatePersonRepository implements PersonRepository { + + private final SessionFactory sessionFactory; + + + @Autowired + public HibernatePersonRepository(SessionFactory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + @Override + public Person save(Person person) { + this.sessionFactory.getCurrentSession().save(person); + return person; + } + + @Override + public Person findByName(String name) { + return (Person) this.sessionFactory.getCurrentSession().createQuery( + "from Person person where person.name = :name").setParameter("name", name).getSingleResult(); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/orm/service/PersonService.java b/spring-test/src/test/java/org/springframework/test/context/junit4/orm/service/PersonService.java new file mode 100644 index 0000000000000000000000000000000000000000..d42b03602e1bdd811f042b3cb49e15bd15d20177 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/orm/service/PersonService.java @@ -0,0 +1,33 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.orm.service; + +import org.springframework.test.context.junit4.orm.domain.Person; + +/** + * Person Service API. + * + * @author Sam Brannen + * @since 3.0 + */ +public interface PersonService { + + Person findByName(String name); + + Person save(Person person); + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/orm/service/impl/StandardPersonService.java b/spring-test/src/test/java/org/springframework/test/context/junit4/orm/service/impl/StandardPersonService.java new file mode 100644 index 0000000000000000000000000000000000000000..2e98308a19d2b9fcf56d15433ba79c00fe9dc375 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/orm/service/impl/StandardPersonService.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.orm.service.impl; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import org.springframework.test.context.junit4.orm.domain.Person; +import org.springframework.test.context.junit4.orm.repository.PersonRepository; +import org.springframework.test.context.junit4.orm.service.PersonService; +import org.springframework.transaction.annotation.Transactional; + +/** + * Standard implementation of the {@link PersonService} API. + * + * @author Sam Brannen + * @since 3.0 + */ +@Service +@Transactional(readOnly = true) +public class StandardPersonService implements PersonService { + + private final PersonRepository personRepository; + + + @Autowired + public StandardPersonService(PersonRepository personRepository) { + this.personRepository = personRepository; + } + + @Override + public Person findByName(String name) { + return this.personRepository.findByName(name); + } + + @Override + @Transactional(readOnly = false) + public Person save(Person person) { + return this.personRepository.save(person); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/DefaultProfileAnnotationConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/DefaultProfileAnnotationConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..7b3737b7a26bd7e439b3921c99b401dfc62c5622 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/DefaultProfileAnnotationConfigTests.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.annotation; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.support.AnnotationConfigContextLoader; +import org.springframework.tests.sample.beans.Employee; +import org.springframework.tests.sample.beans.Pet; + +import static org.junit.Assert.*; + +/** + * @author Sam Brannen + * @since 3.1 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(classes = { DefaultProfileConfig.class, DevProfileConfig.class }, loader = AnnotationConfigContextLoader.class) +public class DefaultProfileAnnotationConfigTests { + + @Autowired + protected Pet pet; + + @Autowired(required = false) + protected Employee employee; + + + @Test + public void pet() { + assertNotNull(pet); + assertEquals("Fido", pet.getName()); + } + + @Test + public void employee() { + assertNull("employee bean should not be created for the default profile", employee); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/DefaultProfileConfig.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/DefaultProfileConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..9f23e674e3a97d192a273b6cf4038310b4dfd5d9 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/DefaultProfileConfig.java @@ -0,0 +1,35 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.annotation; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.tests.sample.beans.Pet; + +/** + * @author Sam Brannen + * @since 3.1 + */ +@Configuration +public class DefaultProfileConfig { + + @Bean + public Pet pet() { + return new Pet("Fido"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/DevProfileAnnotationConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/DevProfileAnnotationConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..470edbf2c26c30d10b0132715bd6d675144ad501 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/DevProfileAnnotationConfigTests.java @@ -0,0 +1,39 @@ +/* + * Copyright 2011 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.annotation; + +import org.junit.Test; + +import org.springframework.test.context.ActiveProfiles; + +import static org.junit.Assert.*; + +/** + * @author Sam Brannen + * @since 3.1 + */ +@ActiveProfiles("dev") +public class DevProfileAnnotationConfigTests extends DefaultProfileAnnotationConfigTests { + + @Test + @Override + public void employee() { + assertNotNull("employee bean should be loaded for the 'dev' profile", employee); + assertEquals("John Smith", employee.getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/DevProfileConfig.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/DevProfileConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..d1a3d05aef444be99a4d325f5f43898bf7d76124 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/DevProfileConfig.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.annotation; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Profile; +import org.springframework.tests.sample.beans.Employee; + +/** + * @author Sam Brannen + * @since 3.1 + */ +@Profile("dev") +@Configuration +public class DevProfileConfig { + + @Bean + public Employee employee() { + Employee employee = new Employee(); + employee.setName("John Smith"); + employee.setAge(42); + employee.setCompany("Acme Widgets, Inc."); + return employee; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/DevProfileResolverAnnotationConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/DevProfileResolverAnnotationConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..9fe028dc7abb19c4c3f316ff9a64d90030d547e8 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/DevProfileResolverAnnotationConfigTests.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.annotation; + +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.ActiveProfilesResolver; + +/** + * @author Michail Nikolaev + * @since 4.0 + */ +@ActiveProfiles(resolver = DevProfileResolverAnnotationConfigTests.class, inheritProfiles = false) +public class DevProfileResolverAnnotationConfigTests extends DevProfileAnnotationConfigTests implements + ActiveProfilesResolver { + + @Override + public String[] resolve(Class testClass) { + return new String[] { "dev" }; + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/ProfileAnnotationConfigTestSuite.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/ProfileAnnotationConfigTestSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..26d29dfa264f39c7fe7d744512d5a1b137b1ba6b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/annotation/ProfileAnnotationConfigTestSuite.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.annotation; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** + * JUnit test suite for bean definition profile support in the + * Spring TestContext Framework with annotation-based configuration. + * + * @author Sam Brannen + * @since 3.1 + */ +@RunWith(Suite.class) +// Note: the following 'multi-line' layout is for enhanced code readability. +@SuiteClasses({// +DefaultProfileAnnotationConfigTests.class,// + DevProfileAnnotationConfigTests.class,// + DevProfileResolverAnnotationConfigTests.class // +}) +public class ProfileAnnotationConfigTestSuite { +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/importresource/DefaultProfileAnnotationConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/importresource/DefaultProfileAnnotationConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..9a499db38adb5e2291acbf5cc7e5ae2df18aa691 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/importresource/DefaultProfileAnnotationConfigTests.java @@ -0,0 +1,56 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.importresource; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.tests.sample.beans.Employee; +import org.springframework.tests.sample.beans.Pet; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + * @since 3.1 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(classes = DefaultProfileConfig.class) +public class DefaultProfileAnnotationConfigTests { + + @Autowired + protected Pet pet; + + @Autowired(required = false) + protected Employee employee; + + + @Test + public void pet() { + assertNotNull(pet); + assertEquals("Fido", pet.getName()); + } + + @Test + public void employee() { + assertNull("employee bean should not be created for the default profile", employee); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/importresource/DefaultProfileConfig.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/importresource/DefaultProfileConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..afcfc8722b0fb24ccba1fb8a9fcda600d04408b3 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/importresource/DefaultProfileConfig.java @@ -0,0 +1,37 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.importresource; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.ImportResource; +import org.springframework.tests.sample.beans.Pet; + +/** + * @author Juergen Hoeller + * @since 3.1 + */ +@Configuration +@ImportResource("org/springframework/test/context/junit4/profile/importresource/import.xml") +public class DefaultProfileConfig { + + @Bean + public Pet pet() { + return new Pet("Fido"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/importresource/DevProfileAnnotationConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/importresource/DevProfileAnnotationConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..f9c531d97034d46af282497d5d7a53dfa26cb56a --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/importresource/DevProfileAnnotationConfigTests.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.importresource; + +import org.junit.Test; + +import org.springframework.test.context.ActiveProfiles; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + * @since 3.1 + */ +@ActiveProfiles("dev") +public class DevProfileAnnotationConfigTests extends DefaultProfileAnnotationConfigTests { + + @Test + @Override + public void employee() { + assertNotNull("employee bean should be loaded for the 'dev' profile", employee); + assertEquals("John Smith", employee.getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/importresource/DevProfileResolverAnnotationConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/importresource/DevProfileResolverAnnotationConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a913abaf6ef13d0884850bacfca004b9e8d42b6f --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/importresource/DevProfileResolverAnnotationConfigTests.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.importresource; + +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.ActiveProfilesResolver; + +/** + * @author Michail Nikolaev + * @since 4.0 + */ +@ActiveProfiles(resolver = DevProfileResolverAnnotationConfigTests.class, inheritProfiles = false) +public class DevProfileResolverAnnotationConfigTests extends DevProfileAnnotationConfigTests implements + ActiveProfilesResolver { + + @Override + public String[] resolve(Class testClass) { + return new String[] { "dev" }; + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/resolver/ClassNameActiveProfilesResolver.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/resolver/ClassNameActiveProfilesResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..d352b5a5a4b0905b9a4074d4d175b9901d1937da --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/resolver/ClassNameActiveProfilesResolver.java @@ -0,0 +1,31 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.resolver; + +import org.springframework.test.context.ActiveProfilesResolver; + +/** + * @author Michail Nikolaev + * @since 4.0 + */ +public class ClassNameActiveProfilesResolver implements ActiveProfilesResolver { + + @Override + public String[] resolve(Class testClass) { + return new String[] { testClass.getSimpleName().toLowerCase() }; + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/resolver/ClassNameActiveProfilesResolverTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/resolver/ClassNameActiveProfilesResolverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8a2ea5b39192ebc201f6c6690db0424ddc009202 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/resolver/ClassNameActiveProfilesResolverTests.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.resolver; + +import java.util.Arrays; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import static org.junit.Assert.*; + +/** + * @author Michail Nikolaev + * @since 4.0 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +@ActiveProfiles(resolver = ClassNameActiveProfilesResolver.class) +public class ClassNameActiveProfilesResolverTests { + + @Configuration + static class Config { + + } + + + @Autowired + private ApplicationContext applicationContext; + + + @Test + public void test() { + assertTrue(Arrays.asList(applicationContext.getEnvironment().getActiveProfiles()).contains( + getClass().getSimpleName().toLowerCase())); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/xml/DefaultProfileXmlConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/xml/DefaultProfileXmlConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..49b2e0516755066abe8791b7a8087755edc98f1e --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/xml/DefaultProfileXmlConfigTests.java @@ -0,0 +1,56 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.xml; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.tests.sample.beans.Employee; +import org.springframework.tests.sample.beans.Pet; + +import static org.junit.Assert.*; + +/** + * @author Sam Brannen + * @since 3.1 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +public class DefaultProfileXmlConfigTests { + + @Autowired + protected Pet pet; + + @Autowired(required = false) + protected Employee employee; + + + @Test + public void pet() { + assertNotNull(pet); + assertEquals("Fido", pet.getName()); + } + + @Test + public void employee() { + assertNull("employee bean should not be created for the default profile", employee); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/xml/DevProfileResolverXmlConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/xml/DevProfileResolverXmlConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..ef30225d9c4088a1c6dfcc12719d5a51dcaa72ec --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/xml/DevProfileResolverXmlConfigTests.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.xml; + +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.ActiveProfilesResolver; + +/** + * @author Michail Nikolaev + * @since 4.0 + */ +@ActiveProfiles(resolver = DevProfileResolverXmlConfigTests.class, inheritProfiles = false) +public class DevProfileResolverXmlConfigTests extends DevProfileXmlConfigTests implements ActiveProfilesResolver { + + @Override + public String[] resolve(Class testClass) { + return new String[] { "dev" }; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/xml/DevProfileXmlConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/xml/DevProfileXmlConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..9e5d0306307bfc2498b4a9f4c35e28a4f54d320a --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/xml/DevProfileXmlConfigTests.java @@ -0,0 +1,39 @@ +/* + * Copyright 2011 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.xml; + +import org.junit.Test; + +import org.springframework.test.context.ActiveProfiles; + +import static org.junit.Assert.*; + +/** + * @author Sam Brannen + * @since 3.1 + */ +@ActiveProfiles("dev") +public class DevProfileXmlConfigTests extends DefaultProfileXmlConfigTests { + + @Test + @Override + public void employee() { + assertNotNull("employee bean should be loaded for the 'dev' profile", employee); + assertEquals("John Smith", employee.getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/profile/xml/ProfileXmlConfigTestSuite.java b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/xml/ProfileXmlConfigTestSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..741327d7e198a80765793e32b0858cad047df064 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/profile/xml/ProfileXmlConfigTestSuite.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.profile.xml; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** + * JUnit test suite for bean definition profile support in the + * Spring TestContext Framework with XML-based configuration. + * + * @author Sam Brannen + * @since 3.1 + */ +@RunWith(Suite.class) +// Note: the following 'multi-line' layout is for enhanced code readability. +@SuiteClasses({// +DefaultProfileXmlConfigTests.class,// + DevProfileXmlConfigTests.class,// + DevProfileResolverXmlConfigTests.class // +}) +public class ProfileXmlConfigTestSuite { +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/rules/AutowiredRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/AutowiredRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..7ef4411f661834913672a6b80e7191821cc8eb49 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/AutowiredRuleTests.java @@ -0,0 +1,83 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.rules; + +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import static org.junit.Assert.*; + +/** + * Integration tests for an issue raised in https://jira.spring.io/browse/SPR-15927. + * + * @author Sam Brannen + * @since 5.0 + */ +public class AutowiredRuleTests { + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + + @Autowired + @Rule + public AutowiredTestRule autowiredTestRule; + + @Test + public void test() { + assertNotNull("TestRule should have been @Autowired", autowiredTestRule); + + // Rationale for the following assertion: + // + // The field value for the custom rule is null when JUnit sees it. JUnit then + // ignores the null value, and at a later point in time Spring injects the rule + // from the ApplicationContext and overrides the null field value. But that's too + // late: JUnit never sees the rule supplied by Spring via dependency injection. + assertFalse("@Autowired TestRule should NOT have been applied", autowiredTestRule.applied); + } + + @Configuration + static class Config { + + @Bean + AutowiredTestRule autowiredTestRule() { + return new AutowiredTestRule(); + } + } + + static class AutowiredTestRule implements TestRule { + + private boolean applied = false; + + @Override + public Statement apply(Statement base, Description description) { + this.applied = true; + return base; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/rules/BaseAppCtxRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/BaseAppCtxRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0d6a870f7249bf23b5198f01289c85b016adb277 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/BaseAppCtxRuleTests.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.rules; + +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; + +import static org.junit.Assert.*; + +/** + * Base class for integration tests involving Spring {@code ApplicationContexts} + * in conjunction with {@link SpringClassRule} and {@link SpringMethodRule}. + * + *

The goal of this class and its subclasses is to ensure that Rule-based + * configuration can be inherited without requiring {@link SpringClassRule} + * or {@link SpringMethodRule} to be redeclared on subclasses. + * + * @author Sam Brannen + * @since 4.2 + * @see Subclass1AppCtxRuleTests + * @see Subclass2AppCtxRuleTests + */ +@ContextConfiguration +public class BaseAppCtxRuleTests { + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + + @Autowired + private String foo; + + + @Test + public void foo() { + assertEquals("foo", foo); + } + + + @Configuration + static class Config { + + @Bean + public String foo() { + return "foo"; + } + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/rules/BasicAnnotationConfigWacSpringRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/BasicAnnotationConfigWacSpringRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8c2f26f8fda4eaa66a960b08c9966cd659d2ec02 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/BasicAnnotationConfigWacSpringRuleTests.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.rules; + +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import org.springframework.test.context.web.BasicAnnotationConfigWacTests; + +/** + * This class is an extension of {@link BasicAnnotationConfigWacTests} + * that has been modified to use {@link SpringClassRule} and + * {@link SpringMethodRule}. + * + * @author Sam Brannen + * @since 4.2 + */ +@RunWith(JUnit4.class) +public class BasicAnnotationConfigWacSpringRuleTests extends BasicAnnotationConfigWacTests { + + // All tests are in superclass. + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/rules/BeforeAndAfterTransactionAnnotationSpringRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/BeforeAndAfterTransactionAnnotationSpringRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..bff96a6fe8dbe7958516f58a1ae3009fba217f3b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/BeforeAndAfterTransactionAnnotationSpringRuleTests.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.rules; + +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import org.springframework.test.context.junit4.BeforeAndAfterTransactionAnnotationTests; + +/** + * This class is an extension of {@link BeforeAndAfterTransactionAnnotationTests} + * that has been modified to use {@link SpringClassRule} and + * {@link SpringMethodRule}. + * + * @author Sam Brannen + * @since 4.2 + */ +@RunWith(JUnit4.class) +public class BeforeAndAfterTransactionAnnotationSpringRuleTests extends BeforeAndAfterTransactionAnnotationTests { + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + + // All tests are in superclass. + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/rules/ClassLevelDisabledSpringRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/ClassLevelDisabledSpringRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..453a9c7ef2bc5537ba33b2950457cadac48ff2e3 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/ClassLevelDisabledSpringRuleTests.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.rules; + +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import org.springframework.test.context.junit4.ClassLevelDisabledSpringRunnerTests; + +/** + * This class is an extension of {@link ClassLevelDisabledSpringRunnerTests} + * that has been modified to use {@link SpringClassRule} and + * {@link SpringMethodRule}. + * + * @author Sam Brannen + * @since 4.2 + */ +@RunWith(JUnit4.class) +public class ClassLevelDisabledSpringRuleTests extends ClassLevelDisabledSpringRunnerTests { + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + + // All tests are in superclass. + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/rules/EnabledAndIgnoredSpringRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/EnabledAndIgnoredSpringRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..27f29eca93d2102fef70d484170789f294321f67 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/EnabledAndIgnoredSpringRuleTests.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.rules; + +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import org.springframework.test.context.junit4.EnabledAndIgnoredSpringRunnerTests; + +/** + * This class is an extension of {@link EnabledAndIgnoredSpringRunnerTests} + * that has been modified to use {@link SpringClassRule} and + * {@link SpringMethodRule}. + * + * @author Sam Brannen + * @since 4.2 + */ +@RunWith(JUnit4.class) +public class EnabledAndIgnoredSpringRuleTests extends EnabledAndIgnoredSpringRunnerTests { + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + + // All tests are in superclass. + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/rules/FailingBeforeAndAfterMethodsSpringRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/FailingBeforeAndAfterMethodsSpringRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..ff0968c3e3d6fef32046bc6243818b97f9ceb5d7 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/FailingBeforeAndAfterMethodsSpringRuleTests.java @@ -0,0 +1,157 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.rules; + +import org.junit.ClassRule; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runner.Runner; +import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized.Parameters; + +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.TestExecutionListeners; +import org.springframework.test.context.junit4.FailingBeforeAndAfterMethodsSpringRunnerTests; +import org.springframework.test.context.transaction.AfterTransaction; +import org.springframework.test.context.transaction.BeforeTransaction; +import org.springframework.transaction.annotation.Transactional; + +import static org.junit.Assert.*; + +/** + * This class is an extension of {@link FailingBeforeAndAfterMethodsSpringRunnerTests} + * that has been modified to use {@link SpringClassRule} and + * {@link SpringMethodRule}. + * + * @author Sam Brannen + * @since 4.2 + */ +public class FailingBeforeAndAfterMethodsSpringRuleTests extends FailingBeforeAndAfterMethodsSpringRunnerTests { + + @Parameters(name = "{0}") + public static Object[] testData() { + return new Object[] {// + AlwaysFailingBeforeTestClassSpringRuleTestCase.class.getSimpleName(),// + AlwaysFailingAfterTestClassSpringRuleTestCase.class.getSimpleName(),// + AlwaysFailingPrepareTestInstanceSpringRuleTestCase.class.getSimpleName(),// + AlwaysFailingBeforeTestMethodSpringRuleTestCase.class.getSimpleName(),// + AlwaysFailingAfterTestMethodSpringRuleTestCase.class.getSimpleName(),// + FailingBeforeTransactionSpringRuleTestCase.class.getSimpleName(),// + FailingAfterTransactionSpringRuleTestCase.class.getSimpleName() // + }; + } + + public FailingBeforeAndAfterMethodsSpringRuleTests(String testClassName) throws Exception { + super(testClassName); + } + + @Override + protected Class getRunnerClass() { + return JUnit4.class; + } + + // All tests are in superclass. + + @RunWith(JUnit4.class) + public static abstract class BaseSpringRuleTestCase { + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + + + @Test + public void testNothing() { + } + } + + @Ignore("TestCase classes are run manually by the enclosing test class") + @TestExecutionListeners(AlwaysFailingBeforeTestClassTestExecutionListener.class) + public static class AlwaysFailingBeforeTestClassSpringRuleTestCase extends BaseSpringRuleTestCase { + } + + @Ignore("TestCase classes are run manually by the enclosing test class") + @TestExecutionListeners(AlwaysFailingAfterTestClassTestExecutionListener.class) + public static class AlwaysFailingAfterTestClassSpringRuleTestCase extends BaseSpringRuleTestCase { + } + + @Ignore("TestCase classes are run manually by the enclosing test class") + @TestExecutionListeners(AlwaysFailingPrepareTestInstanceTestExecutionListener.class) + public static class AlwaysFailingPrepareTestInstanceSpringRuleTestCase extends BaseSpringRuleTestCase { + } + + @Ignore("TestCase classes are run manually by the enclosing test class") + @TestExecutionListeners(AlwaysFailingBeforeTestMethodTestExecutionListener.class) + public static class AlwaysFailingBeforeTestMethodSpringRuleTestCase extends BaseSpringRuleTestCase { + } + + @Ignore("TestCase classes are run manually by the enclosing test class") + @TestExecutionListeners(AlwaysFailingAfterTestMethodTestExecutionListener.class) + public static class AlwaysFailingAfterTestMethodSpringRuleTestCase extends BaseSpringRuleTestCase { + } + + @Ignore("TestCase classes are run manually by the enclosing test class") + @RunWith(JUnit4.class) + @ContextConfiguration("../FailingBeforeAndAfterMethodsTests-context.xml") + @Transactional + public static class FailingBeforeTransactionSpringRuleTestCase { + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + + + @Test + public void testNothing() { + } + + @BeforeTransaction + public void beforeTransaction() { + fail("always failing beforeTransaction()"); + } + } + + @Ignore("TestCase classes are run manually by the enclosing test class") + @RunWith(JUnit4.class) + @ContextConfiguration("../FailingBeforeAndAfterMethodsTests-context.xml") + @Transactional + public static class FailingAfterTransactionSpringRuleTestCase { + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + + + @Test + public void testNothing() { + } + + @AfterTransaction + public void afterTransaction() { + fail("always failing afterTransaction()"); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/rules/ParameterizedSpringRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/ParameterizedSpringRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a0d2dec4e42ecfface4d2c56ec8e945630049a36 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/ParameterizedSpringRuleTests.java @@ -0,0 +1,102 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.rules; + +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.tests.sample.beans.Employee; +import org.springframework.tests.sample.beans.Pet; + +import static org.junit.Assert.*; + +/** + * Integration test which demonstrates how to use JUnit's {@link Parameterized} + * runner in conjunction with {@link SpringClassRule} and {@link SpringMethodRule} + * to provide dependency injection to a parameterized test instance. + * + * @author Sam Brannen + * @since 4.2 + * @see org.springframework.test.context.junit4.ParameterizedDependencyInjectionTests + */ +@RunWith(Parameterized.class) +@ContextConfiguration("/org/springframework/test/context/junit4/ParameterizedDependencyInjectionTests-context.xml") +public class ParameterizedSpringRuleTests { + + private static final AtomicInteger invocationCount = new AtomicInteger(); + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + + @Autowired + private ApplicationContext applicationContext; + + @Autowired + private Pet pet; + + @Parameter(0) + public String employeeBeanName; + + @Parameter(1) + public String employeeName; + + + @Parameters(name = "bean [{0}], employee [{1}]") + public static String[][] employeeData() { + return new String[][] { { "employee1", "John Smith" }, { "employee2", "Jane Smith" } }; + } + + @BeforeClass + public static void BeforeClass() { + invocationCount.set(0); + } + + @Test + public final void verifyPetAndEmployee() { + invocationCount.incrementAndGet(); + + // Verifying dependency injection: + assertNotNull("The pet field should have been autowired.", this.pet); + + // Verifying 'parameterized' support: + Employee employee = this.applicationContext.getBean(this.employeeBeanName, Employee.class); + assertEquals("Name of the employee configured as bean [" + this.employeeBeanName + "].", this.employeeName, + employee.getName()); + } + + @AfterClass + public static void verifyNumParameterizedRuns() { + assertEquals("Number of times the parameterized test method was executed.", employeeData().length, + invocationCount.get()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/rules/ProgrammaticTxMgmtSpringRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/ProgrammaticTxMgmtSpringRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8bd5043cfba2375f8112bc0beb32bd48c2e950fa --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/ProgrammaticTxMgmtSpringRuleTests.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.rules; + +import javax.sql.DataSource; + +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.jdbc.datasource.DataSourceTransactionManager; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.transaction.programmatic.ProgrammaticTxMgmtTests; +import org.springframework.transaction.PlatformTransactionManager; + +/** + * This class is an extension of {@link ProgrammaticTxMgmtTests} + * that has been modified to use {@link SpringClassRule} and + * {@link SpringMethodRule}. + * + * @author Sam Brannen + * @since 4.2 + */ +@RunWith(JUnit4.class) +@ContextConfiguration +public class ProgrammaticTxMgmtSpringRuleTests extends ProgrammaticTxMgmtTests { + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + + // All tests are in superclass. + + // ------------------------------------------------------------------------- + + @Configuration + static class Config { + + @Bean + public PlatformTransactionManager transactionManager() { + return new DataSourceTransactionManager(dataSource()); + } + + @Bean + public DataSource dataSource() { + return new EmbeddedDatabaseBuilder()// + .generateUniqueName(true)// + .addScript("classpath:/org/springframework/test/context/jdbc/schema.sql") // + .build(); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/rules/RepeatedSpringRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/RepeatedSpringRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d06cdd71015f8cdf1206ec03519bc35f6b444348 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/RepeatedSpringRuleTests.java @@ -0,0 +1,177 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.rules; + +import java.io.IOException; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +import org.junit.ClassRule; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.Runner; +import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized.Parameters; + +import org.springframework.test.annotation.Repeat; +import org.springframework.test.annotation.Timed; +import org.springframework.test.context.TestExecutionListeners; +import org.springframework.test.context.junit4.RepeatedSpringRunnerTests; + +/** + * This class is an extension of {@link RepeatedSpringRunnerTests} + * that has been modified to use {@link SpringClassRule} and + * {@link SpringMethodRule}. + * + * @author Sam Brannen + * @since 4.2 + */ +public class RepeatedSpringRuleTests extends RepeatedSpringRunnerTests { + + @Parameters(name = "{0}") + public static Object[][] repetitionData() { + return new Object[][] {// + { NonAnnotatedRepeatedTestCase.class.getSimpleName(), 0, 1, 1, 1 },// + { DefaultRepeatValueRepeatedTestCase.class.getSimpleName(), 0, 1, 1, 1 },// + { NegativeRepeatValueRepeatedTestCase.class.getSimpleName(), 0, 1, 1, 1 },// + { RepeatedFiveTimesRepeatedTestCase.class.getSimpleName(), 0, 1, 1, 5 },// + { RepeatedFiveTimesViaMetaAnnotationRepeatedTestCase.class.getSimpleName(), 0, 1, 1, 5 },// + { TimedRepeatedTestCase.class.getSimpleName(), 3, 4, 4, (5 + 1 + 4 + 10) } // + }; + } + + public RepeatedSpringRuleTests(String testClassName, int expectedFailureCount, int expectedTestStartedCount, + int expectedTestFinishedCount, int expectedInvocationCount) throws Exception { + + super(testClassName, expectedFailureCount, expectedTestStartedCount, expectedTestFinishedCount, + expectedInvocationCount); + } + + @Override + protected Class getRunnerClass() { + return JUnit4.class; + } + + // All tests are in superclass. + + @TestExecutionListeners({}) + public abstract static class AbstractRepeatedTestCase { + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + + + protected void incrementInvocationCount() throws IOException { + invocationCount.incrementAndGet(); + } + } + + public static final class NonAnnotatedRepeatedTestCase extends AbstractRepeatedTestCase { + + @Test + @Timed(millis = 10000) + public void nonAnnotated() throws Exception { + incrementInvocationCount(); + } + } + + public static final class DefaultRepeatValueRepeatedTestCase extends AbstractRepeatedTestCase { + + @Test + @Repeat + @Timed(millis = 10000) + public void defaultRepeatValue() throws Exception { + incrementInvocationCount(); + } + } + + public static final class NegativeRepeatValueRepeatedTestCase extends AbstractRepeatedTestCase { + + @Test + @Repeat(-5) + @Timed(millis = 10000) + public void negativeRepeatValue() throws Exception { + incrementInvocationCount(); + } + } + + public static final class RepeatedFiveTimesRepeatedTestCase extends AbstractRepeatedTestCase { + + @Test + @Repeat(5) + public void repeatedFiveTimes() throws Exception { + incrementInvocationCount(); + } + } + + @Repeat(5) + @Retention(RetentionPolicy.RUNTIME) + private static @interface RepeatedFiveTimes { + } + + public static final class RepeatedFiveTimesViaMetaAnnotationRepeatedTestCase extends AbstractRepeatedTestCase { + + @Test + @RepeatedFiveTimes + public void repeatedFiveTimes() throws Exception { + incrementInvocationCount(); + } + } + + /** + * Unit tests for claims raised in SPR-6011. + */ + @Ignore("TestCase classes are run manually by the enclosing test class") + public static final class TimedRepeatedTestCase extends AbstractRepeatedTestCase { + + @Test + @Timed(millis = 1000) + @Repeat(5) + public void repeatedFiveTimesButDoesNotExceedTimeout() throws Exception { + incrementInvocationCount(); + } + + @Test + @Timed(millis = 10) + @Repeat(1) + public void singleRepetitionExceedsTimeout() throws Exception { + incrementInvocationCount(); + Thread.sleep(15); + } + + @Test + @Timed(millis = 20) + @Repeat(4) + public void firstRepetitionOfManyExceedsTimeout() throws Exception { + incrementInvocationCount(); + Thread.sleep(25); + } + + @Test + @Timed(millis = 100) + @Repeat(10) + public void collectiveRepetitionsExceedTimeout() throws Exception { + incrementInvocationCount(); + Thread.sleep(11); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/rules/Subclass1AppCtxRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/Subclass1AppCtxRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3cb915e8b70bdbfbfb3eed55a8a90a0f32fd015e --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/Subclass1AppCtxRuleTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.rules; + +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; + +import static org.junit.Assert.*; + +/** + * Subclass #1 of {@link BaseAppCtxRuleTests}. + * + * @author Sam Brannen + * @since 4.2 + */ +@ContextConfiguration +public class Subclass1AppCtxRuleTests extends BaseAppCtxRuleTests { + + @Autowired + private String bar; + + + @Test + public void bar() { + assertEquals("bar", bar); + } + + + @Configuration + static class Config { + + @Bean + public String bar() { + return "bar"; + } + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/rules/Subclass2AppCtxRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/Subclass2AppCtxRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..cf43bd2c240bbad6c3d85a744af115e5b106443f --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/Subclass2AppCtxRuleTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.rules; + +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; + +import static org.junit.Assert.*; + +/** + * Subclass #2 of {@link BaseAppCtxRuleTests}. + * + * @author Sam Brannen + * @since 4.2 + */ +@ContextConfiguration +public class Subclass2AppCtxRuleTests extends BaseAppCtxRuleTests { + + @Autowired + private String baz; + + + @Test + public void baz() { + assertEquals("baz", baz); + } + + + @Configuration + static class Config { + + @Bean + public String baz() { + return "baz"; + } + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/rules/TimedSpringRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/TimedSpringRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..802c69b41ead11f1b210cfe44eedae56bead4cfe --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/TimedSpringRuleTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.rules; + +import org.junit.ClassRule; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.runner.Runner; +import org.junit.runners.JUnit4; + +import org.springframework.test.context.TestExecutionListeners; +import org.springframework.test.context.junit4.TimedSpringRunnerTests; + +import static org.junit.Assert.*; + +/** + * This class is an extension of {@link TimedSpringRunnerTests} + * that has been modified to use {@link SpringClassRule} and + * {@link SpringMethodRule}. + * + * @author Sam Brannen + * @since 4.2 + */ +public class TimedSpringRuleTests extends TimedSpringRunnerTests { + + // All tests are in superclass. + + @Override + protected Class getTestCase() { + return TimedSpringRuleTestCase.class; + } + + @Override + protected Class getRunnerClass() { + return JUnit4.class; + } + + + @Ignore("TestCase classes are run manually by the enclosing test class") + @TestExecutionListeners({}) + public static final class TimedSpringRuleTestCase extends TimedSpringRunnerTestCase { + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + + + /** + * Overridden to always throw an exception, since Spring's Rule-based + * JUnit integration does not fail a test for duplicate configuration + * of timeouts. + */ + @Override + public void springAndJUnitTimeouts() { + fail("intentional failure to make tests in superclass pass"); + } + + // All other tests are in superclass. + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/rules/TimedTransactionalSpringRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/TimedTransactionalSpringRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..105017497b14ffd3fa187af398628dd42cfd2547 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/TimedTransactionalSpringRuleTests.java @@ -0,0 +1,75 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.rules; + +import java.util.concurrent.TimeUnit; + +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import org.springframework.test.annotation.Repeat; +import org.springframework.test.context.junit4.TimedTransactionalSpringRunnerTests; + +import static org.springframework.test.transaction.TransactionTestUtils.*; + +/** + * This class is an extension of {@link TimedTransactionalSpringRunnerTests} + * that has been modified to use {@link SpringClassRule} and + * {@link SpringMethodRule}. + * + * @author Sam Brannen + * @since 4.2 + */ +@RunWith(JUnit4.class) +public class TimedTransactionalSpringRuleTests extends TimedTransactionalSpringRunnerTests { + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + + @Rule + public Timeout timeout = Timeout.builder().withTimeout(10, TimeUnit.SECONDS).build(); + + + /** + * Overridden since Spring's Rule-based JUnit support cannot properly + * integrate with timed execution that is controlled by a third-party runner. + */ + @Test(timeout = 10000) + @Repeat(5) + @Override + public void transactionalWithJUnitTimeout() { + assertInTransaction(false); + } + + /** + * {@code timeout} explicitly not declared due to presence of Timeout rule. + */ + @Test + public void transactionalWithJUnitRuleBasedTimeout() { + assertInTransaction(true); + } + + // All other tests are in superclass. + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/rules/TransactionalSqlScriptsSpringRuleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/TransactionalSqlScriptsSpringRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..523a3b36eb0279e3fcbcf530faf2735ebdc1667a --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/rules/TransactionalSqlScriptsSpringRuleTests.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.rules; + +import java.util.concurrent.TimeUnit; + +import org.junit.ClassRule; +import org.junit.FixMethodOrder; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.junit.runners.MethodSorters; + +import org.springframework.test.context.jdbc.Sql; +import org.springframework.test.context.jdbc.TransactionalSqlScriptsTests; + +/** + * This class is an extension of {@link TransactionalSqlScriptsTests} + * that has been modified to use {@link SpringClassRule} and + * {@link SpringMethodRule}. + * + * @author Sam Brannen + * @since 4.2 + */ +@RunWith(JUnit4.class) +// Note: @FixMethodOrder is NOT @Inherited. +@FixMethodOrder(MethodSorters.NAME_ASCENDING) +// Overriding @Sql declaration to reference scripts using relative path. +@Sql({ "../../jdbc/schema.sql", "../../jdbc/data.sql" }) +public class TransactionalSqlScriptsSpringRuleTests extends TransactionalSqlScriptsTests { + + @ClassRule + public static final SpringClassRule springClassRule = new SpringClassRule(); + + @Rule + public final SpringMethodRule springMethodRule = new SpringMethodRule(); + + @Rule + public Timeout timeout = Timeout.builder().withTimeout(10, TimeUnit.SECONDS).build(); + + + /** + * Redeclared to ensure that {@code @FixMethodOrder} is properly applied. + */ + @Test + @Override + // test##_ prefix is required for @FixMethodOrder. + public void test01_classLevelScripts() { + assertNumUsers(1); + } + + /** + * Overriding {@code @Sql} declaration to reference scripts using relative path. + */ + @Test + @Sql({ "../../jdbc/drop-schema.sql", "../../jdbc/schema.sql", "../../jdbc/data.sql", "../../jdbc/data-add-dogbert.sql" }) + @Override + // test##_ prefix is required for @FixMethodOrder. + public void test02_methodLevelScripts() { + assertNumUsers(2); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/BeanOverridingDefaultLocationsInheritedTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/BeanOverridingDefaultLocationsInheritedTests.java new file mode 100644 index 0000000000000000000000000000000000000000..97274401f0fbf9dfc2b91b144b3afea82b55568b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/BeanOverridingDefaultLocationsInheritedTests.java @@ -0,0 +1,44 @@ +/* + * Copyright 2007 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr3896; + +import org.junit.Test; + +import org.springframework.test.context.ContextConfiguration; + +import static org.junit.Assert.*; + +/** + * JUnit 4 based integration test for verifying support for the + * {@link ContextConfiguration#inheritLocations() inheritLocations} flag of + * {@link ContextConfiguration @ContextConfiguration} indirectly proposed in SPR-3896. + * + * @author Sam Brannen + * @since 2.5 + */ +@ContextConfiguration +public class BeanOverridingDefaultLocationsInheritedTests extends DefaultLocationsBaseTests { + + @Test + @Override + public void verifyEmployeeSetFromBaseContextConfig() { + assertNotNull("The employee should have been autowired.", this.employee); + assertEquals("The employee bean should have been overridden.", "Yoda", this.employee.getName()); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/BeanOverridingExplicitLocationsInheritedTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/BeanOverridingExplicitLocationsInheritedTests.java new file mode 100644 index 0000000000000000000000000000000000000000..efd60eadbc36734b88c1c0b0cb13caa315b40fe5 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/BeanOverridingExplicitLocationsInheritedTests.java @@ -0,0 +1,44 @@ +/* + * Copyright 2007 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr3896; + +import org.junit.Test; + +import org.springframework.test.context.ContextConfiguration; + +import static org.junit.Assert.*; + +/** + * JUnit 4 based integration test for verifying support for the + * {@link ContextConfiguration#inheritLocations() inheritLocations} flag of + * {@link ContextConfiguration @ContextConfiguration} indirectly proposed in SPR-3896. + * + * @author Sam Brannen + * @since 2.5 + */ +@ContextConfiguration("BeanOverridingDefaultLocationsInheritedTests-context.xml") +public class BeanOverridingExplicitLocationsInheritedTests extends ExplicitLocationsBaseTests { + + @Test + @Override + public void verifyEmployeeSetFromBaseContextConfig() { + assertNotNull("The employee should have been autowired.", this.employee); + assertEquals("The employee bean should have been overridden.", "Yoda", this.employee.getName()); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/DefaultLocationsBaseTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/DefaultLocationsBaseTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d4eb3a0291f871de93490d1fc43f1fa1f8fb4d63 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/DefaultLocationsBaseTests.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr3896; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.tests.sample.beans.Employee; + +import static org.junit.Assert.*; + +/** + * JUnit 4 based integration test for verifying support for the + * {@link ContextConfiguration#inheritLocations() inheritLocations} flag of + * {@link ContextConfiguration @ContextConfiguration} indirectly proposed in SPR-3896. + * + * @author Sam Brannen + * @since 2.5 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +public class DefaultLocationsBaseTests { + + @Autowired + protected Employee employee; + + + @Test + public void verifyEmployeeSetFromBaseContextConfig() { + assertNotNull("The employee should have been autowired.", this.employee); + assertEquals("John Smith", this.employee.getName()); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/DefaultLocationsInheritedTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/DefaultLocationsInheritedTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a282649a6f4b67bbfc69a1a2de7f959a6ee35c94 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/DefaultLocationsInheritedTests.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr3896; + +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.tests.sample.beans.Pet; + +import static org.junit.Assert.*; + +/** + * JUnit 4 based integration test for verifying support for the + * {@link ContextConfiguration#inheritLocations() inheritLocations} flag of + * {@link ContextConfiguration @ContextConfiguration} indirectly proposed in SPR-3896. + * + * @author Sam Brannen + * @since 2.5 + */ +@ContextConfiguration +public class DefaultLocationsInheritedTests extends DefaultLocationsBaseTests { + + @Autowired + private Pet pet; + + + @Test + public void verifyPetSetFromExtendedContextConfig() { + assertNotNull("The pet should have been autowired.", this.pet); + assertEquals("Fido", this.pet.getName()); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/ExplicitLocationsBaseTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/ExplicitLocationsBaseTests.java new file mode 100644 index 0000000000000000000000000000000000000000..f3b75c99c966dee22ff534ef7047e20ed343afdf --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/ExplicitLocationsBaseTests.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr3896; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.tests.sample.beans.Employee; + +import static org.junit.Assert.*; + +/** + * JUnit 4 based integration test for verifying support for the + * {@link ContextConfiguration#inheritLocations() inheritLocations} flag of + * {@link ContextConfiguration @ContextConfiguration} indirectly proposed in SPR-3896. + * + * @author Sam Brannen + * @since 2.5 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration("DefaultLocationsBaseTests-context.xml") +public class ExplicitLocationsBaseTests { + + @Autowired + protected Employee employee; + + + @Test + public void verifyEmployeeSetFromBaseContextConfig() { + assertNotNull("The employee should have been autowired.", this.employee); + assertEquals("John Smith", this.employee.getName()); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/ExplicitLocationsInheritedTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/ExplicitLocationsInheritedTests.java new file mode 100644 index 0000000000000000000000000000000000000000..2f3da5a432dab81ba3ec24c1cdd006c93312f820 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/ExplicitLocationsInheritedTests.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr3896; + +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.tests.sample.beans.Pet; + +import static org.junit.Assert.*; + +/** + * JUnit 4 based integration test for verifying support for the + * {@link ContextConfiguration#inheritLocations() inheritLocations} flag of + * {@link ContextConfiguration @ContextConfiguration} indirectly proposed in SPR-3896. + * + * @author Sam Brannen + * @since 2.5 + */ +@ContextConfiguration("DefaultLocationsInheritedTests-context.xml") +public class ExplicitLocationsInheritedTests extends ExplicitLocationsBaseTests { + + @Autowired + private Pet pet; + + + @Test + public void verifyPetSetFromExtendedContextConfig() { + assertNotNull("The pet should have been autowired.", this.pet); + assertEquals("Fido", this.pet.getName()); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/Spr3896SuiteTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/Spr3896SuiteTests.java new file mode 100644 index 0000000000000000000000000000000000000000..2e9b6e9cbb2ce48b14a44c0715ff3acf56a132ff --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr3896/Spr3896SuiteTests.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2007 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr3896; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** + * JUnit 4 based test suite for functionality proposed in SPR-3896. + * + * @author Sam Brannen + * @since 2.5 + */ +@RunWith(Suite.class) +// Note: the following 'multi-line' layout is for enhanced code readability. +@SuiteClasses({ + +DefaultLocationsBaseTests.class, + +DefaultLocationsInheritedTests.class, + +ExplicitLocationsBaseTests.class, + +ExplicitLocationsInheritedTests.class, + +BeanOverridingDefaultLocationsInheritedTests.class, + +BeanOverridingExplicitLocationsInheritedTests.class + +}) +public class Spr3896SuiteTests { +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr4868/Jsr250LifecycleTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr4868/Jsr250LifecycleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e968a97f314ec630ab64dc68786bc7e2c65290b9 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr4868/Jsr250LifecycleTests.java @@ -0,0 +1,119 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr4868; + +import javax.annotation.PostConstruct; +import javax.annotation.PreDestroy; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.TestExecutionListeners; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.support.DependencyInjectionTestExecutionListener; + +import static org.junit.Assert.*; + +/** + * Integration tests that investigate the applicability of JSR-250 lifecycle + * annotations in test classes. + * + *

This class does not really contain actual tests per se. Rather it + * can be used to empirically verify the expected log output (see below). In + * order to see the log output, one would naturally need to ensure that the + * logger category for this class is enabled at {@code INFO} level. + * + *

Expected Log Output

+ *
+ * INFO : org.springframework.test.context.junit4.spr4868.LifecycleBean - initializing
+ * INFO : org.springframework.test.context.junit4.spr4868.ExampleTest - beforeAllTests()
+ * INFO : org.springframework.test.context.junit4.spr4868.ExampleTest - setUp()
+ * INFO : org.springframework.test.context.junit4.spr4868.ExampleTest - test1()
+ * INFO : org.springframework.test.context.junit4.spr4868.ExampleTest - tearDown()
+ * INFO : org.springframework.test.context.junit4.spr4868.ExampleTest - beforeAllTests()
+ * INFO : org.springframework.test.context.junit4.spr4868.ExampleTest - setUp()
+ * INFO : org.springframework.test.context.junit4.spr4868.ExampleTest - test2()
+ * INFO : org.springframework.test.context.junit4.spr4868.ExampleTest - tearDown()
+ * INFO : org.springframework.test.context.junit4.spr4868.LifecycleBean - destroying
+ * 
+ * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@TestExecutionListeners({ DependencyInjectionTestExecutionListener.class }) +@ContextConfiguration +public class Jsr250LifecycleTests { + + private final Log logger = LogFactory.getLog(Jsr250LifecycleTests.class); + + + @Configuration + static class Config { + + @Bean + public LifecycleBean lifecycleBean() { + return new LifecycleBean(); + } + } + + + @Autowired + private LifecycleBean lifecycleBean; + + + @PostConstruct + public void beforeAllTests() { + logger.info("beforeAllTests()"); + } + + @PreDestroy + public void afterTestSuite() { + logger.info("afterTestSuite()"); + } + + @Before + public void setUp() throws Exception { + logger.info("setUp()"); + } + + @After + public void tearDown() throws Exception { + logger.info("tearDown()"); + } + + @Test + public void test1() { + logger.info("test1()"); + assertNotNull(lifecycleBean); + } + + @Test + public void test2() { + logger.info("test2()"); + assertNotNull(lifecycleBean); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr4868/LifecycleBean.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr4868/LifecycleBean.java new file mode 100644 index 0000000000000000000000000000000000000000..be527d05a90cedeef9bf514d964252209a6424da --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr4868/LifecycleBean.java @@ -0,0 +1,44 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr4868; + +import javax.annotation.PostConstruct; +import javax.annotation.PreDestroy; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +/** + * @author Sam Brannen + * @since 3.2 + */ +class LifecycleBean { + + private final Log logger = LogFactory.getLog(LifecycleBean.class); + + + @PostConstruct + public void init() { + logger.info("initializing"); + } + + @PreDestroy + public void destroy() { + logger.info("destroying"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr6128/AutowiredQualifierTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr6128/AutowiredQualifierTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5c32bdfe37d08d3dbb798c76acd4b3ce99ca10f0 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr6128/AutowiredQualifierTests.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr6128; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import static org.hamcrest.core.IsEqual.*; +import static org.junit.Assert.*; + +/** + * Integration tests to verify claims made in SPR-6128. + * + * @author Sam Brannen + * @author Chris Beams + * @since 3.0 + */ +@ContextConfiguration +@RunWith(SpringJUnit4ClassRunner.class) +public class AutowiredQualifierTests { + + @Autowired + private String foo; + + @Autowired + @Qualifier("customFoo") + private String customFoo; + + + @Test + public void test() { + assertThat(foo, equalTo("normal")); + assertThat(customFoo, equalTo("custom")); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr8849/Spr8849Tests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr8849/Spr8849Tests.java new file mode 100644 index 0000000000000000000000000000000000000000..b6429d9ff4ebbe8b7884fee7a65fd8b08d13bff3 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr8849/Spr8849Tests.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr8849; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** + * Test suite to investigate claims raised in + * SPR-8849. + * + *

Work Around

+ *

By using a SpEL expression to generate a random {@code database-name} + * for the embedded database (see {@code datasource-config.xml}), we ensure + * that each {@code ApplicationContext} that imports the common configuration + * will create an embedded database with a unique name. + * + *

To reproduce the problem mentioned in SPR-8849, delete the declaration + * of the {@code database-name} attribute of the embedded database in + * {@code datasource-config.xml} and run this suite. + * + *

Solution

+ *

As of Spring 4.2, a proper solution is possible thanks to SPR-8849. + * {@link TestClass3} and {@link TestClass4} both import + * {@code datasource-config-with-auto-generated-db-name.xml} which makes + * use of the new {@code generate-name} attribute of {@code }. + * + * @author Sam Brannen + * @since 3.2 + */ +@SuppressWarnings("javadoc") +@RunWith(Suite.class) +@SuiteClasses({ TestClass1.class, TestClass2.class, TestClass3.class, TestClass4.class }) +public class Spr8849Tests { + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr8849/TestClass1.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr8849/TestClass1.java new file mode 100644 index 0000000000000000000000000000000000000000..7a49d74a234545016f1a1098deb19eeb5f864252 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr8849/TestClass1.java @@ -0,0 +1,59 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr8849; + +import javax.annotation.Resource; +import javax.sql.DataSource; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.ImportResource; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import static org.junit.Assert.*; + +/** + * This name of this class intentionally does not end with "Test" or "Tests" + * since it should only be run as part of the test suite: {@link Spr8849Tests}. + * + * @author Mickael Leduque + * @author Sam Brannen + * @since 3.2 + * @see Spr8849Tests + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +public class TestClass1 { + + @Configuration + @ImportResource("classpath:/org/springframework/test/context/junit4/spr8849/datasource-config.xml") + static class Config { + } + + + @Resource + DataSource dataSource; + + + @Test + public void dummyTest() { + assertNotNull(dataSource); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr8849/TestClass2.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr8849/TestClass2.java new file mode 100644 index 0000000000000000000000000000000000000000..b184d6549d1e5fb82a55287f716c633b4a3bf1e6 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr8849/TestClass2.java @@ -0,0 +1,59 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr8849; + +import javax.annotation.Resource; +import javax.sql.DataSource; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.ImportResource; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import static org.junit.Assert.*; + +/** + * This name of this class intentionally does not end with "Test" or "Tests" + * since it should only be run as part of the test suite: {@link Spr8849Tests}. + * + * @author Mickael Leduque + * @author Sam Brannen + * @since 3.2 + * @see Spr8849Tests + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +public class TestClass2 { + + @Configuration + @ImportResource("classpath:/org/springframework/test/context/junit4/spr8849/datasource-config.xml") + static class Config { + } + + + @Resource + DataSource dataSource; + + + @Test + public void dummyTest() { + assertNotNull(dataSource); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr8849/TestClass3.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr8849/TestClass3.java new file mode 100644 index 0000000000000000000000000000000000000000..c9b374635cacf941c940d2263048cb3d82820b90 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr8849/TestClass3.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr8849; + +import javax.annotation.Resource; +import javax.sql.DataSource; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.ImportResource; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import static org.junit.Assert.*; + +/** + * This name of this class intentionally does not end with "Test" or "Tests" + * since it should only be run as part of the test suite: {@link Spr8849Tests}. + * + * @author Sam Brannen + * @since 4.2 + * @see Spr8849Tests + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +public class TestClass3 { + + @Configuration + @ImportResource("classpath:/org/springframework/test/context/junit4/spr8849/datasource-config-with-auto-generated-db-name.xml") + static class Config { + } + + + @Resource + DataSource dataSource; + + + @Test + public void dummyTest() { + assertNotNull(dataSource); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr8849/TestClass4.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr8849/TestClass4.java new file mode 100644 index 0000000000000000000000000000000000000000..7dc87487178db8738082e2b454e2013cc20a8bbf --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr8849/TestClass4.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr8849; + +import javax.annotation.Resource; +import javax.sql.DataSource; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.ImportResource; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import static org.junit.Assert.*; + +/** + * This name of this class intentionally does not end with "Test" or "Tests" + * since it should only be run as part of the test suite: {@link Spr8849Tests}. + * + * @author Sam Brannen + * @since 4.2 + * @see Spr8849Tests + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +public class TestClass4 { + + @Configuration + @ImportResource("classpath:/org/springframework/test/context/junit4/spr8849/datasource-config-with-auto-generated-db-name.xml") + static class Config { + } + + + @Resource + DataSource dataSource; + + + @Test + public void dummyTest() { + assertNotNull(dataSource); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/AbstractTransactionalAnnotatedConfigClassTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/AbstractTransactionalAnnotatedConfigClassTests.java new file mode 100644 index 0000000000000000000000000000000000000000..48e673bf937001e2ca31fb8ad5f4df7f6e054bbb --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/AbstractTransactionalAnnotatedConfigClassTests.java @@ -0,0 +1,135 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr9051; + +import javax.sql.DataSource; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.DataSourceTransactionManager; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.annotation.DirtiesContext.ClassMode; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.transaction.AfterTransaction; +import org.springframework.test.context.transaction.BeforeTransaction; +import org.springframework.tests.sample.beans.Employee; +import org.springframework.transaction.annotation.Transactional; + +import static org.junit.Assert.*; +import static org.springframework.test.transaction.TransactionTestUtils.*; + +/** + * This set of tests (i.e., all concrete subclasses) investigates the claims made in + * SPR-9051 + * with regard to transactional tests. + * + * @author Sam Brannen + * @since 3.2 + * @see org.springframework.test.context.testng.AnnotationConfigTransactionalTestNGSpringContextTests + */ +@RunWith(SpringJUnit4ClassRunner.class) +@DirtiesContext(classMode = ClassMode.AFTER_EACH_TEST_METHOD) +public abstract class AbstractTransactionalAnnotatedConfigClassTests { + + protected static final String JANE = "jane"; + protected static final String SUE = "sue"; + protected static final String YODA = "yoda"; + + protected DataSource dataSourceFromTxManager; + protected DataSource dataSourceViaInjection; + + protected JdbcTemplate jdbcTemplate; + + @Autowired + private Employee employee; + + + @Autowired + public void setTransactionManager(DataSourceTransactionManager transactionManager) { + this.dataSourceFromTxManager = transactionManager.getDataSource(); + } + + @Autowired + public void setDataSource(DataSource dataSource) { + this.dataSourceViaInjection = dataSource; + this.jdbcTemplate = new JdbcTemplate(dataSource); + } + + private int countRowsInTable(String tableName) { + return jdbcTemplate.queryForObject("SELECT COUNT(0) FROM " + tableName, Integer.class); + } + + private int createPerson(String name) { + return jdbcTemplate.update("INSERT INTO person VALUES(?)", name); + } + + protected int deletePerson(String name) { + return jdbcTemplate.update("DELETE FROM person WHERE name=?", name); + } + + protected void assertNumRowsInPersonTable(int expectedNumRows, String testState) { + assertEquals("the number of rows in the person table (" + testState + ").", expectedNumRows, + countRowsInTable("person")); + } + + protected void assertAddPerson(final String name) { + assertEquals("Adding '" + name + "'", 1, createPerson(name)); + } + + @Test + public void autowiringFromConfigClass() { + assertNotNull("The employee should have been autowired.", employee); + assertEquals("John Smith", employee.getName()); + } + + @BeforeTransaction + public void beforeTransaction() { + assertNumRowsInPersonTable(0, "before a transactional test method"); + assertAddPerson(YODA); + } + + @Before + public void setUp() throws Exception { + assertNumRowsInPersonTable((inTransaction() ? 1 : 0), "before a test method"); + } + + @Test + @Transactional + public void modifyTestDataWithinTransaction() { + assertInTransaction(true); + assertAddPerson(JANE); + assertAddPerson(SUE); + assertNumRowsInPersonTable(3, "in modifyTestDataWithinTransaction()"); + } + + @After + public void tearDown() throws Exception { + assertNumRowsInPersonTable((inTransaction() ? 3 : 0), "after a test method"); + } + + @AfterTransaction + public void afterTransaction() { + assertEquals("Deleting yoda", 1, deletePerson(YODA)); + assertNumRowsInPersonTable(0, "after a transactional test method"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/AnnotatedConfigClassesWithoutAtConfigurationTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/AnnotatedConfigClassesWithoutAtConfigurationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..72d8379ef11f3f8aaaf99935b5a494385be37218 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/AnnotatedConfigClassesWithoutAtConfigurationTests.java @@ -0,0 +1,101 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr9051; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import static org.junit.Assert.*; + +/** + * This set of tests refutes the claims made in + * SPR-9051. + * + *

The Claims: + * + *

+ * When a {@code @ContextConfiguration} test class references a config class + * missing an {@code @Configuration} annotation, {@code @Bean} dependencies are + * wired successfully but the bean lifecycle is not applied (no init methods are + * invoked, for example). Adding the missing {@code @Configuration} annotation + * solves the problem, however the problem and solution isn't obvious since + * wiring/injection appeared to work. + *
+ * + * @author Sam Brannen + * @author Phillip Webb + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(classes = AnnotatedConfigClassesWithoutAtConfigurationTests.AnnotatedFactoryBeans.class) +public class AnnotatedConfigClassesWithoutAtConfigurationTests { + + /** + * This is intentionally not annotated with {@code @Configuration}. + * Consequently, this class contains what we call annotated factory bean + * methods instead of standard bean definition methods. + */ + static class AnnotatedFactoryBeans { + + static final AtomicInteger enigmaCallCount = new AtomicInteger(); + + + @Bean + public String enigma() { + return "enigma #" + enigmaCallCount.incrementAndGet(); + } + + @Bean + public LifecycleBean lifecycleBean() { + // The following call to enigma() literally invokes the local + // enigma() method, not a CGLIB proxied version, since these methods + // are essentially factory bean methods. + LifecycleBean bean = new LifecycleBean(enigma()); + assertFalse(bean.isInitialized()); + return bean; + } + } + + + @Autowired + private String enigma; + + @Autowired + private LifecycleBean lifecycleBean; + + + @Test + public void testSPR_9051() throws Exception { + assertNotNull(enigma); + assertNotNull(lifecycleBean); + assertTrue(lifecycleBean.isInitialized()); + Set names = new HashSet<>(); + names.add(enigma.toString()); + names.add(lifecycleBean.getName()); + assertEquals(names, new HashSet<>(Arrays.asList("enigma #1", "enigma #2"))); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/AtBeanLiteModeScopeTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/AtBeanLiteModeScopeTests.java new file mode 100644 index 0000000000000000000000000000000000000000..47755fb29fc38450c90e6fbf323050ce6e418a7b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/AtBeanLiteModeScopeTests.java @@ -0,0 +1,101 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr9051; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Scope; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify proper scoping of beans created in + * {@code @Bean} Lite Mode. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(classes = AtBeanLiteModeScopeTests.LiteBeans.class) +public class AtBeanLiteModeScopeTests { + + /** + * This is intentionally not annotated with {@code @Configuration}. + */ + static class LiteBeans { + + @Bean + public LifecycleBean singleton() { + LifecycleBean bean = new LifecycleBean("singleton"); + assertFalse(bean.isInitialized()); + return bean; + } + + @Bean + @Scope("prototype") + public LifecycleBean prototype() { + LifecycleBean bean = new LifecycleBean("prototype"); + assertFalse(bean.isInitialized()); + return bean; + } + } + + + @Autowired + private ApplicationContext applicationContext; + + @Autowired + @Qualifier("singleton") + private LifecycleBean injectedSingletonBean; + + @Autowired + @Qualifier("prototype") + private LifecycleBean injectedPrototypeBean; + + + @Test + public void singletonLiteBean() { + assertNotNull(injectedSingletonBean); + assertTrue(injectedSingletonBean.isInitialized()); + + LifecycleBean retrievedSingletonBean = applicationContext.getBean("singleton", LifecycleBean.class); + assertNotNull(retrievedSingletonBean); + assertTrue(retrievedSingletonBean.isInitialized()); + + assertSame(injectedSingletonBean, retrievedSingletonBean); + } + + @Test + public void prototypeLiteBean() { + assertNotNull(injectedPrototypeBean); + assertTrue(injectedPrototypeBean.isInitialized()); + + LifecycleBean retrievedPrototypeBean = applicationContext.getBean("prototype", LifecycleBean.class); + assertNotNull(retrievedPrototypeBean); + assertTrue(retrievedPrototypeBean.isInitialized()); + + assertNotSame(injectedPrototypeBean, retrievedPrototypeBean); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/LifecycleBean.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/LifecycleBean.java new file mode 100644 index 0000000000000000000000000000000000000000..64876355d0535069fcab8f4bb16078a5f2eb9211 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/LifecycleBean.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr9051; + +import javax.annotation.PostConstruct; + +/** + * Simple POJO that contains lifecycle callbacks. + * + * @author Sam Brannen + * @since 3.2 + */ +public class LifecycleBean { + + private final String name; + + private boolean initialized = false; + + + public LifecycleBean(String name) { + this.name = name; + } + + public String getName() { + return this.name; + } + + @PostConstruct + public void init() { + initialized = true; + } + + public boolean isInitialized() { + return this.initialized; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/TransactionalAnnotatedConfigClassWithAtConfigurationTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/TransactionalAnnotatedConfigClassWithAtConfigurationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..da3d03829d0bb9871280067fc49de36c76de897c --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/TransactionalAnnotatedConfigClassWithAtConfigurationTests.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr9051; + +import javax.sql.DataSource; + +import org.junit.Before; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.jdbc.datasource.DataSourceTransactionManager; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.tests.sample.beans.Employee; +import org.springframework.transaction.PlatformTransactionManager; + +import static org.junit.Assert.*; + +/** + * Concrete implementation of {@link AbstractTransactionalAnnotatedConfigClassTests} + * that uses a true {@link Configuration @Configuration class}. + * + * @author Sam Brannen + * @since 3.2 + * @see TransactionalAnnotatedConfigClassesWithoutAtConfigurationTests + */ +@ContextConfiguration +public class TransactionalAnnotatedConfigClassWithAtConfigurationTests extends + AbstractTransactionalAnnotatedConfigClassTests { + + /** + * This is intentionally annotated with {@code @Configuration}. + * + *

Consequently, this class contains standard singleton bean methods + * instead of annotated factory bean methods. + */ + @Configuration + static class Config { + + @Bean + public Employee employee() { + Employee employee = new Employee(); + employee.setName("John Smith"); + employee.setAge(42); + employee.setCompany("Acme Widgets, Inc."); + return employee; + } + + @Bean + public PlatformTransactionManager transactionManager() { + return new DataSourceTransactionManager(dataSource()); + } + + @Bean + public DataSource dataSource() { + return new EmbeddedDatabaseBuilder()// + .addScript("classpath:/org/springframework/test/jdbc/schema.sql")// + // Ensure that this in-memory database is only used by this class: + .setName(getClass().getName())// + .build(); + } + + } + + + @Before + public void compareDataSources() throws Exception { + // NOTE: the two DataSource instances ARE the same! + assertSame(dataSourceFromTxManager, dataSourceViaInjection); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/TransactionalAnnotatedConfigClassesWithoutAtConfigurationTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/TransactionalAnnotatedConfigClassesWithoutAtConfigurationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..97dd76e26d7e6df3c888a9ecd177c2068a4365fc --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9051/TransactionalAnnotatedConfigClassesWithoutAtConfigurationTests.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr9051; + +import javax.sql.DataSource; + +import org.junit.Before; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.DataSourceTransactionManager; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.transaction.AfterTransaction; +import org.springframework.test.context.transaction.TransactionalTestExecutionListener; +import org.springframework.tests.sample.beans.Employee; +import org.springframework.transaction.PlatformTransactionManager; + +import static org.junit.Assert.*; + +/** + * Concrete implementation of {@link AbstractTransactionalAnnotatedConfigClassTests} + * that does not use a true {@link Configuration @Configuration class} but + * rather a lite mode configuration class (see the Javadoc for {@link Bean @Bean} + * for details). + * + * @author Sam Brannen + * @since 3.2 + * @see Bean + * @see TransactionalAnnotatedConfigClassWithAtConfigurationTests + */ +@ContextConfiguration(classes = TransactionalAnnotatedConfigClassesWithoutAtConfigurationTests.AnnotatedFactoryBeans.class) +public class TransactionalAnnotatedConfigClassesWithoutAtConfigurationTests extends + AbstractTransactionalAnnotatedConfigClassTests { + + /** + * This is intentionally not annotated with {@code @Configuration}. + * + *

Consequently, this class contains annotated factory bean methods + * instead of standard singleton bean methods. + */ + // @Configuration + static class AnnotatedFactoryBeans { + + @Bean + public Employee employee() { + Employee employee = new Employee(); + employee.setName("John Smith"); + employee.setAge(42); + employee.setCompany("Acme Widgets, Inc."); + return employee; + } + + @Bean + public PlatformTransactionManager transactionManager() { + return new DataSourceTransactionManager(dataSource()); + } + + /** + * Since this method does not reside in a true {@code @Configuration class}, + * it acts as a factory method when invoked directly (e.g., from + * {@link #transactionManager()}) and as a singleton bean when retrieved + * through the application context (e.g., when injected into the test + * instance). The result is that this method will be called twice: + * + *

    + *
  1. once indirectly by the {@link TransactionalTestExecutionListener} + * when it retrieves the {@link PlatformTransactionManager} from the + * application context
  2. + *
  3. and again when the {@link DataSource} is injected into the test + * instance in {@link AbstractTransactionalAnnotatedConfigClassTests#setDataSource(DataSource)}.
  4. + *
+ * + * Consequently, the {@link JdbcTemplate} used by this test instance and + * the {@link PlatformTransactionManager} used by the Spring TestContext + * Framework will operate on two different {@code DataSource} instances, + * which is almost certainly not the desired or intended behavior. + */ + @Bean + public DataSource dataSource() { + return new EmbeddedDatabaseBuilder()// + .addScript("classpath:/org/springframework/test/jdbc/schema.sql")// + // Ensure that this in-memory database is only used by this class: + .setName(getClass().getName())// + .build(); + } + + } + + + @Before + public void compareDataSources() throws Exception { + // NOTE: the two DataSource instances are NOT the same! + assertNotSame(dataSourceFromTxManager, dataSourceViaInjection); + } + + /** + * Overrides {@code afterTransaction()} in order to assert a different result. + * + *

See in-line comments for details. + * + * @see AbstractTransactionalAnnotatedConfigClassTests#afterTransaction() + * @see AbstractTransactionalAnnotatedConfigClassTests#modifyTestDataWithinTransaction() + */ + @AfterTransaction + @Override + public void afterTransaction() { + assertEquals("Deleting yoda", 1, deletePerson(YODA)); + + // NOTE: We would actually expect that there are now ZERO entries in the + // person table, since the transaction is rolled back by the framework; + // however, since our JdbcTemplate and the transaction manager used by + // the Spring TestContext Framework use two different DataSource + // instances, our insert statements were executed in transactions that + // are not controlled by the test framework. Consequently, there was no + // rollback for the two insert statements in + // modifyTestDataWithinTransaction(). + // + assertNumRowsInPersonTable(2, "after a transactional test method"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr9604/LookUpTxMgrViaTransactionManagementConfigurerTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9604/LookUpTxMgrViaTransactionManagementConfigurerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a7c76ce73242ef2b31cfb726261108f439600c73 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9604/LookUpTxMgrViaTransactionManagementConfigurerTests.java @@ -0,0 +1,103 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr9604; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.transaction.AfterTransaction; +import org.springframework.test.context.transaction.BeforeTransaction; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.annotation.TransactionManagementConfigurer; +import org.springframework.transaction.annotation.Transactional; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify the behavior requested in + * SPR-9604. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +@Transactional +public class LookUpTxMgrViaTransactionManagementConfigurerTests { + + private static final CallCountingTransactionManager txManager1 = new CallCountingTransactionManager(); + private static final CallCountingTransactionManager txManager2 = new CallCountingTransactionManager(); + + + @Configuration + static class Config implements TransactionManagementConfigurer { + + @Override + public PlatformTransactionManager annotationDrivenTransactionManager() { + return txManager1(); + } + + @Bean + public PlatformTransactionManager txManager1() { + return txManager1; + } + + @Bean + public PlatformTransactionManager txManager2() { + return txManager2; + } + } + + + @BeforeTransaction + public void beforeTransaction() { + txManager1.clear(); + txManager2.clear(); + } + + @Test + public void transactionalTest() { + assertEquals(1, txManager1.begun); + assertEquals(1, txManager1.inflight); + assertEquals(0, txManager1.commits); + assertEquals(0, txManager1.rollbacks); + + assertEquals(0, txManager2.begun); + assertEquals(0, txManager2.inflight); + assertEquals(0, txManager2.commits); + assertEquals(0, txManager2.rollbacks); + } + + @AfterTransaction + public void afterTransaction() { + assertEquals(1, txManager1.begun); + assertEquals(0, txManager1.inflight); + assertEquals(0, txManager1.commits); + assertEquals(1, txManager1.rollbacks); + + assertEquals(0, txManager2.begun); + assertEquals(0, txManager2.inflight); + assertEquals(0, txManager2.commits); + assertEquals(0, txManager2.rollbacks); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpNonexistentTxMgrTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpNonexistentTxMgrTests.java new file mode 100644 index 0000000000000000000000000000000000000000..300fcd578e3e3a05b529a4e81c0e412c0d639be8 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpNonexistentTxMgrTests.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr9645; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.PlatformTransactionManager; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify the behavior requested in + * SPR-9645. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +public class LookUpNonexistentTxMgrTests { + + private static final CallCountingTransactionManager txManager = new CallCountingTransactionManager(); + + @Configuration + static class Config { + + @Bean + public PlatformTransactionManager transactionManager() { + return txManager; + } + } + + @Test + public void nonTransactionalTest() { + assertEquals(0, txManager.begun); + assertEquals(0, txManager.inflight); + assertEquals(0, txManager.commits); + assertEquals(0, txManager.rollbacks); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpTxMgrByTypeAndDefaultNameTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpTxMgrByTypeAndDefaultNameTests.java new file mode 100644 index 0000000000000000000000000000000000000000..74314c2ea6669aa90f8c57b8fb6d404fc211a991 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpTxMgrByTypeAndDefaultNameTests.java @@ -0,0 +1,85 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr9645; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.transaction.AfterTransaction; +import org.springframework.test.context.transaction.BeforeTransaction; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.annotation.Transactional; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify the behavior requested in + * SPR-9645. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +@Transactional +public class LookUpTxMgrByTypeAndDefaultNameTests { + + private static final CallCountingTransactionManager txManager1 = new CallCountingTransactionManager(); + private static final CallCountingTransactionManager txManager2 = new CallCountingTransactionManager(); + + @Configuration + static class Config { + + @Bean + public PlatformTransactionManager transactionManager() { + return txManager1; + } + + @Bean + public PlatformTransactionManager txManager2() { + return txManager2; + } + } + + @BeforeTransaction + public void beforeTransaction() { + txManager1.clear(); + txManager2.clear(); + } + + @Test + public void transactionalTest() { + assertEquals(1, txManager1.begun); + assertEquals(1, txManager1.inflight); + assertEquals(0, txManager1.commits); + assertEquals(0, txManager1.rollbacks); + } + + @AfterTransaction + public void afterTransaction() { + assertEquals(1, txManager1.begun); + assertEquals(0, txManager1.inflight); + assertEquals(0, txManager1.commits); + assertEquals(1, txManager1.rollbacks); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpTxMgrByTypeAndNameTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpTxMgrByTypeAndNameTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d6ae560da6e4d75c558b0f0cd2ea3e4cfe2576a1 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpTxMgrByTypeAndNameTests.java @@ -0,0 +1,85 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr9645; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.transaction.AfterTransaction; +import org.springframework.test.context.transaction.BeforeTransaction; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.annotation.Transactional; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify the behavior requested in + * SPR-9645. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +@Transactional("txManager1") +public class LookUpTxMgrByTypeAndNameTests { + + private static final CallCountingTransactionManager txManager1 = new CallCountingTransactionManager(); + private static final CallCountingTransactionManager txManager2 = new CallCountingTransactionManager(); + + @Configuration + static class Config { + + @Bean + public PlatformTransactionManager txManager1() { + return txManager1; + } + + @Bean + public PlatformTransactionManager txManager2() { + return txManager2; + } + } + + @BeforeTransaction + public void beforeTransaction() { + txManager1.clear(); + txManager2.clear(); + } + + @Test + public void transactionalTest() { + assertEquals(1, txManager1.begun); + assertEquals(1, txManager1.inflight); + assertEquals(0, txManager1.commits); + assertEquals(0, txManager1.rollbacks); + } + + @AfterTransaction + public void afterTransaction() { + assertEquals(1, txManager1.begun); + assertEquals(0, txManager1.inflight); + assertEquals(0, txManager1.commits); + assertEquals(1, txManager1.rollbacks); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpTxMgrByTypeAndQualifierAtClassLevelTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpTxMgrByTypeAndQualifierAtClassLevelTests.java new file mode 100644 index 0000000000000000000000000000000000000000..455209b1760eb5adb9b1c5ebc6522b9d4c4cd793 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpTxMgrByTypeAndQualifierAtClassLevelTests.java @@ -0,0 +1,85 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr9645; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.transaction.AfterTransaction; +import org.springframework.test.context.transaction.BeforeTransaction; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.annotation.Transactional; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify the behavior requested in + * SPR-9645. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +@Transactional("txManager1") +public class LookUpTxMgrByTypeAndQualifierAtClassLevelTests { + + private static final CallCountingTransactionManager txManager1 = new CallCountingTransactionManager(); + private static final CallCountingTransactionManager txManager2 = new CallCountingTransactionManager(); + + @Configuration + static class Config { + + @Bean + public PlatformTransactionManager txManager1() { + return txManager1; + } + + @Bean + public PlatformTransactionManager txManager2() { + return txManager2; + } + } + + @BeforeTransaction + public void beforeTransaction() { + txManager1.clear(); + txManager2.clear(); + } + + @Test + public void transactionalTest() { + assertEquals(1, txManager1.begun); + assertEquals(1, txManager1.inflight); + assertEquals(0, txManager1.commits); + assertEquals(0, txManager1.rollbacks); + } + + @AfterTransaction + public void afterTransaction() { + assertEquals(1, txManager1.begun); + assertEquals(0, txManager1.inflight); + assertEquals(0, txManager1.commits); + assertEquals(1, txManager1.rollbacks); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpTxMgrByTypeAndQualifierAtMethodLevelTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpTxMgrByTypeAndQualifierAtMethodLevelTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d9e20694f97f06e3b622d60a0ba11ab98a5d511b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpTxMgrByTypeAndQualifierAtMethodLevelTests.java @@ -0,0 +1,85 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr9645; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.transaction.AfterTransaction; +import org.springframework.test.context.transaction.BeforeTransaction; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.annotation.Transactional; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify the behavior requested in + * SPR-9645. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +public class LookUpTxMgrByTypeAndQualifierAtMethodLevelTests { + + private static final CallCountingTransactionManager txManager1 = new CallCountingTransactionManager(); + private static final CallCountingTransactionManager txManager2 = new CallCountingTransactionManager(); + + @Configuration + static class Config { + + @Bean + public PlatformTransactionManager txManager1() { + return txManager1; + } + + @Bean + public PlatformTransactionManager txManager2() { + return txManager2; + } + } + + @BeforeTransaction + public void beforeTransaction() { + txManager1.clear(); + txManager2.clear(); + } + + @Transactional("txManager1") + @Test + public void transactionalTest() { + assertEquals(1, txManager1.begun); + assertEquals(1, txManager1.inflight); + assertEquals(0, txManager1.commits); + assertEquals(0, txManager1.rollbacks); + } + + @AfterTransaction + public void afterTransaction() { + assertEquals(1, txManager1.begun); + assertEquals(0, txManager1.inflight); + assertEquals(0, txManager1.commits); + assertEquals(1, txManager1.rollbacks); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpTxMgrByTypeTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpTxMgrByTypeTests.java new file mode 100644 index 0000000000000000000000000000000000000000..fd748a127498fc82a18afcb5b5ec93a947ef5a18 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9645/LookUpTxMgrByTypeTests.java @@ -0,0 +1,78 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr9645; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.transaction.AfterTransaction; +import org.springframework.test.context.transaction.BeforeTransaction; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.annotation.Transactional; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify the behavior requested in + * SPR-9645. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +@Transactional +public class LookUpTxMgrByTypeTests { + + private static final CallCountingTransactionManager txManager = new CallCountingTransactionManager(); + + @Configuration + static class Config { + + @Bean + public PlatformTransactionManager txManager() { + return txManager; + } + } + + @BeforeTransaction + public void beforeTransaction() { + txManager.clear(); + } + + @Test + public void transactionalTest() { + assertEquals(1, txManager.begun); + assertEquals(1, txManager.inflight); + assertEquals(0, txManager.commits); + assertEquals(0, txManager.rollbacks); + } + + @AfterTransaction + public void afterTransaction() { + assertEquals(1, txManager.begun); + assertEquals(0, txManager.inflight); + assertEquals(0, txManager.commits); + assertEquals(1, txManager.rollbacks); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr9799/Spr9799AnnotationConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9799/Spr9799AnnotationConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..91c809c3dc8ba195beac7ba0c77feb87e134166f --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9799/Spr9799AnnotationConfigTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr9799; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; + +/** + * Integration tests used to assess claims raised in + * SPR-9799. + * + * @author Sam Brannen + * @since 3.2 + * @see Spr9799XmlConfigTests + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +// NOTE: if we omit the @WebAppConfiguration declaration, the ApplicationContext will fail +// to load since @EnableWebMvc requires that the context be a WebApplicationContext. +@WebAppConfiguration +public class Spr9799AnnotationConfigTests { + + @Configuration + @EnableWebMvc + static class Config { + /* intentionally no beans defined */ + } + + + @Test + public void applicationContextLoads() { + // no-op + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/spr9799/Spr9799XmlConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9799/Spr9799XmlConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..37c7f933847562e3b3cd8cc77d3ca8158d623ed5 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/spr9799/Spr9799XmlConfigTests.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.spr9799; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +/** + * Integration tests used to assess claims raised in + * SPR-9799. + * + * @author Sam Brannen + * @since 3.2 + * @see Spr9799AnnotationConfigTests + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +public class Spr9799XmlConfigTests { + + @Test + public void applicationContextLoads() { + // nothing to assert: we just want to make sure that the context loads without + // errors. + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/junit4/statements/SpringFailOnTimeoutTests.java b/spring-test/src/test/java/org/springframework/test/context/junit4/statements/SpringFailOnTimeoutTests.java new file mode 100644 index 0000000000000000000000000000000000000000..fd9d7bb452bca7af3eb81c3246bcc1701e8a5902 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/junit4/statements/SpringFailOnTimeoutTests.java @@ -0,0 +1,91 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.junit4.statements; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runners.model.Statement; +import org.mockito.stubbing.Answer; + +import org.springframework.test.context.junit4.statements.SpringFailOnTimeout; + +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link SpringFailOnTimeout}. + * + * @author Igor Suhorukov + * @author Sam Brannen + * @since 4.3.17 + */ +public class SpringFailOnTimeoutTests { + + private Statement statement = mock(Statement.class); + + @Rule + public final ExpectedException exception = ExpectedException.none(); + + + @Test + public void nullNextStatement() throws Throwable { + exception.expect(IllegalArgumentException.class); + new SpringFailOnTimeout(null, 1); + } + + @Test + public void negativeTimeout() throws Throwable { + exception.expect(IllegalArgumentException.class); + new SpringFailOnTimeout(statement, -1); + } + + @Test + public void userExceptionPropagates() throws Throwable { + doThrow(new Boom()).when(statement).evaluate(); + + exception.expect(Boom.class); + new SpringFailOnTimeout(statement, 1).evaluate(); + } + + @Test + public void timeoutExceptionThrownIfNoUserException() throws Throwable { + doAnswer((Answer) invocation -> { + TimeUnit.MILLISECONDS.sleep(50); + return null; + }).when(statement).evaluate(); + + exception.expect(TimeoutException.class); + new SpringFailOnTimeout(statement, 1).evaluate(); + } + + @Test + public void noExceptionThrownIfNoUserExceptionAndTimeoutDoesNotOccur() throws Throwable { + doAnswer((Answer) invocation -> { + return null; + }).when(statement).evaluate(); + + new SpringFailOnTimeout(statement, 100).evaluate(); + } + + @SuppressWarnings("serial") + private static class Boom extends RuntimeException { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/AbstractContextConfigurationUtilsTests.java b/spring-test/src/test/java/org/springframework/test/context/support/AbstractContextConfigurationUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..1ebcc41bbd9daa02edd1a09b8a1a627a8097087b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/AbstractContextConfigurationUtilsTests.java @@ -0,0 +1,220 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.util.Collections; +import java.util.Set; + +import org.mockito.Mockito; + +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.BootstrapContext; +import org.springframework.test.context.BootstrapTestUtils; +import org.springframework.test.context.CacheAwareContextLoaderDelegate; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.ContextConfigurationAttributes; +import org.springframework.test.context.ContextLoader; +import org.springframework.test.context.MergedContextConfiguration; +import org.springframework.test.context.TestContextBootstrapper; +import org.springframework.test.context.web.WebAppConfiguration; + +import static org.junit.Assert.*; + +/** + * Abstract base class for tests involving {@link ContextLoaderUtils}, + * {@link BootstrapTestUtils}, and {@link ActiveProfilesUtils}. + * + * @author Sam Brannen + * @since 3.1 + */ +abstract class AbstractContextConfigurationUtilsTests { + + static final Class[] EMPTY_CLASS_ARRAY = new Class[0]; + + static final String[] EMPTY_STRING_ARRAY = new String[0]; + + static final Set>> + EMPTY_INITIALIZER_CLASSES = Collections.>> emptySet(); + + + MergedContextConfiguration buildMergedContextConfiguration(Class testClass) { + CacheAwareContextLoaderDelegate cacheAwareContextLoaderDelegate = Mockito.mock(CacheAwareContextLoaderDelegate.class); + BootstrapContext bootstrapContext = BootstrapTestUtils.buildBootstrapContext(testClass, cacheAwareContextLoaderDelegate); + TestContextBootstrapper bootstrapper = BootstrapTestUtils.resolveTestContextBootstrapper(bootstrapContext); + return bootstrapper.buildMergedContextConfiguration(); + } + + void assertAttributes(ContextConfigurationAttributes attributes, Class expectedDeclaringClass, + String[] expectedLocations, Class[] expectedClasses, + Class expectedContextLoaderClass, boolean expectedInheritLocations) { + + assertEquals("declaring class", expectedDeclaringClass, attributes.getDeclaringClass()); + assertArrayEquals("locations", expectedLocations, attributes.getLocations()); + assertArrayEquals("classes", expectedClasses, attributes.getClasses()); + assertEquals("inherit locations", expectedInheritLocations, attributes.isInheritLocations()); + assertEquals("context loader", expectedContextLoaderClass, attributes.getContextLoaderClass()); + } + + void assertMergedConfig(MergedContextConfiguration mergedConfig, Class expectedTestClass, + String[] expectedLocations, Class[] expectedClasses, + Class expectedContextLoaderClass) { + + assertMergedConfig(mergedConfig, expectedTestClass, expectedLocations, expectedClasses, + EMPTY_INITIALIZER_CLASSES, expectedContextLoaderClass); + } + + void assertMergedConfig( + MergedContextConfiguration mergedConfig, + Class expectedTestClass, + String[] expectedLocations, + Class[] expectedClasses, + Set>> expectedInitializerClasses, + Class expectedContextLoaderClass) { + + assertNotNull(mergedConfig); + assertEquals(expectedTestClass, mergedConfig.getTestClass()); + assertNotNull(mergedConfig.getLocations()); + assertArrayEquals(expectedLocations, mergedConfig.getLocations()); + assertNotNull(mergedConfig.getClasses()); + assertArrayEquals(expectedClasses, mergedConfig.getClasses()); + assertNotNull(mergedConfig.getActiveProfiles()); + if (expectedContextLoaderClass == null) { + assertNull(mergedConfig.getContextLoader()); + } + else { + assertEquals(expectedContextLoaderClass, mergedConfig.getContextLoader().getClass()); + } + assertNotNull(mergedConfig.getContextInitializerClasses()); + assertEquals(expectedInitializerClasses, mergedConfig.getContextInitializerClasses()); + } + + @SafeVarargs + static T[] array(T... objects) { + return objects; + } + + + static class Enigma { + } + + @ContextConfiguration + @ActiveProfiles + static class BareAnnotations { + } + + @Configuration + static class FooConfig { + } + + @Configuration + static class BarConfig { + } + + @ContextConfiguration("/foo.xml") + @ActiveProfiles(profiles = "foo") + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.TYPE) + public static @interface MetaLocationsFooConfig { + } + + @ContextConfiguration + @ActiveProfiles + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.TYPE) + public static @interface MetaLocationsFooConfigWithOverrides { + + String[] locations() default "/foo.xml"; + + String[] profiles() default "foo"; + } + + @ContextConfiguration("/bar.xml") + @ActiveProfiles(profiles = "bar") + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.TYPE) + public static @interface MetaLocationsBarConfig { + } + + @MetaLocationsFooConfig + static class MetaLocationsFoo { + } + + @MetaLocationsBarConfig + static class MetaLocationsBar extends MetaLocationsFoo { + } + + @MetaLocationsFooConfigWithOverrides + static class MetaLocationsFooWithOverrides { + } + + @MetaLocationsFooConfigWithOverrides(locations = {"foo1.xml", "foo2.xml"}, profiles = {"foo1", "foo2"}) + static class MetaLocationsFooWithOverriddenAttributes { + } + + @ContextConfiguration(locations = "/foo.xml", inheritLocations = false) + @ActiveProfiles("foo") + static class LocationsFoo { + } + + @ContextConfiguration(classes = FooConfig.class, inheritLocations = false) + @ActiveProfiles("foo") + static class ClassesFoo { + } + + @WebAppConfiguration + static class WebClassesFoo extends ClassesFoo { + } + + @ContextConfiguration(locations = "/bar.xml", inheritLocations = true, loader = AnnotationConfigContextLoader.class) + @ActiveProfiles("bar") + static class LocationsBar extends LocationsFoo { + } + + @ContextConfiguration(locations = "/bar.xml", inheritLocations = false, loader = AnnotationConfigContextLoader.class) + @ActiveProfiles("bar") + static class OverriddenLocationsBar extends LocationsFoo { + } + + @ContextConfiguration(classes = BarConfig.class, inheritLocations = true, loader = AnnotationConfigContextLoader.class) + @ActiveProfiles("bar") + static class ClassesBar extends ClassesFoo { + } + + @ContextConfiguration(classes = BarConfig.class, inheritLocations = false, loader = AnnotationConfigContextLoader.class) + @ActiveProfiles("bar") + static class OverriddenClassesBar extends ClassesFoo { + } + + @ContextConfiguration(locations = "/foo.properties", loader = GenericPropertiesContextLoader.class) + @ActiveProfiles("foo") + static class PropertiesLocationsFoo { + } + + // Combining @Configuration classes with a Properties based loader doesn't really make + // sense, but that's OK for unit testing purposes. + @ContextConfiguration(classes = FooConfig.class, loader = GenericPropertiesContextLoader.class) + @ActiveProfiles("foo") + static class PropertiesClassesFoo { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/ActiveProfilesUtilsTests.java b/spring-test/src/test/java/org/springframework/test/context/support/ActiveProfilesUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3f4cce058987cdaa5fcc3c65350c375807fb4b13 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/ActiveProfilesUtilsTests.java @@ -0,0 +1,386 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.junit.Test; + +import org.springframework.core.annotation.AnnotationConfigurationException; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.ActiveProfilesResolver; +import org.springframework.util.StringUtils; + +import static org.junit.Assert.*; +import static org.springframework.test.context.support.ActiveProfilesUtils.*; + +/** + * Unit tests for {@link ActiveProfilesUtils} involving resolution of active bean + * definition profiles. + * + * @author Sam Brannen + * @author Michail Nikolaev + * @since 3.1 + */ +public class ActiveProfilesUtilsTests extends AbstractContextConfigurationUtilsTests { + + private void assertResolvedProfiles(Class testClass, String... expected) { + assertArrayEquals(expected, resolveActiveProfiles(testClass)); + } + + @Test + public void resolveActiveProfilesWithoutAnnotation() { + assertResolvedProfiles(Enigma.class, EMPTY_STRING_ARRAY); + } + + @Test + public void resolveActiveProfilesWithNoProfilesDeclared() { + assertResolvedProfiles(BareAnnotations.class, EMPTY_STRING_ARRAY); + } + + @Test + public void resolveActiveProfilesWithEmptyProfiles() { + assertResolvedProfiles(EmptyProfiles.class, EMPTY_STRING_ARRAY); + } + + @Test + public void resolveActiveProfilesWithDuplicatedProfiles() { + assertResolvedProfiles(DuplicatedProfiles.class, "foo", "bar", "baz"); + } + + @Test + public void resolveActiveProfilesWithLocalAndInheritedDuplicatedProfiles() { + assertResolvedProfiles(ExtendedDuplicatedProfiles.class, "foo", "bar", "baz", "cat", "dog"); + } + + @Test + public void resolveActiveProfilesWithLocalAnnotation() { + assertResolvedProfiles(LocationsFoo.class, "foo"); + } + + @Test + public void resolveActiveProfilesWithInheritedAnnotationAndLocations() { + assertResolvedProfiles(InheritedLocationsFoo.class, "foo"); + } + + @Test + public void resolveActiveProfilesWithInheritedAnnotationAndClasses() { + assertResolvedProfiles(InheritedClassesFoo.class, "foo"); + } + + @Test + public void resolveActiveProfilesWithLocalAndInheritedAnnotations() { + assertResolvedProfiles(LocationsBar.class, "foo", "bar"); + } + + @Test + public void resolveActiveProfilesWithOverriddenAnnotation() { + assertResolvedProfiles(Animals.class, "dog", "cat"); + } + + /** + * @since 4.0 + */ + @Test + public void resolveActiveProfilesWithMetaAnnotation() { + assertResolvedProfiles(MetaLocationsFoo.class, "foo"); + } + + /** + * @since 4.0 + */ + @Test + public void resolveActiveProfilesWithMetaAnnotationAndOverrides() { + assertResolvedProfiles(MetaLocationsFooWithOverrides.class, "foo"); + } + + /** + * @since 4.0 + */ + @Test + public void resolveActiveProfilesWithMetaAnnotationAndOverriddenAttributes() { + assertResolvedProfiles(MetaLocationsFooWithOverriddenAttributes.class, "foo1", "foo2"); + } + + /** + * @since 4.0 + */ + @Test + public void resolveActiveProfilesWithLocalAndInheritedMetaAnnotations() { + assertResolvedProfiles(MetaLocationsBar.class, "foo", "bar"); + } + + /** + * @since 4.0 + */ + @Test + public void resolveActiveProfilesWithOverriddenMetaAnnotation() { + assertResolvedProfiles(MetaAnimals.class, "dog", "cat"); + } + + /** + * @since 4.0 + */ + @Test + public void resolveActiveProfilesWithResolver() { + assertResolvedProfiles(FooActiveProfilesResolverTestCase.class, "foo"); + } + + /** + * @since 4.0 + */ + @Test + public void resolveActiveProfilesWithInheritedResolver() { + assertResolvedProfiles(InheritedFooActiveProfilesResolverTestCase.class, "foo"); + } + + /** + * @since 4.0 + */ + @Test + public void resolveActiveProfilesWithMergedInheritedResolver() { + assertResolvedProfiles(MergedInheritedFooActiveProfilesResolverTestCase.class, "foo", "bar"); + } + + /** + * @since 4.0 + */ + @Test + public void resolveActiveProfilesWithOverridenInheritedResolver() { + assertResolvedProfiles(OverridenInheritedFooActiveProfilesResolverTestCase.class, "bar"); + } + + /** + * @since 4.0 + */ + @Test + public void resolveActiveProfilesWithResolverAndProfiles() { + assertResolvedProfiles(ResolverAndProfilesTestCase.class, "bar"); + } + + /** + * @since 4.0 + */ + @Test + public void resolveActiveProfilesWithResolverAndValue() { + assertResolvedProfiles(ResolverAndValueTestCase.class, "bar"); + } + + /** + * @since 4.0 + */ + @Test(expected = AnnotationConfigurationException.class) + public void resolveActiveProfilesWithConflictingProfilesAndValue() { + resolveActiveProfiles(ConflictingProfilesAndValueTestCase.class); + } + + /** + * @since 4.0 + */ + @Test(expected = IllegalStateException.class) + public void resolveActiveProfilesWithResolverWithoutDefaultConstructor() { + resolveActiveProfiles(NoDefaultConstructorActiveProfilesResolverTestCase.class); + } + + /** + * @since 4.0 + */ + public void resolveActiveProfilesWithResolverThatReturnsNull() { + assertResolvedProfiles(NullActiveProfilesResolverTestCase.class); + } + + /** + * This test verifies that the actual test class, not the composed annotation, + * is passed to the resolver. + * @since 4.0.3 + */ + @Test + public void resolveActiveProfilesWithMetaAnnotationAndTestClassVerifyingResolver() { + Class testClass = TestClassVerifyingActiveProfilesResolverTestCase.class; + assertResolvedProfiles(testClass, testClass.getSimpleName()); + } + + /** + * This test verifies that {@link DefaultActiveProfilesResolver} can be declared explicitly. + * @since 4.1.5 + */ + @Test + public void resolveActiveProfilesWithDefaultActiveProfilesResolver() { + assertResolvedProfiles(DefaultActiveProfilesResolverTestCase.class, "default"); + } + + /** + * This test verifies that {@link DefaultActiveProfilesResolver} can be extended. + * @since 4.1.5 + */ + @Test + public void resolveActiveProfilesWithExtendedDefaultActiveProfilesResolver() { + assertResolvedProfiles(ExtendedDefaultActiveProfilesResolverTestCase.class, "default", "foo"); + } + + + // ------------------------------------------------------------------------- + + @ActiveProfiles({ " ", "\t" }) + private static class EmptyProfiles { + } + + @ActiveProfiles({ "foo", "bar", " foo", "bar ", "baz" }) + private static class DuplicatedProfiles { + } + + @ActiveProfiles({ "cat", "dog", " foo", "bar ", "cat" }) + private static class ExtendedDuplicatedProfiles extends DuplicatedProfiles { + } + + @ActiveProfiles(profiles = { "dog", "cat" }, inheritProfiles = false) + private static class Animals extends LocationsBar { + } + + @ActiveProfiles(profiles = { "dog", "cat" }, inheritProfiles = false) + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.TYPE) + private static @interface MetaAnimalsConfig { + } + + @ActiveProfiles(resolver = TestClassVerifyingActiveProfilesResolver.class) + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.TYPE) + private static @interface MetaResolverConfig { + } + + @MetaAnimalsConfig + private static class MetaAnimals extends MetaLocationsBar { + } + + private static class InheritedLocationsFoo extends LocationsFoo { + } + + private static class InheritedClassesFoo extends ClassesFoo { + } + + @ActiveProfiles(resolver = NullActiveProfilesResolver.class) + private static class NullActiveProfilesResolverTestCase { + } + + @ActiveProfiles(resolver = NoDefaultConstructorActiveProfilesResolver.class) + private static class NoDefaultConstructorActiveProfilesResolverTestCase { + } + + @ActiveProfiles(resolver = FooActiveProfilesResolver.class) + private static class FooActiveProfilesResolverTestCase { + } + + private static class InheritedFooActiveProfilesResolverTestCase extends FooActiveProfilesResolverTestCase { + } + + @ActiveProfiles(resolver = BarActiveProfilesResolver.class) + private static class MergedInheritedFooActiveProfilesResolverTestCase extends + InheritedFooActiveProfilesResolverTestCase { + } + + @ActiveProfiles(resolver = BarActiveProfilesResolver.class, inheritProfiles = false) + private static class OverridenInheritedFooActiveProfilesResolverTestCase extends + InheritedFooActiveProfilesResolverTestCase { + } + + @ActiveProfiles(resolver = BarActiveProfilesResolver.class, profiles = "ignored by custom resolver") + private static class ResolverAndProfilesTestCase { + } + + @ActiveProfiles(resolver = BarActiveProfilesResolver.class, value = "ignored by custom resolver") + private static class ResolverAndValueTestCase { + } + + @MetaResolverConfig + private static class TestClassVerifyingActiveProfilesResolverTestCase { + } + + @ActiveProfiles(profiles = "default", resolver = DefaultActiveProfilesResolver.class) + private static class DefaultActiveProfilesResolverTestCase { + } + + @ActiveProfiles(profiles = "default", resolver = ExtendedDefaultActiveProfilesResolver.class) + private static class ExtendedDefaultActiveProfilesResolverTestCase { + } + + @ActiveProfiles(profiles = "conflict 1", value = "conflict 2") + private static class ConflictingProfilesAndValueTestCase { + } + + private static class FooActiveProfilesResolver implements ActiveProfilesResolver { + + @Override + public String[] resolve(Class testClass) { + return new String[] { "foo" }; + } + } + + private static class BarActiveProfilesResolver implements ActiveProfilesResolver { + + @Override + public String[] resolve(Class testClass) { + return new String[] { "bar" }; + } + } + + private static class NullActiveProfilesResolver implements ActiveProfilesResolver { + + @Override + public String[] resolve(Class testClass) { + return null; + } + } + + private static class NoDefaultConstructorActiveProfilesResolver implements ActiveProfilesResolver { + + @SuppressWarnings("unused") + NoDefaultConstructorActiveProfilesResolver(Object argument) { + } + + @Override + public String[] resolve(Class testClass) { + return null; + } + } + + private static class TestClassVerifyingActiveProfilesResolver implements ActiveProfilesResolver { + + @Override + public String[] resolve(Class testClass) { + return testClass.isAnnotation() ? new String[] { "@" + testClass.getSimpleName() } + : new String[] { testClass.getSimpleName() }; + } + } + + private static class ExtendedDefaultActiveProfilesResolver extends DefaultActiveProfilesResolver { + + @Override + public String[] resolve(Class testClass) { + List profiles = new ArrayList<>(Arrays.asList(super.resolve(testClass))); + profiles.add("foo"); + return StringUtils.toStringArray(profiles); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/AnnotatedFooConfigInnerClassTestCase.java b/spring-test/src/test/java/org/springframework/test/context/support/AnnotatedFooConfigInnerClassTestCase.java new file mode 100644 index 0000000000000000000000000000000000000000..21dec677f0376174753dc15461757895c8462d6d --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/AnnotatedFooConfigInnerClassTestCase.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +/** + * Not an actual test case. + * + * @author Sam Brannen + * @since 3.1 + * @see AnnotationConfigContextLoaderTests + */ +public class AnnotatedFooConfigInnerClassTestCase { + + @Configuration + static class FooConfig { + + @Bean + public String foo() { + return "foo"; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/AnnotationConfigContextLoaderTests.java b/spring-test/src/test/java/org/springframework/test/context/support/AnnotationConfigContextLoaderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..158c3700ebe88fbdd037f305034ffecbb8c10e62 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/AnnotationConfigContextLoaderTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.test.context.MergedContextConfiguration; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +/** + * Unit tests for {@link AnnotationConfigContextLoader}. + * + * @author Sam Brannen + * @since 3.1 + */ +public class AnnotationConfigContextLoaderTests { + + private final AnnotationConfigContextLoader contextLoader = new AnnotationConfigContextLoader(); + + private static final String[] EMPTY_STRING_ARRAY = new String[0]; + private static final Class[] EMPTY_CLASS_ARRAY = new Class[0]; + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + + /** + * @since 4.0.4 + */ + @Test + public void configMustNotContainLocations() throws Exception { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage(containsString("does not support resource locations")); + + MergedContextConfiguration mergedConfig = new MergedContextConfiguration(getClass(), + new String[] { "config.xml" }, EMPTY_CLASS_ARRAY, EMPTY_STRING_ARRAY, contextLoader); + contextLoader.loadContext(mergedConfig); + } + + @Test + public void detectDefaultConfigurationClassesForAnnotatedInnerClass() { + Class[] configClasses = contextLoader.detectDefaultConfigurationClasses(ContextConfigurationInnerClassTestCase.class); + assertNotNull(configClasses); + assertEquals("annotated static ContextConfiguration should be considered.", 1, configClasses.length); + + configClasses = contextLoader.detectDefaultConfigurationClasses(AnnotatedFooConfigInnerClassTestCase.class); + assertNotNull(configClasses); + assertEquals("annotated static FooConfig should be considered.", 1, configClasses.length); + } + + @Test + public void detectDefaultConfigurationClassesForMultipleAnnotatedInnerClasses() { + Class[] configClasses = contextLoader.detectDefaultConfigurationClasses(MultipleStaticConfigurationClassesTestCase.class); + assertNotNull(configClasses); + assertEquals("multiple annotated static classes should be considered.", 2, configClasses.length); + } + + @Test + public void detectDefaultConfigurationClassesForNonAnnotatedInnerClass() { + Class[] configClasses = contextLoader.detectDefaultConfigurationClasses(PlainVanillaFooConfigInnerClassTestCase.class); + assertNotNull(configClasses); + assertEquals("non-annotated static FooConfig should NOT be considered.", 0, configClasses.length); + } + + @Test + public void detectDefaultConfigurationClassesForFinalAnnotatedInnerClass() { + Class[] configClasses = contextLoader.detectDefaultConfigurationClasses(FinalConfigInnerClassTestCase.class); + assertNotNull(configClasses); + assertEquals("final annotated static Config should NOT be considered.", 0, configClasses.length); + } + + @Test + public void detectDefaultConfigurationClassesForPrivateAnnotatedInnerClass() { + Class[] configClasses = contextLoader.detectDefaultConfigurationClasses(PrivateConfigInnerClassTestCase.class); + assertNotNull(configClasses); + assertEquals("private annotated inner classes should NOT be considered.", 0, configClasses.length); + } + + @Test + public void detectDefaultConfigurationClassesForNonStaticAnnotatedInnerClass() { + Class[] configClasses = contextLoader.detectDefaultConfigurationClasses(NonStaticConfigInnerClassesTestCase.class); + assertNotNull(configClasses); + assertEquals("non-static annotated inner classes should NOT be considered.", 0, configClasses.length); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/AnnotationConfigContextLoaderUtilsTests.java b/spring-test/src/test/java/org/springframework/test/context/support/AnnotationConfigContextLoaderUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..286072aa5379af2418f63a04de0d68173f73cb44 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/AnnotationConfigContextLoaderUtilsTests.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.junit.Test; +import org.springframework.context.annotation.Configuration; + +import static org.junit.Assert.*; +import static org.springframework.test.context.support.AnnotationConfigContextLoaderUtils.*; + +/** + * Unit tests for {@link AnnotationConfigContextLoaderUtils}. + * + * @author Sam Brannen + * @since 4.1.5 + */ +public class AnnotationConfigContextLoaderUtilsTests { + + @Test(expected = IllegalArgumentException.class) + public void detectDefaultConfigurationClassesWithNullDeclaringClass() { + detectDefaultConfigurationClasses(null); + } + + @Test + public void detectDefaultConfigurationClassesWithoutConfigurationClass() { + Class[] configClasses = detectDefaultConfigurationClasses(NoConfigTestCase.class); + assertNotNull(configClasses); + assertEquals(0, configClasses.length); + } + + @Test + public void detectDefaultConfigurationClassesWithExplicitConfigurationAnnotation() { + Class[] configClasses = detectDefaultConfigurationClasses(ExplicitConfigTestCase.class); + assertNotNull(configClasses); + assertArrayEquals(new Class[] { ExplicitConfigTestCase.Config.class }, configClasses); + } + + @Test + public void detectDefaultConfigurationClassesWithConfigurationMetaAnnotation() { + Class[] configClasses = detectDefaultConfigurationClasses(MetaAnnotatedConfigTestCase.class); + assertNotNull(configClasses); + assertArrayEquals(new Class[] { MetaAnnotatedConfigTestCase.Config.class }, configClasses); + } + + + private static class NoConfigTestCase { + + } + + private static class ExplicitConfigTestCase { + + @Configuration + static class Config { + } + } + + @Configuration + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.TYPE) + private static @interface MetaConfig { + } + + private static class MetaAnnotatedConfigTestCase { + + @MetaConfig + static class Config { + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/BootstrapTestUtilsContextInitializerTests.java b/spring-test/src/test/java/org/springframework/test/context/support/BootstrapTestUtilsContextInitializerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..278dc81fa9d3639fc8784fe41e7099c86eeb8b39 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/BootstrapTestUtilsContextInitializerTests.java @@ -0,0 +1,131 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import org.junit.Test; + +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.test.context.BootstrapTestUtils; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.MergedContextConfiguration; +import org.springframework.web.context.support.GenericWebApplicationContext; + +/** + * Unit tests for {@link BootstrapTestUtils} involving {@link ApplicationContextInitializer}s. + * + * @author Sam Brannen + * @since 3.1 + */ +@SuppressWarnings("unchecked") +public class BootstrapTestUtilsContextInitializerTests extends AbstractContextConfigurationUtilsTests { + + @Test + public void buildMergedConfigWithSingleLocalInitializer() { + Class testClass = SingleInitializer.class; + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig(mergedConfig, testClass, EMPTY_STRING_ARRAY, EMPTY_CLASS_ARRAY, + initializers(FooInitializer.class), DelegatingSmartContextLoader.class); + } + + @Test + public void buildMergedConfigWithLocalInitializerAndConfigClass() { + Class testClass = InitializersFoo.class; + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig(mergedConfig, testClass, EMPTY_STRING_ARRAY, classes(FooConfig.class), + initializers(FooInitializer.class), DelegatingSmartContextLoader.class); + } + + @Test + public void buildMergedConfigWithLocalAndInheritedInitializer() { + Class testClass = InitializersBar.class; + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig(mergedConfig, testClass, EMPTY_STRING_ARRAY, classes(FooConfig.class, BarConfig.class), + initializers(FooInitializer.class, BarInitializer.class), DelegatingSmartContextLoader.class); + } + + @Test + public void buildMergedConfigWithOverriddenInitializers() { + Class testClass = OverriddenInitializersBar.class; + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig(mergedConfig, testClass, EMPTY_STRING_ARRAY, classes(FooConfig.class, BarConfig.class), + initializers(BarInitializer.class), DelegatingSmartContextLoader.class); + } + + @Test + public void buildMergedConfigWithOverriddenInitializersAndClasses() { + Class testClass = OverriddenInitializersAndClassesBar.class; + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig(mergedConfig, testClass, EMPTY_STRING_ARRAY, classes(BarConfig.class), + initializers(BarInitializer.class), DelegatingSmartContextLoader.class); + } + + private Set>> initializers( + Class>... classes) { + + return new HashSet<>(Arrays.asList(classes)); + } + + private Class[] classes(Class... classes) { + return classes; + } + + + private static class FooInitializer implements ApplicationContextInitializer { + + @Override + public void initialize(GenericApplicationContext applicationContext) { + } + } + + private static class BarInitializer implements ApplicationContextInitializer { + + @Override + public void initialize(GenericWebApplicationContext applicationContext) { + } + } + + @ContextConfiguration(initializers = FooInitializer.class) + private static class SingleInitializer { + } + + @ContextConfiguration(classes = FooConfig.class, initializers = FooInitializer.class) + private static class InitializersFoo { + } + + @ContextConfiguration(classes = BarConfig.class, initializers = BarInitializer.class) + private static class InitializersBar extends InitializersFoo { + } + + @ContextConfiguration(classes = BarConfig.class, initializers = BarInitializer.class, inheritInitializers = false) + private static class OverriddenInitializersBar extends InitializersFoo { + } + + @ContextConfiguration(classes = BarConfig.class, inheritLocations = false, initializers = BarInitializer.class, inheritInitializers = false) + private static class OverriddenInitializersAndClassesBar extends InitializersFoo { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/BootstrapTestUtilsMergedConfigTests.java b/spring-test/src/test/java/org/springframework/test/context/support/BootstrapTestUtilsMergedConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..248426df62c2d0630a5b1049293edb63ddde5b36 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/BootstrapTestUtilsMergedConfigTests.java @@ -0,0 +1,228 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.test.context.BootstrapTestUtils; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.ContextLoader; +import org.springframework.test.context.MergedContextConfiguration; +import org.springframework.test.context.web.WebDelegatingSmartContextLoader; +import org.springframework.test.context.web.WebMergedContextConfiguration; + +import static org.hamcrest.CoreMatchers.startsWith; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +/** + * Unit tests for {@link BootstrapTestUtils} involving {@link MergedContextConfiguration}. + * + * @author Sam Brannen + * @since 3.1 + */ +public class BootstrapTestUtilsMergedConfigTests extends AbstractContextConfigurationUtilsTests { + + @Rule + public final ExpectedException exception = ExpectedException.none(); + + + @Test + public void buildImplicitMergedConfigWithoutAnnotation() { + Class testClass = Enigma.class; + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig(mergedConfig, testClass, EMPTY_STRING_ARRAY, EMPTY_CLASS_ARRAY, DelegatingSmartContextLoader.class); + } + + /** + * @since 4.3 + */ + @Test + public void buildMergedConfigWithContextConfigurationWithoutLocationsClassesOrInitializers() { + exception.expect(IllegalStateException.class); + exception.expectMessage(startsWith("DelegatingSmartContextLoader was unable to detect defaults, " + + "and no ApplicationContextInitializers or ContextCustomizers were declared for context configuration attributes")); + + buildMergedContextConfiguration(MissingContextAttributesTestCase.class); + } + + @Test + public void buildMergedConfigWithBareAnnotations() { + Class testClass = BareAnnotations.class; + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig( + mergedConfig, + testClass, + array("classpath:org/springframework/test/context/support/AbstractContextConfigurationUtilsTests$BareAnnotations-context.xml"), + EMPTY_CLASS_ARRAY, DelegatingSmartContextLoader.class); + } + + @Test + public void buildMergedConfigWithLocalAnnotationAndLocations() { + Class testClass = LocationsFoo.class; + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig(mergedConfig, testClass, array("classpath:/foo.xml"), EMPTY_CLASS_ARRAY, + DelegatingSmartContextLoader.class); + } + + @Test + public void buildMergedConfigWithMetaAnnotationAndLocations() { + Class testClass = MetaLocationsFoo.class; + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig(mergedConfig, testClass, array("classpath:/foo.xml"), EMPTY_CLASS_ARRAY, + DelegatingSmartContextLoader.class); + } + + @Test + public void buildMergedConfigWithMetaAnnotationAndClasses() { + buildMergedConfigWithMetaAnnotationAndClasses(Dog.class); + buildMergedConfigWithMetaAnnotationAndClasses(WorkingDog.class); + buildMergedConfigWithMetaAnnotationAndClasses(GermanShepherd.class); + } + + private void buildMergedConfigWithMetaAnnotationAndClasses(Class testClass) { + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + assertMergedConfig(mergedConfig, testClass, EMPTY_STRING_ARRAY, array(FooConfig.class, + BarConfig.class), DelegatingSmartContextLoader.class); + } + + @Test + public void buildMergedConfigWithLocalAnnotationAndClasses() { + Class testClass = ClassesFoo.class; + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig(mergedConfig, testClass, EMPTY_STRING_ARRAY, array(FooConfig.class), + DelegatingSmartContextLoader.class); + } + + /** + * Introduced to investigate claims made in a discussion on + * Stack Overflow. + */ + @Test + public void buildMergedConfigWithAtWebAppConfigurationWithAnnotationAndClassesOnSuperclass() { + Class webTestClass = WebClassesFoo.class; + Class standardTestClass = ClassesFoo.class; + WebMergedContextConfiguration webMergedConfig = (WebMergedContextConfiguration) buildMergedContextConfiguration(webTestClass); + MergedContextConfiguration standardMergedConfig = buildMergedContextConfiguration(standardTestClass); + + assertEquals(webMergedConfig, webMergedConfig); + assertEquals(standardMergedConfig, standardMergedConfig); + assertNotEquals(standardMergedConfig, webMergedConfig); + assertNotEquals(webMergedConfig, standardMergedConfig); + + assertMergedConfig(webMergedConfig, webTestClass, EMPTY_STRING_ARRAY, array(FooConfig.class), + WebDelegatingSmartContextLoader.class); + assertMergedConfig(standardMergedConfig, standardTestClass, EMPTY_STRING_ARRAY, + array(FooConfig.class), DelegatingSmartContextLoader.class); + } + + @Test + public void buildMergedConfigWithLocalAnnotationAndOverriddenContextLoaderAndLocations() { + Class testClass = PropertiesLocationsFoo.class; + Class expectedContextLoaderClass = GenericPropertiesContextLoader.class; + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig(mergedConfig, testClass, array("classpath:/foo.properties"), EMPTY_CLASS_ARRAY, + expectedContextLoaderClass); + } + + @Test + public void buildMergedConfigWithLocalAnnotationAndOverriddenContextLoaderAndClasses() { + Class testClass = PropertiesClassesFoo.class; + Class expectedContextLoaderClass = GenericPropertiesContextLoader.class; + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig(mergedConfig, testClass, EMPTY_STRING_ARRAY, array(FooConfig.class), + expectedContextLoaderClass); + } + + @Test + public void buildMergedConfigWithLocalAndInheritedAnnotationsAndLocations() { + Class testClass = LocationsBar.class; + String[] expectedLocations = array("/foo.xml", "/bar.xml"); + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig(mergedConfig, testClass, expectedLocations, EMPTY_CLASS_ARRAY, + AnnotationConfigContextLoader.class); + } + + @Test + public void buildMergedConfigWithLocalAndInheritedAnnotationsAndClasses() { + Class testClass = ClassesBar.class; + Class[] expectedClasses = array(FooConfig.class, BarConfig.class); + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig(mergedConfig, testClass, EMPTY_STRING_ARRAY, expectedClasses, + AnnotationConfigContextLoader.class); + } + + @Test + public void buildMergedConfigWithAnnotationsAndOverriddenLocations() { + Class testClass = OverriddenLocationsBar.class; + String[] expectedLocations = array("/bar.xml"); + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig(mergedConfig, testClass, expectedLocations, EMPTY_CLASS_ARRAY, + AnnotationConfigContextLoader.class); + } + + @Test + public void buildMergedConfigWithAnnotationsAndOverriddenClasses() { + Class testClass = OverriddenClassesBar.class; + Class[] expectedClasses = array(BarConfig.class); + MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); + + assertMergedConfig(mergedConfig, testClass, EMPTY_STRING_ARRAY, expectedClasses, + AnnotationConfigContextLoader.class); + } + + + @ContextConfiguration + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.TYPE) + public static @interface SpringAppConfig { + + Class[] classes() default {}; + } + + @SpringAppConfig(classes = { FooConfig.class, BarConfig.class }) + public static abstract class Dog { + } + + public static abstract class WorkingDog extends Dog { + } + + public static class GermanShepherd extends WorkingDog { + } + + @ContextConfiguration + static class MissingContextAttributesTestCase { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/ContextConfigurationInnerClassTestCase.java b/spring-test/src/test/java/org/springframework/test/context/support/ContextConfigurationInnerClassTestCase.java new file mode 100644 index 0000000000000000000000000000000000000000..bc106845a2830e3ae2d0a992c5fbca08370b91ea --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/ContextConfigurationInnerClassTestCase.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import org.springframework.context.annotation.Configuration; + +/** + * Not an actual test case. + * + * @author Sam Brannen + * @since 3.1 + * @see AnnotationConfigContextLoaderTests + */ +public class ContextConfigurationInnerClassTestCase { + + @Configuration + static class ContextConfiguration { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/ContextLoaderUtilsConfigurationAttributesTests.java b/spring-test/src/test/java/org/springframework/test/context/support/ContextLoaderUtilsConfigurationAttributesTests.java new file mode 100644 index 0000000000000000000000000000000000000000..ac99c22c7a8e0157c2fbb1d77fe76ad955affeae --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/ContextLoaderUtilsConfigurationAttributesTests.java @@ -0,0 +1,188 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import java.util.List; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.core.annotation.AnnotationConfigurationException; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.ContextConfigurationAttributes; +import org.springframework.test.context.ContextLoader; + +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; +import static org.springframework.test.context.support.ContextLoaderUtils.*; + +/** + * Unit tests for {@link ContextLoaderUtils} involving {@link ContextConfigurationAttributes}. + * + * @author Sam Brannen + * @since 3.1 + */ +public class ContextLoaderUtilsConfigurationAttributesTests extends AbstractContextConfigurationUtilsTests { + + @Rule + public final ExpectedException exception = ExpectedException.none(); + + + private void assertLocationsFooAttributes(ContextConfigurationAttributes attributes) { + assertAttributes(attributes, LocationsFoo.class, new String[] { "/foo.xml" }, EMPTY_CLASS_ARRAY, + ContextLoader.class, false); + } + + private void assertClassesFooAttributes(ContextConfigurationAttributes attributes) { + assertAttributes(attributes, ClassesFoo.class, EMPTY_STRING_ARRAY, new Class[] {FooConfig.class}, + ContextLoader.class, false); + } + + private void assertLocationsBarAttributes(ContextConfigurationAttributes attributes) { + assertAttributes(attributes, LocationsBar.class, new String[] {"/bar.xml"}, EMPTY_CLASS_ARRAY, + AnnotationConfigContextLoader.class, true); + } + + private void assertClassesBarAttributes(ContextConfigurationAttributes attributes) { + assertAttributes(attributes, ClassesBar.class, EMPTY_STRING_ARRAY, new Class[] {BarConfig.class}, + AnnotationConfigContextLoader.class, true); + } + + @Test + public void resolveConfigAttributesWithConflictingLocations() { + exception.expect(AnnotationConfigurationException.class); + exception.expectMessage(containsString(ConflictingLocations.class.getName())); + exception.expectMessage(either( + containsString("attribute 'value' and its alias 'locations'")).or( + containsString("attribute 'locations' and its alias 'value'"))); + exception.expectMessage(either( + containsString("values of [{x}] and [{y}]")).or( + containsString("values of [{y}] and [{x}]"))); + exception.expectMessage(containsString("but only one is permitted")); + resolveContextConfigurationAttributes(ConflictingLocations.class); + } + + @Test + public void resolveConfigAttributesWithBareAnnotations() { + Class testClass = BareAnnotations.class; + List attributesList = resolveContextConfigurationAttributes(testClass); + assertNotNull(attributesList); + assertEquals(1, attributesList.size()); + assertAttributes(attributesList.get(0), + testClass, EMPTY_STRING_ARRAY, EMPTY_CLASS_ARRAY, ContextLoader.class, true); + } + + @Test + public void resolveConfigAttributesWithLocalAnnotationAndLocations() { + List attributesList = resolveContextConfigurationAttributes(LocationsFoo.class); + assertNotNull(attributesList); + assertEquals(1, attributesList.size()); + assertLocationsFooAttributes(attributesList.get(0)); + } + + @Test + public void resolveConfigAttributesWithMetaAnnotationAndLocations() { + Class testClass = MetaLocationsFoo.class; + List attributesList = resolveContextConfigurationAttributes(testClass); + assertNotNull(attributesList); + assertEquals(1, attributesList.size()); + assertAttributes(attributesList.get(0), + testClass, new String[] {"/foo.xml"}, EMPTY_CLASS_ARRAY, ContextLoader.class, true); + } + + @Test + public void resolveConfigAttributesWithMetaAnnotationAndLocationsAndOverrides() { + Class testClass = MetaLocationsFooWithOverrides.class; + List attributesList = resolveContextConfigurationAttributes(testClass); + assertNotNull(attributesList); + assertEquals(1, attributesList.size()); + assertAttributes(attributesList.get(0), + testClass, new String[] {"/foo.xml"}, EMPTY_CLASS_ARRAY, ContextLoader.class, true); + } + + @Test + public void resolveConfigAttributesWithMetaAnnotationAndLocationsAndOverriddenAttributes() { + Class testClass = MetaLocationsFooWithOverriddenAttributes.class; + List attributesList = resolveContextConfigurationAttributes(testClass); + assertNotNull(attributesList); + assertEquals(1, attributesList.size()); + assertAttributes(attributesList.get(0), + testClass, new String[] {"foo1.xml", "foo2.xml"}, EMPTY_CLASS_ARRAY, ContextLoader.class, true); + } + + @Test + public void resolveConfigAttributesWithMetaAnnotationAndLocationsInClassHierarchy() { + Class testClass = MetaLocationsBar.class; + List attributesList = resolveContextConfigurationAttributes(testClass); + assertNotNull(attributesList); + assertEquals(2, attributesList.size()); + assertAttributes(attributesList.get(0), + testClass, new String[] {"/bar.xml"}, EMPTY_CLASS_ARRAY, ContextLoader.class, true); + assertAttributes(attributesList.get(1), + MetaLocationsFoo.class, new String[] {"/foo.xml"}, EMPTY_CLASS_ARRAY, ContextLoader.class, true); + } + + @Test + public void resolveConfigAttributesWithLocalAnnotationAndClasses() { + List attributesList = resolveContextConfigurationAttributes(ClassesFoo.class); + assertNotNull(attributesList); + assertEquals(1, attributesList.size()); + assertClassesFooAttributes(attributesList.get(0)); + } + + @Test + public void resolveConfigAttributesWithLocalAndInheritedAnnotationsAndLocations() { + List attributesList = resolveContextConfigurationAttributes(LocationsBar.class); + assertNotNull(attributesList); + assertEquals(2, attributesList.size()); + assertLocationsBarAttributes(attributesList.get(0)); + assertLocationsFooAttributes(attributesList.get(1)); + } + + @Test + public void resolveConfigAttributesWithLocalAndInheritedAnnotationsAndClasses() { + List attributesList = resolveContextConfigurationAttributes(ClassesBar.class); + assertNotNull(attributesList); + assertEquals(2, attributesList.size()); + assertClassesBarAttributes(attributesList.get(0)); + assertClassesFooAttributes(attributesList.get(1)); + } + + /** + * Verifies change requested in SPR-11634. + * @since 4.0.4 + */ + @Test + public void resolveConfigAttributesWithLocationsAndClasses() { + List attributesList = resolveContextConfigurationAttributes(LocationsAndClasses.class); + assertNotNull(attributesList); + assertEquals(1, attributesList.size()); + } + + + // ------------------------------------------------------------------------- + + @ContextConfiguration(value = "x", locations = "y") + private static class ConflictingLocations { + } + + @ContextConfiguration(locations = "x", classes = Object.class) + private static class LocationsAndClasses { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/ContextLoaderUtilsContextHierarchyTests.java b/spring-test/src/test/java/org/springframework/test/context/support/ContextLoaderUtilsContextHierarchyTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0ffa9b580455bc31982f6efdfbf1fe8327eceae4 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/ContextLoaderUtilsContextHierarchyTests.java @@ -0,0 +1,611 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.junit.Test; + +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.ContextConfigurationAttributes; +import org.springframework.test.context.ContextHierarchy; +import org.springframework.test.context.ContextLoader; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; +import static org.springframework.test.context.support.ContextLoaderUtils.*; + +/** + * Unit tests for {@link ContextLoaderUtils} involving context hierarchies. + * + * @author Sam Brannen + * @since 3.2.2 + */ +public class ContextLoaderUtilsContextHierarchyTests extends AbstractContextConfigurationUtilsTests { + + private void debugConfigAttributes(List configAttributesList) { + // for (ContextConfigurationAttributes configAttributes : configAttributesList) { + // System.err.println(configAttributes); + // } + } + + @Test(expected = IllegalStateException.class) + public void resolveContextHierarchyAttributesForSingleTestClassWithContextConfigurationAndContextHierarchy() { + resolveContextHierarchyAttributes(SingleTestClassWithContextConfigurationAndContextHierarchy.class); + } + + @Test(expected = IllegalStateException.class) + public void resolveContextHierarchyAttributesForSingleTestClassWithContextConfigurationAndContextHierarchyOnSingleMetaAnnotation() { + resolveContextHierarchyAttributes(SingleTestClassWithContextConfigurationAndContextHierarchyOnSingleMetaAnnotation.class); + } + + @Test + public void resolveContextHierarchyAttributesForSingleTestClassWithImplicitSingleLevelContextHierarchy() { + List> hierarchyAttributes = resolveContextHierarchyAttributes(BareAnnotations.class); + assertEquals(1, hierarchyAttributes.size()); + List configAttributesList = hierarchyAttributes.get(0); + assertEquals(1, configAttributesList.size()); + debugConfigAttributes(configAttributesList); + } + + @Test + public void resolveContextHierarchyAttributesForSingleTestClassWithSingleLevelContextHierarchy() { + List> hierarchyAttributes = resolveContextHierarchyAttributes(SingleTestClassWithSingleLevelContextHierarchy.class); + assertEquals(1, hierarchyAttributes.size()); + List configAttributesList = hierarchyAttributes.get(0); + assertEquals(1, configAttributesList.size()); + debugConfigAttributes(configAttributesList); + } + + @Test + public void resolveContextHierarchyAttributesForSingleTestClassWithSingleLevelContextHierarchyFromMetaAnnotation() { + Class testClass = SingleTestClassWithSingleLevelContextHierarchyFromMetaAnnotation.class; + List> hierarchyAttributes = resolveContextHierarchyAttributes(testClass); + assertEquals(1, hierarchyAttributes.size()); + + List configAttributesList = hierarchyAttributes.get(0); + assertNotNull(configAttributesList); + assertEquals(1, configAttributesList.size()); + debugConfigAttributes(configAttributesList); + assertAttributes(configAttributesList.get(0), testClass, new String[] { "A.xml" }, EMPTY_CLASS_ARRAY, + ContextLoader.class, true); + } + + @Test + public void resolveContextHierarchyAttributesForSingleTestClassWithTripleLevelContextHierarchy() { + Class testClass = SingleTestClassWithTripleLevelContextHierarchy.class; + List> hierarchyAttributes = resolveContextHierarchyAttributes(testClass); + assertEquals(1, hierarchyAttributes.size()); + + List configAttributesList = hierarchyAttributes.get(0); + assertNotNull(configAttributesList); + assertEquals(3, configAttributesList.size()); + debugConfigAttributes(configAttributesList); + assertAttributes(configAttributesList.get(0), testClass, new String[] { "A.xml" }, EMPTY_CLASS_ARRAY, + ContextLoader.class, true); + assertAttributes(configAttributesList.get(1), testClass, new String[] { "B.xml" }, EMPTY_CLASS_ARRAY, + ContextLoader.class, true); + assertAttributes(configAttributesList.get(2), testClass, new String[] { "C.xml" }, EMPTY_CLASS_ARRAY, + ContextLoader.class, true); + } + + @Test + public void resolveContextHierarchyAttributesForTestClassHierarchyWithSingleLevelContextHierarchies() { + List> hierarchyAttributes = resolveContextHierarchyAttributes(TestClass3WithSingleLevelContextHierarchy.class); + assertEquals(3, hierarchyAttributes.size()); + + List configAttributesListClassLevel1 = hierarchyAttributes.get(0); + debugConfigAttributes(configAttributesListClassLevel1); + assertEquals(1, configAttributesListClassLevel1.size()); + assertThat(configAttributesListClassLevel1.get(0).getLocations()[0], equalTo("one.xml")); + + List configAttributesListClassLevel2 = hierarchyAttributes.get(1); + debugConfigAttributes(configAttributesListClassLevel2); + assertEquals(1, configAttributesListClassLevel2.size()); + assertArrayEquals(new String[] { "two-A.xml", "two-B.xml" }, + configAttributesListClassLevel2.get(0).getLocations()); + + List configAttributesListClassLevel3 = hierarchyAttributes.get(2); + debugConfigAttributes(configAttributesListClassLevel3); + assertEquals(1, configAttributesListClassLevel3.size()); + assertThat(configAttributesListClassLevel3.get(0).getLocations()[0], equalTo("three.xml")); + } + + @Test + public void resolveContextHierarchyAttributesForTestClassHierarchyWithSingleLevelContextHierarchiesAndMetaAnnotations() { + List> hierarchyAttributes = resolveContextHierarchyAttributes(TestClass3WithSingleLevelContextHierarchyFromMetaAnnotation.class); + assertEquals(3, hierarchyAttributes.size()); + + List configAttributesListClassLevel1 = hierarchyAttributes.get(0); + debugConfigAttributes(configAttributesListClassLevel1); + assertEquals(1, configAttributesListClassLevel1.size()); + assertThat(configAttributesListClassLevel1.get(0).getLocations()[0], equalTo("A.xml")); + assertAttributes(configAttributesListClassLevel1.get(0), + TestClass1WithSingleLevelContextHierarchyFromMetaAnnotation.class, new String[] { "A.xml" }, + EMPTY_CLASS_ARRAY, ContextLoader.class, true); + + List configAttributesListClassLevel2 = hierarchyAttributes.get(1); + debugConfigAttributes(configAttributesListClassLevel2); + assertEquals(1, configAttributesListClassLevel2.size()); + assertArrayEquals(new String[] { "B-one.xml", "B-two.xml" }, + configAttributesListClassLevel2.get(0).getLocations()); + assertAttributes(configAttributesListClassLevel2.get(0), + TestClass2WithSingleLevelContextHierarchyFromMetaAnnotation.class, + new String[] { "B-one.xml", + "B-two.xml" }, EMPTY_CLASS_ARRAY, ContextLoader.class, true); + + List configAttributesListClassLevel3 = hierarchyAttributes.get(2); + debugConfigAttributes(configAttributesListClassLevel3); + assertEquals(1, configAttributesListClassLevel3.size()); + assertThat(configAttributesListClassLevel3.get(0).getLocations()[0], equalTo("C.xml")); + assertAttributes(configAttributesListClassLevel3.get(0), + TestClass3WithSingleLevelContextHierarchyFromMetaAnnotation.class, new String[] { "C.xml" }, + EMPTY_CLASS_ARRAY, ContextLoader.class, true); + } + + private void assertOneTwo(List> hierarchyAttributes) { + assertEquals(2, hierarchyAttributes.size()); + + List configAttributesListClassLevel1 = hierarchyAttributes.get(0); + List configAttributesListClassLevel2 = hierarchyAttributes.get(1); + debugConfigAttributes(configAttributesListClassLevel1); + debugConfigAttributes(configAttributesListClassLevel2); + + assertEquals(1, configAttributesListClassLevel1.size()); + assertThat(configAttributesListClassLevel1.get(0).getLocations()[0], equalTo("one.xml")); + + assertEquals(1, configAttributesListClassLevel2.size()); + assertThat(configAttributesListClassLevel2.get(0).getLocations()[0], equalTo("two.xml")); + } + + @Test + public void resolveContextHierarchyAttributesForTestClassHierarchyWithBareContextConfigurationInSuperclass() { + assertOneTwo(resolveContextHierarchyAttributes(TestClass2WithBareContextConfigurationInSuperclass.class)); + } + + @Test + public void resolveContextHierarchyAttributesForTestClassHierarchyWithBareContextConfigurationInSubclass() { + assertOneTwo(resolveContextHierarchyAttributes(TestClass2WithBareContextConfigurationInSubclass.class)); + } + + @Test + public void resolveContextHierarchyAttributesForTestClassHierarchyWithBareMetaContextConfigWithOverridesInSuperclass() { + assertOneTwo(resolveContextHierarchyAttributes(TestClass2WithBareMetaContextConfigWithOverridesInSuperclass.class)); + } + + @Test + public void resolveContextHierarchyAttributesForTestClassHierarchyWithBareMetaContextConfigWithOverridesInSubclass() { + assertOneTwo(resolveContextHierarchyAttributes(TestClass2WithBareMetaContextConfigWithOverridesInSubclass.class)); + } + + @Test + public void resolveContextHierarchyAttributesForTestClassHierarchyWithMultiLevelContextHierarchies() { + List> hierarchyAttributes = resolveContextHierarchyAttributes(TestClass3WithMultiLevelContextHierarchy.class); + assertEquals(3, hierarchyAttributes.size()); + + List configAttributesListClassLevel1 = hierarchyAttributes.get(0); + debugConfigAttributes(configAttributesListClassLevel1); + assertEquals(2, configAttributesListClassLevel1.size()); + assertThat(configAttributesListClassLevel1.get(0).getLocations()[0], equalTo("1-A.xml")); + assertThat(configAttributesListClassLevel1.get(1).getLocations()[0], equalTo("1-B.xml")); + + List configAttributesListClassLevel2 = hierarchyAttributes.get(1); + debugConfigAttributes(configAttributesListClassLevel2); + assertEquals(2, configAttributesListClassLevel2.size()); + assertThat(configAttributesListClassLevel2.get(0).getLocations()[0], equalTo("2-A.xml")); + assertThat(configAttributesListClassLevel2.get(1).getLocations()[0], equalTo("2-B.xml")); + + List configAttributesListClassLevel3 = hierarchyAttributes.get(2); + debugConfigAttributes(configAttributesListClassLevel3); + assertEquals(3, configAttributesListClassLevel3.size()); + assertThat(configAttributesListClassLevel3.get(0).getLocations()[0], equalTo("3-A.xml")); + assertThat(configAttributesListClassLevel3.get(1).getLocations()[0], equalTo("3-B.xml")); + assertThat(configAttributesListClassLevel3.get(2).getLocations()[0], equalTo("3-C.xml")); + } + + @Test + public void buildContextHierarchyMapForTestClassHierarchyWithMultiLevelContextHierarchies() { + Map> map = buildContextHierarchyMap(TestClass3WithMultiLevelContextHierarchy.class); + + assertThat(map.size(), is(3)); + assertThat(map.keySet(), hasItems("alpha", "beta", "gamma")); + + List alphaConfig = map.get("alpha"); + assertThat(alphaConfig.size(), is(3)); + assertThat(alphaConfig.get(0).getLocations()[0], is("1-A.xml")); + assertThat(alphaConfig.get(1).getLocations()[0], is("2-A.xml")); + assertThat(alphaConfig.get(2).getLocations()[0], is("3-A.xml")); + + List betaConfig = map.get("beta"); + assertThat(betaConfig.size(), is(3)); + assertThat(betaConfig.get(0).getLocations()[0], is("1-B.xml")); + assertThat(betaConfig.get(1).getLocations()[0], is("2-B.xml")); + assertThat(betaConfig.get(2).getLocations()[0], is("3-B.xml")); + + List gammaConfig = map.get("gamma"); + assertThat(gammaConfig.size(), is(1)); + assertThat(gammaConfig.get(0).getLocations()[0], is("3-C.xml")); + } + + @Test + public void buildContextHierarchyMapForTestClassHierarchyWithMultiLevelContextHierarchiesAndUnnamedConfig() { + Map> map = buildContextHierarchyMap(TestClass3WithMultiLevelContextHierarchyAndUnnamedConfig.class); + + String level1 = GENERATED_CONTEXT_HIERARCHY_LEVEL_PREFIX + 1; + String level2 = GENERATED_CONTEXT_HIERARCHY_LEVEL_PREFIX + 2; + String level3 = GENERATED_CONTEXT_HIERARCHY_LEVEL_PREFIX + 3; + String level4 = GENERATED_CONTEXT_HIERARCHY_LEVEL_PREFIX + 4; + String level5 = GENERATED_CONTEXT_HIERARCHY_LEVEL_PREFIX + 5; + String level6 = GENERATED_CONTEXT_HIERARCHY_LEVEL_PREFIX + 6; + String level7 = GENERATED_CONTEXT_HIERARCHY_LEVEL_PREFIX + 7; + + assertThat(map.size(), is(7)); + assertThat(map.keySet(), hasItems(level1, level2, level3, level4, level5, level6, level7)); + + List level1Config = map.get(level1); + assertThat(level1Config.size(), is(1)); + assertThat(level1Config.get(0).getLocations()[0], is("1-A.xml")); + + List level2Config = map.get(level2); + assertThat(level2Config.size(), is(1)); + assertThat(level2Config.get(0).getLocations()[0], is("1-B.xml")); + + List level3Config = map.get(level3); + assertThat(level3Config.size(), is(1)); + assertThat(level3Config.get(0).getLocations()[0], is("2-A.xml")); + + // ... + + List level7Config = map.get(level7); + assertThat(level7Config.size(), is(1)); + assertThat(level7Config.get(0).getLocations()[0], is("3-C.xml")); + } + + @Test + public void buildContextHierarchyMapForTestClassHierarchyWithMultiLevelContextHierarchiesAndPartiallyNamedConfig() { + Map> map = buildContextHierarchyMap(TestClass2WithMultiLevelContextHierarchyAndPartiallyNamedConfig.class); + + String level1 = "parent"; + String level2 = GENERATED_CONTEXT_HIERARCHY_LEVEL_PREFIX + 2; + String level3 = GENERATED_CONTEXT_HIERARCHY_LEVEL_PREFIX + 3; + + assertThat(map.size(), is(3)); + assertThat(map.keySet(), hasItems(level1, level2, level3)); + Iterator levels = map.keySet().iterator(); + assertThat(levels.next(), is(level1)); + assertThat(levels.next(), is(level2)); + assertThat(levels.next(), is(level3)); + + List level1Config = map.get(level1); + assertThat(level1Config.size(), is(2)); + assertThat(level1Config.get(0).getLocations()[0], is("1-A.xml")); + assertThat(level1Config.get(1).getLocations()[0], is("2-A.xml")); + + List level2Config = map.get(level2); + assertThat(level2Config.size(), is(1)); + assertThat(level2Config.get(0).getLocations()[0], is("1-B.xml")); + + List level3Config = map.get(level3); + assertThat(level3Config.size(), is(1)); + assertThat(level3Config.get(0).getLocations()[0], is("2-C.xml")); + } + + private void assertContextConfigEntriesAreNotUnique(Class testClass) { + try { + buildContextHierarchyMap(testClass); + fail("Should throw an IllegalStateException"); + } + catch (IllegalStateException e) { + String msg = String.format( + "The @ContextConfiguration elements configured via @ContextHierarchy in test class [%s] and its superclasses must define unique contexts per hierarchy level.", + testClass.getName()); + assertEquals(msg, e.getMessage()); + } + } + + @Test + public void buildContextHierarchyMapForSingleTestClassWithMultiLevelContextHierarchyWithEmptyContextConfig() { + assertContextConfigEntriesAreNotUnique(SingleTestClassWithMultiLevelContextHierarchyWithEmptyContextConfig.class); + } + + @Test + public void buildContextHierarchyMapForSingleTestClassWithMultiLevelContextHierarchyWithDuplicatedContextConfig() { + assertContextConfigEntriesAreNotUnique(SingleTestClassWithMultiLevelContextHierarchyWithDuplicatedContextConfig.class); + } + + /** + * Used to reproduce bug reported in https://jira.spring.io/browse/SPR-10997 + */ + @Test + public void buildContextHierarchyMapForTestClassHierarchyWithMultiLevelContextHierarchiesAndOverriddenInitializers() { + Map> map = buildContextHierarchyMap(TestClass2WithMultiLevelContextHierarchyWithOverriddenInitializers.class); + + assertThat(map.size(), is(2)); + assertThat(map.keySet(), hasItems("alpha", "beta")); + + List alphaConfig = map.get("alpha"); + assertThat(alphaConfig.size(), is(2)); + assertThat(alphaConfig.get(0).getLocations().length, is(1)); + assertThat(alphaConfig.get(0).getLocations()[0], is("1-A.xml")); + assertThat(alphaConfig.get(0).getInitializers().length, is(0)); + assertThat(alphaConfig.get(1).getLocations().length, is(0)); + assertThat(alphaConfig.get(1).getInitializers().length, is(1)); + assertEquals(DummyApplicationContextInitializer.class, alphaConfig.get(1).getInitializers()[0]); + + List betaConfig = map.get("beta"); + assertThat(betaConfig.size(), is(2)); + assertThat(betaConfig.get(0).getLocations().length, is(1)); + assertThat(betaConfig.get(0).getLocations()[0], is("1-B.xml")); + assertThat(betaConfig.get(0).getInitializers().length, is(0)); + assertThat(betaConfig.get(1).getLocations().length, is(0)); + assertThat(betaConfig.get(1).getInitializers().length, is(1)); + assertEquals(DummyApplicationContextInitializer.class, betaConfig.get(1).getInitializers()[0]); + } + + + // ------------------------------------------------------------------------- + + @ContextConfiguration("foo.xml") + @ContextHierarchy(@ContextConfiguration("bar.xml")) + private static class SingleTestClassWithContextConfigurationAndContextHierarchy { + } + + @ContextConfiguration("foo.xml") + @ContextHierarchy(@ContextConfiguration("bar.xml")) + @Retention(RetentionPolicy.RUNTIME) + private static @interface ContextConfigurationAndContextHierarchyOnSingleMeta { + } + + @ContextConfigurationAndContextHierarchyOnSingleMeta + private static class SingleTestClassWithContextConfigurationAndContextHierarchyOnSingleMetaAnnotation { + } + + @ContextHierarchy(@ContextConfiguration("A.xml")) + private static class SingleTestClassWithSingleLevelContextHierarchy { + } + + @ContextHierarchy({// + // + @ContextConfiguration("A.xml"),// + @ContextConfiguration("B.xml"),// + @ContextConfiguration("C.xml") // + }) + private static class SingleTestClassWithTripleLevelContextHierarchy { + } + + @ContextHierarchy(@ContextConfiguration("one.xml")) + private static class TestClass1WithSingleLevelContextHierarchy { + } + + @ContextHierarchy(@ContextConfiguration({ "two-A.xml", "two-B.xml" })) + private static class TestClass2WithSingleLevelContextHierarchy extends TestClass1WithSingleLevelContextHierarchy { + } + + @ContextHierarchy(@ContextConfiguration("three.xml")) + private static class TestClass3WithSingleLevelContextHierarchy extends TestClass2WithSingleLevelContextHierarchy { + } + + @ContextConfiguration("one.xml") + private static class TestClass1WithBareContextConfigurationInSuperclass { + } + + @ContextHierarchy(@ContextConfiguration("two.xml")) + private static class TestClass2WithBareContextConfigurationInSuperclass extends + TestClass1WithBareContextConfigurationInSuperclass { + } + + @ContextHierarchy(@ContextConfiguration("one.xml")) + private static class TestClass1WithBareContextConfigurationInSubclass { + } + + @ContextConfiguration("two.xml") + private static class TestClass2WithBareContextConfigurationInSubclass extends + TestClass1WithBareContextConfigurationInSubclass { + } + + @ContextHierarchy({// + // + @ContextConfiguration(locations = "1-A.xml", name = "alpha"),// + @ContextConfiguration(locations = "1-B.xml", name = "beta") // + }) + private static class TestClass1WithMultiLevelContextHierarchy { + } + + @ContextHierarchy({// + // + @ContextConfiguration(locations = "2-A.xml", name = "alpha"),// + @ContextConfiguration(locations = "2-B.xml", name = "beta") // + }) + private static class TestClass2WithMultiLevelContextHierarchy extends TestClass1WithMultiLevelContextHierarchy { + } + + @ContextHierarchy({// + // + @ContextConfiguration(locations = "3-A.xml", name = "alpha"),// + @ContextConfiguration(locations = "3-B.xml", name = "beta"),// + @ContextConfiguration(locations = "3-C.xml", name = "gamma") // + }) + private static class TestClass3WithMultiLevelContextHierarchy extends TestClass2WithMultiLevelContextHierarchy { + } + + @ContextHierarchy({// + // + @ContextConfiguration(locations = "1-A.xml"),// + @ContextConfiguration(locations = "1-B.xml") // + }) + private static class TestClass1WithMultiLevelContextHierarchyAndUnnamedConfig { + } + + @ContextHierarchy({// + // + @ContextConfiguration(locations = "2-A.xml"),// + @ContextConfiguration(locations = "2-B.xml") // + }) + private static class TestClass2WithMultiLevelContextHierarchyAndUnnamedConfig extends + TestClass1WithMultiLevelContextHierarchyAndUnnamedConfig { + } + + @ContextHierarchy({// + // + @ContextConfiguration(locations = "3-A.xml"),// + @ContextConfiguration(locations = "3-B.xml"),// + @ContextConfiguration(locations = "3-C.xml") // + }) + private static class TestClass3WithMultiLevelContextHierarchyAndUnnamedConfig extends + TestClass2WithMultiLevelContextHierarchyAndUnnamedConfig { + } + + @ContextHierarchy({// + // + @ContextConfiguration(locations = "1-A.xml", name = "parent"),// + @ContextConfiguration(locations = "1-B.xml") // + }) + private static class TestClass1WithMultiLevelContextHierarchyAndPartiallyNamedConfig { + } + + @ContextHierarchy({// + // + @ContextConfiguration(locations = "2-A.xml", name = "parent"),// + @ContextConfiguration(locations = "2-C.xml") // + }) + private static class TestClass2WithMultiLevelContextHierarchyAndPartiallyNamedConfig extends + TestClass1WithMultiLevelContextHierarchyAndPartiallyNamedConfig { + } + + @ContextHierarchy({ + // + @ContextConfiguration,// + @ContextConfiguration // + }) + private static class SingleTestClassWithMultiLevelContextHierarchyWithEmptyContextConfig { + } + + @ContextHierarchy({ + // + @ContextConfiguration("foo.xml"),// + @ContextConfiguration(classes = BarConfig.class),// duplicate! + @ContextConfiguration("baz.xml"),// + @ContextConfiguration(classes = BarConfig.class),// duplicate! + @ContextConfiguration(loader = AnnotationConfigContextLoader.class) // + }) + private static class SingleTestClassWithMultiLevelContextHierarchyWithDuplicatedContextConfig { + } + + /** + * Used to reproduce bug reported in https://jira.spring.io/browse/SPR-10997 + */ + @ContextHierarchy({// + // + @ContextConfiguration(name = "alpha", locations = "1-A.xml"),// + @ContextConfiguration(name = "beta", locations = "1-B.xml") // + }) + private static class TestClass1WithMultiLevelContextHierarchyWithUniqueContextConfig { + } + + /** + * Used to reproduce bug reported in https://jira.spring.io/browse/SPR-10997 + */ + @ContextHierarchy({// + // + @ContextConfiguration(name = "alpha", initializers = DummyApplicationContextInitializer.class),// + @ContextConfiguration(name = "beta", initializers = DummyApplicationContextInitializer.class) // + }) + private static class TestClass2WithMultiLevelContextHierarchyWithOverriddenInitializers extends + TestClass1WithMultiLevelContextHierarchyWithUniqueContextConfig { + } + + /** + * Used to reproduce bug reported in https://jira.spring.io/browse/SPR-10997 + */ + private static class DummyApplicationContextInitializer implements + ApplicationContextInitializer { + + @Override + public void initialize(ConfigurableApplicationContext applicationContext) { + /* no-op */ + } + } + + // ------------------------------------------------------------------------- + + @ContextHierarchy(@ContextConfiguration("A.xml")) + @Retention(RetentionPolicy.RUNTIME) + private static @interface ContextHierarchyA { + } + + @ContextHierarchy(@ContextConfiguration({ "B-one.xml", "B-two.xml" })) + @Retention(RetentionPolicy.RUNTIME) + private static @interface ContextHierarchyB { + } + + @ContextHierarchy(@ContextConfiguration("C.xml")) + @Retention(RetentionPolicy.RUNTIME) + private static @interface ContextHierarchyC { + } + + @ContextHierarchyA + private static class SingleTestClassWithSingleLevelContextHierarchyFromMetaAnnotation { + } + + @ContextHierarchyA + private static class TestClass1WithSingleLevelContextHierarchyFromMetaAnnotation { + } + + @ContextHierarchyB + private static class TestClass2WithSingleLevelContextHierarchyFromMetaAnnotation extends + TestClass1WithSingleLevelContextHierarchyFromMetaAnnotation { + } + + @ContextHierarchyC + private static class TestClass3WithSingleLevelContextHierarchyFromMetaAnnotation extends + TestClass2WithSingleLevelContextHierarchyFromMetaAnnotation { + } + + // ------------------------------------------------------------------------- + + @ContextConfiguration + @Retention(RetentionPolicy.RUNTIME) + private static @interface ContextConfigWithOverrides { + + String[] locations() default "A.xml"; + } + + @ContextConfigWithOverrides(locations = "one.xml") + private static class TestClass1WithBareMetaContextConfigWithOverridesInSuperclass { + } + + @ContextHierarchy(@ContextConfiguration(locations = "two.xml")) + private static class TestClass2WithBareMetaContextConfigWithOverridesInSuperclass extends + TestClass1WithBareMetaContextConfigWithOverridesInSuperclass { + } + + @ContextHierarchy(@ContextConfiguration(locations = "one.xml")) + private static class TestClass1WithBareMetaContextConfigWithOverridesInSubclass { + } + + @ContextConfigWithOverrides(locations = "two.xml") + private static class TestClass2WithBareMetaContextConfigWithOverridesInSubclass extends + TestClass1WithBareMetaContextConfigWithOverridesInSubclass { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/CustomizedGenericXmlContextLoaderTests.java b/spring-test/src/test/java/org/springframework/test/context/support/CustomizedGenericXmlContextLoaderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..753e10addbe247ad4b8053b5e91efa16658167dc --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/CustomizedGenericXmlContextLoaderTests.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import org.junit.Test; + +import org.springframework.context.support.GenericApplicationContext; + +import static org.junit.Assert.*; + +/** + * Unit test which verifies that extensions of + * {@link AbstractGenericContextLoader} are able to customize the + * newly created {@code ApplicationContext}. Specifically, this test + * addresses the issues raised in SPR-4008: Supply an opportunity to customize context + * before calling refresh in ContextLoaders. + * + * @author Sam Brannen + * @since 2.5 + */ +public class CustomizedGenericXmlContextLoaderTests { + + @Test + public void customizeContext() throws Exception { + + final StringBuilder builder = new StringBuilder(); + final String expectedContents = "customizeContext() was called"; + + new GenericXmlContextLoader() { + + @Override + protected void customizeContext(GenericApplicationContext context) { + assertFalse("The context should not yet have been refreshed.", context.isActive()); + builder.append(expectedContents); + } + }.loadContext("classpath:/org/springframework/test/context/support/CustomizedGenericXmlContextLoaderTests-context.xml"); + + assertEquals("customizeContext() should have been called.", expectedContents, builder.toString()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/DelegatingSmartContextLoaderTests.java b/spring-test/src/test/java/org/springframework/test/context/support/DelegatingSmartContextLoaderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e05ace59fb7acd523a904c903c644dba8a55e5dd --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/DelegatingSmartContextLoaderTests.java @@ -0,0 +1,208 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.context.ApplicationContext; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfigurationAttributes; +import org.springframework.test.context.ContextLoader; +import org.springframework.test.context.MergedContextConfiguration; +import org.springframework.util.ObjectUtils; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +/** + * Unit tests for {@link DelegatingSmartContextLoader}. + * + * @author Sam Brannen + * @since 3.1 + */ +public class DelegatingSmartContextLoaderTests { + + private static final String[] EMPTY_STRING_ARRAY = new String[0]; + private static final Class[] EMPTY_CLASS_ARRAY = new Class[0]; + + private final DelegatingSmartContextLoader loader = new DelegatingSmartContextLoader(); + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + + private static void assertEmpty(Object[] array) { + assertTrue(ObjectUtils.isEmpty(array)); + } + + // --- SmartContextLoader - processContextConfiguration() ------------------ + + @Test + public void processContextConfigurationWithDefaultXmlConfigGeneration() { + ContextConfigurationAttributes configAttributes = new ContextConfigurationAttributes( + XmlTestCase.class, EMPTY_STRING_ARRAY, EMPTY_CLASS_ARRAY, true, null, true, ContextLoader.class); + loader.processContextConfiguration(configAttributes); + assertEquals(1, configAttributes.getLocations().length); + assertEmpty(configAttributes.getClasses()); + } + + @Test + public void processContextConfigurationWithDefaultConfigurationClassGeneration() { + ContextConfigurationAttributes configAttributes = new ContextConfigurationAttributes( + ConfigClassTestCase.class, EMPTY_STRING_ARRAY, EMPTY_CLASS_ARRAY, true, null, true, ContextLoader.class); + loader.processContextConfiguration(configAttributes); + assertEquals(1, configAttributes.getClasses().length); + assertEmpty(configAttributes.getLocations()); + } + + @Test + public void processContextConfigurationWithDefaultXmlConfigAndConfigurationClassGeneration() { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage(containsString("both default locations AND default configuration classes were detected")); + + ContextConfigurationAttributes configAttributes = new ContextConfigurationAttributes( + ImproperDuplicateDefaultXmlAndConfigClassTestCase.class, EMPTY_STRING_ARRAY, EMPTY_CLASS_ARRAY, + true, null, true, ContextLoader.class); + loader.processContextConfiguration(configAttributes); + } + + @Test + public void processContextConfigurationWithLocation() { + String[] locations = new String[] {"classpath:/foo.xml"}; + ContextConfigurationAttributes configAttributes = new ContextConfigurationAttributes( + getClass(), locations, EMPTY_CLASS_ARRAY, true, null, true, ContextLoader.class); + loader.processContextConfiguration(configAttributes); + assertArrayEquals(locations, configAttributes.getLocations()); + assertEmpty(configAttributes.getClasses()); + } + + @Test + public void processContextConfigurationWithConfigurationClass() { + Class[] classes = new Class[] {getClass()}; + ContextConfigurationAttributes configAttributes = new ContextConfigurationAttributes( + getClass(), EMPTY_STRING_ARRAY, classes, true, null, true, ContextLoader.class); + loader.processContextConfiguration(configAttributes); + assertArrayEquals(classes, configAttributes.getClasses()); + assertEmpty(configAttributes.getLocations()); + } + + // --- SmartContextLoader - loadContext() ---------------------------------- + + @Test(expected = IllegalArgumentException.class) + public void loadContextWithNullConfig() throws Exception { + MergedContextConfiguration mergedConfig = null; + loader.loadContext(mergedConfig); + } + + @Test + public void loadContextWithoutLocationsAndConfigurationClasses() throws Exception { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage(startsWith("Neither")); + expectedException.expectMessage(containsString("was able to load an ApplicationContext from")); + + MergedContextConfiguration mergedConfig = new MergedContextConfiguration( + getClass(), EMPTY_STRING_ARRAY, EMPTY_CLASS_ARRAY, EMPTY_STRING_ARRAY, loader); + loader.loadContext(mergedConfig); + } + + /** + * @since 4.1 + */ + @Test + public void loadContextWithLocationsAndConfigurationClasses() throws Exception { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage(startsWith("Neither")); + expectedException.expectMessage(endsWith("declare either 'locations' or 'classes' but not both.")); + + MergedContextConfiguration mergedConfig = new MergedContextConfiguration(getClass(), + new String[] {"test.xml"}, new Class[] {getClass()}, EMPTY_STRING_ARRAY, loader); + loader.loadContext(mergedConfig); + } + + private void assertApplicationContextLoadsAndContainsFooString(MergedContextConfiguration mergedConfig) + throws Exception { + + ApplicationContext applicationContext = loader.loadContext(mergedConfig); + assertNotNull(applicationContext); + assertEquals("foo", applicationContext.getBean(String.class)); + assertTrue(applicationContext instanceof ConfigurableApplicationContext); + ((ConfigurableApplicationContext) applicationContext).close(); + } + + @Test + public void loadContextWithXmlConfig() throws Exception { + MergedContextConfiguration mergedConfig = new MergedContextConfiguration( + XmlTestCase.class, + new String[] {"classpath:/org/springframework/test/context/support/DelegatingSmartContextLoaderTests$XmlTestCase-context.xml"}, + EMPTY_CLASS_ARRAY, EMPTY_STRING_ARRAY, loader); + assertApplicationContextLoadsAndContainsFooString(mergedConfig); + } + + @Test + public void loadContextWithConfigurationClass() throws Exception { + MergedContextConfiguration mergedConfig = new MergedContextConfiguration(ConfigClassTestCase.class, + EMPTY_STRING_ARRAY, new Class[] {ConfigClassTestCase.Config.class}, EMPTY_STRING_ARRAY, loader); + assertApplicationContextLoadsAndContainsFooString(mergedConfig); + } + + // --- ContextLoader ------------------------------------------------------- + + @Test(expected = UnsupportedOperationException.class) + public void processLocations() { + loader.processLocations(getClass(), EMPTY_STRING_ARRAY); + } + + @Test(expected = UnsupportedOperationException.class) + public void loadContextFromLocations() throws Exception { + loader.loadContext(EMPTY_STRING_ARRAY); + } + + + // ------------------------------------------------------------------------- + + static class XmlTestCase { + } + + static class ConfigClassTestCase { + + @Configuration + static class Config { + + @Bean + public String foo() { + return new String("foo"); + } + } + + static class NotAConfigClass { + } + } + + static class ImproperDuplicateDefaultXmlAndConfigClassTestCase { + + @Configuration + static class Config { + // intentionally empty: we just need the class to be present to fail + // the test + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/DirtiesContextTestExecutionListenerTests.java b/spring-test/src/test/java/org/springframework/test/context/support/DirtiesContextTestExecutionListenerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c2ee97646fcb5925b0e3326ae17ba96279154845 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/DirtiesContextTestExecutionListenerTests.java @@ -0,0 +1,396 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +import org.junit.Test; + +import org.mockito.BDDMockito; + +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.annotation.DirtiesContext.ClassMode; +import org.springframework.test.annotation.DirtiesContext.HierarchyMode; +import org.springframework.test.context.TestContext; +import org.springframework.test.context.TestExecutionListener; + +import static org.mockito.BDDMockito.*; +import static org.springframework.test.annotation.DirtiesContext.ClassMode.*; +import static org.springframework.test.annotation.DirtiesContext.HierarchyMode.*; +import static org.springframework.test.annotation.DirtiesContext.MethodMode.*; + +/** + * Unit tests for {@link DirtiesContextBeforeModesTestExecutionListener}. + * and {@link DirtiesContextTestExecutionListener} + * + * @author Sam Brannen + * @since 4.0 + */ +public class DirtiesContextTestExecutionListenerTests { + + private final TestExecutionListener beforeListener = new DirtiesContextBeforeModesTestExecutionListener(); + private final TestExecutionListener afterListener = new DirtiesContextTestExecutionListener(); + private final TestContext testContext = mock(TestContext.class); + + + @Test + public void beforeAndAfterTestMethodForDirtiesContextDeclaredLocallyOnMethodWithBeforeMethodMode() throws Exception { + Class clazz = getClass(); + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + given(testContext.getTestMethod()).willReturn( + clazz.getDeclaredMethod("dirtiesContextDeclaredLocallyWithBeforeMethodMode")); + beforeListener.beforeTestMethod(testContext); + afterListener.beforeTestMethod(testContext); + verify(testContext, times(1)).markApplicationContextDirty(EXHAUSTIVE); + afterListener.afterTestMethod(testContext); + beforeListener.afterTestMethod(testContext); + verify(testContext, times(1)).markApplicationContextDirty(EXHAUSTIVE); + } + + @Test + public void beforeAndAfterTestMethodForDirtiesContextDeclaredLocallyOnMethodWithAfterMethodMode() throws Exception { + Class clazz = getClass(); + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + given(testContext.getTestMethod()).willReturn( + clazz.getDeclaredMethod("dirtiesContextDeclaredLocallyWithAfterMethodMode")); + beforeListener.beforeTestMethod(testContext); + afterListener.beforeTestMethod(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + afterListener.afterTestMethod(testContext); + beforeListener.afterTestMethod(testContext); + verify(testContext, times(1)).markApplicationContextDirty(EXHAUSTIVE); + } + + @Test + public void beforeAndAfterTestMethodForDirtiesContextDeclaredOnMethodViaMetaAnnotationWithAfterMethodMode() + throws Exception { + Class clazz = getClass(); + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + given(testContext.getTestMethod()).willReturn( + clazz.getDeclaredMethod("dirtiesContextDeclaredViaMetaAnnotationWithAfterMethodMode")); + beforeListener.beforeTestMethod(testContext); + afterListener.beforeTestMethod(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + afterListener.afterTestMethod(testContext); + beforeListener.afterTestMethod(testContext); + verify(testContext, times(1)).markApplicationContextDirty(EXHAUSTIVE); + } + + @Test + public void beforeAndAfterTestMethodForDirtiesContextDeclaredLocallyOnClassBeforeEachTestMethod() throws Exception { + Class clazz = DirtiesContextDeclaredLocallyBeforeEachTestMethod.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("clean")); + beforeListener.beforeTestMethod(testContext); + afterListener.beforeTestMethod(testContext); + verify(testContext, times(1)).markApplicationContextDirty(EXHAUSTIVE); + afterListener.afterTestMethod(testContext); + beforeListener.afterTestMethod(testContext); + verify(testContext, times(1)).markApplicationContextDirty(EXHAUSTIVE); + } + + @Test + public void beforeAndAfterTestMethodForDirtiesContextDeclaredLocallyOnClassAfterEachTestMethod() throws Exception { + Class clazz = DirtiesContextDeclaredLocallyAfterEachTestMethod.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("clean")); + beforeListener.beforeTestMethod(testContext); + afterListener.beforeTestMethod(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + afterListener.afterTestMethod(testContext); + beforeListener.afterTestMethod(testContext); + verify(testContext, times(1)).markApplicationContextDirty(EXHAUSTIVE); + } + + @Test + public void beforeAndAfterTestMethodForDirtiesContextDeclaredViaMetaAnnotationOnClassAfterEachTestMethod() + throws Exception { + Class clazz = DirtiesContextDeclaredViaMetaAnnotationAfterEachTestMethod.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("clean")); + beforeListener.beforeTestMethod(testContext); + afterListener.beforeTestMethod(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + afterListener.afterTestMethod(testContext); + beforeListener.afterTestMethod(testContext); + verify(testContext, times(1)).markApplicationContextDirty(EXHAUSTIVE); + } + + @Test + public void beforeAndAfterTestMethodForDirtiesContextDeclaredLocallyOnClassBeforeClass() throws Exception { + Class clazz = DirtiesContextDeclaredLocallyBeforeClass.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("clean")); + beforeListener.beforeTestMethod(testContext); + afterListener.beforeTestMethod(testContext); + afterListener.afterTestMethod(testContext); + beforeListener.afterTestMethod(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + } + + @Test + public void beforeAndAfterTestMethodForDirtiesContextDeclaredLocallyOnClassAfterClass() throws Exception { + Class clazz = DirtiesContextDeclaredLocallyAfterClass.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("clean")); + beforeListener.beforeTestMethod(testContext); + afterListener.beforeTestMethod(testContext); + afterListener.afterTestMethod(testContext); + beforeListener.afterTestMethod(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + } + + @Test + public void beforeAndAfterTestMethodForDirtiesContextDeclaredViaMetaAnnotationOnClassAfterClass() throws Exception { + Class clazz = DirtiesContextDeclaredViaMetaAnnotationAfterClass.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("clean")); + beforeListener.beforeTestMethod(testContext); + afterListener.beforeTestMethod(testContext); + afterListener.afterTestMethod(testContext); + beforeListener.afterTestMethod(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + } + + @Test + public void beforeAndAfterTestMethodForDirtiesContextViaMetaAnnotationWithOverrides() throws Exception { + Class clazz = DirtiesContextViaMetaAnnotationWithOverrides.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("clean")); + beforeListener.beforeTestMethod(testContext); + afterListener.beforeTestMethod(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + afterListener.afterTestMethod(testContext); + beforeListener.afterTestMethod(testContext); + verify(testContext, times(1)).markApplicationContextDirty(CURRENT_LEVEL); + } + + // ------------------------------------------------------------------------- + + @Test + public void beforeAndAfterTestClassForDirtiesContextDeclaredLocallyOnMethod() throws Exception { + Class clazz = getClass(); + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + beforeListener.beforeTestClass(testContext); + afterListener.beforeTestClass(testContext); + afterListener.afterTestClass(testContext); + beforeListener.afterTestClass(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + } + + @Test + public void beforeAndAfterTestClassForDirtiesContextDeclaredLocallyOnClassBeforeEachTestMethod() throws Exception { + Class clazz = DirtiesContextDeclaredLocallyBeforeEachTestMethod.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + beforeListener.beforeTestClass(testContext); + afterListener.beforeTestClass(testContext); + afterListener.afterTestClass(testContext); + beforeListener.afterTestClass(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + } + + @Test + public void beforeAndAfterTestClassForDirtiesContextDeclaredLocallyOnClassAfterEachTestMethod() throws Exception { + Class clazz = DirtiesContextDeclaredLocallyAfterEachTestMethod.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + beforeListener.beforeTestClass(testContext); + afterListener.beforeTestClass(testContext); + afterListener.afterTestClass(testContext); + beforeListener.afterTestClass(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + } + + @Test + public void beforeAndAfterTestClassForDirtiesContextDeclaredViaMetaAnnotationOnClassAfterEachTestMethod() + throws Exception { + Class clazz = DirtiesContextDeclaredViaMetaAnnotationAfterEachTestMethod.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + beforeListener.beforeTestClass(testContext); + afterListener.beforeTestClass(testContext); + afterListener.afterTestClass(testContext); + beforeListener.afterTestClass(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + } + + @Test + public void beforeAndAfterTestClassForDirtiesContextDeclaredLocallyOnClassBeforeClass() throws Exception { + Class clazz = DirtiesContextDeclaredLocallyBeforeClass.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + beforeListener.beforeTestClass(testContext); + afterListener.beforeTestClass(testContext); + verify(testContext, times(1)).markApplicationContextDirty(EXHAUSTIVE); + afterListener.afterTestClass(testContext); + beforeListener.afterTestClass(testContext); + verify(testContext, times(1)).markApplicationContextDirty(EXHAUSTIVE); + } + + @Test + public void beforeAndAfterTestClassForDirtiesContextDeclaredLocallyOnClassAfterClass() throws Exception { + Class clazz = DirtiesContextDeclaredLocallyAfterClass.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + beforeListener.beforeTestClass(testContext); + afterListener.beforeTestClass(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + afterListener.afterTestClass(testContext); + beforeListener.afterTestClass(testContext); + verify(testContext, times(1)).markApplicationContextDirty(EXHAUSTIVE); + } + + @Test + public void beforeAndAfterTestClassForDirtiesContextDeclaredViaMetaAnnotationOnClassAfterClass() throws Exception { + Class clazz = DirtiesContextDeclaredViaMetaAnnotationAfterClass.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + beforeListener.beforeTestClass(testContext); + afterListener.beforeTestClass(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + afterListener.afterTestClass(testContext); + beforeListener.afterTestClass(testContext); + verify(testContext, times(1)).markApplicationContextDirty(EXHAUSTIVE); + } + + @Test + public void beforeAndAfterTestClassForDirtiesContextDeclaredViaMetaAnnotationWithOverrides() throws Exception { + Class clazz = DirtiesContextViaMetaAnnotationWithOverrides.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + beforeListener.beforeTestClass(testContext); + afterListener.beforeTestClass(testContext); + afterListener.afterTestClass(testContext); + beforeListener.afterTestClass(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + } + + @Test + public void beforeAndAfterTestClassForDirtiesContextDeclaredViaMetaAnnotationWithOverridenAttributes() + throws Exception { + Class clazz = DirtiesContextViaMetaAnnotationWithOverridenAttributes.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + beforeListener.beforeTestClass(testContext); + afterListener.beforeTestClass(testContext); + verify(testContext, times(0)).markApplicationContextDirty(any(HierarchyMode.class)); + afterListener.afterTestClass(testContext); + beforeListener.afterTestClass(testContext); + verify(testContext, times(1)).markApplicationContextDirty(EXHAUSTIVE); + } + + // ------------------------------------------------------------------------- + + @DirtiesContext(methodMode = BEFORE_METHOD) + void dirtiesContextDeclaredLocallyWithBeforeMethodMode() { + /* no-op */ + } + + @DirtiesContext + void dirtiesContextDeclaredLocallyWithAfterMethodMode() { + /* no-op */ + } + + @MetaDirtyAfterMethod + void dirtiesContextDeclaredViaMetaAnnotationWithAfterMethodMode() { + /* no-op */ + } + + + @DirtiesContext + @Retention(RetentionPolicy.RUNTIME) + static @interface MetaDirtyAfterMethod { + } + + @DirtiesContext(classMode = AFTER_EACH_TEST_METHOD) + @Retention(RetentionPolicy.RUNTIME) + static @interface MetaDirtyAfterEachTestMethod { + } + + @DirtiesContext(classMode = AFTER_CLASS) + @Retention(RetentionPolicy.RUNTIME) + static @interface MetaDirtyAfterClass { + } + + @DirtiesContext(classMode = BEFORE_EACH_TEST_METHOD) + static class DirtiesContextDeclaredLocallyBeforeEachTestMethod { + + void clean() { + /* no-op */ + } + } + + @DirtiesContext(classMode = AFTER_EACH_TEST_METHOD) + static class DirtiesContextDeclaredLocallyAfterEachTestMethod { + + void clean() { + /* no-op */ + } + } + + @DirtiesContext + @Retention(RetentionPolicy.RUNTIME) + static @interface MetaDirtyWithOverrides { + + ClassMode classMode() default AFTER_EACH_TEST_METHOD; + + HierarchyMode hierarchyMode() default HierarchyMode.CURRENT_LEVEL; + } + + @MetaDirtyAfterEachTestMethod + static class DirtiesContextDeclaredViaMetaAnnotationAfterEachTestMethod { + + void clean() { + /* no-op */ + } + } + + @DirtiesContext(classMode = BEFORE_CLASS) + static class DirtiesContextDeclaredLocallyBeforeClass { + + void clean() { + /* no-op */ + } + } + + @DirtiesContext(classMode = AFTER_CLASS) + static class DirtiesContextDeclaredLocallyAfterClass { + + void clean() { + /* no-op */ + } + } + + @MetaDirtyAfterClass + static class DirtiesContextDeclaredViaMetaAnnotationAfterClass { + + void clean() { + /* no-op */ + } + } + + @MetaDirtyWithOverrides + static class DirtiesContextViaMetaAnnotationWithOverrides { + + void clean() { + /* no-op */ + } + } + + @MetaDirtyWithOverrides(classMode = AFTER_CLASS, hierarchyMode = EXHAUSTIVE) + static class DirtiesContextViaMetaAnnotationWithOverridenAttributes { + + void clean() { + /* no-op */ + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/FinalConfigInnerClassTestCase.java b/spring-test/src/test/java/org/springframework/test/context/support/FinalConfigInnerClassTestCase.java new file mode 100644 index 0000000000000000000000000000000000000000..44cc16d5de6a2c94c8127683e123056a48dca4f3 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/FinalConfigInnerClassTestCase.java @@ -0,0 +1,35 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import org.springframework.context.annotation.Configuration; + +/** + * Not an actual test case. + * + * @author Sam Brannen + * @since 3.1 + * @see AnnotationConfigContextLoaderTests + */ +public class FinalConfigInnerClassTestCase { + + // Intentionally FINAL. + @Configuration + static final class Config { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/GenericPropertiesContextLoaderTests.java b/spring-test/src/test/java/org/springframework/test/context/support/GenericPropertiesContextLoaderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d48a43614a26b0984594f86e90c7eed2f937cb63 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/GenericPropertiesContextLoaderTests.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.test.context.MergedContextConfiguration; + +import static org.hamcrest.CoreMatchers.*; + +/** + * Unit tests for {@link GenericPropertiesContextLoader}. + * + * @author Sam Brannen + * @since 4.0.4 + */ +public class GenericPropertiesContextLoaderTests { + + private static final String[] EMPTY_STRING_ARRAY = new String[0]; + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + + @Test + public void configMustNotContainAnnotatedClasses() throws Exception { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage(containsString("does not support annotated classes")); + + GenericPropertiesContextLoader loader = new GenericPropertiesContextLoader(); + MergedContextConfiguration mergedConfig = new MergedContextConfiguration(getClass(), EMPTY_STRING_ARRAY, + new Class[] { getClass() }, EMPTY_STRING_ARRAY, loader); + loader.loadContext(mergedConfig); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/GenericXmlContextLoaderResourceLocationsTests.java b/spring-test/src/test/java/org/springframework/test/context/support/GenericXmlContextLoaderResourceLocationsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d36983847e44d4c2b95e1c3776679af17f13c2ae --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/GenericXmlContextLoaderResourceLocationsTests.java @@ -0,0 +1,141 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import java.util.Arrays; +import java.util.Collection; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.ContextLoader; +import org.springframework.util.ClassUtils; +import org.springframework.util.ObjectUtils; + +import static org.junit.Assert.*; + +/** + * JUnit 4 based unit test which verifies proper + * {@link ContextLoader#processLocations(Class, String...) processing} of + * {@code resource locations} by a {@link GenericXmlContextLoader} + * configured via {@link ContextConfiguration @ContextConfiguration}. + * Specifically, this test addresses the issues raised in SPR-3949: + * ContextConfiguration annotation should accept not only classpath resources. + * + * @author Sam Brannen + * @since 2.5 + */ +@RunWith(Parameterized.class) +public class GenericXmlContextLoaderResourceLocationsTests { + + private static final Log logger = LogFactory.getLog(GenericXmlContextLoaderResourceLocationsTests.class); + + protected final Class testClass; + protected final String[] expectedLocations; + + + @Parameters(name = "{0}") + public static Collection contextConfigurationLocationsData() { + @ContextConfiguration + class ClasspathNonExistentDefaultLocationsTestCase { + } + + @ContextConfiguration + class ClasspathExistentDefaultLocationsTestCase { + } + + @ContextConfiguration({ "context1.xml", "context2.xml" }) + class ImplicitClasspathLocationsTestCase { + } + + @ContextConfiguration("classpath:context.xml") + class ExplicitClasspathLocationsTestCase { + } + + @ContextConfiguration("file:/testing/directory/context.xml") + class ExplicitFileLocationsTestCase { + } + + @ContextConfiguration("https://example.com/context.xml") + class ExplicitUrlLocationsTestCase { + } + + @ContextConfiguration({ "context1.xml", "classpath:context2.xml", "/context3.xml", + "file:/testing/directory/context.xml", "https://example.com/context.xml" }) + class ExplicitMixedPathTypesLocationsTestCase { + } + + return Arrays.asList(new Object[][] { + + { ClasspathNonExistentDefaultLocationsTestCase.class.getSimpleName(), new String[] {} }, + + { + ClasspathExistentDefaultLocationsTestCase.class.getSimpleName(), + new String[] { "classpath:org/springframework/test/context/support/GenericXmlContextLoaderResourceLocationsTests$1ClasspathExistentDefaultLocationsTestCase-context.xml" } }, + + { + ImplicitClasspathLocationsTestCase.class.getSimpleName(), + new String[] { "classpath:/org/springframework/test/context/support/context1.xml", + "classpath:/org/springframework/test/context/support/context2.xml" } }, + + { ExplicitClasspathLocationsTestCase.class.getSimpleName(), new String[] { "classpath:context.xml" } }, + + { ExplicitFileLocationsTestCase.class.getSimpleName(), new String[] { "file:/testing/directory/context.xml" } }, + + { ExplicitUrlLocationsTestCase.class.getSimpleName(), new String[] { "https://example.com/context.xml" } }, + + { + ExplicitMixedPathTypesLocationsTestCase.class.getSimpleName(), + new String[] { "classpath:/org/springframework/test/context/support/context1.xml", + "classpath:context2.xml", "classpath:/context3.xml", "file:/testing/directory/context.xml", + "https://example.com/context.xml" } } + + }); + } + + public GenericXmlContextLoaderResourceLocationsTests(final String testClassName, final String[] expectedLocations) throws Exception { + this.testClass = ClassUtils.forName(getClass().getName() + "$1" + testClassName, getClass().getClassLoader()); + this.expectedLocations = expectedLocations; + } + + @Test + public void assertContextConfigurationLocations() throws Exception { + + final ContextConfiguration contextConfig = this.testClass.getAnnotation(ContextConfiguration.class); + final ContextLoader contextLoader = new GenericXmlContextLoader(); + final String[] configuredLocations = (String[]) AnnotationUtils.getValue(contextConfig); + final String[] processedLocations = contextLoader.processLocations(this.testClass, configuredLocations); + + if (logger.isDebugEnabled()) { + logger.debug("----------------------------------------------------------------------"); + logger.debug("Configured locations: " + ObjectUtils.nullSafeToString(configuredLocations)); + logger.debug("Expected locations: " + ObjectUtils.nullSafeToString(this.expectedLocations)); + logger.debug("Processed locations: " + ObjectUtils.nullSafeToString(processedLocations)); + } + + assertArrayEquals("Verifying locations for test [" + this.testClass + "].", this.expectedLocations, + processedLocations); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/GenericXmlContextLoaderTests.java b/spring-test/src/test/java/org/springframework/test/context/support/GenericXmlContextLoaderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c8ae43f73099299534c137ad24483fa244fb4232 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/GenericXmlContextLoaderTests.java @@ -0,0 +1,53 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.test.context.MergedContextConfiguration; + +import static org.hamcrest.CoreMatchers.*; + +/** + * Unit tests for {@link GenericXmlContextLoader}. + * + * @author Sam Brannen + * @since 4.0.4 + * @see GenericXmlContextLoaderResourceLocationsTests + */ +public class GenericXmlContextLoaderTests { + + private static final String[] EMPTY_STRING_ARRAY = new String[0]; + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + + @Test + public void configMustNotContainAnnotatedClasses() throws Exception { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage(containsString("does not support annotated classes")); + + GenericXmlContextLoader loader = new GenericXmlContextLoader(); + MergedContextConfiguration mergedConfig = new MergedContextConfiguration(getClass(), EMPTY_STRING_ARRAY, + new Class[] { getClass() }, EMPTY_STRING_ARRAY, loader); + loader.loadContext(mergedConfig); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/MultipleStaticConfigurationClassesTestCase.java b/spring-test/src/test/java/org/springframework/test/context/support/MultipleStaticConfigurationClassesTestCase.java new file mode 100644 index 0000000000000000000000000000000000000000..7c50c69be03b5828e37c4e8063c5de3db013e66e --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/MultipleStaticConfigurationClassesTestCase.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import org.springframework.context.annotation.Configuration; + +/** + * Not an actual test case. + * + * @author Sam Brannen + * @since 3.1 + * @see AnnotationConfigContextLoaderTests + */ +public class MultipleStaticConfigurationClassesTestCase { + + @Configuration + static class ConfigA { + } + + @Configuration + static class ConfigB { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/NonStaticConfigInnerClassesTestCase.java b/spring-test/src/test/java/org/springframework/test/context/support/NonStaticConfigInnerClassesTestCase.java new file mode 100644 index 0000000000000000000000000000000000000000..773d8b38b9be611b08c2223d9e37066fa33a7b5b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/NonStaticConfigInnerClassesTestCase.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import org.springframework.context.annotation.Configuration; + +/** + * Not an actual test case. + * + * @author Sam Brannen + * @since 3.1 + * @see AnnotationConfigContextLoaderTests + */ +public class NonStaticConfigInnerClassesTestCase { + + // Intentionally not static + @Configuration + class FooConfig { + } + + // Intentionally not static + @Configuration + class BarConfig { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/PlainVanillaFooConfigInnerClassTestCase.java b/spring-test/src/test/java/org/springframework/test/context/support/PlainVanillaFooConfigInnerClassTestCase.java new file mode 100644 index 0000000000000000000000000000000000000000..33e15a806b4f3ddafbd86c4f74201950ab08cda5 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/PlainVanillaFooConfigInnerClassTestCase.java @@ -0,0 +1,32 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +/** + * Not an actual test case. + * + * @author Sam Brannen + * @since 3.1 + * @see AnnotationConfigContextLoaderTests + */ +public class PlainVanillaFooConfigInnerClassTestCase { + + // Intentionally NOT annotated with @Configuration + static class FooConfig { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/PrivateConfigInnerClassTestCase.java b/spring-test/src/test/java/org/springframework/test/context/support/PrivateConfigInnerClassTestCase.java new file mode 100644 index 0000000000000000000000000000000000000000..c185cf49d7331b7279e1b6ca8feabceb62b656a5 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/PrivateConfigInnerClassTestCase.java @@ -0,0 +1,35 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import org.springframework.context.annotation.Configuration; + +/** + * Not an actual test case. + * + * @author Sam Brannen + * @since 3.1 + * @see AnnotationConfigContextLoaderTests + */ +public class PrivateConfigInnerClassTestCase { + + // Intentionally private + @Configuration + private static class PrivateConfig { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/support/TestPropertySourceUtilsTests.java b/spring-test/src/test/java/org/springframework/test/context/support/TestPropertySourceUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..f4febfee1c9dfee2ca374cabf1845dbc550cba04 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/support/TestPropertySourceUtilsTests.java @@ -0,0 +1,289 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.support; + +import java.util.Map; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.core.annotation.AnnotationConfigurationException; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.env.MutablePropertySources; +import org.springframework.core.io.ByteArrayResource; +import org.springframework.core.io.ResourceLoader; +import org.springframework.mock.env.MockEnvironment; +import org.springframework.mock.env.MockPropertySource; +import org.springframework.test.context.TestPropertySource; + +import static org.hamcrest.CoreMatchers.*; +import static org.hamcrest.CoreMatchers.startsWith; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; +import static org.springframework.test.context.support.TestPropertySourceUtils.*; + +/** + * Unit tests for {@link TestPropertySourceUtils}. + * + * @author Sam Brannen + * @since 4.1 + */ +public class TestPropertySourceUtilsTests { + + private static final String[] EMPTY_STRING_ARRAY = new String[0]; + + private static final String[] KEY_VALUE_PAIR = new String[] {"key = value"}; + + private static final String[] FOO_LOCATIONS = new String[] {"classpath:/foo.properties"}; + + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + + @Test + public void emptyAnnotation() { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage(startsWith("Could not detect default properties file for test")); + expectedException.expectMessage(containsString("EmptyPropertySources.properties")); + buildMergedTestPropertySources(EmptyPropertySources.class); + } + + @Test + public void extendedEmptyAnnotation() { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage(startsWith("Could not detect default properties file for test")); + expectedException.expectMessage(containsString("ExtendedEmptyPropertySources.properties")); + buildMergedTestPropertySources(ExtendedEmptyPropertySources.class); + } + + @Test + public void value() { + assertMergedTestPropertySources(ValuePropertySources.class, asArray("classpath:/value.xml"), + EMPTY_STRING_ARRAY); + } + + @Test + public void locationsAndValueAttributes() { + expectedException.expect(AnnotationConfigurationException.class); + buildMergedTestPropertySources(LocationsAndValuePropertySources.class); + } + + @Test + public void locationsAndProperties() { + assertMergedTestPropertySources(LocationsAndPropertiesPropertySources.class, + asArray("classpath:/foo1.xml", "classpath:/foo2.xml"), asArray("k1a=v1a", "k1b: v1b")); + } + + @Test + public void inheritedLocationsAndProperties() { + assertMergedTestPropertySources(InheritedPropertySources.class, + asArray("classpath:/foo1.xml", "classpath:/foo2.xml"), asArray("k1a=v1a", "k1b: v1b")); + } + + @Test + public void extendedLocationsAndProperties() { + assertMergedTestPropertySources(ExtendedPropertySources.class, + asArray("classpath:/foo1.xml", "classpath:/foo2.xml", "classpath:/bar1.xml", "classpath:/bar2.xml"), + asArray("k1a=v1a", "k1b: v1b", "k2a v2a", "k2b: v2b")); + } + + @Test + public void overriddenLocations() { + assertMergedTestPropertySources(OverriddenLocationsPropertySources.class, + asArray("classpath:/baz.properties"), asArray("k1a=v1a", "k1b: v1b", "key = value")); + } + + @Test + public void overriddenProperties() { + assertMergedTestPropertySources(OverriddenPropertiesPropertySources.class, + asArray("classpath:/foo1.xml", "classpath:/foo2.xml", "classpath:/baz.properties"), KEY_VALUE_PAIR); + } + + @Test + public void overriddenLocationsAndProperties() { + assertMergedTestPropertySources(OverriddenLocationsAndPropertiesPropertySources.class, + asArray("classpath:/baz.properties"), KEY_VALUE_PAIR); + } + + + @Test + public void addPropertiesFilesToEnvironmentWithNullContext() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("must not be null"); + addPropertiesFilesToEnvironment((ConfigurableApplicationContext) null, FOO_LOCATIONS); + } + + @Test + public void addPropertiesFilesToEnvironmentWithContextAndNullLocations() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("must not be null"); + addPropertiesFilesToEnvironment(mock(ConfigurableApplicationContext.class), (String[]) null); + } + + @Test + public void addPropertiesFilesToEnvironmentWithNullEnvironment() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("must not be null"); + addPropertiesFilesToEnvironment((ConfigurableEnvironment) null, mock(ResourceLoader.class), FOO_LOCATIONS); + } + + @Test + public void addPropertiesFilesToEnvironmentWithEnvironmentAndNullLocations() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("must not be null"); + addPropertiesFilesToEnvironment(new MockEnvironment(), mock(ResourceLoader.class), (String[]) null); + } + + @Test + public void addPropertiesFilesToEnvironmentWithSinglePropertyFromVirtualFile() { + ConfigurableEnvironment environment = new MockEnvironment(); + + MutablePropertySources propertySources = environment.getPropertySources(); + propertySources.remove(MockPropertySource.MOCK_PROPERTIES_PROPERTY_SOURCE_NAME); + assertEquals(0, propertySources.size()); + + String pair = "key = value"; + ByteArrayResource resource = new ByteArrayResource(pair.getBytes(), "from inlined property: " + pair); + ResourceLoader resourceLoader = mock(ResourceLoader.class); + when(resourceLoader.getResource(anyString())).thenReturn(resource); + + addPropertiesFilesToEnvironment(environment, resourceLoader, FOO_LOCATIONS); + assertEquals(1, propertySources.size()); + assertEquals("value", environment.getProperty("key")); + } + + @Test + public void addInlinedPropertiesToEnvironmentWithNullContext() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("context"); + addInlinedPropertiesToEnvironment((ConfigurableApplicationContext) null, KEY_VALUE_PAIR); + } + + @Test + public void addInlinedPropertiesToEnvironmentWithContextAndNullInlinedProperties() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("inlined"); + addInlinedPropertiesToEnvironment(mock(ConfigurableApplicationContext.class), (String[]) null); + } + + @Test + public void addInlinedPropertiesToEnvironmentWithNullEnvironment() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("environment"); + addInlinedPropertiesToEnvironment((ConfigurableEnvironment) null, KEY_VALUE_PAIR); + } + + @Test + public void addInlinedPropertiesToEnvironmentWithEnvironmentAndNullInlinedProperties() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("inlined"); + addInlinedPropertiesToEnvironment(new MockEnvironment(), (String[]) null); + } + + @Test + public void addInlinedPropertiesToEnvironmentWithMalformedUnicodeInValue() { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage("Failed to load test environment property"); + addInlinedPropertiesToEnvironment(new MockEnvironment(), asArray("key = \\uZZZZ")); + } + + @Test + public void addInlinedPropertiesToEnvironmentWithMultipleKeyValuePairsInSingleInlinedProperty() { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage("Failed to load exactly one test environment property"); + addInlinedPropertiesToEnvironment(new MockEnvironment(), asArray("a=b\nx=y")); + } + + @Test + @SuppressWarnings("rawtypes") + public void addInlinedPropertiesToEnvironmentWithEmptyProperty() { + ConfigurableEnvironment environment = new MockEnvironment(); + MutablePropertySources propertySources = environment.getPropertySources(); + propertySources.remove(MockPropertySource.MOCK_PROPERTIES_PROPERTY_SOURCE_NAME); + assertEquals(0, propertySources.size()); + addInlinedPropertiesToEnvironment(environment, asArray(" ")); + assertEquals(1, propertySources.size()); + assertEquals(0, ((Map) propertySources.iterator().next().getSource()).size()); + } + + @Test + public void convertInlinedPropertiesToMapWithNullInlinedProperties() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("inlined"); + convertInlinedPropertiesToMap((String[]) null); + } + + + private static void assertMergedTestPropertySources(Class testClass, String[] expectedLocations, + String[] expectedProperties) { + + MergedTestPropertySources mergedPropertySources = buildMergedTestPropertySources(testClass); + assertNotNull(mergedPropertySources); + assertArrayEquals(expectedLocations, mergedPropertySources.getLocations()); + assertArrayEquals(expectedProperties, mergedPropertySources.getProperties()); + } + + + @SafeVarargs + private static T[] asArray(T... arr) { + return arr; + } + + + @TestPropertySource + static class EmptyPropertySources { + } + + @TestPropertySource + static class ExtendedEmptyPropertySources extends EmptyPropertySources { + } + + @TestPropertySource(locations = "/foo", value = "/bar") + static class LocationsAndValuePropertySources { + } + + @TestPropertySource("/value.xml") + static class ValuePropertySources { + } + + @TestPropertySource(locations = { "/foo1.xml", "/foo2.xml" }, properties = { "k1a=v1a", "k1b: v1b" }) + static class LocationsAndPropertiesPropertySources { + } + + static class InheritedPropertySources extends LocationsAndPropertiesPropertySources { + } + + @TestPropertySource(locations = { "/bar1.xml", "/bar2.xml" }, properties = { "k2a v2a", "k2b: v2b" }) + static class ExtendedPropertySources extends LocationsAndPropertiesPropertySources { + } + + @TestPropertySource(locations = "/baz.properties", properties = "key = value", inheritLocations = false) + static class OverriddenLocationsPropertySources extends LocationsAndPropertiesPropertySources { + } + + @TestPropertySource(locations = "/baz.properties", properties = "key = value", inheritProperties = false) + static class OverriddenPropertiesPropertySources extends LocationsAndPropertiesPropertySources { + } + + @TestPropertySource(locations = "/baz.properties", properties = "key = value", inheritLocations = false, inheritProperties = false) + static class OverriddenLocationsAndPropertiesPropertySources extends LocationsAndPropertiesPropertySources { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/testng/AnnotationConfigTestNGSpringContextTests.java b/spring-test/src/test/java/org/springframework/test/context/testng/AnnotationConfigTestNGSpringContextTests.java new file mode 100644 index 0000000000000000000000000000000000000000..f04280d8644eba9282a7700a463ec64e7b6bc2dc --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/testng/AnnotationConfigTestNGSpringContextTests.java @@ -0,0 +1,75 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.testng; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.tests.sample.beans.Employee; +import org.springframework.tests.sample.beans.Pet; + +import org.testng.annotations.Test; + +import static org.testng.Assert.*; + +/** + * Integration tests that verify support for + * {@link org.springframework.context.annotation.Configuration @Configuration} classes + * with TestNG-based tests. + * + *

Configuration will be loaded from + * {@link AnnotationConfigTestNGSpringContextTests.Config}. + * + * @author Sam Brannen + * @since 5.1 + */ +@ContextConfiguration +public class AnnotationConfigTestNGSpringContextTests extends AbstractTestNGSpringContextTests { + + @Autowired + Employee employee; + + @Autowired + Pet pet; + + @Test + void autowiringFromConfigClass() { + assertNotNull(employee, "The employee should have been autowired."); + assertEquals(employee.getName(), "John Smith"); + + assertNotNull(pet, "The pet should have been autowired."); + assertEquals(pet.getName(), "Fido"); + } + + + @Configuration + static class Config { + + @Bean + Employee employee() { + return new Employee("John Smith"); + } + + @Bean + Pet pet() { + return new Pet("Fido"); + } + + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/testng/AnnotationConfigTransactionalTestNGSpringContextTests.java b/spring-test/src/test/java/org/springframework/test/context/testng/AnnotationConfigTransactionalTestNGSpringContextTests.java new file mode 100644 index 0000000000000000000000000000000000000000..70777580f7d5516d54c81f796cb4eefbacf7ccc4 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/testng/AnnotationConfigTransactionalTestNGSpringContextTests.java @@ -0,0 +1,192 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.testng; + +import javax.sql.DataSource; + +import org.testng.annotations.AfterClass; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.jdbc.datasource.DataSourceTransactionManager; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.transaction.AfterTransaction; +import org.springframework.test.context.transaction.BeforeTransaction; +import org.springframework.tests.sample.beans.Employee; +import org.springframework.tests.sample.beans.Pet; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.annotation.Propagation; +import org.springframework.transaction.annotation.Transactional; + +import static org.springframework.test.transaction.TransactionTestUtils.*; +import static org.testng.Assert.*; + +/** + * Integration tests that verify support for + * {@link org.springframework.context.annotation.Configuration @Configuration} classes + * with transactional TestNG-based tests. + * + *

Configuration will be loaded from + * {@link AnnotationConfigTransactionalTestNGSpringContextTests.ContextConfiguration}. + * + * @author Sam Brannen + * @since 3.1 + */ +@ContextConfiguration +public class AnnotationConfigTransactionalTestNGSpringContextTests + extends AbstractTransactionalTestNGSpringContextTests { + + private static final String JANE = "jane"; + private static final String SUE = "sue"; + private static final String YODA = "yoda"; + + private static final int NUM_TESTS = 2; + private static final int NUM_TX_TESTS = 1; + + private static int numSetUpCalls = 0; + private static int numSetUpCallsInTransaction = 0; + private static int numTearDownCalls = 0; + private static int numTearDownCallsInTransaction = 0; + + @Autowired + private Employee employee; + + @Autowired + private Pet pet; + + + private int createPerson(String name) { + return jdbcTemplate.update("INSERT INTO person VALUES(?)", name); + } + + private int deletePerson(String name) { + return jdbcTemplate.update("DELETE FROM person WHERE name=?", name); + } + + private void assertNumRowsInPersonTable(int expectedNumRows, String testState) { + assertEquals(countRowsInTable("person"), expectedNumRows, "the number of rows in the person table (" + + testState + ")."); + } + + private void assertAddPerson(final String name) { + assertEquals(createPerson(name), 1, "Adding '" + name + "'"); + } + + @BeforeClass + void beforeClass() { + numSetUpCalls = 0; + numSetUpCallsInTransaction = 0; + numTearDownCalls = 0; + numTearDownCallsInTransaction = 0; + } + + @AfterClass + void afterClass() { + assertEquals(numSetUpCalls, NUM_TESTS, "number of calls to setUp()."); + assertEquals(numSetUpCallsInTransaction, NUM_TX_TESTS, "number of calls to setUp() within a transaction."); + assertEquals(numTearDownCalls, NUM_TESTS, "number of calls to tearDown()."); + assertEquals(numTearDownCallsInTransaction, NUM_TX_TESTS, "number of calls to tearDown() within a transaction."); + } + + @Test + @Transactional(propagation = Propagation.NOT_SUPPORTED) + void autowiringFromConfigClass() { + assertNotNull(employee, "The employee should have been autowired."); + assertEquals(employee.getName(), "John Smith"); + + assertNotNull(pet, "The pet should have been autowired."); + assertEquals(pet.getName(), "Fido"); + } + + @BeforeTransaction + void beforeTransaction() { + assertNumRowsInPersonTable(1, "before a transactional test method"); + assertAddPerson(YODA); + } + + @BeforeMethod + void setUp() throws Exception { + numSetUpCalls++; + if (inTransaction()) { + numSetUpCallsInTransaction++; + } + assertNumRowsInPersonTable((inTransaction() ? 2 : 1), "before a test method"); + } + + @Test + void modifyTestDataWithinTransaction() { + assertInTransaction(true); + assertAddPerson(JANE); + assertAddPerson(SUE); + assertNumRowsInPersonTable(4, "in modifyTestDataWithinTransaction()"); + } + + @AfterMethod + void tearDown() throws Exception { + numTearDownCalls++; + if (inTransaction()) { + numTearDownCallsInTransaction++; + } + assertNumRowsInPersonTable((inTransaction() ? 4 : 1), "after a test method"); + } + + @AfterTransaction + void afterTransaction() { + assertEquals(deletePerson(YODA), 1, "Deleting yoda"); + assertNumRowsInPersonTable(1, "after a transactional test method"); + } + + + @Configuration + static class ContextConfiguration { + + @Bean + Employee employee() { + Employee employee = new Employee(); + employee.setName("John Smith"); + employee.setAge(42); + employee.setCompany("Acme Widgets, Inc."); + return employee; + } + + @Bean + Pet pet() { + return new Pet("Fido"); + } + + @Bean + PlatformTransactionManager transactionManager() { + return new DataSourceTransactionManager(dataSource()); + } + + @Bean + DataSource dataSource() { + return new EmbeddedDatabaseBuilder()// + .addScript("classpath:/org/springframework/test/jdbc/schema.sql")// + .addScript("classpath:/org/springframework/test/jdbc/data.sql")// + .build(); + } + + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/testng/ConcreteTransactionalTestNGSpringContextTests.java b/spring-test/src/test/java/org/springframework/test/context/testng/ConcreteTransactionalTestNGSpringContextTests.java new file mode 100644 index 0000000000000000000000000000000000000000..caba59481834f4581c943eb62c346167e852d116 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/testng/ConcreteTransactionalTestNGSpringContextTests.java @@ -0,0 +1,234 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.testng; + +import javax.annotation.Resource; + +import org.testng.annotations.AfterClass; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import org.springframework.beans.factory.BeanNameAware; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.transaction.AfterTransaction; +import org.springframework.test.context.transaction.BeforeTransaction; +import org.springframework.tests.sample.beans.Employee; +import org.springframework.tests.sample.beans.Pet; +import org.springframework.transaction.annotation.Propagation; +import org.springframework.transaction.annotation.Transactional; + +import static org.springframework.test.transaction.TransactionTestUtils.*; +import static org.testng.Assert.*; + +/** + * Combined integration test for {@link AbstractTestNGSpringContextTests} and + * {@link AbstractTransactionalTestNGSpringContextTests}. + * + * @author Sam Brannen + * @since 2.5 + */ +@ContextConfiguration +public class ConcreteTransactionalTestNGSpringContextTests extends AbstractTransactionalTestNGSpringContextTests + implements BeanNameAware, InitializingBean { + + private static final String JANE = "jane"; + private static final String SUE = "sue"; + private static final String YODA = "yoda"; + + private static final int NUM_TESTS = 8; + private static final int NUM_TX_TESTS = 1; + + private static int numSetUpCalls = 0; + private static int numSetUpCallsInTransaction = 0; + private static int numTearDownCalls = 0; + private static int numTearDownCallsInTransaction = 0; + + + private Employee employee; + + @Autowired + private Pet pet; + + @Autowired(required = false) + private Long nonrequiredLong; + + @Resource + private String foo; + + private String bar; + + private String beanName; + + private boolean beanInitialized = false; + + + @Autowired + private void setEmployee(Employee employee) { + this.employee = employee; + } + + @Resource + private void setBar(String bar) { + this.bar = bar; + } + + @Override + public void setBeanName(String beanName) { + this.beanName = beanName; + } + + @Override + public void afterPropertiesSet() { + this.beanInitialized = true; + } + + + @BeforeClass + void beforeClass() { + numSetUpCalls = 0; + numSetUpCallsInTransaction = 0; + numTearDownCalls = 0; + numTearDownCallsInTransaction = 0; + } + + @AfterClass + void afterClass() { + assertEquals(numSetUpCalls, NUM_TESTS, "number of calls to setUp()."); + assertEquals(numSetUpCallsInTransaction, NUM_TX_TESTS, "number of calls to setUp() within a transaction."); + assertEquals(numTearDownCalls, NUM_TESTS, "number of calls to tearDown()."); + assertEquals(numTearDownCallsInTransaction, NUM_TX_TESTS, "number of calls to tearDown() within a transaction."); + } + + @BeforeMethod + void setUp() { + numSetUpCalls++; + if (inTransaction()) { + numSetUpCallsInTransaction++; + } + assertNumRowsInPersonTable((inTransaction() ? 2 : 1), "before a test method"); + } + + @AfterMethod + void tearDown() { + numTearDownCalls++; + if (inTransaction()) { + numTearDownCallsInTransaction++; + } + assertNumRowsInPersonTable((inTransaction() ? 4 : 1), "after a test method"); + } + + @BeforeTransaction + void beforeTransaction() { + assertNumRowsInPersonTable(1, "before a transactional test method"); + assertAddPerson(YODA); + } + + @AfterTransaction + void afterTransaction() { + assertEquals(deletePerson(YODA), 1, "Deleting yoda"); + assertNumRowsInPersonTable(1, "after a transactional test method"); + } + + + @Test + @Transactional(propagation = Propagation.NOT_SUPPORTED) + void verifyBeanNameSet() { + assertInTransaction(false); + assertTrue(this.beanName.startsWith(getClass().getName()), "The bean name of this test instance " + + "should have been set to the fully qualified class name due to BeanNameAware semantics."); + } + + @Test + @Transactional(propagation = Propagation.NOT_SUPPORTED) + void verifyApplicationContextSet() { + assertInTransaction(false); + assertNotNull(super.applicationContext, + "The application context should have been set due to ApplicationContextAware semantics."); + Employee employeeBean = (Employee) super.applicationContext.getBean("employee"); + assertEquals(employeeBean.getName(), "John Smith", "employee's name."); + } + + @Test + @Transactional(propagation = Propagation.NOT_SUPPORTED) + void verifyBeanInitialized() { + assertInTransaction(false); + assertTrue(beanInitialized, + "This test instance should have been initialized due to InitializingBean semantics."); + } + + @Test + @Transactional(propagation = Propagation.NOT_SUPPORTED) + void verifyAnnotationAutowiredFields() { + assertInTransaction(false); + assertNull(nonrequiredLong, "The nonrequiredLong field should NOT have been autowired."); + assertNotNull(pet, "The pet field should have been autowired."); + assertEquals(pet.getName(), "Fido", "pet's name."); + } + + @Test + @Transactional(propagation = Propagation.NOT_SUPPORTED) + void verifyAnnotationAutowiredMethods() { + assertInTransaction(false); + assertNotNull(employee, "The setEmployee() method should have been autowired."); + assertEquals(employee.getName(), "John Smith", "employee's name."); + } + + @Test + @Transactional(propagation = Propagation.NOT_SUPPORTED) + void verifyResourceAnnotationInjectedFields() { + assertInTransaction(false); + assertEquals(foo, "Foo", "The foo field should have been injected via @Resource."); + } + + @Test + @Transactional(propagation = Propagation.NOT_SUPPORTED) + void verifyResourceAnnotationInjectedMethods() { + assertInTransaction(false); + assertEquals(bar, "Bar", "The setBar() method should have been injected via @Resource."); + } + + @Test + void modifyTestDataWithinTransaction() { + assertInTransaction(true); + assertAddPerson(JANE); + assertAddPerson(SUE); + assertNumRowsInPersonTable(4, "in modifyTestDataWithinTransaction()"); + } + + + private int createPerson(String name) { + return jdbcTemplate.update("INSERT INTO person VALUES(?)", name); + } + + private int deletePerson(String name) { + return jdbcTemplate.update("DELETE FROM person WHERE name=?", name); + } + + private void assertNumRowsInPersonTable(int expectedNumRows, String testState) { + assertEquals(countRowsInTable("person"), expectedNumRows, + "the number of rows in the person table (" + testState + ")."); + } + + private void assertAddPerson(String name) { + assertEquals(createPerson(name), 1, "Adding '" + name + "'"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/testng/DirtiesContextTransactionalTestNGSpringContextTests.java b/spring-test/src/test/java/org/springframework/test/context/testng/DirtiesContextTransactionalTestNGSpringContextTests.java new file mode 100644 index 0000000000000000000000000000000000000000..66cb01be63a6839478a6b1cb5e5ea99a87e6450a --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/testng/DirtiesContextTransactionalTestNGSpringContextTests.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.testng; + +import org.testng.annotations.Test; + +import org.springframework.context.ApplicationContext; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.TestContextManager; + +import static org.springframework.test.transaction.TransactionTestUtils.*; +import static org.testng.Assert.*; + +/** + *

+ * TestNG based integration test to assess the claim in SPR-3880 that a "context marked dirty using + * {@link DirtiesContext @DirtiesContext} in [a] TestNG based test is not + * reloaded in subsequent tests". + *

+ *

+ * After careful analysis, it turns out that the {@link ApplicationContext} was + * in fact reloaded; however, due to how the test instance was instrumented with + * the {@link TestContextManager} in {@link AbstractTestNGSpringContextTests}, + * dependency injection was not being performed on the test instance between + * individual tests. DirtiesContextTransactionalTestNGSpringContextTests + * therefore verifies the expected behavior and correct semantics. + *

+ * + * @author Sam Brannen + * @since 2.5 + */ +@ContextConfiguration +public class DirtiesContextTransactionalTestNGSpringContextTests extends AbstractTransactionalTestNGSpringContextTests { + + private ApplicationContext dirtiedApplicationContext; + + + private void performCommonAssertions() { + assertInTransaction(true); + assertNotNull(super.applicationContext, + "The application context should have been set due to ApplicationContextAware semantics."); + assertNotNull(super.jdbcTemplate, + "The JdbcTemplate should have been created in setDataSource() via DI for the DataSource."); + } + + @Test + @DirtiesContext + public void dirtyContext() { + performCommonAssertions(); + this.dirtiedApplicationContext = super.applicationContext; + } + + @Test(dependsOnMethods = { "dirtyContext" }) + public void verifyContextWasDirtied() { + performCommonAssertions(); + assertNotSame(super.applicationContext, this.dirtiedApplicationContext, + "The application context should have been 'dirtied'."); + this.dirtiedApplicationContext = super.applicationContext; + } + + @Test(dependsOnMethods = { "verifyContextWasDirtied" }) + public void verifyContextWasNotDirtied() { + assertSame(this.applicationContext, this.dirtiedApplicationContext, + "The application context should NOT have been 'dirtied'."); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/testng/TimedTransactionalTestNGSpringContextTests.java b/spring-test/src/test/java/org/springframework/test/context/testng/TimedTransactionalTestNGSpringContextTests.java new file mode 100644 index 0000000000000000000000000000000000000000..39c7e22b633b63b568a5d93f7bf75bd79374b3d5 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/testng/TimedTransactionalTestNGSpringContextTests.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.testng; + +import org.testng.annotations.Test; + +import org.springframework.test.context.ContextConfiguration; + +import static org.springframework.test.transaction.TransactionTestUtils.*; + +/** + * Timed integration tests for + * {@link AbstractTransactionalTestNGSpringContextTests}; used to verify claim + * raised in SPR-6124. + * + * @author Sam Brannen + * @since 3.0 + */ +@ContextConfiguration +public class TimedTransactionalTestNGSpringContextTests extends AbstractTransactionalTestNGSpringContextTests { + + @Test + public void testWithoutTimeout() { + assertInTransaction(true); + } + + // TODO Enable TestNG test with timeout once we have a solution. + @Test(timeOut = 10000, enabled = false) + public void testWithTimeout() { + assertInTransaction(true); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/testng/TrackingTestNGTestListener.java b/spring-test/src/test/java/org/springframework/test/context/testng/TrackingTestNGTestListener.java new file mode 100644 index 0000000000000000000000000000000000000000..192c5e346cd3a5d942ee6ce96470e1454c04d571 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/testng/TrackingTestNGTestListener.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.testng; + +import org.testng.ITestContext; +import org.testng.ITestListener; +import org.testng.ITestResult; + +/** + * Simple {@link ITestListener} which tracks how many times certain TestNG + * callback methods were called: only intended for the integration test suite. + * + * @author Sam Brannen + * @since 4.2 + */ +public class TrackingTestNGTestListener implements ITestListener { + + public int testStartCount = 0; + + public int testSuccessCount = 0; + + public int testFailureCount = 0; + + public int failedConfigurationsCount = 0; + + + @Override + public void onFinish(ITestContext testContext) { + this.failedConfigurationsCount += testContext.getFailedConfigurations().size(); + } + + @Override + public void onStart(ITestContext testContext) { + } + + @Override + public void onTestFailedButWithinSuccessPercentage(ITestResult testResult) { + } + + @Override + public void onTestFailure(ITestResult testResult) { + this.testFailureCount++; + } + + @Override + public void onTestSkipped(ITestResult testResult) { + } + + @Override + public void onTestStart(ITestResult testResult) { + this.testStartCount++; + } + + @Override + public void onTestSuccess(ITestResult testResult) { + this.testSuccessCount++; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/testng/transaction/ejb/AbstractEjbTxDaoTestNGTests.java b/spring-test/src/test/java/org/springframework/test/context/testng/transaction/ejb/AbstractEjbTxDaoTestNGTests.java new file mode 100644 index 0000000000000000000000000000000000000000..078ec529ef24da5a723bf0cfc34ed38f3ae3fb28 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/testng/transaction/ejb/AbstractEjbTxDaoTestNGTests.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.testng.transaction.ejb; + +import javax.ejb.EJB; +import javax.persistence.EntityManager; +import javax.persistence.PersistenceContext; + +import org.testng.annotations.AfterMethod; +import org.testng.annotations.Test; + +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.annotation.DirtiesContext.ClassMode; +import org.springframework.test.context.testng.AbstractTransactionalTestNGSpringContextTests; +import org.springframework.test.context.transaction.ejb.dao.TestEntityDao; + +import static org.testng.AssertJUnit.*; + +/** + * Abstract base class for all TestNG-based tests involving EJB transaction + * support in the TestContext framework. + * + * @author Sam Brannen + * @author Xavier Detant + * @since 4.0.1 + */ +@DirtiesContext(classMode = ClassMode.AFTER_CLASS) +public abstract class AbstractEjbTxDaoTestNGTests extends AbstractTransactionalTestNGSpringContextTests { + + protected static final String TEST_NAME = "test-name"; + + @EJB + protected TestEntityDao dao; + + @PersistenceContext + protected EntityManager em; + + + @Test + public void test1InitialState() { + int count = dao.getCount(TEST_NAME); + assertEquals("New TestEntity should have count=0.", 0, count); + } + + @Test(dependsOnMethods = "test1InitialState") + public void test2IncrementCount1() { + int count = dao.incrementCount(TEST_NAME); + assertEquals("Expected count=1 after first increment.", 1, count); + } + + /** + * The default implementation of this method assumes that the transaction + * for {@link #test2IncrementCount1()} was committed. Therefore, it is + * expected that the previous increment has been persisted in the database. + */ + @Test(dependsOnMethods = "test2IncrementCount1") + public void test3IncrementCount2() { + int count = dao.getCount(TEST_NAME); + assertEquals("Expected count=1 after test2IncrementCount1().", 1, count); + + count = dao.incrementCount(TEST_NAME); + assertEquals("Expected count=2 now.", 2, count); + } + + @AfterMethod(alwaysRun = true) + public void synchronizePersistenceContext() { + em.flush(); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/testng/transaction/ejb/CommitForRequiredEjbTxDaoTestNGTests.java b/spring-test/src/test/java/org/springframework/test/context/testng/transaction/ejb/CommitForRequiredEjbTxDaoTestNGTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8aea564adb02d7585ea73d663ad68403536612f2 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/testng/transaction/ejb/CommitForRequiredEjbTxDaoTestNGTests.java @@ -0,0 +1,42 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.testng.transaction.ejb; + +import org.springframework.test.annotation.Commit; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.transaction.TransactionalTestExecutionListener; +import org.springframework.test.context.transaction.ejb.dao.RequiredEjbTxTestEntityDao; + +import org.testng.annotations.Test; + +/** + * Concrete subclass of {@link AbstractEjbTxDaoTestNGTests} which uses the + * {@link RequiredEjbTxTestEntityDao} and sets the default rollback semantics + * for the {@link TransactionalTestExecutionListener} to {@code false} (i.e., + * commit). + * + * @author Sam Brannen + * @since 4.0.1 + */ +@Test(suiteName = "Commit for REQUIRED") +@ContextConfiguration("/org/springframework/test/context/transaction/ejb/required-tx-config.xml") +@Commit +public class CommitForRequiredEjbTxDaoTestNGTests extends AbstractEjbTxDaoTestNGTests { + + /* test methods in superclass */ + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/testng/transaction/ejb/CommitForRequiresNewEjbTxDaoTestNGTests.java b/spring-test/src/test/java/org/springframework/test/context/testng/transaction/ejb/CommitForRequiresNewEjbTxDaoTestNGTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5cf811184b8765d72d465447054fa5598bf47c2d --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/testng/transaction/ejb/CommitForRequiresNewEjbTxDaoTestNGTests.java @@ -0,0 +1,42 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.testng.transaction.ejb; + +import org.springframework.test.annotation.Commit; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.transaction.TransactionalTestExecutionListener; +import org.springframework.test.context.transaction.ejb.dao.RequiresNewEjbTxTestEntityDao; + +import org.testng.annotations.Test; + +/** + * Concrete subclass of {@link AbstractEjbTxDaoTestNGTests} which uses the + * {@link RequiresNewEjbTxTestEntityDao} and sets the default rollback semantics + * for the {@link TransactionalTestExecutionListener} to {@code false} (i.e., + * commit). + * + * @author Sam Brannen + * @since 4.0.1 + */ +@Test(suiteName = "Commit for REQUIRES_NEW") +@ContextConfiguration("/org/springframework/test/context/transaction/ejb/requires-new-tx-config.xml") +@Commit +public class CommitForRequiresNewEjbTxDaoTestNGTests extends AbstractEjbTxDaoTestNGTests { + + /* test methods in superclass */ + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/testng/transaction/ejb/RollbackForRequiredEjbTxDaoTestNGTests.java b/spring-test/src/test/java/org/springframework/test/context/testng/transaction/ejb/RollbackForRequiredEjbTxDaoTestNGTests.java new file mode 100644 index 0000000000000000000000000000000000000000..eda666379fa8a4abe80ee4f15a583f27738fd234 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/testng/transaction/ejb/RollbackForRequiredEjbTxDaoTestNGTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.testng.transaction.ejb; + +import org.springframework.test.annotation.Rollback; +import org.springframework.test.context.transaction.TransactionalTestExecutionListener; + +import org.testng.annotations.Test; + +import static org.testng.AssertJUnit.*; + +/** + * Extension of {@link CommitForRequiredEjbTxDaoTestNGTests} which sets the default + * rollback semantics for the {@link TransactionalTestExecutionListener} to + * {@code true}. The transaction managed by the TestContext framework will be + * rolled back after each test method. Consequently, any work performed in + * transactional methods that participate in the test-managed transaction will + * be rolled back automatically. + * + * @author Sam Brannen + * @since 4.0.1 + */ +@Test(suiteName = "Rollback for REQUIRED") +@Rollback +public class RollbackForRequiredEjbTxDaoTestNGTests extends CommitForRequiredEjbTxDaoTestNGTests { + + /** + * Overrides parent implementation in order to change expectations to align with + * behavior associated with "required" transactions on repositories/DAOs and + * default rollback semantics for transactions managed by the TestContext + * framework. + */ + @Test(dependsOnMethods = "test2IncrementCount1") + @Override + public void test3IncrementCount2() { + int count = dao.getCount(TEST_NAME); + // Expecting count=0 after test2IncrementCount1() since REQUIRED transactions + // participate in the existing transaction (if present), which in this case is the + // transaction managed by the TestContext framework which will be rolled back + // after each test method. + assertEquals("Expected count=0 after test2IncrementCount1().", 0, count); + + count = dao.incrementCount(TEST_NAME); + assertEquals("Expected count=1 now.", 1, count); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/testng/transaction/ejb/RollbackForRequiresNewEjbTxDaoTestNGTests.java b/spring-test/src/test/java/org/springframework/test/context/testng/transaction/ejb/RollbackForRequiresNewEjbTxDaoTestNGTests.java new file mode 100644 index 0000000000000000000000000000000000000000..84d312bfc319de531f3520f08a3667faa58027a0 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/testng/transaction/ejb/RollbackForRequiresNewEjbTxDaoTestNGTests.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.testng.transaction.ejb; + +import org.springframework.test.annotation.Rollback; +import org.springframework.test.context.transaction.TransactionalTestExecutionListener; + +import org.testng.annotations.Test; + +/** + * Extension of {@link CommitForRequiresNewEjbTxDaoTestNGTests} which sets the default + * rollback semantics for the {@link TransactionalTestExecutionListener} to + * {@code true}. The transaction managed by the TestContext framework will be + * rolled back after each test method. Consequently, any work performed in + * transactional methods that participate in the test-managed transaction will + * be rolled back automatically. On the other hand, any work performed in + * transactional methods that do not participate in the + * test-managed transaction will not be affected by the rollback of the + * test-managed transaction. For example, such work may in fact be committed + * outside the scope of the test-managed transaction. + * + * @author Sam Brannen + * @since 4.0.1 + */ +@Test(suiteName = "Rollback for REQUIRES_NEW") +@Rollback +public class RollbackForRequiresNewEjbTxDaoTestNGTests extends CommitForRequiresNewEjbTxDaoTestNGTests { + + /* test methods in superclass */ + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/testng/transaction/programmatic/ProgrammaticTxMgmtTestNGTests.java b/spring-test/src/test/java/org/springframework/test/context/testng/transaction/programmatic/ProgrammaticTxMgmtTestNGTests.java new file mode 100644 index 0000000000000000000000000000000000000000..39f2b0b477f985c3995987c07a15cb54899c2020 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/testng/transaction/programmatic/ProgrammaticTxMgmtTestNGTests.java @@ -0,0 +1,279 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.testng.transaction.programmatic; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import javax.sql.DataSource; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.jdbc.datasource.DataSourceTransactionManager; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.test.annotation.Commit; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.testng.AbstractTransactionalTestNGSpringContextTests; +import org.springframework.test.context.transaction.AfterTransaction; +import org.springframework.test.context.transaction.BeforeTransaction; +import org.springframework.test.context.transaction.TestTransaction; +import org.springframework.test.context.transaction.programmatic.ProgrammaticTxMgmtTests; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.annotation.Propagation; +import org.springframework.transaction.annotation.Transactional; + +import org.testng.IHookCallBack; +import org.testng.ITestResult; +import org.testng.annotations.Test; + +import static org.junit.Assert.*; +import static org.springframework.test.transaction.TransactionTestUtils.*; + +/** + * This class is a copy of the JUnit-based {@link ProgrammaticTxMgmtTests} class + * that has been modified to run with TestNG. + * + * @author Sam Brannen + * @since 4.1 + */ +@ContextConfiguration +public class ProgrammaticTxMgmtTestNGTests extends AbstractTransactionalTestNGSpringContextTests { + + private String method; + + + @Override + public void run(IHookCallBack callBack, ITestResult testResult) { + this.method = testResult.getMethod().getMethodName(); + super.run(callBack, testResult); + } + + @BeforeTransaction + public void beforeTransaction() { + deleteFromTables("user"); + executeSqlScript("classpath:/org/springframework/test/context/jdbc/data.sql", false); + } + + @AfterTransaction + public void afterTransaction() { + switch (method) { + case "commitTxAndStartNewTx": + case "commitTxButDoNotStartNewTx": { + assertUsers("Dogbert"); + break; + } + case "rollbackTxAndStartNewTx": + case "rollbackTxButDoNotStartNewTx": + case "startTxWithExistingTransaction": { + assertUsers("Dilbert"); + break; + } + case "rollbackTxAndStartNewTxWithDefaultCommitSemantics": { + assertUsers("Dilbert", "Dogbert"); + break; + } + default: { + fail("missing 'after transaction' assertion for test method: " + method); + } + } + } + + @Test + @Transactional(propagation = Propagation.NOT_SUPPORTED) + public void isActiveWithNonExistentTransactionContext() { + assertFalse(TestTransaction.isActive()); + } + + @Test(expectedExceptions = IllegalStateException.class) + @Transactional(propagation = Propagation.NOT_SUPPORTED) + public void flagForRollbackWithNonExistentTransactionContext() { + TestTransaction.flagForRollback(); + } + + @Test(expectedExceptions = IllegalStateException.class) + @Transactional(propagation = Propagation.NOT_SUPPORTED) + public void flagForCommitWithNonExistentTransactionContext() { + TestTransaction.flagForCommit(); + } + + @Test(expectedExceptions = IllegalStateException.class) + @Transactional(propagation = Propagation.NOT_SUPPORTED) + public void isFlaggedForRollbackWithNonExistentTransactionContext() { + TestTransaction.isFlaggedForRollback(); + } + + @Test(expectedExceptions = IllegalStateException.class) + @Transactional(propagation = Propagation.NOT_SUPPORTED) + public void startTxWithNonExistentTransactionContext() { + TestTransaction.start(); + } + + @Test(expectedExceptions = IllegalStateException.class) + public void startTxWithExistingTransaction() { + TestTransaction.start(); + } + + @Test(expectedExceptions = IllegalStateException.class) + @Transactional(propagation = Propagation.NOT_SUPPORTED) + public void endTxWithNonExistentTransactionContext() { + TestTransaction.end(); + } + + @Test + public void commitTxAndStartNewTx() { + assertInTransaction(true); + assertTrue(TestTransaction.isActive()); + assertUsers("Dilbert"); + deleteFromTables("user"); + assertUsers(); + + // Commit + TestTransaction.flagForCommit(); + assertFalse(TestTransaction.isFlaggedForRollback()); + TestTransaction.end(); + assertInTransaction(false); + assertFalse(TestTransaction.isActive()); + assertUsers(); + + executeSqlScript("classpath:/org/springframework/test/context/jdbc/data-add-dogbert.sql", false); + assertUsers("Dogbert"); + + TestTransaction.start(); + assertInTransaction(true); + assertTrue(TestTransaction.isActive()); + } + + @Test + public void commitTxButDoNotStartNewTx() { + assertInTransaction(true); + assertTrue(TestTransaction.isActive()); + assertUsers("Dilbert"); + deleteFromTables("user"); + assertUsers(); + + // Commit + TestTransaction.flagForCommit(); + assertFalse(TestTransaction.isFlaggedForRollback()); + TestTransaction.end(); + assertFalse(TestTransaction.isActive()); + assertInTransaction(false); + assertUsers(); + + executeSqlScript("classpath:/org/springframework/test/context/jdbc/data-add-dogbert.sql", false); + assertUsers("Dogbert"); + } + + @Test + public void rollbackTxAndStartNewTx() { + assertInTransaction(true); + assertTrue(TestTransaction.isActive()); + assertUsers("Dilbert"); + deleteFromTables("user"); + assertUsers(); + + // Rollback (automatically) + assertTrue(TestTransaction.isFlaggedForRollback()); + TestTransaction.end(); + assertFalse(TestTransaction.isActive()); + assertInTransaction(false); + assertUsers("Dilbert"); + + // Start new transaction with default rollback semantics + TestTransaction.start(); + assertInTransaction(true); + assertTrue(TestTransaction.isFlaggedForRollback()); + assertTrue(TestTransaction.isActive()); + + executeSqlScript("classpath:/org/springframework/test/context/jdbc/data-add-dogbert.sql", false); + assertUsers("Dilbert", "Dogbert"); + } + + @Test + public void rollbackTxButDoNotStartNewTx() { + assertInTransaction(true); + assertTrue(TestTransaction.isActive()); + assertUsers("Dilbert"); + deleteFromTables("user"); + assertUsers(); + + // Rollback (automatically) + assertTrue(TestTransaction.isFlaggedForRollback()); + TestTransaction.end(); + assertFalse(TestTransaction.isActive()); + assertInTransaction(false); + assertUsers("Dilbert"); + } + + @Test + @Commit + public void rollbackTxAndStartNewTxWithDefaultCommitSemantics() { + assertInTransaction(true); + assertTrue(TestTransaction.isActive()); + assertUsers("Dilbert"); + deleteFromTables("user"); + assertUsers(); + + // Rollback + TestTransaction.flagForRollback(); + assertTrue(TestTransaction.isFlaggedForRollback()); + TestTransaction.end(); + assertFalse(TestTransaction.isActive()); + assertInTransaction(false); + assertUsers("Dilbert"); + + // Start new transaction with default commit semantics + TestTransaction.start(); + assertInTransaction(true); + assertFalse(TestTransaction.isFlaggedForRollback()); + assertTrue(TestTransaction.isActive()); + + executeSqlScript("classpath:/org/springframework/test/context/jdbc/data-add-dogbert.sql", false); + assertUsers("Dilbert", "Dogbert"); + } + + // ------------------------------------------------------------------------- + + private void assertUsers(String... users) { + List expected = Arrays.asList(users); + Collections.sort(expected); + List actual = jdbcTemplate.queryForList("select name from user", String.class); + Collections.sort(actual); + assertEquals("Users in database;", expected, actual); + } + + + // ------------------------------------------------------------------------- + + @Configuration + static class Config { + + @Bean + public PlatformTransactionManager transactionManager() { + return new DataSourceTransactionManager(dataSource()); + } + + @Bean + public DataSource dataSource() { + return new EmbeddedDatabaseBuilder()// + .setName("programmatic-tx-mgmt-test-db")// + .addScript("classpath:/org/springframework/test/context/jdbc/schema.sql") // + .build(); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/testng/web/ServletTestExecutionListenerTestNGIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/context/testng/web/ServletTestExecutionListenerTestNGIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..f81b7e63dbbc9611e58c53cbf3f20a2a4a9b9fb6 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/testng/web/ServletTestExecutionListenerTestNGIntegrationTests.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.testng.web; + +import org.testng.annotations.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Configuration; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.testng.AbstractTestNGSpringContextTests; +import org.springframework.test.context.web.ServletTestExecutionListener; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; + +import static org.junit.Assert.*; + +/** + * TestNG-based integration tests for {@link ServletTestExecutionListener}. + * + * @author Sam Brannen + * @since 3.2.9 + * @see org.springframework.test.context.web.ServletTestExecutionListenerJUnitIntegrationTests + */ +@ContextConfiguration +@WebAppConfiguration +public class ServletTestExecutionListenerTestNGIntegrationTests extends AbstractTestNGSpringContextTests { + + @Configuration + static class Config { + /* no beans required for this test */ + } + + + @Autowired + private MockHttpServletRequest servletRequest; + + + /** + * Verifies bug fix for SPR-11626. + * + * @see #ensureMocksAreReinjectedBetweenTests_2 + */ + @Test + void ensureMocksAreReinjectedBetweenTests_1() { + assertInjectedServletRequestEqualsRequestInRequestContextHolder(); + } + + /** + * Verifies bug fix for SPR-11626. + * + * @see #ensureMocksAreReinjectedBetweenTests_1 + */ + @Test + void ensureMocksAreReinjectedBetweenTests_2() { + assertInjectedServletRequestEqualsRequestInRequestContextHolder(); + } + + private void assertInjectedServletRequestEqualsRequestInRequestContextHolder() { + assertEquals("Injected ServletRequest must be stored in the RequestContextHolder", servletRequest, + ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/testng/web/TestNGSpringContextWebTests.java b/spring-test/src/test/java/org/springframework/test/context/testng/web/TestNGSpringContextWebTests.java new file mode 100644 index 0000000000000000000000000000000000000000..4c0ced774719275a9efc7dfd7e7e9b53d31df19f --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/testng/web/TestNGSpringContextWebTests.java @@ -0,0 +1,120 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.testng.web; + +import java.io.File; + +import javax.servlet.ServletContext; + +import org.testng.annotations.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockHttpSession; +import org.springframework.mock.web.MockServletContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.testng.AbstractTestNGSpringContextTests; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.web.context.ServletContextAware; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.request.ServletWebRequest; + +import static org.junit.Assert.*; + +/** + * TestNG-based integration tests that verify support for loading a + * {@link WebApplicationContext} when extending {@link AbstractTestNGSpringContextTests}. + * + * @author Sam Brannen + * @since 3.2.7 + */ +@ContextConfiguration +@WebAppConfiguration +public class TestNGSpringContextWebTests extends AbstractTestNGSpringContextTests implements ServletContextAware { + + @Configuration + static class Config { + + @Bean + String foo() { + return "enigma"; + } + } + + + ServletContext servletContext; + + @Autowired + WebApplicationContext wac; + + @Autowired + MockServletContext mockServletContext; + + @Autowired + MockHttpServletRequest request; + + @Autowired + MockHttpServletResponse response; + + @Autowired + MockHttpSession session; + + @Autowired + ServletWebRequest webRequest; + + @Autowired + String foo; + + + @Override + public void setServletContext(ServletContext servletContext) { + this.servletContext = servletContext; + } + + @Test + void basicWacFeatures() throws Exception { + assertNotNull("ServletContext should be set in the WAC.", wac.getServletContext()); + + assertNotNull("ServletContext should have been set via ServletContextAware.", servletContext); + + assertNotNull("ServletContext should have been autowired from the WAC.", mockServletContext); + assertNotNull("MockHttpServletRequest should have been autowired from the WAC.", request); + assertNotNull("MockHttpServletResponse should have been autowired from the WAC.", response); + assertNotNull("MockHttpSession should have been autowired from the WAC.", session); + assertNotNull("ServletWebRequest should have been autowired from the WAC.", webRequest); + + Object rootWac = mockServletContext.getAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE); + assertNotNull("Root WAC must be stored in the ServletContext as: " + + WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, rootWac); + assertSame("test WAC and Root WAC in ServletContext must be the same object.", wac, rootWac); + assertSame("ServletContext instances must be the same object.", mockServletContext, wac.getServletContext()); + assertSame("ServletContext in the WAC and in the mock request", mockServletContext, request.getServletContext()); + + assertEquals("Getting real path for ServletContext resource.", + new File("src/main/webapp/index.jsp").getCanonicalPath(), mockServletContext.getRealPath("index.jsp")); + + } + + @Test + void fooEnigmaAutowired() { + assertEquals("enigma", foo); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/transaction/PrimaryTransactionManagerTests.java b/spring-test/src/test/java/org/springframework/test/context/transaction/PrimaryTransactionManagerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3b08738ef2049c68192f8e917729617bcc9fca6e --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/transaction/PrimaryTransactionManagerTests.java @@ -0,0 +1,122 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.transaction; + +import javax.sql.DataSource; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Primary; +import org.springframework.core.io.ClassPathResource; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.DataSourceTransactionManager; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.init.ResourceDatabasePopulator; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.jdbc.JdbcTestUtils; +import org.springframework.test.transaction.TransactionTestUtils; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.annotation.EnableTransactionManagement; +import org.springframework.transaction.annotation.Transactional; + +import static org.junit.Assert.*; + +/** + * Integration tests that ensure that primary transaction managers + * are supported. + * + * @author Sam Brannen + * @since 4.3 + * @see org.springframework.test.context.jdbc.PrimaryDataSourceTests + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +@DirtiesContext +public final class PrimaryTransactionManagerTests { + + private JdbcTemplate jdbcTemplate; + + + @Autowired + public void setDataSource(DataSource dataSource1) { + this.jdbcTemplate = new JdbcTemplate(dataSource1); + } + + + @BeforeTransaction + public void beforeTransaction() { + assertNumUsers(0); + } + + @AfterTransaction + public void afterTransaction() { + assertNumUsers(0); + } + + @Test + @Transactional + public void transactionalTest() { + TransactionTestUtils.assertInTransaction(true); + + ClassPathResource resource = new ClassPathResource("/org/springframework/test/context/jdbc/data.sql"); + new ResourceDatabasePopulator(resource).execute(jdbcTemplate.getDataSource()); + + assertNumUsers(1); + } + + private void assertNumUsers(int expected) { + assertEquals("Number of rows in the 'user' table", expected, + JdbcTestUtils.countRowsInTable(this.jdbcTemplate, "user")); + } + + + @Configuration + @EnableTransactionManagement // SPR-17137: should not break trying to proxy the final test class + static class Config { + + @Primary + @Bean + public PlatformTransactionManager primaryTransactionManager() { + return new DataSourceTransactionManager(dataSource1()); + } + + @Bean + public PlatformTransactionManager additionalTransactionManager() { + return new DataSourceTransactionManager(dataSource2()); + } + + @Bean + public DataSource dataSource1() { + return new EmbeddedDatabaseBuilder() + .generateUniqueName(true) + .addScript("classpath:/org/springframework/test/context/jdbc/schema.sql") + .build(); + } + + @Bean + public DataSource dataSource2() { + return new EmbeddedDatabaseBuilder().generateUniqueName(true).build(); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/transaction/TransactionalTestExecutionListenerTests.java b/spring-test/src/test/java/org/springframework/test/context/transaction/TransactionalTestExecutionListenerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..528d98cfa558d2cc2c086a3fad41e78069f6daec --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/transaction/TransactionalTestExecutionListenerTests.java @@ -0,0 +1,583 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.transaction; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +import org.junit.After; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.BDDMockito; + +import org.springframework.beans.BeanUtils; +import org.springframework.core.annotation.AliasFor; +import org.springframework.test.annotation.Commit; +import org.springframework.test.annotation.Rollback; +import org.springframework.test.context.TestContext; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.annotation.Propagation; +import org.springframework.transaction.annotation.Transactional; +import org.springframework.transaction.support.SimpleTransactionStatus; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; +import static org.springframework.transaction.annotation.Propagation.*; + +/** + * Unit tests for {@link TransactionalTestExecutionListener}. + * + * @author Sam Brannen + * @since 4.0 + */ +public class TransactionalTestExecutionListenerTests { + + private final PlatformTransactionManager tm = mock(PlatformTransactionManager.class); + + private final TransactionalTestExecutionListener listener = new TransactionalTestExecutionListener() { + @Override + protected PlatformTransactionManager getTransactionManager(TestContext testContext, String qualifier) { + return tm; + } + }; + + private final TestContext testContext = mock(TestContext.class); + + @Rule + public ExpectedException exception = ExpectedException.none(); + + + @After + public void cleanUpThreadLocalStateForSubsequentTestClassesInSuite() { + TransactionContextHolder.removeCurrentTransactionContext(); + } + + + @Test // SPR-13895 + public void transactionalTestWithoutTransactionManager() throws Exception { + TransactionalTestExecutionListener listener = new TransactionalTestExecutionListener() { + protected PlatformTransactionManager getTransactionManager(TestContext testContext, String qualifier) { + return null; + } + }; + + Class clazz = TransactionalDeclaredOnClassLocallyTestCase.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + Invocable instance = BeanUtils.instantiateClass(clazz); + given(testContext.getTestInstance()).willReturn(instance); + given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("transactionalTest")); + + assertFalse("callback should not have been invoked", instance.invoked()); + TransactionContextHolder.removeCurrentTransactionContext(); + + try { + listener.beforeTestMethod(testContext); + fail("Should have thrown an IllegalStateException"); + } + catch (IllegalStateException e) { + assertTrue(e.getMessage().startsWith( + "Failed to retrieve PlatformTransactionManager for @Transactional test")); + } + } + + @Test + public void beforeTestMethodWithTransactionalDeclaredOnClassLocally() throws Exception { + assertBeforeTestMethodWithTransactionalTestMethod(TransactionalDeclaredOnClassLocallyTestCase.class); + } + + @Test + public void beforeTestMethodWithTransactionalDeclaredOnClassViaMetaAnnotation() throws Exception { + assertBeforeTestMethodWithTransactionalTestMethod(TransactionalDeclaredOnClassViaMetaAnnotationTestCase.class); + } + + @Test + public void beforeTestMethodWithTransactionalDeclaredOnClassViaMetaAnnotationWithOverride() throws Exception { + // Note: not actually invoked within a transaction since the test class is + // annotated with @MetaTxWithOverride(propagation = NOT_SUPPORTED) + assertBeforeTestMethodWithTransactionalTestMethod( + TransactionalDeclaredOnClassViaMetaAnnotationWithOverrideTestCase.class, false); + } + + @Test + public void beforeTestMethodWithTransactionalDeclaredOnMethodViaMetaAnnotationWithOverride() throws Exception { + // Note: not actually invoked within a transaction since the method is + // annotated with @MetaTxWithOverride(propagation = NOT_SUPPORTED) + assertBeforeTestMethodWithTransactionalTestMethod( + TransactionalDeclaredOnMethodViaMetaAnnotationWithOverrideTestCase.class, false); + assertBeforeTestMethodWithNonTransactionalTestMethod(TransactionalDeclaredOnMethodViaMetaAnnotationWithOverrideTestCase.class); + } + + @Test + public void beforeTestMethodWithTransactionalDeclaredOnMethodLocally() throws Exception { + assertBeforeTestMethod(TransactionalDeclaredOnMethodLocallyTestCase.class); + } + + @Test + public void beforeTestMethodWithTransactionalDeclaredOnMethodViaMetaAnnotation() throws Exception { + assertBeforeTestMethod(TransactionalDeclaredOnMethodViaMetaAnnotationTestCase.class); + } + + @Test + public void beforeTestMethodWithBeforeTransactionDeclaredLocally() throws Exception { + assertBeforeTestMethod(BeforeTransactionDeclaredLocallyTestCase.class); + } + + @Test + public void beforeTestMethodWithBeforeTransactionDeclaredViaMetaAnnotation() throws Exception { + assertBeforeTestMethod(BeforeTransactionDeclaredViaMetaAnnotationTestCase.class); + } + + @Test + public void afterTestMethodWithAfterTransactionDeclaredLocally() throws Exception { + assertAfterTestMethod(AfterTransactionDeclaredLocallyTestCase.class); + } + + @Test + public void afterTestMethodWithAfterTransactionDeclaredViaMetaAnnotation() throws Exception { + assertAfterTestMethod(AfterTransactionDeclaredViaMetaAnnotationTestCase.class); + } + + @Test + public void beforeTestMethodWithBeforeTransactionDeclaredAsInterfaceDefaultMethod() throws Exception { + assertBeforeTestMethod(BeforeTransactionDeclaredAsInterfaceDefaultMethodTestCase.class); + } + + @Test + public void afterTestMethodWithAfterTransactionDeclaredAsInterfaceDefaultMethod() throws Exception { + assertAfterTestMethod(AfterTransactionDeclaredAsInterfaceDefaultMethodTestCase.class); + } + + @Test + public void isRollbackWithMissingRollback() throws Exception { + assertIsRollback(MissingRollbackTestCase.class, true); + } + + @Test + public void isRollbackWithEmptyMethodLevelRollback() throws Exception { + assertIsRollback(EmptyMethodLevelRollbackTestCase.class, true); + } + + @Test + public void isRollbackWithMethodLevelRollbackWithExplicitValue() throws Exception { + assertIsRollback(MethodLevelRollbackWithExplicitValueTestCase.class, false); + } + + @Test + public void isRollbackWithMethodLevelRollbackViaMetaAnnotation() throws Exception { + assertIsRollback(MethodLevelRollbackViaMetaAnnotationTestCase.class, false); + } + + @Test + public void isRollbackWithEmptyClassLevelRollback() throws Exception { + assertIsRollback(EmptyClassLevelRollbackTestCase.class, true); + } + + @Test + public void isRollbackWithClassLevelRollbackWithExplicitValue() throws Exception { + assertIsRollback(ClassLevelRollbackWithExplicitValueTestCase.class, false); + } + + @Test + public void isRollbackWithClassLevelRollbackViaMetaAnnotation() throws Exception { + assertIsRollback(ClassLevelRollbackViaMetaAnnotationTestCase.class, false); + } + + @Test + public void isRollbackWithClassLevelRollbackWithExplicitValueOnTestInterface() throws Exception { + assertIsRollback(ClassLevelRollbackWithExplicitValueOnTestInterfaceTestCase.class, false); + } + + @Test + public void isRollbackWithClassLevelRollbackViaMetaAnnotationOnTestInterface() throws Exception { + assertIsRollback(ClassLevelRollbackViaMetaAnnotationOnTestInterfaceTestCase.class, false); + } + + + private void assertBeforeTestMethod(Class clazz) throws Exception { + assertBeforeTestMethodWithTransactionalTestMethod(clazz); + assertBeforeTestMethodWithNonTransactionalTestMethod(clazz); + } + + private void assertBeforeTestMethodWithTransactionalTestMethod(Class clazz) throws Exception { + assertBeforeTestMethodWithTransactionalTestMethod(clazz, true); + } + + private void assertBeforeTestMethodWithTransactionalTestMethod(Class clazz, boolean invokedInTx) + throws Exception { + + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + Invocable instance = BeanUtils.instantiateClass(clazz); + given(testContext.getTestInstance()).willReturn(instance); + given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("transactionalTest")); + + assertFalse("callback should not have been invoked", instance.invoked()); + TransactionContextHolder.removeCurrentTransactionContext(); + listener.beforeTestMethod(testContext); + assertEquals(invokedInTx, instance.invoked()); + } + + private void assertBeforeTestMethodWithNonTransactionalTestMethod(Class clazz) throws Exception { + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + Invocable instance = BeanUtils.instantiateClass(clazz); + given(testContext.getTestInstance()).willReturn(instance); + given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("nonTransactionalTest")); + + assertFalse("callback should not have been invoked", instance.invoked()); + TransactionContextHolder.removeCurrentTransactionContext(); + listener.beforeTestMethod(testContext); + assertFalse("callback should not have been invoked", instance.invoked()); + } + + private void assertAfterTestMethod(Class clazz) throws Exception { + assertAfterTestMethodWithTransactionalTestMethod(clazz); + assertAfterTestMethodWithNonTransactionalTestMethod(clazz); + } + + private void assertAfterTestMethodWithTransactionalTestMethod(Class clazz) throws Exception { + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + Invocable instance = BeanUtils.instantiateClass(clazz); + given(testContext.getTestInstance()).willReturn(instance); + given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("transactionalTest")); + given(tm.getTransaction(BDDMockito.any(TransactionDefinition.class))).willReturn(new SimpleTransactionStatus()); + + assertFalse("callback should not have been invoked", instance.invoked()); + TransactionContextHolder.removeCurrentTransactionContext(); + listener.beforeTestMethod(testContext); + assertFalse("callback should not have been invoked", instance.invoked()); + listener.afterTestMethod(testContext); + assertTrue("callback should have been invoked", instance.invoked()); + } + + private void assertAfterTestMethodWithNonTransactionalTestMethod(Class clazz) throws Exception { + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + Invocable instance = BeanUtils.instantiateClass(clazz); + given(testContext.getTestInstance()).willReturn(instance); + given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("nonTransactionalTest")); + + assertFalse("callback should not have been invoked", instance.invoked()); + TransactionContextHolder.removeCurrentTransactionContext(); + listener.beforeTestMethod(testContext); + listener.afterTestMethod(testContext); + assertFalse("callback should not have been invoked", instance.invoked()); + } + + private void assertIsRollback(Class clazz, boolean rollback) throws Exception { + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("test")); + assertEquals(rollback, listener.isRollback(testContext)); + } + + + @Transactional + @Retention(RetentionPolicy.RUNTIME) + private @interface MetaTransactional { + } + + @Transactional + @Retention(RetentionPolicy.RUNTIME) + private static @interface MetaTxWithOverride { + + @AliasFor(annotation = Transactional.class, attribute = "value") + String transactionManager() default ""; + + Propagation propagation() default REQUIRED; + } + + @BeforeTransaction + @Retention(RetentionPolicy.RUNTIME) + private @interface MetaBeforeTransaction { + } + + @AfterTransaction + @Retention(RetentionPolicy.RUNTIME) + private @interface MetaAfterTransaction { + } + + private interface Invocable { + + void invoked(boolean invoked); + + boolean invoked(); + } + + private static class AbstractInvocable implements Invocable { + + boolean invoked = false; + + + @Override + public void invoked(boolean invoked) { + this.invoked = invoked; + } + + @Override + public boolean invoked() { + return this.invoked; + } + } + + @Transactional + static class TransactionalDeclaredOnClassLocallyTestCase extends AbstractInvocable { + + @BeforeTransaction + public void beforeTransaction() { + invoked(true); + } + + public void transactionalTest() { + } + } + + static class TransactionalDeclaredOnMethodLocallyTestCase extends AbstractInvocable { + + @BeforeTransaction + public void beforeTransaction() { + invoked(true); + } + + @Transactional + public void transactionalTest() { + } + + public void nonTransactionalTest() { + } + } + + @MetaTransactional + static class TransactionalDeclaredOnClassViaMetaAnnotationTestCase extends AbstractInvocable { + + @BeforeTransaction + public void beforeTransaction() { + invoked(true); + } + + public void transactionalTest() { + } + } + + static class TransactionalDeclaredOnMethodViaMetaAnnotationTestCase extends AbstractInvocable { + + @BeforeTransaction + public void beforeTransaction() { + invoked(true); + } + + @MetaTransactional + public void transactionalTest() { + } + + public void nonTransactionalTest() { + } + } + + @MetaTxWithOverride(propagation = NOT_SUPPORTED) + static class TransactionalDeclaredOnClassViaMetaAnnotationWithOverrideTestCase extends AbstractInvocable { + + @BeforeTransaction + public void beforeTransaction() { + invoked(true); + } + + public void transactionalTest() { + } + } + + static class TransactionalDeclaredOnMethodViaMetaAnnotationWithOverrideTestCase extends AbstractInvocable { + + @BeforeTransaction + public void beforeTransaction() { + invoked(true); + } + + @MetaTxWithOverride(propagation = NOT_SUPPORTED) + public void transactionalTest() { + } + + public void nonTransactionalTest() { + } + } + + static class BeforeTransactionDeclaredLocallyTestCase extends AbstractInvocable { + + @BeforeTransaction + public void beforeTransaction() { + invoked(true); + } + + @Transactional + public void transactionalTest() { + } + + public void nonTransactionalTest() { + } + } + + static class BeforeTransactionDeclaredViaMetaAnnotationTestCase extends AbstractInvocable { + + @MetaBeforeTransaction + public void beforeTransaction() { + invoked(true); + } + + @Transactional + public void transactionalTest() { + } + + public void nonTransactionalTest() { + } + } + + static class AfterTransactionDeclaredLocallyTestCase extends AbstractInvocable { + + @AfterTransaction + public void afterTransaction() { + invoked(true); + } + + @Transactional + public void transactionalTest() { + } + + public void nonTransactionalTest() { + } + } + + static class AfterTransactionDeclaredViaMetaAnnotationTestCase extends AbstractInvocable { + + @MetaAfterTransaction + public void afterTransaction() { + invoked(true); + } + + @Transactional + public void transactionalTest() { + } + + public void nonTransactionalTest() { + } + } + + interface BeforeTransactionInterface extends Invocable { + + @BeforeTransaction + default void beforeTransaction() { + invoked(true); + } + } + + interface AfterTransactionInterface extends Invocable { + + @AfterTransaction + default void afterTransaction() { + invoked(true); + } + } + + static class BeforeTransactionDeclaredAsInterfaceDefaultMethodTestCase extends AbstractInvocable + implements BeforeTransactionInterface { + + @Transactional + public void transactionalTest() { + } + + public void nonTransactionalTest() { + } + } + + static class AfterTransactionDeclaredAsInterfaceDefaultMethodTestCase extends AbstractInvocable + implements AfterTransactionInterface { + + @Transactional + public void transactionalTest() { + } + + public void nonTransactionalTest() { + } + } + + static class MissingRollbackTestCase { + + public void test() { + } + } + + static class EmptyMethodLevelRollbackTestCase { + + @Rollback + public void test() { + } + } + + static class MethodLevelRollbackWithExplicitValueTestCase { + + @Rollback(false) + public void test() { + } + } + + static class MethodLevelRollbackViaMetaAnnotationTestCase { + + @Commit + public void test() { + } + } + + @Rollback + static class EmptyClassLevelRollbackTestCase { + + public void test() { + } + } + + @Rollback(false) + static class ClassLevelRollbackWithExplicitValueTestCase { + + public void test() { + } + } + + @Commit + static class ClassLevelRollbackViaMetaAnnotationTestCase { + + public void test() { + } + } + + @Rollback(false) + interface RollbackFalseTestInterface { + } + + static class ClassLevelRollbackWithExplicitValueOnTestInterfaceTestCase implements RollbackFalseTestInterface { + + public void test() { + } + } + + @Commit + interface RollbackFalseViaMetaAnnotationTestInterface { + } + + static class ClassLevelRollbackViaMetaAnnotationOnTestInterfaceTestCase + implements RollbackFalseViaMetaAnnotationTestInterface { + + public void test() { + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/AbstractEjbTxDaoTests.java b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/AbstractEjbTxDaoTests.java new file mode 100644 index 0000000000000000000000000000000000000000..42997c53fe7bd51530f121723cd1309788cb5d7d --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/AbstractEjbTxDaoTests.java @@ -0,0 +1,87 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.transaction.ejb; + +import javax.ejb.EJB; +import javax.persistence.EntityManager; +import javax.persistence.PersistenceContext; + +import org.junit.After; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; + +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.annotation.DirtiesContext.ClassMode; +import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests; +import org.springframework.test.context.transaction.ejb.dao.TestEntityDao; + +import static org.junit.Assert.*; + +/** + * Abstract base class for all tests involving EJB transaction support in the + * TestContext framework. + * + * @author Sam Brannen + * @author Xavier Detant + * @since 4.0.1 + */ +@FixMethodOrder(MethodSorters.NAME_ASCENDING) +@DirtiesContext(classMode = ClassMode.AFTER_CLASS) +public abstract class AbstractEjbTxDaoTests extends AbstractTransactionalJUnit4SpringContextTests { + + protected static final String TEST_NAME = "test-name"; + + @EJB + protected TestEntityDao dao; + + @PersistenceContext + protected EntityManager em; + + + @Test + public void test1InitialState() { + int count = dao.getCount(TEST_NAME); + assertEquals("New TestEntity should have count=0.", 0, count); + } + + @Test + public void test2IncrementCount1() { + int count = dao.incrementCount(TEST_NAME); + assertEquals("Expected count=1 after first increment.", 1, count); + } + + /** + * The default implementation of this method assumes that the transaction + * for {@link #test2IncrementCount1()} was committed. Therefore, it is + * expected that the previous increment has been persisted in the database. + */ + @Test + public void test3IncrementCount2() { + int count = dao.getCount(TEST_NAME); + assertEquals("Expected count=1 after test2IncrementCount1().", 1, count); + + count = dao.incrementCount(TEST_NAME); + assertEquals("Expected count=2 now.", 2, count); + } + + @After + public void synchronizePersistenceContext() { + em.flush(); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/CommitForRequiredEjbTxDaoTests.java b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/CommitForRequiredEjbTxDaoTests.java new file mode 100644 index 0000000000000000000000000000000000000000..237cebd38da1e0f61e01932cfeb48efddfbec913 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/CommitForRequiredEjbTxDaoTests.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.transaction.ejb; + +import org.springframework.test.annotation.Commit; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.transaction.TransactionalTestExecutionListener; +import org.springframework.test.context.transaction.ejb.dao.RequiredEjbTxTestEntityDao; + +/** + * Concrete subclass of {@link AbstractEjbTxDaoTests} which uses the + * {@link RequiredEjbTxTestEntityDao} and sets the default rollback semantics + * for the {@link TransactionalTestExecutionListener} to {@code false} (i.e., + * commit). + * + * @author Sam Brannen + * @since 4.0.1 + */ +@ContextConfiguration("required-tx-config.xml") +@Commit +public class CommitForRequiredEjbTxDaoTests extends AbstractEjbTxDaoTests { + + /* test methods in superclass */ + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/CommitForRequiresNewEjbTxDaoTests.java b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/CommitForRequiresNewEjbTxDaoTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a726008be0a33c9dcbe5fe9562874dfab58d3fa0 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/CommitForRequiresNewEjbTxDaoTests.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.transaction.ejb; + +import org.springframework.test.annotation.Commit; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.transaction.TransactionalTestExecutionListener; +import org.springframework.test.context.transaction.ejb.dao.RequiresNewEjbTxTestEntityDao; + +/** + * Concrete subclass of {@link AbstractEjbTxDaoTests} which uses the + * {@link RequiresNewEjbTxTestEntityDao} and sets the default rollback semantics + * for the {@link TransactionalTestExecutionListener} to {@code false} (i.e., + * commit). + * + * @author Sam Brannen + * @since 4.0.1 + */ +@ContextConfiguration("requires-new-tx-config.xml") +@Commit +public class CommitForRequiresNewEjbTxDaoTests extends AbstractEjbTxDaoTests { + + /* test methods in superclass */ + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/RollbackForRequiredEjbTxDaoTests.java b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/RollbackForRequiredEjbTxDaoTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c30cc5782624b6741786dcd370d569c0ae2c75c9 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/RollbackForRequiredEjbTxDaoTests.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.transaction.ejb; + +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; + +import org.springframework.test.annotation.Rollback; +import org.springframework.test.context.transaction.TransactionalTestExecutionListener; + +import static org.junit.Assert.*; + +/** + * Extension of {@link CommitForRequiredEjbTxDaoTests} which sets the default + * rollback semantics for the {@link TransactionalTestExecutionListener} to + * {@code true}. The transaction managed by the TestContext framework will be + * rolled back after each test method. Consequently, any work performed in + * transactional methods that participate in the test-managed transaction will + * be rolled back automatically. + * + * @author Sam Brannen + * @since 4.0.1 + */ +@FixMethodOrder(MethodSorters.NAME_ASCENDING) +@Rollback +public class RollbackForRequiredEjbTxDaoTests extends CommitForRequiredEjbTxDaoTests { + + /** + * Redeclared to ensure test method execution order. Simply delegates to super. + */ + @Test + @Override + public void test1InitialState() { + super.test1InitialState(); + } + + /** + * Redeclared to ensure test method execution order. Simply delegates to super. + */ + @Test + @Override + public void test2IncrementCount1() { + super.test2IncrementCount1(); + } + + /** + * Overrides parent implementation in order to change expectations to align with + * behavior associated with "required" transactions on repositories/DAOs and + * default rollback semantics for transactions managed by the TestContext + * framework. + */ + @Test + @Override + public void test3IncrementCount2() { + int count = dao.getCount(TEST_NAME); + // Expecting count=0 after test2IncrementCount1() since REQUIRED transactions + // participate in the existing transaction (if present), which in this case is the + // transaction managed by the TestContext framework which will be rolled back + // after each test method. + assertEquals("Expected count=0 after test2IncrementCount1().", 0, count); + + count = dao.incrementCount(TEST_NAME); + assertEquals("Expected count=1 now.", 1, count); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/RollbackForRequiresNewEjbTxDaoTests.java b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/RollbackForRequiresNewEjbTxDaoTests.java new file mode 100644 index 0000000000000000000000000000000000000000..4b44c8872b3639d0c7c3a9a5b1dbc9516d6df1ad --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/RollbackForRequiresNewEjbTxDaoTests.java @@ -0,0 +1,42 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.transaction.ejb; + +import org.springframework.test.annotation.Rollback; +import org.springframework.test.context.transaction.TransactionalTestExecutionListener; + +/** + * Extension of {@link CommitForRequiresNewEjbTxDaoTests} which sets the default + * rollback semantics for the {@link TransactionalTestExecutionListener} to + * {@code true}. The transaction managed by the TestContext framework will be + * rolled back after each test method. Consequently, any work performed in + * transactional methods that participate in the test-managed transaction will + * be rolled back automatically. On the other hand, any work performed in + * transactional methods that do not participate in the + * test-managed transaction will not be affected by the rollback of the + * test-managed transaction. For example, such work may in fact be committed + * outside the scope of the test-managed transaction. + * + * @author Sam Brannen + * @since 4.0.1 + */ +@Rollback +public class RollbackForRequiresNewEjbTxDaoTests extends CommitForRequiresNewEjbTxDaoTests { + + /* test methods in superclass */ + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/dao/AbstractEjbTxTestEntityDao.java b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/dao/AbstractEjbTxTestEntityDao.java new file mode 100644 index 0000000000000000000000000000000000000000..ed518cfc257cc2ba3fb404b46ce945f6a7478274 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/dao/AbstractEjbTxTestEntityDao.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.transaction.ejb.dao; + +import javax.ejb.TransactionAttribute; +import javax.persistence.EntityManager; +import javax.persistence.PersistenceContext; + +import org.springframework.test.context.transaction.ejb.model.TestEntity; + +/** + * Abstract base class for EJB implementations of {@link TestEntityDao} which + * declare transaction semantics for {@link #incrementCount(String)} via + * {@link TransactionAttribute}. + * + * @author Sam Brannen + * @author Xavier Detant + * @since 4.0.1 + * @see RequiredEjbTxTestEntityDao + * @see RequiresNewEjbTxTestEntityDao + */ +public abstract class AbstractEjbTxTestEntityDao implements TestEntityDao { + + @PersistenceContext + protected EntityManager entityManager; + + + protected final TestEntity getTestEntity(String name) { + TestEntity te = entityManager.find(TestEntity.class, name); + if (te == null) { + te = new TestEntity(name, 0); + entityManager.persist(te); + } + return te; + } + + protected final int getCountInternal(String name) { + return getTestEntity(name).getCount(); + } + + protected final int incrementCountInternal(String name) { + TestEntity te = getTestEntity(name); + int count = te.getCount(); + count++; + te.setCount(count); + return count; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/dao/RequiredEjbTxTestEntityDao.java b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/dao/RequiredEjbTxTestEntityDao.java new file mode 100644 index 0000000000000000000000000000000000000000..f87fa0aa97daf164a3a87a4d498cd79dd7061a99 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/dao/RequiredEjbTxTestEntityDao.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.transaction.ejb.dao; + +import javax.ejb.Local; +import javax.ejb.Stateless; +import javax.ejb.TransactionAttribute; +import javax.ejb.TransactionAttributeType; + +/** + * EJB implementation of {@link TestEntityDao} which declares transaction + * semantics for {@link #incrementCount(String)} with + * {@link TransactionAttributeType#REQUIRED}. + * + * @author Sam Brannen + * @author Xavier Detant + * @since 4.0.1 + * @see RequiresNewEjbTxTestEntityDao + */ +@Stateless +@Local(TestEntityDao.class) +@TransactionAttribute(TransactionAttributeType.MANDATORY) +public class RequiredEjbTxTestEntityDao extends AbstractEjbTxTestEntityDao { + + @Override + public int getCount(String name) { + return super.getCountInternal(name); + } + + @TransactionAttribute(TransactionAttributeType.REQUIRED) + @Override + public int incrementCount(String name) { + return super.incrementCountInternal(name); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/dao/RequiresNewEjbTxTestEntityDao.java b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/dao/RequiresNewEjbTxTestEntityDao.java new file mode 100644 index 0000000000000000000000000000000000000000..e9957398501d6df775af7f7bbf334a8d485cf673 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/dao/RequiresNewEjbTxTestEntityDao.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.transaction.ejb.dao; + +import javax.ejb.Local; +import javax.ejb.Stateless; +import javax.ejb.TransactionAttribute; +import javax.ejb.TransactionAttributeType; + +/** + * EJB implementation of {@link TestEntityDao} which declares transaction + * semantics for {@link #incrementCount(String)} with + * {@link TransactionAttributeType#REQUIRES_NEW}. + * + * @author Sam Brannen + * @author Xavier Detant + * @since 4.0.1 + * @see RequiredEjbTxTestEntityDao + */ +@Stateless +@Local(TestEntityDao.class) +@TransactionAttribute(TransactionAttributeType.MANDATORY) +public class RequiresNewEjbTxTestEntityDao extends AbstractEjbTxTestEntityDao { + + @Override + public int getCount(String name) { + return super.getCountInternal(name); + } + + @TransactionAttribute(TransactionAttributeType.REQUIRES_NEW) + @Override + public int incrementCount(String name) { + return super.incrementCountInternal(name); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/dao/TestEntityDao.java b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/dao/TestEntityDao.java new file mode 100644 index 0000000000000000000000000000000000000000..86e21148d672e4ef71f901695d1a62e9290fae83 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/dao/TestEntityDao.java @@ -0,0 +1,32 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.transaction.ejb.dao; + +/** + * Test DAO for EJB transaction support in the TestContext framework. + * + * @author Xavier Detant + * @author Sam Brannen + * @since 4.0.1 + */ +public interface TestEntityDao { + + int getCount(String name); + + int incrementCount(String name); + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/model/TestEntity.java b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/model/TestEntity.java new file mode 100644 index 0000000000000000000000000000000000000000..9f1cd44840bb35369921901977c18ed6f9301e0b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/transaction/ejb/model/TestEntity.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.transaction.ejb.model; + +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.Id; +import javax.persistence.Table; + +/** + * Test entity for EJB transaction support in the TestContext framework. + * + * @author Xavier Detant + * @author Sam Brannen + * @since 4.0.1 + */ +@Entity +@Table(name = TestEntity.TABLE_NAME) +public class TestEntity { + + public static final String TABLE_NAME = "TEST_ENTITY"; + + @Id + @Column(name = "TE_NAME", nullable = false) + private String name; + + @Column(name = "TE_COUNT", nullable = false) + private int count; + + + public TestEntity() { + } + + public TestEntity(String name, int count) { + this.name = name; + this.count = count; + } + + public String getName() { + return this.name; + } + + public void setName(String name) { + this.name = name; + } + + public int getCount() { + return this.count; + } + + public void setCount(int count) { + this.count = count; + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/transaction/programmatic/ProgrammaticTxMgmtTests.java b/spring-test/src/test/java/org/springframework/test/context/transaction/programmatic/ProgrammaticTxMgmtTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d1b99f9681f753f93e50bd8c092d0c764bc2974f --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/transaction/programmatic/ProgrammaticTxMgmtTests.java @@ -0,0 +1,305 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.transaction.programmatic; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import javax.sql.DataSource; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.io.Resource; +import org.springframework.dao.DataAccessException; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.DataSourceTransactionManager; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.init.ResourceDatabasePopulator; +import org.springframework.test.annotation.Commit; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.transaction.AfterTransaction; +import org.springframework.test.context.transaction.BeforeTransaction; +import org.springframework.test.context.transaction.TestTransaction; +import org.springframework.test.jdbc.JdbcTestUtils; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.annotation.Propagation; +import org.springframework.transaction.annotation.Transactional; + +import static org.junit.Assert.*; +import static org.springframework.test.transaction.TransactionTestUtils.*; + +/** + * JUnit-based integration tests that verify support for programmatic transaction + * management within the Spring TestContext Framework. + * + * @author Sam Brannen + * @since 4.1 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +@Transactional +public class ProgrammaticTxMgmtTests { + + private String sqlScriptEncoding; + + protected JdbcTemplate jdbcTemplate; + + @Autowired + protected ApplicationContext applicationContext; + + @Rule + public TestName testName = new TestName(); + + + @Autowired + public void setDataSource(DataSource dataSource) { + this.jdbcTemplate = new JdbcTemplate(dataSource); + } + + + @BeforeTransaction + public void beforeTransaction() { + deleteFromTables("user"); + executeSqlScript("classpath:/org/springframework/test/context/jdbc/data.sql", false); + } + + @AfterTransaction + public void afterTransaction() { + String method = testName.getMethodName(); + switch (method) { + case "commitTxAndStartNewTx": + case "commitTxButDoNotStartNewTx": { + assertUsers("Dogbert"); + break; + } + case "rollbackTxAndStartNewTx": + case "rollbackTxButDoNotStartNewTx": + case "startTxWithExistingTransaction": { + assertUsers("Dilbert"); + break; + } + case "rollbackTxAndStartNewTxWithDefaultCommitSemantics": { + assertUsers("Dilbert", "Dogbert"); + break; + } + default: { + fail("missing 'after transaction' assertion for test method: " + method); + } + } + } + + @Test + @Transactional(propagation = Propagation.NOT_SUPPORTED) + public void isActiveWithNonExistentTransactionContext() { + assertFalse(TestTransaction.isActive()); + } + + @Test(expected = IllegalStateException.class) + @Transactional(propagation = Propagation.NOT_SUPPORTED) + public void flagForRollbackWithNonExistentTransactionContext() { + TestTransaction.flagForRollback(); + } + + @Test(expected = IllegalStateException.class) + @Transactional(propagation = Propagation.NOT_SUPPORTED) + public void flagForCommitWithNonExistentTransactionContext() { + TestTransaction.flagForCommit(); + } + + @Test(expected = IllegalStateException.class) + @Transactional(propagation = Propagation.NOT_SUPPORTED) + public void isFlaggedForRollbackWithNonExistentTransactionContext() { + TestTransaction.isFlaggedForRollback(); + } + + @Test(expected = IllegalStateException.class) + @Transactional(propagation = Propagation.NOT_SUPPORTED) + public void startTxWithNonExistentTransactionContext() { + TestTransaction.start(); + } + + @Test(expected = IllegalStateException.class) + public void startTxWithExistingTransaction() { + TestTransaction.start(); + } + + @Test(expected = IllegalStateException.class) + @Transactional(propagation = Propagation.NOT_SUPPORTED) + public void endTxWithNonExistentTransactionContext() { + TestTransaction.end(); + } + + @Test + public void commitTxAndStartNewTx() { + assertInTransaction(true); + assertTrue(TestTransaction.isActive()); + assertUsers("Dilbert"); + deleteFromTables("user"); + assertUsers(); + + // Commit + TestTransaction.flagForCommit(); + assertFalse(TestTransaction.isFlaggedForRollback()); + TestTransaction.end(); + assertInTransaction(false); + assertFalse(TestTransaction.isActive()); + assertUsers(); + + executeSqlScript("classpath:/org/springframework/test/context/jdbc/data-add-dogbert.sql", false); + assertUsers("Dogbert"); + + TestTransaction.start(); + assertInTransaction(true); + assertTrue(TestTransaction.isActive()); + } + + @Test + public void commitTxButDoNotStartNewTx() { + assertInTransaction(true); + assertTrue(TestTransaction.isActive()); + assertUsers("Dilbert"); + deleteFromTables("user"); + assertUsers(); + + // Commit + TestTransaction.flagForCommit(); + assertFalse(TestTransaction.isFlaggedForRollback()); + TestTransaction.end(); + assertFalse(TestTransaction.isActive()); + assertInTransaction(false); + assertUsers(); + + executeSqlScript("classpath:/org/springframework/test/context/jdbc/data-add-dogbert.sql", false); + assertUsers("Dogbert"); + } + + @Test + public void rollbackTxAndStartNewTx() { + assertInTransaction(true); + assertTrue(TestTransaction.isActive()); + assertUsers("Dilbert"); + deleteFromTables("user"); + assertUsers(); + + // Rollback (automatically) + assertTrue(TestTransaction.isFlaggedForRollback()); + TestTransaction.end(); + assertFalse(TestTransaction.isActive()); + assertInTransaction(false); + assertUsers("Dilbert"); + + // Start new transaction with default rollback semantics + TestTransaction.start(); + assertInTransaction(true); + assertTrue(TestTransaction.isFlaggedForRollback()); + assertTrue(TestTransaction.isActive()); + + executeSqlScript("classpath:/org/springframework/test/context/jdbc/data-add-dogbert.sql", false); + assertUsers("Dilbert", "Dogbert"); + } + + @Test + public void rollbackTxButDoNotStartNewTx() { + assertInTransaction(true); + assertTrue(TestTransaction.isActive()); + assertUsers("Dilbert"); + deleteFromTables("user"); + assertUsers(); + + // Rollback (automatically) + assertTrue(TestTransaction.isFlaggedForRollback()); + TestTransaction.end(); + assertFalse(TestTransaction.isActive()); + assertInTransaction(false); + assertUsers("Dilbert"); + } + + @Test + @Commit + public void rollbackTxAndStartNewTxWithDefaultCommitSemantics() { + assertInTransaction(true); + assertTrue(TestTransaction.isActive()); + assertUsers("Dilbert"); + deleteFromTables("user"); + assertUsers(); + + // Rollback + TestTransaction.flagForRollback(); + assertTrue(TestTransaction.isFlaggedForRollback()); + TestTransaction.end(); + assertFalse(TestTransaction.isActive()); + assertInTransaction(false); + assertUsers("Dilbert"); + + // Start new transaction with default commit semantics + TestTransaction.start(); + assertInTransaction(true); + assertFalse(TestTransaction.isFlaggedForRollback()); + assertTrue(TestTransaction.isActive()); + + executeSqlScript("classpath:/org/springframework/test/context/jdbc/data-add-dogbert.sql", false); + assertUsers("Dilbert", "Dogbert"); + } + + // ------------------------------------------------------------------------- + + protected int deleteFromTables(String... names) { + return JdbcTestUtils.deleteFromTables(this.jdbcTemplate, names); + } + + protected void executeSqlScript(String sqlResourcePath, boolean continueOnError) throws DataAccessException { + Resource resource = this.applicationContext.getResource(sqlResourcePath); + new ResourceDatabasePopulator(continueOnError, false, this.sqlScriptEncoding, resource).execute(jdbcTemplate.getDataSource()); + } + + private void assertUsers(String... users) { + List expected = Arrays.asList(users); + Collections.sort(expected); + List actual = jdbcTemplate.queryForList("select name from user", String.class); + Collections.sort(actual); + assertEquals("Users in database;", expected, actual); + } + + // ------------------------------------------------------------------------- + + @Configuration + static class Config { + + @Bean + public PlatformTransactionManager transactionManager() { + return new DataSourceTransactionManager(dataSource()); + } + + @Bean + public DataSource dataSource() { + return new EmbeddedDatabaseBuilder()// + .generateUniqueName(true)// + .addScript("classpath:/org/springframework/test/context/jdbc/schema.sql") // + .build(); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/AbstractBasicWacTests.java b/spring-test/src/test/java/org/springframework/test/context/web/AbstractBasicWacTests.java new file mode 100644 index 0000000000000000000000000000000000000000..9df5450d593ab8baa7e277caa4afff02ab245e2e --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/AbstractBasicWacTests.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import java.io.File; + +import javax.servlet.ServletContext; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockHttpSession; +import org.springframework.mock.web.MockServletContext; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.web.context.ServletContextAware; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.request.ServletWebRequest; + +import static org.junit.Assert.*; + +/** + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@WebAppConfiguration +public abstract class AbstractBasicWacTests implements ServletContextAware { + + protected ServletContext servletContext; + + @Autowired + protected WebApplicationContext wac; + + @Autowired + protected MockServletContext mockServletContext; + + @Autowired + protected MockHttpServletRequest request; + + @Autowired + protected MockHttpServletResponse response; + + @Autowired + protected MockHttpSession session; + + @Autowired + protected ServletWebRequest webRequest; + + @Autowired + protected String foo; + + + @Override + public void setServletContext(ServletContext servletContext) { + this.servletContext = servletContext; + } + + @Test + public void basicWacFeatures() throws Exception { + assertNotNull("ServletContext should be set in the WAC.", wac.getServletContext()); + + assertNotNull("ServletContext should have been set via ServletContextAware.", servletContext); + + assertNotNull("ServletContext should have been autowired from the WAC.", mockServletContext); + assertNotNull("MockHttpServletRequest should have been autowired from the WAC.", request); + assertNotNull("MockHttpServletResponse should have been autowired from the WAC.", response); + assertNotNull("MockHttpSession should have been autowired from the WAC.", session); + assertNotNull("ServletWebRequest should have been autowired from the WAC.", webRequest); + + Object rootWac = mockServletContext.getAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE); + assertNotNull("Root WAC must be stored in the ServletContext as: " + + WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, rootWac); + assertSame("test WAC and Root WAC in ServletContext must be the same object.", wac, rootWac); + assertSame("ServletContext instances must be the same object.", mockServletContext, wac.getServletContext()); + assertSame("ServletContext in the WAC and in the mock request", mockServletContext, request.getServletContext()); + + assertEquals("Getting real path for ServletContext resource.", + new File("src/main/webapp/index.jsp").getCanonicalPath(), mockServletContext.getRealPath("index.jsp")); + + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/AnnotationConfigWebContextLoaderTests.java b/spring-test/src/test/java/org/springframework/test/context/web/AnnotationConfigWebContextLoaderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..6a8a1d90197e5628d4d6200055feef711b974fdc --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/AnnotationConfigWebContextLoaderTests.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import static org.hamcrest.CoreMatchers.*; + +/** + * Unit tests for {@link AnnotationConfigWebContextLoader}. + * + * @author Sam Brannen + * @since 4.0.4 + */ +public class AnnotationConfigWebContextLoaderTests { + + private static final String[] EMPTY_STRING_ARRAY = new String[0]; + private static final Class[] EMPTY_CLASS_ARRAY = new Class[0]; + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + + @Test + public void configMustNotContainLocations() throws Exception { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage(containsString("does not support resource locations")); + + AnnotationConfigWebContextLoader loader = new AnnotationConfigWebContextLoader(); + WebMergedContextConfiguration mergedConfig = new WebMergedContextConfiguration(getClass(), + new String[] { "config.xml" }, EMPTY_CLASS_ARRAY, null, EMPTY_STRING_ARRAY, EMPTY_STRING_ARRAY, + EMPTY_STRING_ARRAY, "resource/path", loader, null, null); + loader.loadContext(mergedConfig); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/BasicAnnotationConfigWacTests.java b/spring-test/src/test/java/org/springframework/test/context/web/BasicAnnotationConfigWacTests.java new file mode 100644 index 0000000000000000000000000000000000000000..96ee067b598596b4a23b189e6f1136b788736731 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/BasicAnnotationConfigWacTests.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; + +import static org.junit.Assert.*; + +/** + * @author Sam Brannen + * @since 3.2 + */ +@ContextConfiguration +public class BasicAnnotationConfigWacTests extends AbstractBasicWacTests { + + @Configuration + static class Config { + + @Bean + public String foo() { + return "enigma"; + } + + @Bean + public ServletContextAwareBean servletContextAwareBean() { + return new ServletContextAwareBean(); + } + } + + @Autowired + protected ServletContextAwareBean servletContextAwareBean; + + @Test + public void fooEnigmaAutowired() { + assertEquals("enigma", foo); + } + + @Test + public void servletContextAwareBeanProcessed() { + assertNotNull(servletContextAwareBean); + assertNotNull(servletContextAwareBean.servletContext); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/BasicGroovyWacTests.java b/spring-test/src/test/java/org/springframework/test/context/web/BasicGroovyWacTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0b5d5b57a08d81e5f9d913950629c082c9f02a8d --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/BasicGroovyWacTests.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import org.junit.Test; + +import org.springframework.test.context.ContextConfiguration; + +import static org.junit.Assert.*; + +/** + * @author Sam Brannen + * @since 4.1 + * @see BasicXmlWacTests + */ +// Config loaded from BasicGroovyWacTestsContext.groovy +@ContextConfiguration +public class BasicGroovyWacTests extends AbstractBasicWacTests { + + @Test + public void groovyFooAutowired() { + assertEquals("Groovy Foo", foo); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/BasicXmlWacTests.java b/spring-test/src/test/java/org/springframework/test/context/web/BasicXmlWacTests.java new file mode 100644 index 0000000000000000000000000000000000000000..ddde8acb6301c8317cf71033f03003cb7deadd39 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/BasicXmlWacTests.java @@ -0,0 +1,37 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import org.junit.Test; + +import org.springframework.test.context.ContextConfiguration; + +import static org.junit.Assert.*; + +/** + * @author Sam Brannen + * @since 3.2 + */ +@ContextConfiguration +public class BasicXmlWacTests extends AbstractBasicWacTests { + + @Test + public void fooBarAutowired() { + assertEquals("bar", foo); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/GenericXmlWebContextLoaderTests.java b/spring-test/src/test/java/org/springframework/test/context/web/GenericXmlWebContextLoaderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..30d320e11e06d533835f0e456c870e69051504c9 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/GenericXmlWebContextLoaderTests.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import static org.hamcrest.CoreMatchers.*; + +/** + * Unit tests for {@link GenericXmlWebContextLoader}. + * + * @author Sam Brannen + * @since 4.0.4 + */ +public class GenericXmlWebContextLoaderTests { + + private static final String[] EMPTY_STRING_ARRAY = new String[0]; + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + + @Test + public void configMustNotContainAnnotatedClasses() throws Exception { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage(containsString("does not support annotated classes")); + + GenericXmlWebContextLoader loader = new GenericXmlWebContextLoader(); + WebMergedContextConfiguration mergedConfig = new WebMergedContextConfiguration(getClass(), EMPTY_STRING_ARRAY, + new Class[] { getClass() }, null, EMPTY_STRING_ARRAY, EMPTY_STRING_ARRAY, EMPTY_STRING_ARRAY, + "resource/path", loader, null, null); + loader.loadContext(mergedConfig); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/JUnit4SpringContextWebTests.java b/spring-test/src/test/java/org/springframework/test/context/web/JUnit4SpringContextWebTests.java new file mode 100644 index 0000000000000000000000000000000000000000..f314d13b4763bdb25ca34b9359b18a6965cd9428 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/JUnit4SpringContextWebTests.java @@ -0,0 +1,119 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import java.io.File; + +import javax.servlet.ServletContext; + +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockHttpSession; +import org.springframework.mock.web.MockServletContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.AbstractJUnit4SpringContextTests; +import org.springframework.web.context.ServletContextAware; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.request.ServletWebRequest; + +import static org.junit.Assert.*; + +/** + * JUnit-based integration tests that verify support for loading a + * {@link WebApplicationContext} when extending {@link AbstractJUnit4SpringContextTests}. + * + * @author Sam Brannen + * @since 3.2.7 + */ +@ContextConfiguration +@WebAppConfiguration +public class JUnit4SpringContextWebTests extends AbstractJUnit4SpringContextTests implements ServletContextAware { + + @Configuration + static class Config { + + @Bean + public String foo() { + return "enigma"; + } + } + + + protected ServletContext servletContext; + + @Autowired + protected WebApplicationContext wac; + + @Autowired + protected MockServletContext mockServletContext; + + @Autowired + protected MockHttpServletRequest request; + + @Autowired + protected MockHttpServletResponse response; + + @Autowired + protected MockHttpSession session; + + @Autowired + protected ServletWebRequest webRequest; + + @Autowired + protected String foo; + + + @Override + public void setServletContext(ServletContext servletContext) { + this.servletContext = servletContext; + } + + @Test + public void basicWacFeatures() throws Exception { + assertNotNull("ServletContext should be set in the WAC.", wac.getServletContext()); + + assertNotNull("ServletContext should have been set via ServletContextAware.", servletContext); + + assertNotNull("ServletContext should have been autowired from the WAC.", mockServletContext); + assertNotNull("MockHttpServletRequest should have been autowired from the WAC.", request); + assertNotNull("MockHttpServletResponse should have been autowired from the WAC.", response); + assertNotNull("MockHttpSession should have been autowired from the WAC.", session); + assertNotNull("ServletWebRequest should have been autowired from the WAC.", webRequest); + + Object rootWac = mockServletContext.getAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE); + assertNotNull("Root WAC must be stored in the ServletContext as: " + + WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, rootWac); + assertSame("test WAC and Root WAC in ServletContext must be the same object.", wac, rootWac); + assertSame("ServletContext instances must be the same object.", mockServletContext, wac.getServletContext()); + assertSame("ServletContext in the WAC and in the mock request", mockServletContext, request.getServletContext()); + + assertEquals("Getting real path for ServletContext resource.", + new File("src/main/webapp/index.jsp").getCanonicalPath(), mockServletContext.getRealPath("index.jsp")); + + } + + @Test + public void fooEnigmaAutowired() { + assertEquals("enigma", foo); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/MetaAnnotationConfigWacTests.java b/spring-test/src/test/java/org/springframework/test/context/web/MetaAnnotationConfigWacTests.java new file mode 100644 index 0000000000000000000000000000000000000000..6066b2ab929d6c65d731ba98d25b7fd83b127967 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/MetaAnnotationConfigWacTests.java @@ -0,0 +1,75 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import java.io.File; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.mock.web.MockServletContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.web.context.WebApplicationContext; + +import static org.junit.Assert.*; + +/** + * Integration test that verifies meta-annotation support for {@link WebAppConfiguration} + * and {@link ContextConfiguration}. + * + * @author Sam Brannen + * @since 4.0 + * @see WebTestConfiguration + */ +@RunWith(SpringJUnit4ClassRunner.class) +@WebTestConfiguration +public class MetaAnnotationConfigWacTests { + + @Autowired + protected WebApplicationContext wac; + + @Autowired + protected MockServletContext mockServletContext; + + @Autowired + protected String foo; + + + @Test + public void fooEnigmaAutowired() { + assertEquals("enigma", foo); + } + + @Test + public void basicWacFeatures() throws Exception { + assertNotNull("ServletContext should be set in the WAC.", wac.getServletContext()); + + assertNotNull("ServletContext should have been autowired from the WAC.", mockServletContext); + + Object rootWac = mockServletContext.getAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE); + assertNotNull("Root WAC must be stored in the ServletContext as: " + + WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, rootWac); + assertSame("test WAC and Root WAC in ServletContext must be the same object.", wac, rootWac); + assertSame("ServletContext instances must be the same object.", mockServletContext, wac.getServletContext()); + + assertEquals("Getting real path for ServletContext resource.", + new File("src/main/webapp/index.jsp").getCanonicalPath(), mockServletContext.getRealPath("index.jsp")); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/RequestAndSessionScopedBeansWacTests.java b/spring-test/src/test/java/org/springframework/test/context/web/RequestAndSessionScopedBeansWacTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d23da4884a274056d91cbaecf27311218f2267ab --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/RequestAndSessionScopedBeansWacTests.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpSession; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.web.context.WebApplicationContext; + +import static org.junit.Assert.*; + +/** + * Integration tests that verify support for request and session scoped beans + * in conjunction with the TestContext Framework. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +@WebAppConfiguration +public class RequestAndSessionScopedBeansWacTests { + + @Autowired + private WebApplicationContext wac; + + @Autowired + private MockHttpServletRequest request; + + @Autowired + private MockHttpSession session; + + + @Test + public void requestScope() throws Exception { + final String beanName = "requestScopedTestBean"; + final String contextPath = "/path"; + + assertNull(request.getAttribute(beanName)); + + request.setContextPath(contextPath); + TestBean testBean = wac.getBean(beanName, TestBean.class); + + assertEquals(contextPath, testBean.getName()); + assertSame(testBean, request.getAttribute(beanName)); + assertSame(testBean, wac.getBean(beanName, TestBean.class)); + } + + @Test + public void sessionScope() throws Exception { + final String beanName = "sessionScopedTestBean"; + + assertNull(session.getAttribute(beanName)); + + TestBean testBean = wac.getBean(beanName, TestBean.class); + + assertSame(testBean, session.getAttribute(beanName)); + assertSame(testBean, wac.getBean(beanName, TestBean.class)); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/ServletContextAwareBean.java b/spring-test/src/test/java/org/springframework/test/context/web/ServletContextAwareBean.java new file mode 100644 index 0000000000000000000000000000000000000000..3e0e84fc9c409f31e6dbe912249f241ad8f68f19 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/ServletContextAwareBean.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import javax.servlet.ServletContext; + +import org.springframework.web.context.ServletContextAware; + +/** + * Introduced to investigate claims in SPR-11145. + * + * @author Sam Brannen + * @since 4.0.2 + */ +public class ServletContextAwareBean implements ServletContextAware { + + protected ServletContext servletContext; + + @Override + public void setServletContext(ServletContext servletContext) { + this.servletContext = servletContext; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/ServletContextAwareBeanWacTests.java b/spring-test/src/test/java/org/springframework/test/context/web/ServletContextAwareBeanWacTests.java new file mode 100644 index 0000000000000000000000000000000000000000..790a03b027487ffaced7f64ace0d31f5f5d1fc31 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/ServletContextAwareBeanWacTests.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import org.junit.Test; + +import static org.springframework.test.context.junit4.JUnitTestingUtils.*; + +/** + * Introduced to investigate claims in SPR-11145. + * + *

Yes, this test class does in fact use JUnit to run JUnit. ;) + * + * @author Sam Brannen + * @since 4.0.2 + */ +public class ServletContextAwareBeanWacTests { + + @Test + public void ensureServletContextAwareBeanIsProcessedProperlyWhenExecutingJUnitManually() throws Exception { + runTestsAndAssertCounters(BasicAnnotationConfigWacTests.class, 3, 0, 3, 0, 0); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/ServletTestExecutionListenerJUnitIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/context/web/ServletTestExecutionListenerJUnitIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0738130d1bc99cf1cd7218566219192c3cbd83fe --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/ServletTestExecutionListenerJUnitIntegrationTests.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Configuration; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; + +import static org.junit.Assert.*; + +/** + * JUnit-based integration tests for {@link ServletTestExecutionListener}. + * + * @author Sam Brannen + * @since 3.2.9 + * @see org.springframework.test.context.testng.web.ServletTestExecutionListenerTestNGIntegrationTests + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +@WebAppConfiguration +public class ServletTestExecutionListenerJUnitIntegrationTests { + + @Configuration + static class Config { + /* no beans required for this test */ + } + + + @Autowired + private MockHttpServletRequest servletRequest; + + + /** + * Verifies bug fix for SPR-11626. + * + * @see #ensureMocksAreReinjectedBetweenTests_2 + */ + @Test + public void ensureMocksAreReinjectedBetweenTests_1() { + assertInjectedServletRequestEqualsRequestInRequestContextHolder(); + } + + /** + * Verifies bug fix for SPR-11626. + * + * @see #ensureMocksAreReinjectedBetweenTests_1 + */ + @Test + public void ensureMocksAreReinjectedBetweenTests_2() { + assertInjectedServletRequestEqualsRequestInRequestContextHolder(); + } + + private void assertInjectedServletRequestEqualsRequestInRequestContextHolder() { + assertEquals("Injected ServletRequest must be stored in the RequestContextHolder", servletRequest, + ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/ServletTestExecutionListenerTests.java b/spring-test/src/test/java/org/springframework/test/context/web/ServletTestExecutionListenerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..84ef5222016217496730025300e5e11a92c0cd45 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/ServletTestExecutionListenerTests.java @@ -0,0 +1,225 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.BDDMockito; + +import org.springframework.context.ApplicationContext; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockServletContext; +import org.springframework.test.context.TestContext; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletWebRequest; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; +import static org.springframework.test.context.web.ServletTestExecutionListener.*; + +/** + * Unit tests for {@link ServletTestExecutionListener}. + * + * @author Sam Brannen + * @author Phillip Webb + * @since 3.2.6 + */ +public class ServletTestExecutionListenerTests { + + private static final String SET_UP_OUTSIDE_OF_STEL = "setUpOutsideOfStel"; + + private final WebApplicationContext wac = mock(WebApplicationContext.class); + private final MockServletContext mockServletContext = new MockServletContext(); + private final TestContext testContext = mock(TestContext.class); + private final ServletTestExecutionListener listener = new ServletTestExecutionListener(); + + + @Before + public void setUp() { + given(wac.getServletContext()).willReturn(mockServletContext); + given(testContext.getApplicationContext()).willReturn(wac); + + MockHttpServletRequest request = new MockHttpServletRequest(mockServletContext); + MockHttpServletResponse response = new MockHttpServletResponse(); + ServletWebRequest servletWebRequest = new ServletWebRequest(request, response); + + request.setAttribute(SET_UP_OUTSIDE_OF_STEL, "true"); + + RequestContextHolder.setRequestAttributes(servletWebRequest); + assertSetUpOutsideOfStelAttributeExists(); + } + + @Test + public void standardApplicationContext() throws Exception { + BDDMockito.> given(testContext.getTestClass()).willReturn(getClass()); + given(testContext.getApplicationContext()).willReturn(mock(ApplicationContext.class)); + + listener.beforeTestClass(testContext); + assertSetUpOutsideOfStelAttributeExists(); + + listener.prepareTestInstance(testContext); + assertSetUpOutsideOfStelAttributeExists(); + + listener.beforeTestMethod(testContext); + assertSetUpOutsideOfStelAttributeExists(); + + listener.afterTestMethod(testContext); + assertSetUpOutsideOfStelAttributeExists(); + } + + @Test + public void legacyWebTestCaseWithoutExistingRequestAttributes() throws Exception { + BDDMockito.> given(testContext.getTestClass()).willReturn(LegacyWebTestCase.class); + + RequestContextHolder.resetRequestAttributes(); + assertRequestAttributesDoNotExist(); + + listener.beforeTestClass(testContext); + + listener.prepareTestInstance(testContext); + assertRequestAttributesDoNotExist(); + verify(testContext, times(0)).setAttribute(RESET_REQUEST_CONTEXT_HOLDER_ATTRIBUTE, Boolean.TRUE); + given(testContext.getAttribute(RESET_REQUEST_CONTEXT_HOLDER_ATTRIBUTE)).willReturn(null); + + listener.beforeTestMethod(testContext); + assertRequestAttributesDoNotExist(); + verify(testContext, times(0)).setAttribute(RESET_REQUEST_CONTEXT_HOLDER_ATTRIBUTE, Boolean.TRUE); + + listener.afterTestMethod(testContext); + verify(testContext, times(1)).removeAttribute(RESET_REQUEST_CONTEXT_HOLDER_ATTRIBUTE); + assertRequestAttributesDoNotExist(); + } + + @Test + public void legacyWebTestCaseWithPresetRequestAttributes() throws Exception { + BDDMockito.> given(testContext.getTestClass()).willReturn(LegacyWebTestCase.class); + + listener.beforeTestClass(testContext); + assertSetUpOutsideOfStelAttributeExists(); + + listener.prepareTestInstance(testContext); + assertSetUpOutsideOfStelAttributeExists(); + verify(testContext, times(0)).setAttribute(RESET_REQUEST_CONTEXT_HOLDER_ATTRIBUTE, Boolean.TRUE); + given(testContext.getAttribute(RESET_REQUEST_CONTEXT_HOLDER_ATTRIBUTE)).willReturn(null); + + listener.beforeTestMethod(testContext); + assertSetUpOutsideOfStelAttributeExists(); + verify(testContext, times(0)).setAttribute(RESET_REQUEST_CONTEXT_HOLDER_ATTRIBUTE, Boolean.TRUE); + given(testContext.getAttribute(RESET_REQUEST_CONTEXT_HOLDER_ATTRIBUTE)).willReturn(null); + + listener.afterTestMethod(testContext); + verify(testContext, times(1)).removeAttribute(RESET_REQUEST_CONTEXT_HOLDER_ATTRIBUTE); + assertSetUpOutsideOfStelAttributeExists(); + } + + @Test + public void atWebAppConfigTestCaseWithoutExistingRequestAttributes() throws Exception { + BDDMockito.> given(testContext.getTestClass()).willReturn(AtWebAppConfigWebTestCase.class); + + RequestContextHolder.resetRequestAttributes(); + listener.beforeTestClass(testContext); + assertRequestAttributesDoNotExist(); + + assertWebAppConfigTestCase(); + } + + @Test + public void atWebAppConfigTestCaseWithPresetRequestAttributes() throws Exception { + BDDMockito.> given(testContext.getTestClass()).willReturn(AtWebAppConfigWebTestCase.class); + + listener.beforeTestClass(testContext); + assertRequestAttributesExist(); + + assertWebAppConfigTestCase(); + } + + /** + * @since 4.3 + */ + @Test + public void activateListenerWithoutExistingRequestAttributes() throws Exception { + BDDMockito.> given(testContext.getTestClass()).willReturn(NoAtWebAppConfigWebTestCase.class); + given(testContext.getAttribute(ServletTestExecutionListener.ACTIVATE_LISTENER)).willReturn(true); + + RequestContextHolder.resetRequestAttributes(); + listener.beforeTestClass(testContext); + assertRequestAttributesDoNotExist(); + + assertWebAppConfigTestCase(); + } + + + private RequestAttributes assertRequestAttributesExist() { + RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes(); + assertNotNull("request attributes should exist", requestAttributes); + return requestAttributes; + } + + private void assertRequestAttributesDoNotExist() { + assertNull("request attributes should not exist", RequestContextHolder.getRequestAttributes()); + } + + private void assertSetUpOutsideOfStelAttributeExists() { + RequestAttributes requestAttributes = assertRequestAttributesExist(); + Object setUpOutsideOfStel = requestAttributes.getAttribute(SET_UP_OUTSIDE_OF_STEL, + RequestAttributes.SCOPE_REQUEST); + assertNotNull(SET_UP_OUTSIDE_OF_STEL + " should exist as a request attribute", setUpOutsideOfStel); + } + + private void assertSetUpOutsideOfStelAttributeDoesNotExist() { + RequestAttributes requestAttributes = assertRequestAttributesExist(); + Object setUpOutsideOfStel = requestAttributes.getAttribute(SET_UP_OUTSIDE_OF_STEL, + RequestAttributes.SCOPE_REQUEST); + assertNull(SET_UP_OUTSIDE_OF_STEL + " should NOT exist as a request attribute", setUpOutsideOfStel); + } + + private void assertWebAppConfigTestCase() throws Exception { + listener.prepareTestInstance(testContext); + assertRequestAttributesExist(); + assertSetUpOutsideOfStelAttributeDoesNotExist(); + verify(testContext, times(1)).setAttribute(POPULATED_REQUEST_CONTEXT_HOLDER_ATTRIBUTE, Boolean.TRUE); + verify(testContext, times(1)).setAttribute(RESET_REQUEST_CONTEXT_HOLDER_ATTRIBUTE, Boolean.TRUE); + given(testContext.getAttribute(POPULATED_REQUEST_CONTEXT_HOLDER_ATTRIBUTE)).willReturn(Boolean.TRUE); + given(testContext.getAttribute(RESET_REQUEST_CONTEXT_HOLDER_ATTRIBUTE)).willReturn(Boolean.TRUE); + + listener.beforeTestMethod(testContext); + assertRequestAttributesExist(); + assertSetUpOutsideOfStelAttributeDoesNotExist(); + verify(testContext, times(1)).setAttribute(POPULATED_REQUEST_CONTEXT_HOLDER_ATTRIBUTE, Boolean.TRUE); + verify(testContext, times(1)).setAttribute(RESET_REQUEST_CONTEXT_HOLDER_ATTRIBUTE, Boolean.TRUE); + + listener.afterTestMethod(testContext); + verify(testContext).removeAttribute(POPULATED_REQUEST_CONTEXT_HOLDER_ATTRIBUTE); + verify(testContext).removeAttribute(RESET_REQUEST_CONTEXT_HOLDER_ATTRIBUTE); + assertRequestAttributesDoNotExist(); + } + + + static class LegacyWebTestCase { + } + + @WebAppConfiguration + static class AtWebAppConfigWebTestCase { + } + + static class NoAtWebAppConfigWebTestCase { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/WebAppConfigurationBootstrapWithTests.java b/spring-test/src/test/java/org/springframework/test/context/web/WebAppConfigurationBootstrapWithTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e61aaa9588686c3dd1e446229310e91dfbe08c4b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/WebAppConfigurationBootstrapWithTests.java @@ -0,0 +1,78 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.io.Resource; +import org.springframework.test.context.BootstrapWith; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.MergedContextConfiguration; +import org.springframework.test.context.junit4.SpringRunner; +import org.springframework.test.context.web.WebAppConfigurationBootstrapWithTests.CustomWebTestContextBootstrapper; +import org.springframework.web.context.WebApplicationContext; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * JUnit-based integration tests that verify support for loading a + * {@link WebApplicationContext} with a custom {@link WebTestContextBootstrapper}. + * + * @author Sam Brannen + * @author Phillip Webb + * @since 4.3 + */ +@RunWith(SpringRunner.class) +@ContextConfiguration +@WebAppConfiguration +@BootstrapWith(CustomWebTestContextBootstrapper.class) +public class WebAppConfigurationBootstrapWithTests { + + @Autowired + WebApplicationContext wac; + + + @Test + public void webApplicationContextIsLoaded() { + // from: src/test/webapp/resources/Spring.js + Resource resource = wac.getResource("/resources/Spring.js"); + assertNotNull(resource); + assertTrue(resource.exists()); + } + + + @Configuration + static class Config { + } + + /** + * Custom {@link WebTestContextBootstrapper} that requires {@code @WebAppConfiguration} + * but hard codes the resource base path. + */ + static class CustomWebTestContextBootstrapper extends WebTestContextBootstrapper { + + @Override + protected MergedContextConfiguration processMergedContextConfiguration(MergedContextConfiguration mergedConfig) { + return new WebMergedContextConfiguration(mergedConfig, "src/test/webapp"); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/WebContextLoaderTestSuite.java b/spring-test/src/test/java/org/springframework/test/context/web/WebContextLoaderTestSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..8a57a6d1c0ba52a88250c14d7828181a67b78765 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/WebContextLoaderTestSuite.java @@ -0,0 +1,42 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +import org.springframework.test.context.ContextLoader; +import org.springframework.web.context.WebApplicationContext; + +/** + * Convenience test suite for integration tests that verify support for + * {@link WebApplicationContext} {@linkplain ContextLoader context loaders} + * in the TestContext framework. + * + * @author Sam Brannen + * @since 3.2 + */ +@RunWith(Suite.class) +// Note: the following 'multi-line' layout is for enhanced code readability. +@SuiteClasses({// +BasicXmlWacTests.class,// + BasicAnnotationConfigWacTests.class,// + RequestAndSessionScopedBeansWacTests.class // +}) +public class WebContextLoaderTestSuite { +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/WebTestConfiguration.java b/spring-test/src/test/java/org/springframework/test/context/web/WebTestConfiguration.java new file mode 100644 index 0000000000000000000000000000000000000000..1aab367a589196761f1ec7b2442bfa113cb12351 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/WebTestConfiguration.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; + +/** + * Custom composed annotation combining {@link WebAppConfiguration} and + * {@link ContextConfiguration} as meta-annotations. + * + * @author Sam Brannen + * @since 4.0 + */ +@WebAppConfiguration +@ContextConfiguration(classes = FooConfig.class) +@Retention(RetentionPolicy.RUNTIME) +public @interface WebTestConfiguration { +} + +@Configuration +class FooConfig { + + @Bean + public String foo() { + return "enigma"; + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/web/socket/WebSocketServletServerContainerFactoryBeanTests.java b/spring-test/src/test/java/org/springframework/test/context/web/socket/WebSocketServletServerContainerFactoryBeanTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0dc43b20cf797fbe5a82f9fcbeea513be95ac460 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/web/socket/WebSocketServletServerContainerFactoryBeanTests.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.web.socket; + +import javax.websocket.server.ServerContainer; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.junit4.SpringRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.web.socket.config.annotation.EnableWebSocket; +import org.springframework.web.socket.server.standard.ServletServerContainerFactoryBean; + +import static org.junit.Assert.*; + +/** + * Integration tests that validate support for {@link ServletServerContainerFactoryBean} + * in conjunction with {@link WebAppConfiguration @WebAppConfiguration} and the + * Spring TestContext Framework. + * + * @author Sam Brannen + * @since 4.3.1 + */ +@RunWith(SpringRunner.class) +@WebAppConfiguration +public class WebSocketServletServerContainerFactoryBeanTests { + + @Autowired + ServerContainer serverContainer; + + + @Test + public void servletServerContainerFactoryBeanSupport() { + assertEquals(42, serverContainer.getDefaultMaxTextMessageBufferSize()); + } + + + @Configuration + @EnableWebSocket + static class WebSocketConfig { + + @Bean + ServletServerContainerFactoryBean createWebSocketContainer() { + ServletServerContainerFactoryBean container = new ServletServerContainerFactoryBean(); + container.setMaxTextMessageBufferSize(42); + return container; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/jdbc/JdbcTestUtilsTests.java b/spring-test/src/test/java/org/springframework/test/jdbc/JdbcTestUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..f5da8d01d5fa9785d235103cc950c92a62509307 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/jdbc/JdbcTestUtilsTests.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.jdbc; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +import org.springframework.jdbc.core.JdbcTemplate; + +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * Unit tests for {@link JdbcTestUtils}. + * + * @author Phillip Webb + * @since 2.5.4 + * @see JdbcTestUtilsIntegrationTests + */ +@RunWith(MockitoJUnitRunner.class) +public class JdbcTestUtilsTests { + + @Mock + private JdbcTemplate jdbcTemplate; + + + @Test + public void deleteWithoutWhereClause() throws Exception { + given(jdbcTemplate.update("DELETE FROM person")).willReturn(10); + int deleted = JdbcTestUtils.deleteFromTableWhere(jdbcTemplate, "person", null); + assertThat(deleted, equalTo(10)); + } + + @Test + public void deleteWithWhereClause() throws Exception { + given(jdbcTemplate.update("DELETE FROM person WHERE name = 'Bob' and age > 25")).willReturn(10); + int deleted = JdbcTestUtils.deleteFromTableWhere(jdbcTemplate, "person", "name = 'Bob' and age > 25"); + assertThat(deleted, equalTo(10)); + } + + @Test + public void deleteWithWhereClauseAndArguments() throws Exception { + given(jdbcTemplate.update("DELETE FROM person WHERE name = ? and age > ?", "Bob", 25)).willReturn(10); + int deleted = JdbcTestUtils.deleteFromTableWhere(jdbcTemplate, "person", "name = ? and age > ?", "Bob", 25); + assertThat(deleted, equalTo(10)); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/transaction/TransactionTestUtils.java b/spring-test/src/test/java/org/springframework/test/transaction/TransactionTestUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..02ef2b9250983dc6e45d0bec4392329cd1564f7b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/transaction/TransactionTestUtils.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.transaction; + +import org.springframework.transaction.support.TransactionSynchronizationManager; + +/** + * Collection of JDK 1.4+ utilities for tests involving transactions. Intended + * for internal use within the Spring testing suite. + * + *

All {@code assert*()} methods throw {@link AssertionError}s. + * + * @author Sam Brannen + * @since 2.5 + */ +public abstract class TransactionTestUtils { + + /** + * Convenience method for determining if a transaction is active for the + * current {@link Thread}. + * @return {@code true} if a transaction is currently active + */ + public static boolean inTransaction() { + return TransactionSynchronizationManager.isActualTransactionActive(); + } + + /** + * Asserts whether or not a transaction is active for the current + * {@link Thread}. + * @param transactionExpected whether or not a transaction is expected + * @throws AssertionError if the supplied assertion fails + * @see #inTransaction() + */ + public static void assertInTransaction(boolean transactionExpected) { + if (transactionExpected) { + assertCondition(inTransaction(), "The current thread should be associated with a transaction."); + } + else { + assertCondition(!inTransaction(), "The current thread should not be associated with a transaction"); + } + } + + /** + * Fails by throwing an {@code AssertionError} with the supplied + * {@code message}. + * @param message the exception message to use + * @see #assertCondition(boolean, String) + */ + private static void fail(String message) throws AssertionError { + throw new AssertionError(message); + } + + /** + * Assert the provided boolean {@code condition}, throwing + * {@code AssertionError} with the supplied {@code message} if + * the test result is {@code false}. + * @param condition a boolean expression + * @param message the exception message to use if the assertion fails + * @throws AssertionError if condition is {@code false} + * @see #fail(String) + */ + private static void assertCondition(boolean condition, String message) throws AssertionError { + if (!condition) { + fail(message); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/util/AopTestUtilsTests.java b/spring-test/src/test/java/org/springframework/test/util/AopTestUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..93f57a5b325984977547959cb98c287e65f3b525 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/util/AopTestUtilsTests.java @@ -0,0 +1,148 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.util; + +import org.junit.Test; + +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.aop.support.AopUtils; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; +import static org.springframework.test.util.AopTestUtils.*; + +/** + * Unit tests for {@link AopTestUtils}. + * + * @author Sam Brannen + * @since 4.2 + */ +public class AopTestUtilsTests { + + private final FooImpl foo = new FooImpl(); + + + @Test(expected = IllegalArgumentException.class) + public void getTargetObjectForNull() { + getTargetObject(null); + } + + @Test + public void getTargetObjectForNonProxiedObject() { + Foo target = getTargetObject(foo); + assertSame(foo, target); + } + + @Test + public void getTargetObjectWrappedInSingleJdkDynamicProxy() { + Foo target = getTargetObject(jdkProxy(foo)); + assertSame(foo, target); + } + + @Test + public void getTargetObjectWrappedInSingleCglibProxy() { + Foo target = getTargetObject(cglibProxy(foo)); + assertSame(foo, target); + } + + @Test + public void getTargetObjectWrappedInDoubleJdkDynamicProxy() { + Foo target = getTargetObject(jdkProxy(jdkProxy(foo))); + assertNotSame(foo, target); + } + + @Test + public void getTargetObjectWrappedInDoubleCglibProxy() { + Foo target = getTargetObject(cglibProxy(cglibProxy(foo))); + assertNotSame(foo, target); + } + + @Test(expected = IllegalArgumentException.class) + public void getUltimateTargetObjectForNull() { + getUltimateTargetObject(null); + } + + @Test + public void getUltimateTargetObjectForNonProxiedObject() { + Foo target = getUltimateTargetObject(foo); + assertSame(foo, target); + } + + @Test + public void getUltimateTargetObjectWrappedInSingleJdkDynamicProxy() { + Foo target = getUltimateTargetObject(jdkProxy(foo)); + assertSame(foo, target); + } + + @Test + public void getUltimateTargetObjectWrappedInSingleCglibProxy() { + Foo target = getUltimateTargetObject(cglibProxy(foo)); + assertSame(foo, target); + } + + @Test + public void getUltimateTargetObjectWrappedInDoubleJdkDynamicProxy() { + Foo target = getUltimateTargetObject(jdkProxy(jdkProxy(foo))); + assertSame(foo, target); + } + + @Test + public void getUltimateTargetObjectWrappedInDoubleCglibProxy() { + Foo target = getUltimateTargetObject(cglibProxy(cglibProxy(foo))); + assertSame(foo, target); + } + + @Test + public void getUltimateTargetObjectWrappedInCglibProxyWrappedInJdkDynamicProxy() { + Foo target = getUltimateTargetObject(jdkProxy(cglibProxy(foo))); + assertSame(foo, target); + } + + @Test + public void getUltimateTargetObjectWrappedInCglibProxyWrappedInDoubleJdkDynamicProxy() { + Foo target = getUltimateTargetObject(jdkProxy(jdkProxy(cglibProxy(foo)))); + assertSame(foo, target); + } + + private Foo jdkProxy(Foo foo) { + ProxyFactory pf = new ProxyFactory(); + pf.setTarget(foo); + pf.addInterface(Foo.class); + Foo proxy = (Foo) pf.getProxy(); + assertTrue("Proxy is a JDK dynamic proxy", AopUtils.isJdkDynamicProxy(proxy)); + assertThat(proxy, instanceOf(Foo.class)); + return proxy; + } + + private Foo cglibProxy(Foo foo) { + ProxyFactory pf = new ProxyFactory(); + pf.setTarget(foo); + pf.setProxyTargetClass(true); + Foo proxy = (Foo) pf.getProxy(); + assertTrue("Proxy is a CGLIB proxy", AopUtils.isCglibProxy(proxy)); + assertThat(proxy, instanceOf(FooImpl.class)); + return proxy; + } + + + static interface Foo { + } + + static class FooImpl implements Foo { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/util/JsonPathExpectationsHelperTests.java b/spring-test/src/test/java/org/springframework/test/util/JsonPathExpectationsHelperTests.java new file mode 100644 index 0000000000000000000000000000000000000000..161b71f0592ea2238c42e66026c404574da669db --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/util/JsonPathExpectationsHelperTests.java @@ -0,0 +1,363 @@ +/* + * Copyright 2004-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.util; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import static org.hamcrest.core.Is.is; + +/** + * Unit tests for {@link JsonPathExpectationsHelper}. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + * @since 3.2 + */ +public class JsonPathExpectationsHelperTests { + + private static final String CONTENT = "{" + // + "'str': 'foo', " + // + "'num': 5, " + // + "'bool': true, " + // + "'arr': [42], " + // + "'colorMap': {'red': 'rojo'}, " + // + "'whitespace': ' ', " + // + "'emptyString': '', " + // + "'emptyArray': [], " + // + "'emptyMap': {} " + // + "}"; + + private static final String SIMPSONS = "{ 'familyMembers': [ " + // + "{'name': 'Homer' }, " + // + "{'name': 'Marge' }, " + // + "{'name': 'Bart' }, " + // + "{'name': 'Lisa' }, " + // + "{'name': 'Maggie'} " + // + " ] }"; + + @Rule + public final ExpectedException exception = ExpectedException.none(); + + + @Test + public void exists() throws Exception { + new JsonPathExpectationsHelper("$.str").exists(CONTENT); + } + + @Test + public void existsForAnEmptyArray() throws Exception { + new JsonPathExpectationsHelper("$.emptyArray").exists(CONTENT); + } + + @Test + public void existsForAnEmptyMap() throws Exception { + new JsonPathExpectationsHelper("$.emptyMap").exists(CONTENT); + } + + @Test + public void existsForIndefinatePathWithResults() throws Exception { + new JsonPathExpectationsHelper("$.familyMembers[?(@.name == 'Bart')]").exists(SIMPSONS); + } + + @Test + public void existsForIndefinatePathWithEmptyResults() throws Exception { + String expression = "$.familyMembers[?(@.name == 'Dilbert')]"; + exception.expect(AssertionError.class); + exception.expectMessage("No value at JSON path \"" + expression + "\""); + new JsonPathExpectationsHelper(expression).exists(SIMPSONS); + } + + @Test + public void doesNotExist() throws Exception { + new JsonPathExpectationsHelper("$.bogus").doesNotExist(CONTENT); + } + + @Test + public void doesNotExistForAnEmptyArray() throws Exception { + String expression = "$.emptyArray"; + exception.expect(AssertionError.class); + exception.expectMessage("Expected no value at JSON path \"" + expression + "\" but found: []"); + new JsonPathExpectationsHelper(expression).doesNotExist(CONTENT); + } + + @Test + public void doesNotExistForAnEmptyMap() throws Exception { + String expression = "$.emptyMap"; + exception.expect(AssertionError.class); + exception.expectMessage("Expected no value at JSON path \"" + expression + "\" but found: {}"); + new JsonPathExpectationsHelper(expression).doesNotExist(CONTENT); + } + + @Test + public void doesNotExistForIndefinatePathWithResults() throws Exception { + String expression = "$.familyMembers[?(@.name == 'Bart')]"; + exception.expect(AssertionError.class); + exception.expectMessage("Expected no value at JSON path \"" + expression + + "\" but found: [{\"name\":\"Bart\"}]"); + new JsonPathExpectationsHelper(expression).doesNotExist(SIMPSONS); + } + + @Test + public void doesNotExistForIndefinatePathWithEmptyResults() throws Exception { + new JsonPathExpectationsHelper("$.familyMembers[?(@.name == 'Dilbert')]").doesNotExist(SIMPSONS); + } + + @Test + public void assertValueIsEmptyForAnEmptyString() throws Exception { + new JsonPathExpectationsHelper("$.emptyString").assertValueIsEmpty(CONTENT); + } + + @Test + public void assertValueIsEmptyForAnEmptyArray() throws Exception { + new JsonPathExpectationsHelper("$.emptyArray").assertValueIsEmpty(CONTENT); + } + + @Test + public void assertValueIsEmptyForAnEmptyMap() throws Exception { + new JsonPathExpectationsHelper("$.emptyMap").assertValueIsEmpty(CONTENT); + } + + @Test + public void assertValueIsEmptyForIndefinatePathWithEmptyResults() throws Exception { + new JsonPathExpectationsHelper("$.familyMembers[?(@.name == 'Dilbert')]").assertValueIsEmpty(SIMPSONS); + } + + @Test + public void assertValueIsEmptyForIndefinatePathWithResults() throws Exception { + String expression = "$.familyMembers[?(@.name == 'Bart')]"; + exception.expect(AssertionError.class); + exception.expectMessage("Expected an empty value at JSON path \"" + expression + + "\" but found: [{\"name\":\"Bart\"}]"); + new JsonPathExpectationsHelper(expression).assertValueIsEmpty(SIMPSONS); + } + + @Test + public void assertValueIsEmptyForWhitespace() throws Exception { + String expression = "$.whitespace"; + exception.expect(AssertionError.class); + exception.expectMessage("Expected an empty value at JSON path \"" + expression + "\" but found: ' '"); + new JsonPathExpectationsHelper(expression).assertValueIsEmpty(CONTENT); + } + + @Test + public void assertValueIsNotEmptyForString() throws Exception { + new JsonPathExpectationsHelper("$.str").assertValueIsNotEmpty(CONTENT); + } + + @Test + public void assertValueIsNotEmptyForNumber() throws Exception { + new JsonPathExpectationsHelper("$.num").assertValueIsNotEmpty(CONTENT); + } + + @Test + public void assertValueIsNotEmptyForBoolean() throws Exception { + new JsonPathExpectationsHelper("$.bool").assertValueIsNotEmpty(CONTENT); + } + + @Test + public void assertValueIsNotEmptyForArray() throws Exception { + new JsonPathExpectationsHelper("$.arr").assertValueIsNotEmpty(CONTENT); + } + + @Test + public void assertValueIsNotEmptyForMap() throws Exception { + new JsonPathExpectationsHelper("$.colorMap").assertValueIsNotEmpty(CONTENT); + } + + @Test + public void assertValueIsNotEmptyForIndefinatePathWithResults() throws Exception { + new JsonPathExpectationsHelper("$.familyMembers[?(@.name == 'Bart')]").assertValueIsNotEmpty(SIMPSONS); + } + + @Test + public void assertValueIsNotEmptyForIndefinatePathWithEmptyResults() throws Exception { + String expression = "$.familyMembers[?(@.name == 'Dilbert')]"; + exception.expect(AssertionError.class); + exception.expectMessage("Expected a non-empty value at JSON path \"" + expression + "\" but found: []"); + new JsonPathExpectationsHelper(expression).assertValueIsNotEmpty(SIMPSONS); + } + + @Test + public void assertValueIsNotEmptyForAnEmptyString() throws Exception { + String expression = "$.emptyString"; + exception.expect(AssertionError.class); + exception.expectMessage("Expected a non-empty value at JSON path \"" + expression + "\" but found: ''"); + new JsonPathExpectationsHelper(expression).assertValueIsNotEmpty(CONTENT); + } + + @Test + public void assertValueIsNotEmptyForAnEmptyArray() throws Exception { + String expression = "$.emptyArray"; + exception.expect(AssertionError.class); + exception.expectMessage("Expected a non-empty value at JSON path \"" + expression + "\" but found: []"); + new JsonPathExpectationsHelper(expression).assertValueIsNotEmpty(CONTENT); + } + + @Test + public void assertValueIsNotEmptyForAnEmptyMap() throws Exception { + String expression = "$.emptyMap"; + exception.expect(AssertionError.class); + exception.expectMessage("Expected a non-empty value at JSON path \"" + expression + "\" but found: {}"); + new JsonPathExpectationsHelper(expression).assertValueIsNotEmpty(CONTENT); + } + + @Test + public void hasJsonPath() { + new JsonPathExpectationsHelper("$.abc").hasJsonPath("{\"abc\": \"123\"}"); + } + + @Test + public void hasJsonPathWithNull() { + new JsonPathExpectationsHelper("$.abc").hasJsonPath("{\"abc\": null}"); + } + + @Test + public void hasJsonPathForIndefinatePathWithResults() { + new JsonPathExpectationsHelper("$.familyMembers[?(@.name == 'Bart')]").hasJsonPath(SIMPSONS); + } + + @Test + public void hasJsonPathForIndefinatePathWithEmptyResults() { + String expression = "$.familyMembers[?(@.name == 'Dilbert')]"; + exception.expect(AssertionError.class); + exception.expectMessage("No values for JSON path \"" + expression + "\""); + new JsonPathExpectationsHelper(expression).hasJsonPath(SIMPSONS); + } + + @Test // SPR-16339 + public void doesNotHaveJsonPath() { + new JsonPathExpectationsHelper("$.abc").doesNotHaveJsonPath("{}"); + } + + @Test // SPR-16339 + public void doesNotHaveJsonPathWithNull() { + exception.expect(AssertionError.class); + new JsonPathExpectationsHelper("$.abc").doesNotHaveJsonPath("{\"abc\": null}"); + } + + @Test + public void doesNotHaveJsonPathForIndefinatePathWithEmptyResults() { + new JsonPathExpectationsHelper("$.familyMembers[?(@.name == 'Dilbert')]").doesNotHaveJsonPath(SIMPSONS); + } + + @Test + public void doesNotHaveEmptyPathForIndefinatePathWithResults() { + String expression = "$.familyMembers[?(@.name == 'Bart')]"; + exception.expect(AssertionError.class); + exception.expectMessage("Expected no values at JSON path \"" + expression + "\" " + + "but found: [{\"name\":\"Bart\"}]"); + new JsonPathExpectationsHelper(expression).doesNotHaveJsonPath(SIMPSONS); + } + + @Test + public void assertValue() throws Exception { + new JsonPathExpectationsHelper("$.num").assertValue(CONTENT, 5); + } + + @Test // SPR-14498 + public void assertValueWithNumberConversion() throws Exception { + new JsonPathExpectationsHelper("$.num").assertValue(CONTENT, 5.0); + } + + @Test // SPR-14498 + public void assertValueWithNumberConversionAndMatcher() throws Exception { + new JsonPathExpectationsHelper("$.num").assertValue(CONTENT, is(5.0), Double.class); + } + + @Test + public void assertValueIsString() throws Exception { + new JsonPathExpectationsHelper("$.str").assertValueIsString(CONTENT); + } + + @Test + public void assertValueIsStringForAnEmptyString() throws Exception { + new JsonPathExpectationsHelper("$.emptyString").assertValueIsString(CONTENT); + } + + @Test + public void assertValueIsStringForNonString() throws Exception { + String expression = "$.bool"; + exception.expect(AssertionError.class); + exception.expectMessage("Expected a string at JSON path \"" + expression + "\" but found: true"); + new JsonPathExpectationsHelper(expression).assertValueIsString(CONTENT); + } + + @Test + public void assertValueIsNumber() throws Exception { + new JsonPathExpectationsHelper("$.num").assertValueIsNumber(CONTENT); + } + + @Test + public void assertValueIsNumberForNonNumber() throws Exception { + String expression = "$.bool"; + exception.expect(AssertionError.class); + exception.expectMessage("Expected a number at JSON path \"" + expression + "\" but found: true"); + new JsonPathExpectationsHelper(expression).assertValueIsNumber(CONTENT); + } + + @Test + public void assertValueIsBoolean() throws Exception { + new JsonPathExpectationsHelper("$.bool").assertValueIsBoolean(CONTENT); + } + + @Test + public void assertValueIsBooleanForNonBoolean() throws Exception { + String expression = "$.num"; + exception.expect(AssertionError.class); + exception.expectMessage("Expected a boolean at JSON path \"" + expression + "\" but found: 5"); + new JsonPathExpectationsHelper(expression).assertValueIsBoolean(CONTENT); + } + + @Test + public void assertValueIsArray() throws Exception { + new JsonPathExpectationsHelper("$.arr").assertValueIsArray(CONTENT); + } + + @Test + public void assertValueIsArrayForAnEmptyArray() throws Exception { + new JsonPathExpectationsHelper("$.emptyArray").assertValueIsArray(CONTENT); + } + + @Test + public void assertValueIsArrayForNonArray() throws Exception { + String expression = "$.str"; + exception.expect(AssertionError.class); + exception.expectMessage("Expected an array at JSON path \"" + expression + "\" but found: 'foo'"); + new JsonPathExpectationsHelper(expression).assertValueIsArray(CONTENT); + } + + @Test + public void assertValueIsMap() throws Exception { + new JsonPathExpectationsHelper("$.colorMap").assertValueIsMap(CONTENT); + } + + @Test + public void assertValueIsMapForAnEmptyMap() throws Exception { + new JsonPathExpectationsHelper("$.emptyMap").assertValueIsMap(CONTENT); + } + + @Test + public void assertValueIsMapForNonMap() throws Exception { + String expression = "$.str"; + exception.expect(AssertionError.class); + exception.expectMessage("Expected a map at JSON path \"" + expression + "\" but found: 'foo'"); + new JsonPathExpectationsHelper(expression).assertValueIsMap(CONTENT); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/util/MetaAnnotationUtilsTests.java b/spring-test/src/test/java/org/springframework/test/util/MetaAnnotationUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8a4853b64e51ae08582966d345b898b6dd525bf2 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/util/MetaAnnotationUtilsTests.java @@ -0,0 +1,647 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.util; + +import java.lang.annotation.Annotation; +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.junit.Test; + +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.core.annotation.Order; +import org.springframework.stereotype.Component; +import org.springframework.stereotype.Service; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.transaction.annotation.Transactional; + +import static org.junit.Assert.*; +import static org.springframework.test.util.MetaAnnotationUtils.*; + +/** + * Unit tests for {@link MetaAnnotationUtils}. + * + * @author Sam Brannen + * @since 4.0 + * @see OverriddenMetaAnnotationAttributesTests + */ +public class MetaAnnotationUtilsTests { + + private void assertAtComponentOnComposedAnnotation( + Class rootDeclaringClass, String name, Class composedAnnotationType) { + + assertAtComponentOnComposedAnnotation(rootDeclaringClass, rootDeclaringClass, name, composedAnnotationType); + } + + private void assertAtComponentOnComposedAnnotation( + Class startClass, Class rootDeclaringClass, String name, Class composedAnnotationType) { + + assertAtComponentOnComposedAnnotation(startClass, rootDeclaringClass, composedAnnotationType, name, composedAnnotationType); + } + + private void assertAtComponentOnComposedAnnotation(Class startClass, Class rootDeclaringClass, + Class declaringClass, String name, Class composedAnnotationType) { + + AnnotationDescriptor descriptor = findAnnotationDescriptor(startClass, Component.class); + assertNotNull("AnnotationDescriptor should not be null", descriptor); + assertEquals("rootDeclaringClass", rootDeclaringClass, descriptor.getRootDeclaringClass()); + assertEquals("declaringClass", declaringClass, descriptor.getDeclaringClass()); + assertEquals("annotationType", Component.class, descriptor.getAnnotationType()); + assertEquals("component name", name, descriptor.getAnnotation().value()); + assertNotNull("composedAnnotation should not be null", descriptor.getComposedAnnotation()); + assertEquals("composedAnnotationType", composedAnnotationType, descriptor.getComposedAnnotationType()); + } + + private void assertAtComponentOnComposedAnnotationForMultipleCandidateTypes( + Class startClass, String name, Class composedAnnotationType) { + + assertAtComponentOnComposedAnnotationForMultipleCandidateTypes( + startClass, startClass, name, composedAnnotationType); + } + + private void assertAtComponentOnComposedAnnotationForMultipleCandidateTypes(Class startClass, + Class rootDeclaringClass, String name, Class composedAnnotationType) { + + assertAtComponentOnComposedAnnotationForMultipleCandidateTypes( + startClass, rootDeclaringClass, composedAnnotationType, name, composedAnnotationType); + } + + @SuppressWarnings("unchecked") + private void assertAtComponentOnComposedAnnotationForMultipleCandidateTypes(Class startClass, + Class rootDeclaringClass, Class declaringClass, String name, + Class composedAnnotationType) { + + Class annotationType = Component.class; + UntypedAnnotationDescriptor descriptor = findAnnotationDescriptorForTypes( + startClass, Service.class, annotationType, Order.class, Transactional.class); + + assertNotNull("UntypedAnnotationDescriptor should not be null", descriptor); + assertEquals("rootDeclaringClass", rootDeclaringClass, descriptor.getRootDeclaringClass()); + assertEquals("declaringClass", declaringClass, descriptor.getDeclaringClass()); + assertEquals("annotationType", annotationType, descriptor.getAnnotationType()); + assertEquals("component name", name, ((Component) descriptor.getAnnotation()).value()); + assertNotNull("composedAnnotation should not be null", descriptor.getComposedAnnotation()); + assertEquals("composedAnnotationType", composedAnnotationType, descriptor.getComposedAnnotationType()); + } + + @Test + public void findAnnotationDescriptorWithNoAnnotationPresent() { + assertNull(findAnnotationDescriptor(NonAnnotatedInterface.class, Transactional.class)); + assertNull(findAnnotationDescriptor(NonAnnotatedClass.class, Transactional.class)); + } + + @Test + public void findAnnotationDescriptorWithInheritedAnnotationOnClass() { + // Note: @Transactional is inherited + assertEquals(InheritedAnnotationClass.class, + findAnnotationDescriptor(InheritedAnnotationClass.class, Transactional.class).getRootDeclaringClass()); + assertEquals(InheritedAnnotationClass.class, + findAnnotationDescriptor(SubInheritedAnnotationClass.class, Transactional.class).getRootDeclaringClass()); + } + + @Test + public void findAnnotationDescriptorWithInheritedAnnotationOnInterface() { + // Note: @Transactional is inherited + Transactional rawAnnotation = InheritedAnnotationInterface.class.getAnnotation(Transactional.class); + + AnnotationDescriptor descriptor = + findAnnotationDescriptor(InheritedAnnotationInterface.class, Transactional.class); + assertNotNull(descriptor); + assertEquals(InheritedAnnotationInterface.class, descriptor.getRootDeclaringClass()); + assertEquals(InheritedAnnotationInterface.class, descriptor.getDeclaringClass()); + assertEquals(rawAnnotation, descriptor.getAnnotation()); + + descriptor = findAnnotationDescriptor(SubInheritedAnnotationInterface.class, Transactional.class); + assertNotNull(descriptor); + assertEquals(SubInheritedAnnotationInterface.class, descriptor.getRootDeclaringClass()); + assertEquals(InheritedAnnotationInterface.class, descriptor.getDeclaringClass()); + assertEquals(rawAnnotation, descriptor.getAnnotation()); + + descriptor = findAnnotationDescriptor(SubSubInheritedAnnotationInterface.class, Transactional.class); + assertNotNull(descriptor); + assertEquals(SubSubInheritedAnnotationInterface.class, descriptor.getRootDeclaringClass()); + assertEquals(InheritedAnnotationInterface.class, descriptor.getDeclaringClass()); + assertEquals(rawAnnotation, descriptor.getAnnotation()); + } + + @Test + public void findAnnotationDescriptorForNonInheritedAnnotationOnClass() { + // Note: @Order is not inherited. + assertEquals(NonInheritedAnnotationClass.class, + findAnnotationDescriptor(NonInheritedAnnotationClass.class, Order.class).getRootDeclaringClass()); + assertEquals(NonInheritedAnnotationClass.class, + findAnnotationDescriptor(SubNonInheritedAnnotationClass.class, Order.class).getRootDeclaringClass()); + } + + @Test + public void findAnnotationDescriptorForNonInheritedAnnotationOnInterface() { + // Note: @Order is not inherited. + Order rawAnnotation = NonInheritedAnnotationInterface.class.getAnnotation(Order.class); + + AnnotationDescriptor descriptor = + findAnnotationDescriptor(NonInheritedAnnotationInterface.class, Order.class); + assertNotNull(descriptor); + assertEquals(NonInheritedAnnotationInterface.class, descriptor.getRootDeclaringClass()); + assertEquals(NonInheritedAnnotationInterface.class, descriptor.getDeclaringClass()); + assertEquals(rawAnnotation, descriptor.getAnnotation()); + + descriptor = findAnnotationDescriptor(SubNonInheritedAnnotationInterface.class, Order.class); + assertNotNull(descriptor); + assertEquals(SubNonInheritedAnnotationInterface.class, descriptor.getRootDeclaringClass()); + assertEquals(NonInheritedAnnotationInterface.class, descriptor.getDeclaringClass()); + assertEquals(rawAnnotation, descriptor.getAnnotation()); + } + + @Test + public void findAnnotationDescriptorWithMetaComponentAnnotation() { + assertAtComponentOnComposedAnnotation(HasMetaComponentAnnotation.class, "meta1", Meta1.class); + } + + @Test + public void findAnnotationDescriptorWithLocalAndMetaComponentAnnotation() { + Class annotationType = Component.class; + AnnotationDescriptor descriptor = findAnnotationDescriptor( + HasLocalAndMetaComponentAnnotation.class, annotationType); + + assertEquals(HasLocalAndMetaComponentAnnotation.class, descriptor.getRootDeclaringClass()); + assertEquals(annotationType, descriptor.getAnnotationType()); + assertNull(descriptor.getComposedAnnotation()); + assertNull(descriptor.getComposedAnnotationType()); + } + + @Test + public void findAnnotationDescriptorForInterfaceWithMetaAnnotation() { + assertAtComponentOnComposedAnnotation(InterfaceWithMetaAnnotation.class, "meta1", Meta1.class); + } + + @Test + public void findAnnotationDescriptorForClassWithMetaAnnotatedInterface() { + Component rawAnnotation = AnnotationUtils.findAnnotation(ClassWithMetaAnnotatedInterface.class, Component.class); + AnnotationDescriptor descriptor = + findAnnotationDescriptor(ClassWithMetaAnnotatedInterface.class, Component.class); + + assertNotNull(descriptor); + assertEquals(ClassWithMetaAnnotatedInterface.class, descriptor.getRootDeclaringClass()); + assertEquals(Meta1.class, descriptor.getDeclaringClass()); + assertEquals(rawAnnotation, descriptor.getAnnotation()); + assertEquals(Meta1.class, descriptor.getComposedAnnotation().annotationType()); + } + + @Test + public void findAnnotationDescriptorForClassWithLocalMetaAnnotationAndAnnotatedSuperclass() { + AnnotationDescriptor descriptor = findAnnotationDescriptor( + MetaAnnotatedAndSuperAnnotatedContextConfigClass.class, ContextConfiguration.class); + + assertNotNull("AnnotationDescriptor should not be null", descriptor); + assertEquals("rootDeclaringClass", MetaAnnotatedAndSuperAnnotatedContextConfigClass.class, descriptor.getRootDeclaringClass()); + assertEquals("declaringClass", MetaConfig.class, descriptor.getDeclaringClass()); + assertEquals("annotationType", ContextConfiguration.class, descriptor.getAnnotationType()); + assertNotNull("composedAnnotation should not be null", descriptor.getComposedAnnotation()); + assertEquals("composedAnnotationType", MetaConfig.class, descriptor.getComposedAnnotationType()); + + assertArrayEquals("configured classes", new Class[] {String.class}, + descriptor.getAnnotationAttributes().getClassArray("classes")); + } + + @Test + public void findAnnotationDescriptorForClassWithLocalMetaAnnotationAndMetaAnnotatedInterface() { + assertAtComponentOnComposedAnnotation(ClassWithLocalMetaAnnotationAndMetaAnnotatedInterface.class, "meta2", Meta2.class); + } + + @Test + public void findAnnotationDescriptorForSubClassWithLocalMetaAnnotationAndMetaAnnotatedInterface() { + assertAtComponentOnComposedAnnotation(SubClassWithLocalMetaAnnotationAndMetaAnnotatedInterface.class, + ClassWithLocalMetaAnnotationAndMetaAnnotatedInterface.class, "meta2", Meta2.class); + } + + /** + * @since 4.0.3 + */ + @Test + public void findAnnotationDescriptorOnMetaMetaAnnotatedClass() { + Class startClass = MetaMetaAnnotatedClass.class; + assertAtComponentOnComposedAnnotation(startClass, startClass, Meta2.class, "meta2", MetaMeta.class); + } + + /** + * @since 4.0.3 + */ + @Test + public void findAnnotationDescriptorOnMetaMetaMetaAnnotatedClass() { + Class startClass = MetaMetaMetaAnnotatedClass.class; + assertAtComponentOnComposedAnnotation(startClass, startClass, Meta2.class, "meta2", MetaMetaMeta.class); + } + + /** + * @since 4.0.3 + */ + @Test + public void findAnnotationDescriptorOnAnnotatedClassWithMissingTargetMetaAnnotation() { + // InheritedAnnotationClass is NOT annotated or meta-annotated with @Component + AnnotationDescriptor descriptor = findAnnotationDescriptor( + InheritedAnnotationClass.class, Component.class); + assertNull("Should not find @Component on InheritedAnnotationClass", descriptor); + } + + /** + * @since 4.0.3 + */ + @Test + public void findAnnotationDescriptorOnMetaCycleAnnotatedClassWithMissingTargetMetaAnnotation() { + AnnotationDescriptor descriptor = findAnnotationDescriptor( + MetaCycleAnnotatedClass.class, Component.class); + assertNull("Should not find @Component on MetaCycleAnnotatedClass", descriptor); + } + + // ------------------------------------------------------------------------- + + @Test + @SuppressWarnings("unchecked") + public void findAnnotationDescriptorForTypesWithNoAnnotationPresent() { + assertNull(findAnnotationDescriptorForTypes(NonAnnotatedInterface.class, Transactional.class, Component.class)); + assertNull(findAnnotationDescriptorForTypes(NonAnnotatedClass.class, Transactional.class, Order.class)); + } + + @Test + @SuppressWarnings("unchecked") + public void findAnnotationDescriptorForTypesWithInheritedAnnotationOnClass() { + // Note: @Transactional is inherited + assertEquals(InheritedAnnotationClass.class, + findAnnotationDescriptorForTypes(InheritedAnnotationClass.class, Transactional.class).getRootDeclaringClass()); + assertEquals( + InheritedAnnotationClass.class, + findAnnotationDescriptorForTypes(SubInheritedAnnotationClass.class, Transactional.class).getRootDeclaringClass()); + } + + @Test + @SuppressWarnings("unchecked") + public void findAnnotationDescriptorForTypesWithInheritedAnnotationOnInterface() { + // Note: @Transactional is inherited + Transactional rawAnnotation = InheritedAnnotationInterface.class.getAnnotation(Transactional.class); + + UntypedAnnotationDescriptor descriptor = + findAnnotationDescriptorForTypes(InheritedAnnotationInterface.class, Transactional.class); + assertNotNull(descriptor); + assertEquals(InheritedAnnotationInterface.class, descriptor.getRootDeclaringClass()); + assertEquals(InheritedAnnotationInterface.class, descriptor.getDeclaringClass()); + assertEquals(rawAnnotation, descriptor.getAnnotation()); + + descriptor = findAnnotationDescriptorForTypes(SubInheritedAnnotationInterface.class, Transactional.class); + assertNotNull(descriptor); + assertEquals(SubInheritedAnnotationInterface.class, descriptor.getRootDeclaringClass()); + assertEquals(InheritedAnnotationInterface.class, descriptor.getDeclaringClass()); + assertEquals(rawAnnotation, descriptor.getAnnotation()); + + descriptor = findAnnotationDescriptorForTypes(SubSubInheritedAnnotationInterface.class, Transactional.class); + assertNotNull(descriptor); + assertEquals(SubSubInheritedAnnotationInterface.class, descriptor.getRootDeclaringClass()); + assertEquals(InheritedAnnotationInterface.class, descriptor.getDeclaringClass()); + assertEquals(rawAnnotation, descriptor.getAnnotation()); + } + + @Test + @SuppressWarnings("unchecked") + public void findAnnotationDescriptorForTypesForNonInheritedAnnotationOnClass() { + // Note: @Order is not inherited. + assertEquals(NonInheritedAnnotationClass.class, + findAnnotationDescriptorForTypes(NonInheritedAnnotationClass.class, Order.class).getRootDeclaringClass()); + assertEquals(NonInheritedAnnotationClass.class, + findAnnotationDescriptorForTypes(SubNonInheritedAnnotationClass.class, Order.class).getRootDeclaringClass()); + } + + @Test + @SuppressWarnings("unchecked") + public void findAnnotationDescriptorForTypesForNonInheritedAnnotationOnInterface() { + // Note: @Order is not inherited. + Order rawAnnotation = NonInheritedAnnotationInterface.class.getAnnotation(Order.class); + + UntypedAnnotationDescriptor descriptor = + findAnnotationDescriptorForTypes(NonInheritedAnnotationInterface.class, Order.class); + assertNotNull(descriptor); + assertEquals(NonInheritedAnnotationInterface.class, descriptor.getRootDeclaringClass()); + assertEquals(NonInheritedAnnotationInterface.class, descriptor.getDeclaringClass()); + assertEquals(rawAnnotation, descriptor.getAnnotation()); + + descriptor = findAnnotationDescriptorForTypes(SubNonInheritedAnnotationInterface.class, Order.class); + assertNotNull(descriptor); + assertEquals(SubNonInheritedAnnotationInterface.class, descriptor.getRootDeclaringClass()); + assertEquals(NonInheritedAnnotationInterface.class, descriptor.getDeclaringClass()); + assertEquals(rawAnnotation, descriptor.getAnnotation()); + } + + @Test + @SuppressWarnings("unchecked") + public void findAnnotationDescriptorForTypesWithLocalAndMetaComponentAnnotation() { + Class annotationType = Component.class; + UntypedAnnotationDescriptor descriptor = findAnnotationDescriptorForTypes( + HasLocalAndMetaComponentAnnotation.class, Transactional.class, annotationType, Order.class); + assertEquals(HasLocalAndMetaComponentAnnotation.class, descriptor.getRootDeclaringClass()); + assertEquals(annotationType, descriptor.getAnnotationType()); + assertNull(descriptor.getComposedAnnotation()); + assertNull(descriptor.getComposedAnnotationType()); + } + + @Test + public void findAnnotationDescriptorForTypesWithMetaComponentAnnotation() { + Class startClass = HasMetaComponentAnnotation.class; + assertAtComponentOnComposedAnnotationForMultipleCandidateTypes(startClass, "meta1", Meta1.class); + } + + @Test + @SuppressWarnings("unchecked") + public void findAnnotationDescriptorForTypesWithMetaAnnotationWithDefaultAttributes() { + Class startClass = MetaConfigWithDefaultAttributesTestCase.class; + Class annotationType = ContextConfiguration.class; + + UntypedAnnotationDescriptor descriptor = findAnnotationDescriptorForTypes(startClass, + Service.class, ContextConfiguration.class, Order.class, Transactional.class); + + assertNotNull(descriptor); + assertEquals(startClass, descriptor.getRootDeclaringClass()); + assertEquals(annotationType, descriptor.getAnnotationType()); + assertArrayEquals(new Class[] {}, ((ContextConfiguration) descriptor.getAnnotation()).value()); + assertArrayEquals(new Class[] {MetaConfig.DevConfig.class, MetaConfig.ProductionConfig.class}, + descriptor.getAnnotationAttributes().getClassArray("classes")); + assertNotNull(descriptor.getComposedAnnotation()); + assertEquals(MetaConfig.class, descriptor.getComposedAnnotationType()); + } + + @Test + @SuppressWarnings("unchecked") + public void findAnnotationDescriptorForTypesWithMetaAnnotationWithOverriddenAttributes() { + Class startClass = MetaConfigWithOverriddenAttributesTestCase.class; + Class annotationType = ContextConfiguration.class; + + UntypedAnnotationDescriptor descriptor = findAnnotationDescriptorForTypes( + startClass, Service.class, ContextConfiguration.class, Order.class, Transactional.class); + + assertNotNull(descriptor); + assertEquals(startClass, descriptor.getRootDeclaringClass()); + assertEquals(annotationType, descriptor.getAnnotationType()); + assertArrayEquals(new Class[] {}, ((ContextConfiguration) descriptor.getAnnotation()).value()); + assertArrayEquals(new Class[] {MetaAnnotationUtilsTests.class}, + descriptor.getAnnotationAttributes().getClassArray("classes")); + assertNotNull(descriptor.getComposedAnnotation()); + assertEquals(MetaConfig.class, descriptor.getComposedAnnotationType()); + } + + @Test + public void findAnnotationDescriptorForTypesForInterfaceWithMetaAnnotation() { + Class startClass = InterfaceWithMetaAnnotation.class; + assertAtComponentOnComposedAnnotationForMultipleCandidateTypes(startClass, "meta1", Meta1.class); + } + + @Test + @SuppressWarnings("unchecked") + public void findAnnotationDescriptorForTypesForClassWithMetaAnnotatedInterface() { + Component rawAnnotation = AnnotationUtils.findAnnotation(ClassWithMetaAnnotatedInterface.class, Component.class); + + UntypedAnnotationDescriptor descriptor = findAnnotationDescriptorForTypes( + ClassWithMetaAnnotatedInterface.class, Service.class, Component.class, Order.class, Transactional.class); + + assertNotNull(descriptor); + assertEquals(ClassWithMetaAnnotatedInterface.class, descriptor.getRootDeclaringClass()); + assertEquals(Meta1.class, descriptor.getDeclaringClass()); + assertEquals(rawAnnotation, descriptor.getAnnotation()); + assertEquals(Meta1.class, descriptor.getComposedAnnotation().annotationType()); + } + + @Test + public void findAnnotationDescriptorForTypesForClassWithLocalMetaAnnotationAndMetaAnnotatedInterface() { + Class startClass = ClassWithLocalMetaAnnotationAndMetaAnnotatedInterface.class; + assertAtComponentOnComposedAnnotationForMultipleCandidateTypes(startClass, "meta2", Meta2.class); + } + + @Test + public void findAnnotationDescriptorForTypesForSubClassWithLocalMetaAnnotationAndMetaAnnotatedInterface() { + assertAtComponentOnComposedAnnotationForMultipleCandidateTypes( + SubClassWithLocalMetaAnnotationAndMetaAnnotatedInterface.class, + ClassWithLocalMetaAnnotationAndMetaAnnotatedInterface.class, "meta2", Meta2.class); + } + + /** + * @since 4.0.3 + */ + @Test + public void findAnnotationDescriptorForTypesOnMetaMetaAnnotatedClass() { + Class startClass = MetaMetaAnnotatedClass.class; + assertAtComponentOnComposedAnnotationForMultipleCandidateTypes( + startClass, startClass, Meta2.class, "meta2", MetaMeta.class); + } + + /** + * @since 4.0.3 + */ + @Test + public void findAnnotationDescriptorForTypesOnMetaMetaMetaAnnotatedClass() { + Class startClass = MetaMetaMetaAnnotatedClass.class; + assertAtComponentOnComposedAnnotationForMultipleCandidateTypes( + startClass, startClass, Meta2.class, "meta2", MetaMetaMeta.class); + } + + /** + * @since 4.0.3 + */ + @Test + @SuppressWarnings("unchecked") + public void findAnnotationDescriptorForTypesOnAnnotatedClassWithMissingTargetMetaAnnotation() { + // InheritedAnnotationClass is NOT annotated or meta-annotated with @Component, + // @Service, or @Order, but it is annotated with @Transactional. + UntypedAnnotationDescriptor descriptor = findAnnotationDescriptorForTypes( + InheritedAnnotationClass.class, Service.class, Component.class, Order.class); + assertNull("Should not find @Component on InheritedAnnotationClass", descriptor); + } + + /** + * @since 4.0.3 + */ + @Test + @SuppressWarnings("unchecked") + public void findAnnotationDescriptorForTypesOnMetaCycleAnnotatedClassWithMissingTargetMetaAnnotation() { + UntypedAnnotationDescriptor descriptor = findAnnotationDescriptorForTypes( + MetaCycleAnnotatedClass.class, Service.class, Component.class, Order.class); + assertNull("Should not find @Component on MetaCycleAnnotatedClass", descriptor); + } + + + // ------------------------------------------------------------------------- + + @Component(value = "meta1") + @Order + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.TYPE) + @Documented + static @interface Meta1 { + } + + @Component(value = "meta2") + @Transactional + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.TYPE) + @Documented + static @interface Meta2 { + } + + @Meta2 + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.TYPE) + @Documented + @interface MetaMeta { + } + + @MetaMeta + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.TYPE) + @Documented + @interface MetaMetaMeta { + } + + @MetaCycle3 + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.ANNOTATION_TYPE) + @Documented + @interface MetaCycle1 { + } + + @MetaCycle1 + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.ANNOTATION_TYPE) + @Documented + @interface MetaCycle2 { + } + + @MetaCycle2 + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.TYPE) + @Documented + @interface MetaCycle3 { + } + + @ContextConfiguration + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.TYPE) + @Documented + static @interface MetaConfig { + + static class DevConfig { + } + + static class ProductionConfig { + } + + + Class[] classes() default { DevConfig.class, ProductionConfig.class }; + } + + // ------------------------------------------------------------------------- + + @Meta1 + static class HasMetaComponentAnnotation { + } + + @Meta1 + @Component(value = "local") + @Meta2 + static class HasLocalAndMetaComponentAnnotation { + } + + @Meta1 + static interface InterfaceWithMetaAnnotation { + } + + static class ClassWithMetaAnnotatedInterface implements InterfaceWithMetaAnnotation { + } + + @Meta2 + static class ClassWithLocalMetaAnnotationAndMetaAnnotatedInterface implements InterfaceWithMetaAnnotation { + } + + static class SubClassWithLocalMetaAnnotationAndMetaAnnotatedInterface extends + ClassWithLocalMetaAnnotationAndMetaAnnotatedInterface { + } + + @MetaMeta + static class MetaMetaAnnotatedClass { + } + + @MetaMetaMeta + static class MetaMetaMetaAnnotatedClass { + } + + @MetaCycle3 + static class MetaCycleAnnotatedClass { + } + + @MetaConfig + public class MetaConfigWithDefaultAttributesTestCase { + } + + @MetaConfig(classes = MetaAnnotationUtilsTests.class) + public class MetaConfigWithOverriddenAttributesTestCase { + } + + // ------------------------------------------------------------------------- + + @Transactional + static interface InheritedAnnotationInterface { + } + + static interface SubInheritedAnnotationInterface extends InheritedAnnotationInterface { + } + + static interface SubSubInheritedAnnotationInterface extends SubInheritedAnnotationInterface { + } + + @Order + static interface NonInheritedAnnotationInterface { + } + + static interface SubNonInheritedAnnotationInterface extends NonInheritedAnnotationInterface { + } + + static class NonAnnotatedClass { + } + + static interface NonAnnotatedInterface { + } + + @Transactional + static class InheritedAnnotationClass { + } + + static class SubInheritedAnnotationClass extends InheritedAnnotationClass { + } + + @Order + static class NonInheritedAnnotationClass { + } + + static class SubNonInheritedAnnotationClass extends NonInheritedAnnotationClass { + } + + @ContextConfiguration(classes = Number.class) + static class AnnotatedContextConfigClass { + } + + @MetaConfig(classes = String.class) + static class MetaAnnotatedAndSuperAnnotatedContextConfigClass extends AnnotatedContextConfigClass { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/util/OverriddenMetaAnnotationAttributesTests.java b/spring-test/src/test/java/org/springframework/test/util/OverriddenMetaAnnotationAttributesTests.java new file mode 100644 index 0000000000000000000000000000000000000000..4765c1ed8647d789d3b3920f38f1a5281b609f85 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/util/OverriddenMetaAnnotationAttributesTests.java @@ -0,0 +1,156 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.util; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +import org.junit.Test; + +import org.springframework.core.annotation.AnnotationAttributes; +import org.springframework.test.context.ContextConfiguration; + +import static org.junit.Assert.*; +import static org.springframework.test.util.MetaAnnotationUtils.*; + +/** + * Unit tests for {@link MetaAnnotationUtils} that verify support for overridden + * meta-annotation attributes. + * + *

See SPR-10181. + * + * @author Sam Brannen + * @since 4.0 + * @see MetaAnnotationUtilsTests + */ +public class OverriddenMetaAnnotationAttributesTests { + + @Test + public void contextConfigurationValue() throws Exception { + Class declaringClass = MetaValueConfigTestCase.class; + AnnotationDescriptor descriptor = findAnnotationDescriptor(declaringClass, + ContextConfiguration.class); + assertNotNull(descriptor); + assertEquals(declaringClass, descriptor.getRootDeclaringClass()); + assertEquals(MetaValueConfig.class, descriptor.getComposedAnnotationType()); + assertEquals(ContextConfiguration.class, descriptor.getAnnotationType()); + assertNotNull(descriptor.getComposedAnnotation()); + assertEquals(MetaValueConfig.class, descriptor.getComposedAnnotationType()); + + // direct access to annotation value: + assertArrayEquals(new String[] { "foo.xml" }, descriptor.getAnnotation().value()); + } + + @Test + public void overriddenContextConfigurationValue() throws Exception { + Class declaringClass = OverriddenMetaValueConfigTestCase.class; + AnnotationDescriptor descriptor = findAnnotationDescriptor(declaringClass, + ContextConfiguration.class); + assertNotNull(descriptor); + assertEquals(declaringClass, descriptor.getRootDeclaringClass()); + assertEquals(MetaValueConfig.class, descriptor.getComposedAnnotationType()); + assertEquals(ContextConfiguration.class, descriptor.getAnnotationType()); + assertNotNull(descriptor.getComposedAnnotation()); + assertEquals(MetaValueConfig.class, descriptor.getComposedAnnotationType()); + + // direct access to annotation value: + assertArrayEquals(new String[] { "foo.xml" }, descriptor.getAnnotation().value()); + + // overridden attribute: + AnnotationAttributes attributes = descriptor.getAnnotationAttributes(); + + // NOTE: we would like to be able to override the 'value' attribute; however, + // Spring currently does not allow overrides for the 'value' attribute. + // See SPR-11393 for related discussions. + assertArrayEquals(new String[] { "foo.xml" }, attributes.getStringArray("value")); + } + + @Test + public void contextConfigurationLocationsAndInheritLocations() throws Exception { + Class declaringClass = MetaLocationsConfigTestCase.class; + AnnotationDescriptor descriptor = findAnnotationDescriptor(declaringClass, + ContextConfiguration.class); + assertNotNull(descriptor); + assertEquals(declaringClass, descriptor.getRootDeclaringClass()); + assertEquals(MetaLocationsConfig.class, descriptor.getComposedAnnotationType()); + assertEquals(ContextConfiguration.class, descriptor.getAnnotationType()); + assertNotNull(descriptor.getComposedAnnotation()); + assertEquals(MetaLocationsConfig.class, descriptor.getComposedAnnotationType()); + + // direct access to annotation attributes: + assertArrayEquals(new String[] { "foo.xml" }, descriptor.getAnnotation().locations()); + assertFalse(descriptor.getAnnotation().inheritLocations()); + } + + @Test + public void overriddenContextConfigurationLocationsAndInheritLocations() throws Exception { + Class declaringClass = OverriddenMetaLocationsConfigTestCase.class; + AnnotationDescriptor descriptor = findAnnotationDescriptor(declaringClass, + ContextConfiguration.class); + assertNotNull(descriptor); + assertEquals(declaringClass, descriptor.getRootDeclaringClass()); + assertEquals(MetaLocationsConfig.class, descriptor.getComposedAnnotationType()); + assertEquals(ContextConfiguration.class, descriptor.getAnnotationType()); + assertNotNull(descriptor.getComposedAnnotation()); + assertEquals(MetaLocationsConfig.class, descriptor.getComposedAnnotationType()); + + // direct access to annotation attributes: + assertArrayEquals(new String[] { "foo.xml" }, descriptor.getAnnotation().locations()); + assertFalse(descriptor.getAnnotation().inheritLocations()); + + // overridden attributes: + AnnotationAttributes attributes = descriptor.getAnnotationAttributes(); + assertArrayEquals(new String[] { "bar.xml" }, attributes.getStringArray("locations")); + assertTrue(attributes.getBoolean("inheritLocations")); + } + + + // ------------------------------------------------------------------------- + + @ContextConfiguration("foo.xml") + @Retention(RetentionPolicy.RUNTIME) + static @interface MetaValueConfig { + + String[] value() default {}; + } + + @MetaValueConfig + public static class MetaValueConfigTestCase { + } + + @MetaValueConfig("bar.xml") + public static class OverriddenMetaValueConfigTestCase { + } + + @ContextConfiguration(locations = "foo.xml", inheritLocations = false) + @Retention(RetentionPolicy.RUNTIME) + static @interface MetaLocationsConfig { + + String[] locations() default {}; + + boolean inheritLocations(); + } + + @MetaLocationsConfig(inheritLocations = true) + static class MetaLocationsConfigTestCase { + } + + @MetaLocationsConfig(locations = "bar.xml", inheritLocations = true) + static class OverriddenMetaLocationsConfigTestCase { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/util/ReflectionTestUtilsTests.java b/spring-test/src/test/java/org/springframework/test/util/ReflectionTestUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..82934530c5b2d2912c37d231d6eb9175fb3b69ed --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/util/ReflectionTestUtilsTests.java @@ -0,0 +1,425 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.util; + +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.aop.support.AopUtils; +import org.springframework.test.util.subpackage.Component; +import org.springframework.test.util.subpackage.LegacyEntity; +import org.springframework.test.util.subpackage.Person; +import org.springframework.test.util.subpackage.PersonEntity; +import org.springframework.test.util.subpackage.StaticFields; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; +import static org.springframework.test.util.ReflectionTestUtils.*; + +/** + * Unit tests for {@link ReflectionTestUtils}. + * + * @author Sam Brannen + * @author Juergen Hoeller + */ +public class ReflectionTestUtilsTests { + + private static final Float PI = Float.valueOf((float) 22 / 7); + + private final Person person = new PersonEntity(); + + private final Component component = new Component(); + + private final LegacyEntity entity = new LegacyEntity(); + + @Rule + public final ExpectedException exception = ExpectedException.none(); + + + @Before + public void resetStaticFields() { + StaticFields.reset(); + } + + @Test + public void setFieldWithNullTargetObject() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage(startsWith("Either targetObject or targetClass")); + setField((Object) null, "id", Long.valueOf(99)); + } + + @Test + public void getFieldWithNullTargetObject() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage(startsWith("Either targetObject or targetClass")); + getField((Object) null, "id"); + } + + @Test + public void setFieldWithNullTargetClass() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage(startsWith("Either targetObject or targetClass")); + setField((Class) null, "id", Long.valueOf(99)); + } + + @Test + public void getFieldWithNullTargetClass() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage(startsWith("Either targetObject or targetClass")); + getField((Class) null, "id"); + } + + @Test + public void setFieldWithNullNameAndNullType() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage(startsWith("Either name or type")); + setField(person, null, Long.valueOf(99), null); + } + + @Test + public void setFieldWithBogusName() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage(startsWith("Could not find field 'bogus'")); + setField(person, "bogus", Long.valueOf(99), long.class); + } + + @Test + public void setFieldWithWrongType() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage(startsWith("Could not find field")); + setField(person, "id", Long.valueOf(99), String.class); + } + + @Test + public void setFieldAndGetFieldForStandardUseCases() throws Exception { + assertSetFieldAndGetFieldBehavior(this.person); + } + + @Test + public void setFieldAndGetFieldViaJdkDynamicProxy() throws Exception { + ProxyFactory pf = new ProxyFactory(this.person); + pf.addInterface(Person.class); + Person proxy = (Person) pf.getProxy(); + assertTrue("Proxy is a JDK dynamic proxy", AopUtils.isJdkDynamicProxy(proxy)); + assertSetFieldAndGetFieldBehaviorForProxy(proxy, this.person); + } + + @Test + public void setFieldAndGetFieldViaCglibProxy() throws Exception { + ProxyFactory pf = new ProxyFactory(this.person); + pf.setProxyTargetClass(true); + Person proxy = (Person) pf.getProxy(); + assertTrue("Proxy is a CGLIB proxy", AopUtils.isCglibProxy(proxy)); + assertSetFieldAndGetFieldBehaviorForProxy(proxy, this.person); + } + + private static void assertSetFieldAndGetFieldBehavior(Person person) { + // Set reflectively + setField(person, "id", Long.valueOf(99), long.class); + setField(person, "name", "Tom"); + setField(person, "age", Integer.valueOf(42)); + setField(person, "eyeColor", "blue", String.class); + setField(person, "likesPets", Boolean.TRUE); + setField(person, "favoriteNumber", PI, Number.class); + + // Get reflectively + assertEquals(Long.valueOf(99), getField(person, "id")); + assertEquals("Tom", getField(person, "name")); + assertEquals(Integer.valueOf(42), getField(person, "age")); + assertEquals("blue", getField(person, "eyeColor")); + assertEquals(Boolean.TRUE, getField(person, "likesPets")); + assertEquals(PI, getField(person, "favoriteNumber")); + + // Get directly + assertEquals("ID (private field in a superclass)", 99, person.getId()); + assertEquals("name (protected field)", "Tom", person.getName()); + assertEquals("age (private field)", 42, person.getAge()); + assertEquals("eye color (package private field)", "blue", person.getEyeColor()); + assertEquals("'likes pets' flag (package private boolean field)", true, person.likesPets()); + assertEquals("'favorite number' (package field)", PI, person.getFavoriteNumber()); + } + + private static void assertSetFieldAndGetFieldBehaviorForProxy(Person proxy, Person target) { + assertSetFieldAndGetFieldBehavior(proxy); + + // Get directly from Target + assertEquals("ID (private field in a superclass)", 99, target.getId()); + assertEquals("name (protected field)", "Tom", target.getName()); + assertEquals("age (private field)", 42, target.getAge()); + assertEquals("eye color (package private field)", "blue", target.getEyeColor()); + assertEquals("'likes pets' flag (package private boolean field)", true, target.likesPets()); + assertEquals("'favorite number' (package field)", PI, target.getFavoriteNumber()); + } + + @Test + public void setFieldWithNullValuesForNonPrimitives() throws Exception { + // Fields must be non-null to start with + setField(person, "name", "Tom"); + setField(person, "eyeColor", "blue", String.class); + setField(person, "favoriteNumber", PI, Number.class); + assertNotNull(person.getName()); + assertNotNull(person.getEyeColor()); + assertNotNull(person.getFavoriteNumber()); + + // Set to null + setField(person, "name", null, String.class); + setField(person, "eyeColor", null, String.class); + setField(person, "favoriteNumber", null, Number.class); + + assertNull("name (protected field)", person.getName()); + assertNull("eye color (package private field)", person.getEyeColor()); + assertNull("'favorite number' (package field)", person.getFavoriteNumber()); + } + + @Test(expected = IllegalArgumentException.class) + public void setFieldWithNullValueForPrimitiveLong() throws Exception { + setField(person, "id", null, long.class); + } + + @Test(expected = IllegalArgumentException.class) + public void setFieldWithNullValueForPrimitiveInt() throws Exception { + setField(person, "age", null, int.class); + } + + @Test(expected = IllegalArgumentException.class) + public void setFieldWithNullValueForPrimitiveBoolean() throws Exception { + setField(person, "likesPets", null, boolean.class); + } + + @Test + public void setStaticFieldViaClass() throws Exception { + setField(StaticFields.class, "publicField", "xxx"); + setField(StaticFields.class, "privateField", "yyy"); + + assertEquals("public static field", "xxx", StaticFields.publicField); + assertEquals("private static field", "yyy", StaticFields.getPrivateField()); + } + + @Test + public void setStaticFieldViaClassWithExplicitType() throws Exception { + setField(StaticFields.class, "publicField", "xxx", String.class); + setField(StaticFields.class, "privateField", "yyy", String.class); + + assertEquals("public static field", "xxx", StaticFields.publicField); + assertEquals("private static field", "yyy", StaticFields.getPrivateField()); + } + + @Test + public void setStaticFieldViaInstance() throws Exception { + StaticFields staticFields = new StaticFields(); + setField(staticFields, null, "publicField", "xxx", null); + setField(staticFields, null, "privateField", "yyy", null); + + assertEquals("public static field", "xxx", StaticFields.publicField); + assertEquals("private static field", "yyy", StaticFields.getPrivateField()); + } + + @Test + public void getStaticFieldViaClass() throws Exception { + assertEquals("public static field", "public", getField(StaticFields.class, "publicField")); + assertEquals("private static field", "private", getField(StaticFields.class, "privateField")); + } + + @Test + public void getStaticFieldViaInstance() throws Exception { + StaticFields staticFields = new StaticFields(); + assertEquals("public static field", "public", getField(staticFields, "publicField")); + assertEquals("private static field", "private", getField(staticFields, "privateField")); + } + + @Test + public void invokeSetterMethodAndInvokeGetterMethodWithExplicitMethodNames() throws Exception { + invokeSetterMethod(person, "setId", Long.valueOf(1), long.class); + invokeSetterMethod(person, "setName", "Jerry", String.class); + invokeSetterMethod(person, "setAge", Integer.valueOf(33), int.class); + invokeSetterMethod(person, "setEyeColor", "green", String.class); + invokeSetterMethod(person, "setLikesPets", Boolean.FALSE, boolean.class); + invokeSetterMethod(person, "setFavoriteNumber", Integer.valueOf(42), Number.class); + + assertEquals("ID (protected method in a superclass)", 1, person.getId()); + assertEquals("name (private method)", "Jerry", person.getName()); + assertEquals("age (protected method)", 33, person.getAge()); + assertEquals("eye color (package private method)", "green", person.getEyeColor()); + assertEquals("'likes pets' flag (protected method for a boolean)", false, person.likesPets()); + assertEquals("'favorite number' (protected method for a Number)", Integer.valueOf(42), person.getFavoriteNumber()); + + assertEquals(Long.valueOf(1), invokeGetterMethod(person, "getId")); + assertEquals("Jerry", invokeGetterMethod(person, "getName")); + assertEquals(Integer.valueOf(33), invokeGetterMethod(person, "getAge")); + assertEquals("green", invokeGetterMethod(person, "getEyeColor")); + assertEquals(Boolean.FALSE, invokeGetterMethod(person, "likesPets")); + assertEquals(Integer.valueOf(42), invokeGetterMethod(person, "getFavoriteNumber")); + } + + @Test + public void invokeSetterMethodAndInvokeGetterMethodWithJavaBeanPropertyNames() throws Exception { + invokeSetterMethod(person, "id", Long.valueOf(99), long.class); + invokeSetterMethod(person, "name", "Tom"); + invokeSetterMethod(person, "age", Integer.valueOf(42)); + invokeSetterMethod(person, "eyeColor", "blue", String.class); + invokeSetterMethod(person, "likesPets", Boolean.TRUE); + invokeSetterMethod(person, "favoriteNumber", PI, Number.class); + + assertEquals("ID (protected method in a superclass)", 99, person.getId()); + assertEquals("name (private method)", "Tom", person.getName()); + assertEquals("age (protected method)", 42, person.getAge()); + assertEquals("eye color (package private method)", "blue", person.getEyeColor()); + assertEquals("'likes pets' flag (protected method for a boolean)", true, person.likesPets()); + assertEquals("'favorite number' (protected method for a Number)", PI, person.getFavoriteNumber()); + + assertEquals(Long.valueOf(99), invokeGetterMethod(person, "id")); + assertEquals("Tom", invokeGetterMethod(person, "name")); + assertEquals(Integer.valueOf(42), invokeGetterMethod(person, "age")); + assertEquals("blue", invokeGetterMethod(person, "eyeColor")); + assertEquals(Boolean.TRUE, invokeGetterMethod(person, "likesPets")); + assertEquals(PI, invokeGetterMethod(person, "favoriteNumber")); + } + + @Test + public void invokeSetterMethodWithNullValuesForNonPrimitives() throws Exception { + invokeSetterMethod(person, "name", null, String.class); + invokeSetterMethod(person, "eyeColor", null, String.class); + invokeSetterMethod(person, "favoriteNumber", null, Number.class); + + assertNull("name (private method)", person.getName()); + assertNull("eye color (package private method)", person.getEyeColor()); + assertNull("'favorite number' (protected method for a Number)", person.getFavoriteNumber()); + } + + @Test(expected = IllegalArgumentException.class) + public void invokeSetterMethodWithNullValueForPrimitiveLong() throws Exception { + invokeSetterMethod(person, "id", null, long.class); + } + + @Test(expected = IllegalArgumentException.class) + public void invokeSetterMethodWithNullValueForPrimitiveInt() throws Exception { + invokeSetterMethod(person, "age", null, int.class); + } + + @Test(expected = IllegalArgumentException.class) + public void invokeSetterMethodWithNullValueForPrimitiveBoolean() throws Exception { + invokeSetterMethod(person, "likesPets", null, boolean.class); + } + + @Test + public void invokeMethodWithAutoboxingAndUnboxing() { + // IntelliJ IDEA 11 won't accept int assignment here + Integer difference = invokeMethod(component, "subtract", 5, 2); + assertEquals("subtract(5, 2)", 3, difference.intValue()); + } + + @Test + @Ignore("[SPR-8644] findMethod() does not currently support var-args") + public void invokeMethodWithPrimitiveVarArgs() { + // IntelliJ IDEA 11 won't accept int assignment here + Integer sum = invokeMethod(component, "add", 1, 2, 3, 4); + assertEquals("add(1,2,3,4)", 10, sum.intValue()); + } + + @Test + public void invokeMethodWithPrimitiveVarArgsAsSingleArgument() { + // IntelliJ IDEA 11 won't accept int assignment here + Integer sum = invokeMethod(component, "add", new int[] { 1, 2, 3, 4 }); + assertEquals("add(1,2,3,4)", 10, sum.intValue()); + } + + @Test + public void invokeMethodSimulatingLifecycleEvents() { + assertNull("number", component.getNumber()); + assertNull("text", component.getText()); + + // Simulate autowiring a configuration method + invokeMethod(component, "configure", Integer.valueOf(42), "enigma"); + assertEquals("number should have been configured", Integer.valueOf(42), component.getNumber()); + assertEquals("text should have been configured", "enigma", component.getText()); + + // Simulate @PostConstruct life-cycle event + invokeMethod(component, "init"); + // assertions in init() should succeed + + // Simulate @PreDestroy life-cycle event + invokeMethod(component, "destroy"); + assertNull("number", component.getNumber()); + assertNull("text", component.getText()); + } + + @Test + public void invokeInitMethodBeforeAutowiring() { + exception.expect(IllegalStateException.class); + exception.expectMessage(equalTo("number must not be null")); + invokeMethod(component, "init"); + } + + @Test + public void invokeMethodWithIncompatibleArgumentTypes() { + exception.expect(IllegalStateException.class); + exception.expectMessage(startsWith("Method not found")); + invokeMethod(component, "subtract", "foo", 2.0); + } + + @Test + public void invokeMethodWithTooFewArguments() { + exception.expect(IllegalStateException.class); + exception.expectMessage(startsWith("Method not found")); + invokeMethod(component, "configure", Integer.valueOf(42)); + } + + @Test + public void invokeMethodWithTooManyArguments() { + exception.expect(IllegalStateException.class); + exception.expectMessage(startsWith("Method not found")); + invokeMethod(component, "configure", Integer.valueOf(42), "enigma", "baz", "quux"); + } + + @Test // SPR-14363 + public void getFieldOnLegacyEntityWithSideEffectsInToString() { + Object collaborator = getField(entity, "collaborator"); + assertNotNull(collaborator); + } + + @Test // SPR-9571 and SPR-14363 + public void setFieldOnLegacyEntityWithSideEffectsInToString() { + String testCollaborator = "test collaborator"; + setField(entity, "collaborator", testCollaborator, Object.class); + assertTrue(entity.toString().contains(testCollaborator)); + } + + @Test // SPR-14363 + public void invokeMethodOnLegacyEntityWithSideEffectsInToString() { + invokeMethod(entity, "configure", Integer.valueOf(42), "enigma"); + assertEquals("number should have been configured", Integer.valueOf(42), entity.getNumber()); + assertEquals("text should have been configured", "enigma", entity.getText()); + } + + @Test // SPR-14363 + public void invokeGetterMethodOnLegacyEntityWithSideEffectsInToString() { + Object collaborator = invokeGetterMethod(entity, "collaborator"); + assertNotNull(collaborator); + } + + @Test // SPR-14363 + public void invokeSetterMethodOnLegacyEntityWithSideEffectsInToString() { + String testCollaborator = "test collaborator"; + invokeSetterMethod(entity, "collaborator", testCollaborator); + assertTrue(entity.toString().contains(testCollaborator)); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/util/XmlExpectationsHelperTests.java b/spring-test/src/test/java/org/springframework/test/util/XmlExpectationsHelperTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8f4690edca7b80e6574f19746707789cb5c8d947 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/util/XmlExpectationsHelperTests.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.util; + +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +/** + * Unit tests for {@link XmlExpectationsHelper}. + * + * @author Matthew Depue + */ +public class XmlExpectationsHelperTests { + + @Rule + public final ExpectedException exception = ExpectedException.none(); + + + @Test + public void assertXmlEqualForEqual() throws Exception { + final String control = "f1f2"; + final String test = "f1f2"; + + final XmlExpectationsHelper xmlHelper = new XmlExpectationsHelper(); + xmlHelper.assertXmlEqual(control, test); + } + + @Test + public void assertXmlEqualExceptionForIncorrectValue() throws Exception { + final String control = "f1f2"; + final String test = "notf1f2"; + + exception.expect(AssertionError.class); + exception.expectMessage(Matchers.startsWith("Body content Expected child 'field1'")); + + final XmlExpectationsHelper xmlHelper = new XmlExpectationsHelper(); + xmlHelper.assertXmlEqual(control, test); + } + + @Test + public void assertXmlEqualForOutOfOrder() throws Exception { + final String control = "f1f2"; + final String test = "f2f1"; + + final XmlExpectationsHelper xmlHelper = new XmlExpectationsHelper(); + xmlHelper.assertXmlEqual(control, test); + } + + @Test + public void assertXmlEqualExceptionForMoreEntries() throws Exception { + final String control = "f1f2"; + final String test = "f1f2f3"; + + exception.expect(AssertionError.class); + exception.expectMessage(Matchers.containsString("Expected child nodelist length '2' but was '3'")); + + final XmlExpectationsHelper xmlHelper = new XmlExpectationsHelper(); + xmlHelper.assertXmlEqual(control, test); + } + + @Test + public void assertXmlEqualExceptionForLessEntries() throws Exception { + final String control = "f1f2f3"; + final String test = "f1f2"; + + exception.expect(AssertionError.class); + exception.expectMessage(Matchers.containsString("Expected child nodelist length '3' but was '2'")); + + final XmlExpectationsHelper xmlHelper = new XmlExpectationsHelper(); + xmlHelper.assertXmlEqual(control, test); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/util/subpackage/Component.java b/spring-test/src/test/java/org/springframework/test/util/subpackage/Component.java new file mode 100644 index 0000000000000000000000000000000000000000..f0b41c658ac9ea6264e2b1bcc9fe37d831cd3363 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/util/subpackage/Component.java @@ -0,0 +1,85 @@ +/* + * Copyright 2002-2011 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.util.subpackage; + +import javax.annotation.PostConstruct; +import javax.annotation.PreDestroy; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * Simple POJO representing a component; intended for use in + * unit tests. + * + * @author Sam Brannen + * @since 3.1 + */ +public class Component { + + private Integer number; + private String text; + + + public Integer getNumber() { + return this.number; + } + + public String getText() { + return this.text; + } + + @Autowired + protected void configure(Integer number, String text) { + this.number = number; + this.text = text; + } + + @PostConstruct + protected void init() { + Assert.state(number != null, "number must not be null"); + Assert.state(StringUtils.hasText(text), "text must not be empty"); + } + + @PreDestroy + protected void destroy() { + this.number = null; + this.text = null; + } + + int subtract(int a, int b) { + return a - b; + } + + int add(int... args) { + int sum = 0; + for (int i = 0; i < args.length; i++) { + sum += args[i]; + } + return sum; + } + + int multiply(Integer... args) { + int product = 1; + for (int i = 0; i < args.length; i++) { + product *= args[i]; + } + return product; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/util/subpackage/LegacyEntity.java b/spring-test/src/test/java/org/springframework/test/util/subpackage/LegacyEntity.java new file mode 100644 index 0000000000000000000000000000000000000000..422ae7d9fb013d0752a714814fe05a1f548d0944 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/util/subpackage/LegacyEntity.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.util.subpackage; + +import org.springframework.core.style.ToStringCreator; + +/** + * A legacy entity whose {@link #toString()} method has side effects; + * intended for use in unit tests. + * + * @author Sam Brannen + * @since 3.2 + */ +public class LegacyEntity { + + private Object collaborator = new Object() { + + @Override + public String toString() { + throw new LegacyEntityException( + "Invoking toString() on the default collaborator causes an undesirable side effect"); + } + }; + + private Integer number; + private String text; + + + public void configure(Integer number, String text) { + this.number = number; + this.text = text; + } + + public Integer getNumber() { + return this.number; + } + + public String getText() { + return this.text; + } + + public Object getCollaborator() { + return this.collaborator; + } + + public void setCollaborator(Object collaborator) { + this.collaborator = collaborator; + } + + @Override + public String toString() { + return new ToStringCreator(this)// + .append("collaborator", this.collaborator)// + .toString(); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/util/subpackage/LegacyEntityException.java b/spring-test/src/test/java/org/springframework/test/util/subpackage/LegacyEntityException.java new file mode 100644 index 0000000000000000000000000000000000000000..a2504fb447a81f3fdb5f23b5d5ca3318868049d3 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/util/subpackage/LegacyEntityException.java @@ -0,0 +1,32 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.util.subpackage; + +/** + * Exception thrown by a {@link LegacyEntity}. + * + * @author Sam Brannen + * @since 4.3.1 + */ +@SuppressWarnings("serial") +public class LegacyEntityException extends RuntimeException { + + public LegacyEntityException(String message) { + super(message); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/util/subpackage/PersistentEntity.java b/spring-test/src/test/java/org/springframework/test/util/subpackage/PersistentEntity.java new file mode 100644 index 0000000000000000000000000000000000000000..8e4e45636ac66366b9a5de1b4a00587a05b25de2 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/util/subpackage/PersistentEntity.java @@ -0,0 +1,39 @@ +/* + * Copyright 2007-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.util.subpackage; + +/** + * Abstract base class for persistent entities; intended for use in + * unit tests. + * + * @author Sam Brannen + * @since 2.5 + */ +public abstract class PersistentEntity { + + private long id; + + + public long getId() { + return this.id; + } + + protected void setId(long id) { + this.id = id; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/util/subpackage/Person.java b/spring-test/src/test/java/org/springframework/test/util/subpackage/Person.java new file mode 100644 index 0000000000000000000000000000000000000000..399e7ef3e55cc3c3c050e7798eea6bb75f61ce33 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/util/subpackage/Person.java @@ -0,0 +1,42 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.util.subpackage; + +/** + * Interface representing a person entity; intended for use in unit tests. + * + *

The introduction of an interface is necessary in order to test support for + * JDK dynamic proxies. + * + * @author Sam Brannen + * @since 4.3 + */ +public interface Person { + + long getId(); + + String getName(); + + int getAge(); + + String getEyeColor(); + + boolean likesPets(); + + Number getFavoriteNumber(); + +} diff --git a/spring-test/src/test/java/org/springframework/test/util/subpackage/PersonEntity.java b/spring-test/src/test/java/org/springframework/test/util/subpackage/PersonEntity.java new file mode 100644 index 0000000000000000000000000000000000000000..df45b57f06e515d19a630580f11eb40f0b16c856 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/util/subpackage/PersonEntity.java @@ -0,0 +1,96 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.util.subpackage; + +import org.springframework.core.style.ToStringCreator; + +/** + * Concrete subclass of {@link PersistentEntity} representing a person + * entity; intended for use in unit tests. + * + * @author Sam Brannen + * @since 2.5 + */ +public class PersonEntity extends PersistentEntity implements Person { + + protected String name; + + private int age; + + String eyeColor; + + boolean likesPets = false; + + private Number favoriteNumber; + + + public String getName() { + return this.name; + } + + @SuppressWarnings("unused") + private void setName(final String name) { + this.name = name; + } + + public int getAge() { + return this.age; + } + + protected void setAge(final int age) { + this.age = age; + } + + public String getEyeColor() { + return this.eyeColor; + } + + void setEyeColor(final String eyeColor) { + this.eyeColor = eyeColor; + } + + public boolean likesPets() { + return this.likesPets; + } + + protected void setLikesPets(final boolean likesPets) { + this.likesPets = likesPets; + } + + public Number getFavoriteNumber() { + return this.favoriteNumber; + } + + protected void setFavoriteNumber(Number favoriteNumber) { + this.favoriteNumber = favoriteNumber; + } + + @Override + public String toString() { + // @formatter:off + return new ToStringCreator(this) + .append("id", this.getId()) + .append("name", this.name) + .append("age", this.age) + .append("eyeColor", this.eyeColor) + .append("likesPets", this.likesPets) + .append("favoriteNumber", this.favoriteNumber) + .toString(); + // @formatter:on + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/util/subpackage/StaticFields.java b/spring-test/src/test/java/org/springframework/test/util/subpackage/StaticFields.java new file mode 100644 index 0000000000000000000000000000000000000000..5a4f2cac0b6e8baa92dd9633be3f68bec20c0727 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/util/subpackage/StaticFields.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.util.subpackage; + +/** + * Simple class with static fields; intended for use in unit tests. + * + * @author Sam Brannen + * @since 4.2 + */ +public class StaticFields { + + public static String publicField = "public"; + + private static String privateField = "private"; + + + public static void reset() { + publicField = "public"; + privateField = "private"; + } + + public static String getPrivateField() { + return privateField; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/Person.java b/spring-test/src/test/java/org/springframework/test/web/Person.java new file mode 100644 index 0000000000000000000000000000000000000000..61e83c129be9451dbcceebc13ea264bc1b4e1b0d --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/Person.java @@ -0,0 +1,93 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web; + +import javax.validation.constraints.NotNull; +import javax.xml.bind.annotation.XmlRootElement; + +import org.springframework.util.ObjectUtils; + +@XmlRootElement +public class Person { + + @NotNull + private String name; + + private double someDouble; + + private boolean someBoolean; + + public Person() { + } + + public Person(String name) { + this.name = name; + } + + public String getName() { + return name; + } + + public Person setName(String name) { + this.name = name; + return this; + } + + public double getSomeDouble() { + return someDouble; + } + + public Person setSomeDouble(double someDouble) { + this.someDouble = someDouble; + return this; + } + + public boolean isSomeBoolean() { + return someBoolean; + } + + public Person setSomeBoolean(boolean someBoolean) { + this.someBoolean = someBoolean; + return this; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof Person)) { + return false; + } + Person otherPerson = (Person) other; + return (ObjectUtils.nullSafeEquals(this.name, otherPerson.name) && + ObjectUtils.nullSafeEquals(this.someDouble, otherPerson.someDouble) && + ObjectUtils.nullSafeEquals(this.someBoolean, otherPerson.someBoolean)); + } + + @Override + public int hashCode() { + return Person.class.hashCode(); + } + + @Override + public String toString() { + return "Person [name=" + this.name + ", someDouble=" + this.someDouble + + ", someBoolean=" + this.someBoolean + "]"; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/DefaultRequestExpectationTests.java b/spring-test/src/test/java/org/springframework/test/web/client/DefaultRequestExpectationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..002f31d4f5b507c84e98c40802e352acb29456af --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/DefaultRequestExpectationTests.java @@ -0,0 +1,95 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client; + +import java.net.URI; +import java.net.URISyntaxException; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.http.HttpMethod; +import org.springframework.http.client.ClientHttpRequest; + +import static org.junit.Assert.*; +import static org.springframework.http.HttpMethod.*; +import static org.springframework.test.web.client.ExpectedCount.*; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.response.MockRestResponseCreators.*; + +/** + * Unit tests for {@link DefaultRequestExpectation}. + * @author Rossen Stoyanchev + */ +public class DefaultRequestExpectationTests { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + + @Test + public void match() throws Exception { + RequestExpectation expectation = new DefaultRequestExpectation(once(), requestTo("/foo")); + expectation.match(createRequest(GET, "/foo")); + } + + @Test + public void matchWithFailedExpectation() throws Exception { + RequestExpectation expectation = new DefaultRequestExpectation(once(), requestTo("/foo")); + expectation.andExpect(method(POST)); + + this.thrown.expectMessage("Unexpected HttpMethod expected: but was:"); + expectation.match(createRequest(GET, "/foo")); + } + + @Test + public void hasRemainingCount() { + RequestExpectation expectation = new DefaultRequestExpectation(twice(), requestTo("/foo")); + expectation.andRespond(withSuccess()); + + expectation.incrementAndValidate(); + assertTrue(expectation.hasRemainingCount()); + + expectation.incrementAndValidate(); + assertFalse(expectation.hasRemainingCount()); + } + + @Test + public void isSatisfied() { + RequestExpectation expectation = new DefaultRequestExpectation(twice(), requestTo("/foo")); + expectation.andRespond(withSuccess()); + + expectation.incrementAndValidate(); + assertFalse(expectation.isSatisfied()); + + expectation.incrementAndValidate(); + assertTrue(expectation.isSatisfied()); + } + + + @SuppressWarnings("deprecation") + private ClientHttpRequest createRequest(HttpMethod method, String url) { + try { + return new org.springframework.mock.http.client.MockAsyncClientHttpRequest(method, new URI(url)); + } + catch (URISyntaxException ex) { + throw new IllegalStateException(ex); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/MockRestServiceServerTests.java b/spring-test/src/test/java/org/springframework/test/web/client/MockRestServiceServerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..cae4a088bd41fd633db604172c486f9c2e4d508c --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/MockRestServiceServerTests.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client; + +import java.net.SocketException; + +import org.junit.Test; + +import org.springframework.test.web.client.MockRestServiceServer.MockRestServiceServerBuilder; +import org.springframework.web.client.RestTemplate; + +import static org.junit.Assert.*; +import static org.springframework.http.HttpMethod.*; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.response.MockRestResponseCreators.*; + +/** + * Unit tests for {@link MockRestServiceServer}. + * + * @author Rossen Stoyanchev + */ +public class MockRestServiceServerTests { + + private final RestTemplate restTemplate = new RestTemplate(); + + + @Test + public void buildMultipleTimes() { + MockRestServiceServerBuilder builder = MockRestServiceServer.bindTo(this.restTemplate); + + MockRestServiceServer server = builder.build(); + server.expect(requestTo("/foo")).andRespond(withSuccess()); + this.restTemplate.getForObject("/foo", Void.class); + server.verify(); + + server = builder.ignoreExpectOrder(true).build(); + server.expect(requestTo("/foo")).andRespond(withSuccess()); + server.expect(requestTo("/bar")).andRespond(withSuccess()); + this.restTemplate.getForObject("/bar", Void.class); + this.restTemplate.getForObject("/foo", Void.class); + server.verify(); + + server = builder.build(); + server.expect(requestTo("/bar")).andRespond(withSuccess()); + this.restTemplate.getForObject("/bar", Void.class); + server.verify(); + } + + @Test(expected = AssertionError.class) + public void exactExpectOrder() { + MockRestServiceServer server = MockRestServiceServer.bindTo(this.restTemplate) + .ignoreExpectOrder(false).build(); + + server.expect(requestTo("/foo")).andRespond(withSuccess()); + server.expect(requestTo("/bar")).andRespond(withSuccess()); + this.restTemplate.getForObject("/bar", Void.class); + } + + @Test + public void ignoreExpectOrder() { + MockRestServiceServer server = MockRestServiceServer.bindTo(this.restTemplate) + .ignoreExpectOrder(true).build(); + + server.expect(requestTo("/foo")).andRespond(withSuccess()); + server.expect(requestTo("/bar")).andRespond(withSuccess()); + this.restTemplate.getForObject("/bar", Void.class); + this.restTemplate.getForObject("/foo", Void.class); + server.verify(); + } + + @Test + public void resetAndReuseServer() { + MockRestServiceServer server = MockRestServiceServer.bindTo(this.restTemplate).build(); + + server.expect(requestTo("/foo")).andRespond(withSuccess()); + this.restTemplate.getForObject("/foo", Void.class); + server.verify(); + server.reset(); + + server.expect(requestTo("/bar")).andRespond(withSuccess()); + this.restTemplate.getForObject("/bar", Void.class); + server.verify(); + } + + @Test + public void resetAndReuseServerWithUnorderedExpectationManager() { + MockRestServiceServer server = MockRestServiceServer.bindTo(this.restTemplate) + .ignoreExpectOrder(true).build(); + + server.expect(requestTo("/foo")).andRespond(withSuccess()); + this.restTemplate.getForObject("/foo", Void.class); + server.verify(); + server.reset(); + + server.expect(requestTo("/foo")).andRespond(withSuccess()); + server.expect(requestTo("/bar")).andRespond(withSuccess()); + this.restTemplate.getForObject("/bar", Void.class); + this.restTemplate.getForObject("/foo", Void.class); + server.verify(); + } + + @Test // SPR-16132 + public void followUpRequestAfterFailure() { + MockRestServiceServer server = MockRestServiceServer.bindTo(this.restTemplate).build(); + + server.expect(requestTo("/some-service/some-endpoint")) + .andRespond(request -> { throw new SocketException("pseudo network error"); }); + + server.expect(requestTo("/reporting-service/report-error")) + .andExpect(method(POST)).andRespond(withSuccess()); + + try { + this.restTemplate.getForEntity("/some-service/some-endpoint", String.class); + fail("Expected exception"); + } + catch (Exception ex) { + this.restTemplate.postForEntity("/reporting-service/report-error", ex.toString(), String.class); + } + + server.verify(); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/SimpleRequestExpectationManagerTests.java b/spring-test/src/test/java/org/springframework/test/web/client/SimpleRequestExpectationManagerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a27f42bc83395e43b120b202132e717cc9edced9 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/SimpleRequestExpectationManagerTests.java @@ -0,0 +1,211 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client; + +import java.net.SocketException; +import java.net.URI; +import java.net.URISyntaxException; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.http.HttpMethod; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.mock.http.client.MockClientHttpRequest; + +import static org.junit.Assert.*; +import static org.springframework.http.HttpMethod.*; +import static org.springframework.test.web.client.ExpectedCount.*; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.response.MockRestResponseCreators.*; + +/** + * Unit tests for {@link SimpleRequestExpectationManager}. + * + * @author Rossen Stoyanchev + */ +public class SimpleRequestExpectationManagerTests { + + private final SimpleRequestExpectationManager manager = new SimpleRequestExpectationManager(); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + + @Test + public void unexpectedRequest() throws Exception { + try { + this.manager.validateRequest(createRequest(GET, "/foo")); + } + catch (AssertionError error) { + assertEquals("No further requests expected: HTTP GET /foo\n" + + "0 request(s) executed.\n", error.getMessage()); + } + } + + @Test + public void zeroExpectedRequests() throws Exception { + this.manager.verify(); + } + + @Test + public void sequentialRequests() throws Exception { + this.manager.expectRequest(once(), requestTo("/foo")).andExpect(method(GET)).andRespond(withSuccess()); + this.manager.expectRequest(once(), requestTo("/bar")).andExpect(method(GET)).andRespond(withSuccess()); + + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/bar")); + this.manager.verify(); + } + + @Test + public void sequentialRequestsTooMany() throws Exception { + this.manager.expectRequest(max(1), requestTo("/foo")).andExpect(method(GET)).andRespond(withSuccess()); + this.manager.expectRequest(max(1), requestTo("/bar")).andExpect(method(GET)).andRespond(withSuccess()); + + this.thrown.expectMessage("No further requests expected: HTTP GET /baz\n" + + "2 request(s) executed:\n" + + "GET /foo\n" + + "GET /bar\n"); + + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/bar")); + this.manager.validateRequest(createRequest(GET, "/baz")); + } + + @Test + public void sequentialRequestsTooFew() throws Exception { + this.manager.expectRequest(min(1), requestTo("/foo")).andExpect(method(GET)).andRespond(withSuccess()); + this.manager.expectRequest(min(1), requestTo("/bar")).andExpect(method(GET)).andRespond(withSuccess()); + + this.thrown.expectMessage("Further request(s) expected leaving 1 unsatisfied expectation(s).\n" + + "1 request(s) executed:\nGET /foo\n"); + + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.verify(); + } + + @Test + public void repeatedRequests() throws Exception { + this.manager.expectRequest(times(3), requestTo("/foo")).andExpect(method(GET)).andRespond(withSuccess()); + this.manager.expectRequest(times(3), requestTo("/bar")).andExpect(method(GET)).andRespond(withSuccess()); + + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/bar")); + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/bar")); + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/bar")); + this.manager.verify(); + } + + @Test + public void repeatedRequestsTooMany() throws Exception { + this.manager.expectRequest(max(2), requestTo("/foo")).andExpect(method(GET)).andRespond(withSuccess()); + this.manager.expectRequest(max(2), requestTo("/bar")).andExpect(method(GET)).andRespond(withSuccess()); + + this.thrown.expectMessage("No further requests expected: HTTP GET /foo\n" + + "4 request(s) executed:\n" + + "GET /foo\n" + + "GET /bar\n" + + "GET /foo\n" + + "GET /bar\n"); + + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/bar")); + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/bar")); + this.manager.validateRequest(createRequest(GET, "/foo")); + } + + @Test + public void repeatedRequestsTooFew() throws Exception { + this.manager.expectRequest(min(2), requestTo("/foo")).andExpect(method(GET)).andRespond(withSuccess()); + this.manager.expectRequest(min(2), requestTo("/bar")).andExpect(method(GET)).andRespond(withSuccess()); + + this.thrown.expectMessage("3 request(s) executed:\n" + + "GET /foo\n" + + "GET /bar\n" + + "GET /foo\n"); + + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/bar")); + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.verify(); + } + + @Test + public void repeatedRequestsNotInOrder() throws Exception { + this.manager.expectRequest(twice(), requestTo("/foo")).andExpect(method(GET)).andRespond(withSuccess()); + this.manager.expectRequest(twice(), requestTo("/bar")).andExpect(method(GET)).andRespond(withSuccess()); + this.manager.expectRequest(twice(), requestTo("/baz")).andExpect(method(GET)).andRespond(withSuccess()); + + this.thrown.expectMessage("Unexpected HttpMethod expected: but was:"); + this.manager.validateRequest(createRequest(POST, "/foo")); + } + + @Test // SPR-15672 + public void sequentialRequestsWithDifferentCount() throws Exception { + this.manager.expectRequest(times(2), requestTo("/foo")).andExpect(method(GET)).andRespond(withSuccess()); + this.manager.expectRequest(once(), requestTo("/bar")).andExpect(method(GET)).andRespond(withSuccess()); + + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/bar")); + } + + @Test // SPR-15719 + public void repeatedRequestsInSequentialOrder() throws Exception { + this.manager.expectRequest(times(2), requestTo("/foo")).andExpect(method(GET)).andRespond(withSuccess()); + this.manager.expectRequest(times(2), requestTo("/bar")).andExpect(method(GET)).andRespond(withSuccess()); + + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/bar")); + this.manager.validateRequest(createRequest(GET, "/bar")); + } + + @Test // SPR-16132 + public void sequentialRequestsWithFirstFailing() throws Exception { + this.manager.expectRequest(once(), requestTo("/foo")). + andExpect(method(GET)).andRespond(request -> { throw new SocketException("pseudo network error"); }); + this.manager.expectRequest(once(), requestTo("/handle-error")). + andExpect(method(POST)).andRespond(withSuccess()); + + try { + this.manager.validateRequest(createRequest(GET, "/foo")); + fail("Expected SocketException"); + } + catch (SocketException ex) { + // expected + } + this.manager.validateRequest(createRequest(POST, "/handle-error")); + this.manager.verify(); + } + + + private ClientHttpRequest createRequest(HttpMethod method, String url) { + try { + return new MockClientHttpRequest(method, new URI(url)); + } + catch (URISyntaxException ex) { + throw new IllegalStateException(ex); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/UnorderedRequestExpectationManagerTests.java b/spring-test/src/test/java/org/springframework/test/web/client/UnorderedRequestExpectationManagerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..60619db3917536de2edb58e96150e1e00ccf82b0 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/UnorderedRequestExpectationManagerTests.java @@ -0,0 +1,132 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client; + +import java.net.URI; +import java.net.URISyntaxException; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.http.HttpMethod; +import org.springframework.http.client.ClientHttpRequest; + +import static org.junit.Assert.*; +import static org.springframework.http.HttpMethod.*; +import static org.springframework.test.web.client.ExpectedCount.*; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.response.MockRestResponseCreators.*; + +/** + * Unit tests for {@link UnorderedRequestExpectationManager}. + * + * @author Rossen Stoyanchev + */ +public class UnorderedRequestExpectationManagerTests { + + private UnorderedRequestExpectationManager manager = new UnorderedRequestExpectationManager(); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + + @Test + public void unexpectedRequest() throws Exception { + try { + this.manager.validateRequest(createRequest(GET, "/foo")); + } + catch (AssertionError error) { + assertEquals("No further requests expected: HTTP GET /foo\n" + + "0 request(s) executed.\n", error.getMessage()); + } + } + + @Test + public void zeroExpectedRequests() throws Exception { + this.manager.verify(); + } + + @Test + public void multipleRequests() throws Exception { + this.manager.expectRequest(once(), requestTo("/foo")).andExpect(method(GET)).andRespond(withSuccess()); + this.manager.expectRequest(once(), requestTo("/bar")).andExpect(method(GET)).andRespond(withSuccess()); + + this.manager.validateRequest(createRequest(GET, "/bar")); + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.verify(); + } + + @Test + public void repeatedRequests() throws Exception { + this.manager.expectRequest(twice(), requestTo("/foo")).andExpect(method(GET)).andRespond(withSuccess()); + this.manager.expectRequest(twice(), requestTo("/bar")).andExpect(method(GET)).andRespond(withSuccess()); + + this.manager.validateRequest(createRequest(GET, "/bar")); + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/bar")); + this.manager.verify(); + } + + @Test + public void repeatedRequestsTooMany() throws Exception { + this.manager.expectRequest(max(2), requestTo("/foo")).andExpect(method(GET)).andRespond(withSuccess()); + this.manager.expectRequest(max(2), requestTo("/bar")).andExpect(method(GET)).andRespond(withSuccess()); + + this.thrown.expectMessage("No further requests expected: HTTP GET /foo\n" + + "4 request(s) executed:\n" + + "GET /bar\n" + + "GET /foo\n" + + "GET /bar\n" + + "GET /foo\n"); + + this.manager.validateRequest(createRequest(GET, "/bar")); + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/bar")); + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/foo")); + } + + @Test + public void repeatedRequestsTooFew() throws Exception { + this.manager.expectRequest(min(2), requestTo("/foo")).andExpect(method(GET)).andRespond(withSuccess()); + this.manager.expectRequest(min(2), requestTo("/bar")).andExpect(method(GET)).andRespond(withSuccess()); + + this.thrown.expectMessage("3 request(s) executed:\n" + + "GET /bar\n" + + "GET /foo\n" + + "GET /foo\n"); + + this.manager.validateRequest(createRequest(GET, "/bar")); + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.validateRequest(createRequest(GET, "/foo")); + this.manager.verify(); + } + + + @SuppressWarnings("deprecation") + private ClientHttpRequest createRequest(HttpMethod method, String url) { + try { + return new org.springframework.mock.http.client.MockAsyncClientHttpRequest(method, new URI(url)); + } + catch (URISyntaxException ex) { + throw new IllegalStateException(ex); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/match/ContentRequestMatchersTests.java b/spring-test/src/test/java/org/springframework/test/web/client/match/ContentRequestMatchersTests.java new file mode 100644 index 0000000000000000000000000000000000000000..7a75928314f10a9c992d797bb8432ee5b92a7f6b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/match/ContentRequestMatchersTests.java @@ -0,0 +1,193 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client.match; + +import java.nio.charset.StandardCharsets; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.mock.http.client.MockClientHttpRequest; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import static org.hamcrest.Matchers.*; + +/** + * Unit tests for {@link ContentRequestMatchers}. + * + * @author Rossen Stoyanchev + */ +public class ContentRequestMatchersTests { + + private MockClientHttpRequest request; + + + @Before + public void setUp() { + this.request = new MockClientHttpRequest(); + } + + + @Test + public void testContentType() throws Exception { + this.request.getHeaders().setContentType(MediaType.APPLICATION_JSON); + + MockRestRequestMatchers.content().contentType("application/json").match(this.request); + MockRestRequestMatchers.content().contentType(MediaType.APPLICATION_JSON).match(this.request); + } + + @Test(expected = AssertionError.class) + public void testContentTypeNoMatch1() throws Exception { + this.request.getHeaders().setContentType(MediaType.APPLICATION_JSON); + + MockRestRequestMatchers.content().contentType("application/xml").match(this.request); + } + + @Test(expected = AssertionError.class) + public void testContentTypeNoMatch2() throws Exception { + this.request.getHeaders().setContentType(MediaType.APPLICATION_JSON); + + MockRestRequestMatchers.content().contentType(MediaType.APPLICATION_ATOM_XML).match(this.request); + } + + @Test + public void testString() throws Exception { + this.request.getBody().write("test".getBytes()); + + MockRestRequestMatchers.content().string("test").match(this.request); + } + + @Test(expected = AssertionError.class) + public void testStringNoMatch() throws Exception { + this.request.getBody().write("test".getBytes()); + + MockRestRequestMatchers.content().string("Test").match(this.request); + } + + @Test + public void testBytes() throws Exception { + byte[] content = "test".getBytes(); + this.request.getBody().write(content); + + MockRestRequestMatchers.content().bytes(content).match(this.request); + } + + @Test(expected = AssertionError.class) + public void testBytesNoMatch() throws Exception { + this.request.getBody().write("test".getBytes()); + + MockRestRequestMatchers.content().bytes("Test".getBytes()).match(this.request); + } + + @Test + public void testFormData() throws Exception { + String contentType = "application/x-www-form-urlencoded;charset=UTF-8"; + String body = "name+1=value+1&name+2=value+A&name+2=value+B&name+3"; + + this.request.getHeaders().setContentType(MediaType.parseMediaType(contentType)); + this.request.getBody().write(body.getBytes(StandardCharsets.UTF_8)); + + MultiValueMap map = new LinkedMultiValueMap<>(); + map.add("name 1", "value 1"); + map.add("name 2", "value A"); + map.add("name 2", "value B"); + map.add("name 3", null); + MockRestRequestMatchers.content().formData(map).match(this.request); + } + + @Test + public void testXml() throws Exception { + String content = "bazbazz"; + this.request.getBody().write(content.getBytes()); + + MockRestRequestMatchers.content().xml(content).match(this.request); + } + + @Test(expected = AssertionError.class) + public void testXmlNoMatch() throws Exception { + this.request.getBody().write("11".getBytes()); + + MockRestRequestMatchers.content().xml("22").match(this.request); + } + + @Test + public void testNodeMatcher() throws Exception { + String content = "baz"; + this.request.getBody().write(content.getBytes()); + + MockRestRequestMatchers.content().node(hasXPath("/foo/bar")).match(this.request); + } + + @Test(expected = AssertionError.class) + public void testNodeMatcherNoMatch() throws Exception { + String content = "baz"; + this.request.getBody().write(content.getBytes()); + + MockRestRequestMatchers.content().node(hasXPath("/foo/bar/bar")).match(this.request); + } + + @Test + public void testJsonLenientMatch() throws Exception { + String content = "{\n \"foo array\":[\"first\",\"second\"] , \"someExtraProperty\": \"which is allowed\" \n}"; + this.request.getBody().write(content.getBytes()); + + MockRestRequestMatchers.content().json("{\n \"foo array\":[\"second\",\"first\"] \n}") + .match(this.request); + MockRestRequestMatchers.content().json("{\n \"foo array\":[\"second\",\"first\"] \n}", false) + .match(this.request); + } + + @Test + public void testJsonStrictMatch() throws Exception { + String content = "{\n \"foo\": \"bar\", \"foo array\":[\"first\",\"second\"] \n}"; + this.request.getBody().write(content.getBytes()); + + MockRestRequestMatchers + .content() + .json("{\n \"foo array\":[\"first\",\"second\"] , \"foo\": \"bar\" \n}", true) + .match(this.request); + } + + @Test(expected = AssertionError.class) + public void testJsonLenientNoMatch() throws Exception { + String content = "{\n \"bar\" : \"foo\" \n}"; + this.request.getBody().write(content.getBytes()); + + MockRestRequestMatchers + .content() + .json("{\n \"foo\" : \"bar\" \n}") + .match(this.request); + MockRestRequestMatchers + .content() + .json("{\n \"foo\" : \"bar\" \n}", false) + .match(this.request); + } + + @Test(expected = AssertionError.class) + public void testJsonStrictNoMatch() throws Exception { + String content = "{\n \"foo array\":[\"first\",\"second\"] , \"someExtraProperty\": \"which is NOT allowed\" \n}"; + this.request.getBody().write(content.getBytes()); + + MockRestRequestMatchers + .content() + .json("{\n \"foo array\":[\"second\",\"first\"] \n}", true) + .match(this.request); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/match/JsonPathRequestMatchersTests.java b/spring-test/src/test/java/org/springframework/test/web/client/match/JsonPathRequestMatchersTests.java new file mode 100644 index 0000000000000000000000000000000000000000..f51019ec751f70c198552c2a9ead6f95cc4f10fe --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/match/JsonPathRequestMatchersTests.java @@ -0,0 +1,243 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client.match; + +import java.io.IOException; + +import org.junit.Test; + +import org.springframework.mock.http.client.MockClientHttpRequest; + +import static org.hamcrest.CoreMatchers.*; + +/** + * Unit tests for {@link JsonPathRequestMatchers}. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + */ +public class JsonPathRequestMatchersTests { + + private static final String REQUEST_CONTENT = "{" + // + "'str': 'foo', " + // + "'num': 5, " + // + "'bool': true, " + // + "'arr': [42], " + // + "'colorMap': {'red': 'rojo'}, " + // + "'emptyString': '', " + // + "'emptyArray': [], " + // + "'emptyMap': {} " + // + "}"; + + private static final MockClientHttpRequest request = new MockClientHttpRequest(); + + static { + try { + request.getBody().write(REQUEST_CONTENT.getBytes()); + } + catch (IOException e) { + throw new IllegalStateException(e); + } + } + + + @Test(expected = AssertionError.class) + public void valueWithMismatch() throws Exception { + new JsonPathRequestMatchers("$.str").value("bogus").match(request); + } + + @Test + public void valueWithDirectMatch() throws Exception { + new JsonPathRequestMatchers("$.str").value("foo").match(request); + } + + @Test // SPR-14498 + public void valueWithNumberConversion() throws Exception { + new JsonPathRequestMatchers("$.num").value(5.0f).match(request); + } + + @Test + public void valueWithMatcher() throws Exception { + new JsonPathRequestMatchers("$.str").value(equalTo("foo")).match(request); + } + + @Test // SPR-14498 + public void valueWithMatcherAndNumberConversion() throws Exception { + new JsonPathRequestMatchers("$.num").value(equalTo(5.0f), Float.class).match(request); + } + + @Test(expected = AssertionError.class) + public void valueWithMatcherAndMismatch() throws Exception { + new JsonPathRequestMatchers("$.str").value(equalTo("bogus")).match(request); + } + + @Test + public void exists() throws Exception { + new JsonPathRequestMatchers("$.str").exists().match(request); + } + + @Test + public void existsForAnEmptyArray() throws Exception { + new JsonPathRequestMatchers("$.emptyArray").exists().match(request); + } + + @Test + public void existsForAnEmptyMap() throws Exception { + new JsonPathRequestMatchers("$.emptyMap").exists().match(request); + } + + @Test(expected = AssertionError.class) + public void existsNoMatch() throws Exception { + new JsonPathRequestMatchers("$.bogus").exists().match(request); + } + + @Test + public void doesNotExist() throws Exception { + new JsonPathRequestMatchers("$.bogus").doesNotExist().match(request); + } + + @Test(expected = AssertionError.class) + public void doesNotExistNoMatch() throws Exception { + new JsonPathRequestMatchers("$.str").doesNotExist().match(request); + } + + @Test(expected = AssertionError.class) + public void doesNotExistForAnEmptyArray() throws Exception { + new JsonPathRequestMatchers("$.emptyArray").doesNotExist().match(request); + } + + @Test(expected = AssertionError.class) + public void doesNotExistForAnEmptyMap() throws Exception { + new JsonPathRequestMatchers("$.emptyMap").doesNotExist().match(request); + } + + @Test + public void isEmptyForAnEmptyString() throws Exception { + new JsonPathRequestMatchers("$.emptyString").isEmpty().match(request); + } + + @Test + public void isEmptyForAnEmptyArray() throws Exception { + new JsonPathRequestMatchers("$.emptyArray").isEmpty().match(request); + } + + @Test + public void isEmptyForAnEmptyMap() throws Exception { + new JsonPathRequestMatchers("$.emptyMap").isEmpty().match(request); + } + + @Test + public void isNotEmptyForString() throws Exception { + new JsonPathRequestMatchers("$.str").isNotEmpty().match(request); + } + + @Test + public void isNotEmptyForNumber() throws Exception { + new JsonPathRequestMatchers("$.num").isNotEmpty().match(request); + } + + @Test + public void isNotEmptyForBoolean() throws Exception { + new JsonPathRequestMatchers("$.bool").isNotEmpty().match(request); + } + + @Test + public void isNotEmptyForArray() throws Exception { + new JsonPathRequestMatchers("$.arr").isNotEmpty().match(request); + } + + @Test + public void isNotEmptyForMap() throws Exception { + new JsonPathRequestMatchers("$.colorMap").isNotEmpty().match(request); + } + + @Test(expected = AssertionError.class) + public void isNotEmptyForAnEmptyString() throws Exception { + new JsonPathRequestMatchers("$.emptyString").isNotEmpty().match(request); + } + + @Test(expected = AssertionError.class) + public void isNotEmptyForAnEmptyArray() throws Exception { + new JsonPathRequestMatchers("$.emptyArray").isNotEmpty().match(request); + } + + @Test(expected = AssertionError.class) + public void isNotEmptyForAnEmptyMap() throws Exception { + new JsonPathRequestMatchers("$.emptyMap").isNotEmpty().match(request); + } + + @Test + public void isArray() throws Exception { + new JsonPathRequestMatchers("$.arr").isArray().match(request); + } + + @Test + public void isArrayForAnEmptyArray() throws Exception { + new JsonPathRequestMatchers("$.emptyArray").isArray().match(request); + } + + @Test(expected = AssertionError.class) + public void isArrayNoMatch() throws Exception { + new JsonPathRequestMatchers("$.str").isArray().match(request); + } + + @Test + public void isMap() throws Exception { + new JsonPathRequestMatchers("$.colorMap").isMap().match(request); + } + + @Test + public void isMapForAnEmptyMap() throws Exception { + new JsonPathRequestMatchers("$.emptyMap").isMap().match(request); + } + + @Test(expected = AssertionError.class) + public void isMapNoMatch() throws Exception { + new JsonPathRequestMatchers("$.str").isMap().match(request); + } + + @Test + public void isBoolean() throws Exception { + new JsonPathRequestMatchers("$.bool").isBoolean().match(request); + } + + @Test(expected = AssertionError.class) + public void isBooleanNoMatch() throws Exception { + new JsonPathRequestMatchers("$.str").isBoolean().match(request); + } + + @Test + public void isNumber() throws Exception { + new JsonPathRequestMatchers("$.num").isNumber().match(request); + } + + @Test(expected = AssertionError.class) + public void isNumberNoMatch() throws Exception { + new JsonPathRequestMatchers("$.str").isNumber().match(request); + } + + @Test + public void isString() throws Exception { + new JsonPathRequestMatchers("$.str").isString().match(request); + } + + @Test(expected = AssertionError.class) + public void isStringNoMatch() throws Exception { + new JsonPathRequestMatchers("$.arr").isString().match(request); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/match/MockRestRequestMatchersTests.java b/spring-test/src/test/java/org/springframework/test/web/client/match/MockRestRequestMatchersTests.java new file mode 100644 index 0000000000000000000000000000000000000000..4f2d584f7b2d4bba550b9abafb453ca5b2606f75 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/match/MockRestRequestMatchersTests.java @@ -0,0 +1,199 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client.match; + +import java.net.URI; +import java.util.Arrays; +import java.util.Collections; + +import org.junit.Test; + +import org.springframework.http.HttpMethod; +import org.springframework.mock.http.client.MockClientHttpRequest; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * Unit tests for {@link MockRestRequestMatchers}. + * + * @author Craig Walls + * @author Rossen Stoyanchev + * @author Sam Brannen + */ +public class MockRestRequestMatchersTests { + + private final MockClientHttpRequest request = new MockClientHttpRequest(); + + + @Test + public void requestTo() throws Exception { + this.request.setURI(new URI("http://www.foo.com/bar")); + + MockRestRequestMatchers.requestTo("http://www.foo.com/bar").match(this.request); + } + + @Test // SPR-15819 + public void requestToUriTemplate() throws Exception { + this.request.setURI(new URI("http://www.foo.com/bar")); + + MockRestRequestMatchers.requestToUriTemplate("http://www.foo.com/{bar}", "bar").match(this.request); + } + + @Test + public void requestToNoMatch() throws Exception { + this.request.setURI(new URI("http://www.foo.com/bar")); + + assertThrows(AssertionError.class, + () -> MockRestRequestMatchers.requestTo("http://www.foo.com/wrong").match(this.request)); + } + + @Test + public void requestToContains() throws Exception { + this.request.setURI(new URI("http://www.foo.com/bar")); + + MockRestRequestMatchers.requestTo(containsString("bar")).match(this.request); + } + + @Test + public void method() throws Exception { + this.request.setMethod(HttpMethod.GET); + + MockRestRequestMatchers.method(HttpMethod.GET).match(this.request); + } + + @Test + public void methodNoMatch() throws Exception { + this.request.setMethod(HttpMethod.POST); + + AssertionError error = assertThrows(AssertionError.class, + () -> MockRestRequestMatchers.method(HttpMethod.GET).match(this.request)); + assertThat(error.getMessage(), containsString("expected: but was:")); + } + + @Test + public void header() throws Exception { + this.request.getHeaders().put("foo", Arrays.asList("bar", "baz")); + + MockRestRequestMatchers.header("foo", "bar", "baz").match(this.request); + } + + @Test + public void headerMissing() throws Exception { + AssertionError error = assertThrows(AssertionError.class, + () -> MockRestRequestMatchers.header("foo", "bar").match(this.request)); + assertThat(error.getMessage(), containsString("was null")); + } + + @Test + public void headerMissingValue() throws Exception { + this.request.getHeaders().put("foo", Arrays.asList("bar", "baz")); + + AssertionError error = assertThrows(AssertionError.class, + () -> MockRestRequestMatchers.header("foo", "bad").match(this.request)); + assertThat(error.getMessage(), containsString("expected: but was:")); + } + + @Test + public void headerContains() throws Exception { + this.request.getHeaders().put("foo", Arrays.asList("bar", "baz")); + + MockRestRequestMatchers.header("foo", containsString("ba")).match(this.request); + } + + @Test + public void headerContainsWithMissingHeader() throws Exception { + AssertionError error = assertThrows(AssertionError.class, + () -> MockRestRequestMatchers.header("foo", containsString("baz")).match(this.request)); + assertThat(error.getMessage(), containsString("but was null")); + } + + @Test + public void headerContainsWithMissingValue() throws Exception { + this.request.getHeaders().put("foo", Arrays.asList("bar", "baz")); + + AssertionError error = assertThrows(AssertionError.class, + () -> MockRestRequestMatchers.header("foo", containsString("bx")).match(this.request)); + assertThat(error.getMessage(), containsString("was \"bar\"")); + } + + @Test + public void headers() throws Exception { + this.request.getHeaders().put("foo", Arrays.asList("bar", "baz")); + + MockRestRequestMatchers.header("foo", "bar", "baz").match(this.request); + } + + @Test + public void headersWithMissingHeader() throws Exception { + AssertionError error = assertThrows(AssertionError.class, + () -> MockRestRequestMatchers.header("foo", "bar").match(this.request)); + assertThat(error.getMessage(), containsString("but was null")); + } + + @Test + public void headersWithMissingValue() throws Exception { + this.request.getHeaders().put("foo", Collections.singletonList("bar")); + + AssertionError error = assertThrows(AssertionError.class, + () -> MockRestRequestMatchers.header("foo", "bar", "baz").match(this.request)); + assertThat(error.getMessage(), containsString("to have at least <2> values")); + } + + @Test + public void queryParam() throws Exception { + this.request.setURI(new URI("http://www.foo.com/a?foo=bar&foo=baz")); + + MockRestRequestMatchers.queryParam("foo", "bar", "baz").match(this.request); + } + + @Test + public void queryParamMissing() throws Exception { + this.request.setURI(new URI("http://www.foo.com/a")); + + AssertionError error = assertThrows(AssertionError.class, + () -> MockRestRequestMatchers.queryParam("foo", "bar").match(this.request)); + assertThat(error.getMessage(), containsString("but was null")); + } + + @Test + public void queryParamMissingValue() throws Exception { + this.request.setURI(new URI("http://www.foo.com/a?foo=bar&foo=baz")); + + AssertionError error = assertThrows(AssertionError.class, + () -> MockRestRequestMatchers.queryParam("foo", "bad").match(this.request)); + assertThat(error.getMessage(), containsString("expected: but was:")); + } + + @Test + public void queryParamContains() throws Exception { + this.request.setURI(new URI("http://www.foo.com/a?foo=bar&foo=baz")); + + MockRestRequestMatchers.queryParam("foo", containsString("ba")).match(this.request); + } + + @Test + public void queryParamContainsWithMissingValue() throws Exception { + this.request.setURI(new URI("http://www.foo.com/a?foo=bar&foo=baz")); + + AssertionError error = assertThrows(AssertionError.class, + () -> MockRestRequestMatchers.queryParam("foo", containsString("bx")).match(this.request)); + assertThat(error.getMessage(), containsString("was \"bar\"")); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/match/XpathRequestMatchersTests.java b/spring-test/src/test/java/org/springframework/test/web/client/match/XpathRequestMatchersTests.java new file mode 100644 index 0000000000000000000000000000000000000000..4273d2d11cfa29b5a95326aaaa425a3676c59254 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/match/XpathRequestMatchersTests.java @@ -0,0 +1,116 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client.match; + +import java.io.IOException; + +import org.hamcrest.Matchers; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.mock.http.client.MockClientHttpRequest; + +/** + * Unit tests for {@link XpathRequestMatchers}. + * + * @author Rossen Stoyanchev + */ +public class XpathRequestMatchersTests { + + private static final String RESPONSE_CONTENT = "111true"; + + private MockClientHttpRequest request; + + + @Before + public void setUp() throws IOException { + this.request = new MockClientHttpRequest(); + this.request.getBody().write(RESPONSE_CONTENT.getBytes()); + } + + + @Test + public void testNodeMatcher() throws Exception { + new XpathRequestMatchers("/foo/bar", null).node(Matchers.notNullValue()).match(this.request); + } + + @Test(expected = AssertionError.class) + public void testNodeMatcherNoMatch() throws Exception { + new XpathRequestMatchers("/foo/bar", null).node(Matchers.nullValue()).match(this.request); + } + + @Test + public void testExists() throws Exception { + new XpathRequestMatchers("/foo/bar", null).exists().match(this.request); + } + + @Test(expected = AssertionError.class) + public void testExistsNoMatch() throws Exception { + new XpathRequestMatchers("/foo/Bar", null).exists().match(this.request); + } + + @Test + public void testDoesNotExist() throws Exception { + new XpathRequestMatchers("/foo/Bar", null).doesNotExist().match(this.request); + } + + @Test(expected = AssertionError.class) + public void testDoesNotExistNoMatch() throws Exception { + new XpathRequestMatchers("/foo/bar", null).doesNotExist().match(this.request); + } + + @Test + public void testNodeCount() throws Exception { + new XpathRequestMatchers("/foo/bar", null).nodeCount(2).match(this.request); + } + + @Test(expected = AssertionError.class) + public void testNodeCountNoMatch() throws Exception { + new XpathRequestMatchers("/foo/bar", null).nodeCount(1).match(this.request); + } + + @Test + public void testString() throws Exception { + new XpathRequestMatchers("/foo/bar[1]", null).string("111").match(this.request); + } + + @Test(expected = AssertionError.class) + public void testStringNoMatch() throws Exception { + new XpathRequestMatchers("/foo/bar[1]", null).string("112").match(this.request); + } + + @Test + public void testNumber() throws Exception { + new XpathRequestMatchers("/foo/bar[1]", null).number(111.0).match(this.request); + } + + @Test(expected = AssertionError.class) + public void testNumberNoMatch() throws Exception { + new XpathRequestMatchers("/foo/bar[1]", null).number(111.1).match(this.request); + } + + @Test + public void testBoolean() throws Exception { + new XpathRequestMatchers("/foo/bar[2]", null).booleanValue(true).match(this.request); + } + + @Test(expected = AssertionError.class) + public void testBooleanNoMatch() throws Exception { + new XpathRequestMatchers("/foo/bar[2]", null).booleanValue(false).match(this.request); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/response/ResponseCreatorsTests.java b/spring-test/src/test/java/org/springframework/test/web/client/response/ResponseCreatorsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d96f6ca64ffc070da288d8372941e59bbabb4ed5 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/response/ResponseCreatorsTests.java @@ -0,0 +1,127 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client.response; + +import java.net.URI; + +import org.junit.Test; + +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.mock.http.client.MockClientHttpResponse; +import org.springframework.util.StreamUtils; + +import static org.junit.Assert.*; + +/** + * Tests for the {@link MockRestResponseCreators} static factory methods. + * + * @author Rossen Stoyanchev + */ +public class ResponseCreatorsTests { + + @Test + public void success() throws Exception { + MockClientHttpResponse response = (MockClientHttpResponse) MockRestResponseCreators.withSuccess().createResponse(null); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + assertTrue(response.getHeaders().isEmpty()); + assertEquals(0, StreamUtils.copyToByteArray(response.getBody()).length); + } + + @Test + public void successWithContent() throws Exception { + DefaultResponseCreator responseCreator = MockRestResponseCreators.withSuccess("foo", MediaType.TEXT_PLAIN); + MockClientHttpResponse response = (MockClientHttpResponse) responseCreator.createResponse(null); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + assertEquals(MediaType.TEXT_PLAIN, response.getHeaders().getContentType()); + assertArrayEquals("foo".getBytes(), StreamUtils.copyToByteArray(response.getBody())); + } + + @Test + public void successWithContentWithoutContentType() throws Exception { + DefaultResponseCreator responseCreator = MockRestResponseCreators.withSuccess("foo", null); + MockClientHttpResponse response = (MockClientHttpResponse) responseCreator.createResponse(null); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + assertNull(response.getHeaders().getContentType()); + assertArrayEquals("foo".getBytes(), StreamUtils.copyToByteArray(response.getBody())); + } + + @Test + public void created() throws Exception { + URI location = new URI("/foo"); + DefaultResponseCreator responseCreator = MockRestResponseCreators.withCreatedEntity(location); + MockClientHttpResponse response = (MockClientHttpResponse) responseCreator.createResponse(null); + + assertEquals(HttpStatus.CREATED, response.getStatusCode()); + assertEquals(location, response.getHeaders().getLocation()); + assertEquals(0, StreamUtils.copyToByteArray(response.getBody()).length); + } + + @Test + public void noContent() throws Exception { + DefaultResponseCreator responseCreator = MockRestResponseCreators.withNoContent(); + MockClientHttpResponse response = (MockClientHttpResponse) responseCreator.createResponse(null); + + assertEquals(HttpStatus.NO_CONTENT, response.getStatusCode()); + assertTrue(response.getHeaders().isEmpty()); + assertEquals(0, StreamUtils.copyToByteArray(response.getBody()).length); + } + + @Test + public void badRequest() throws Exception { + DefaultResponseCreator responseCreator = MockRestResponseCreators.withBadRequest(); + MockClientHttpResponse response = (MockClientHttpResponse) responseCreator.createResponse(null); + + assertEquals(HttpStatus.BAD_REQUEST, response.getStatusCode()); + assertTrue(response.getHeaders().isEmpty()); + assertEquals(0, StreamUtils.copyToByteArray(response.getBody()).length); + } + + @Test + public void unauthorized() throws Exception { + DefaultResponseCreator responseCreator = MockRestResponseCreators.withUnauthorizedRequest(); + MockClientHttpResponse response = (MockClientHttpResponse) responseCreator.createResponse(null); + + assertEquals(HttpStatus.UNAUTHORIZED, response.getStatusCode()); + assertTrue(response.getHeaders().isEmpty()); + assertEquals(0, StreamUtils.copyToByteArray(response.getBody()).length); + } + + @Test + public void serverError() throws Exception { + DefaultResponseCreator responseCreator = MockRestResponseCreators.withServerError(); + MockClientHttpResponse response = (MockClientHttpResponse) responseCreator.createResponse(null); + + assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, response.getStatusCode()); + assertTrue(response.getHeaders().isEmpty()); + assertEquals(0, StreamUtils.copyToByteArray(response.getBody()).length); + } + + @Test + public void withStatus() throws Exception { + DefaultResponseCreator responseCreator = MockRestResponseCreators.withStatus(HttpStatus.FORBIDDEN); + MockClientHttpResponse response = (MockClientHttpResponse) responseCreator.createResponse(null); + + assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); + assertTrue(response.getHeaders().isEmpty()); + assertEquals(0, StreamUtils.copyToByteArray(response.getBody()).length); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/samples/MockMvcClientHttpRequestFactoryTests.java b/spring-test/src/test/java/org/springframework/test/web/client/samples/MockMvcClientHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..db2e299a2ae91936fe8219e93fe7d64d02f8b41a --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/samples/MockMvcClientHttpRequestFactoryTests.java @@ -0,0 +1,103 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client.samples; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.ComponentScan; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.ResponseEntity; +import org.springframework.stereotype.Controller; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.client.MockMvcClientHttpRequestFactory; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.web.bind.annotation.ResponseBody; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; + +import static org.junit.Assert.assertEquals; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +/** + * Tests that use a {@link RestTemplate} configured with a + * {@link MockMvcClientHttpRequestFactory} that is in turn configured with a + * {@link MockMvc} instance that uses a {@link WebApplicationContext} loaded by + * the TestContext framework. + * + * @author Rossen Stoyanchev + */ +@RunWith(SpringJUnit4ClassRunner.class) +@WebAppConfiguration +@ContextConfiguration +public class MockMvcClientHttpRequestFactoryTests { + + @Autowired + private WebApplicationContext wac; + + private MockMvc mockMvc; + + + @Before + public void setup() { + this.mockMvc = MockMvcBuilders.webAppContextSetup(this.wac).alwaysExpect(status().isOk()).build(); + } + + @Test + public void test() throws Exception { + RestTemplate template = new RestTemplate(new MockMvcClientHttpRequestFactory(this.mockMvc)); + String result = template.getForObject("/foo", String.class); + assertEquals("bar", result); + } + + @Test + @SuppressWarnings("deprecation") + public void testAsyncTemplate() throws Exception { + org.springframework.web.client.AsyncRestTemplate template = new org.springframework.web.client.AsyncRestTemplate( + new MockMvcClientHttpRequestFactory(this.mockMvc)); + ListenableFuture> entity = template.getForEntity("/foo", String.class); + assertEquals("bar", entity.get().getBody()); + } + + + @EnableWebMvc + @Configuration + @ComponentScan(basePackageClasses=MockMvcClientHttpRequestFactoryTests.class) + static class MyWebConfig implements WebMvcConfigurer { + } + + @Controller + static class MyController { + + @RequestMapping(value="/foo", method=RequestMethod.GET) + @ResponseBody + public String handle() { + return "bar"; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/samples/SampleAsyncTests.java b/spring-test/src/test/java/org/springframework/test/web/client/samples/SampleAsyncTests.java new file mode 100644 index 0000000000000000000000000000000000000000..78c93cc1f9ef348d6458c7a6b4ed3f99a2c4bce1 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/samples/SampleAsyncTests.java @@ -0,0 +1,143 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client.samples; + +import org.junit.Test; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.test.web.Person; +import org.springframework.test.web.client.MockRestServiceServer; +import org.springframework.util.concurrent.ListenableFuture; + +import static org.junit.Assert.*; +import static org.springframework.test.web.client.ExpectedCount.manyTimes; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.response.MockRestResponseCreators.*; + +/** + * Examples to demonstrate writing client-side REST tests with Spring MVC Test. + * While the tests in this class invoke the RestTemplate directly, in actual + * tests the RestTemplate may likely be invoked indirectly, i.e. through client + * code. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +@SuppressWarnings("deprecation") +public class SampleAsyncTests { + + private final org.springframework.web.client.AsyncRestTemplate restTemplate = new org.springframework.web.client.AsyncRestTemplate(); + + private final MockRestServiceServer mockServer = MockRestServiceServer.createServer(this.restTemplate); + + + @Test + public void performGet() throws Exception { + + String responseBody = "{\"name\" : \"Ludwig van Beethoven\", \"someDouble\" : \"1.6035\"}"; + + this.mockServer.expect(requestTo("/composers/42")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess(responseBody, MediaType.APPLICATION_JSON)); + + @SuppressWarnings("unused") + ListenableFuture> ludwig = + this.restTemplate.getForEntity("/composers/{id}", Person.class, 42); + + // We are only validating the request. The response is mocked out. + // person.getName().equals("Ludwig van Beethoven") + // person.getDouble().equals(1.6035) + + this.mockServer.verify(); + } + + @Test + public void performGetManyTimes() throws Exception { + + String responseBody = "{\"name\" : \"Ludwig van Beethoven\", \"someDouble\" : \"1.6035\"}"; + + this.mockServer.expect(manyTimes(), requestTo("/composers/42")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess(responseBody, MediaType.APPLICATION_JSON)); + + @SuppressWarnings("unused") + ListenableFuture> ludwig = + this.restTemplate.getForEntity("/composers/{id}", Person.class, 42); + + // We are only validating the request. The response is mocked out. + // person.getName().equals("Ludwig van Beethoven") + // person.getDouble().equals(1.6035) + + this.restTemplate.getForEntity("/composers/{id}", Person.class, 42); + this.restTemplate.getForEntity("/composers/{id}", Person.class, 42); + this.restTemplate.getForEntity("/composers/{id}", Person.class, 42); + this.restTemplate.getForEntity("/composers/{id}", Person.class, 42); + + this.mockServer.verify(); + } + + @Test + public void performGetWithResponseBodyFromFile() throws Exception { + + Resource responseBody = new ClassPathResource("ludwig.json", this.getClass()); + + this.mockServer.expect(requestTo("/composers/42")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess(responseBody, MediaType.APPLICATION_JSON)); + + @SuppressWarnings("unused") + ListenableFuture> ludwig = + this.restTemplate.getForEntity("/composers/{id}", Person.class, 42); + + // hotel.getId() == 42 + // hotel.getName().equals("Holiday Inn") + + this.mockServer.verify(); + } + + @Test + public void verify() { + + this.mockServer.expect(requestTo("/number")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess("1", MediaType.TEXT_PLAIN)); + + this.mockServer.expect(requestTo("/number")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess("2", MediaType.TEXT_PLAIN)); + + this.mockServer.expect(requestTo("/number")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess("4", MediaType.TEXT_PLAIN)); + + this.mockServer.expect(requestTo("/number")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess("8", MediaType.TEXT_PLAIN)); + + @SuppressWarnings("unused") + ListenableFuture> result = this.restTemplate.getForEntity("/number", String.class); + // result == "1" + + result = this.restTemplate.getForEntity("/number", String.class); + // result == "2" + + try { + this.mockServer.verify(); + } + catch (AssertionError error) { + assertTrue(error.getMessage(), error.getMessage().contains("2 unsatisfied expectation(s)")); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/samples/SampleTests.java b/spring-test/src/test/java/org/springframework/test/web/client/samples/SampleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..32d955eb9a4d6bb28cdd7dc2e80835bc0d8383f4 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/samples/SampleTests.java @@ -0,0 +1,225 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.test.web.client.samples; + +import java.io.IOException; +import java.util.Collections; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRequest; +import org.springframework.http.MediaType; +import org.springframework.http.client.ClientHttpRequestExecution; +import org.springframework.http.client.ClientHttpRequestInterceptor; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.test.web.Person; +import org.springframework.test.web.client.MockRestServiceServer; +import org.springframework.util.FileCopyUtils; +import org.springframework.web.client.RestTemplate; + +import static org.junit.Assert.*; +import static org.springframework.test.web.client.ExpectedCount.manyTimes; +import static org.springframework.test.web.client.ExpectedCount.never; +import static org.springframework.test.web.client.ExpectedCount.once; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; +import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; + +/** + * Examples to demonstrate writing client-side REST tests with Spring MVC Test. + * While the tests in this class invoke the RestTemplate directly, in actual + * tests the RestTemplate may likely be invoked indirectly, i.e. through client + * code. + * + * @author Rossen Stoyanchev + */ +public class SampleTests { + + private MockRestServiceServer mockServer; + + private RestTemplate restTemplate; + + @Before + public void setup() { + this.restTemplate = new RestTemplate(); + this.mockServer = MockRestServiceServer.bindTo(this.restTemplate).ignoreExpectOrder(true).build(); + } + + @Test + public void performGet() { + + String responseBody = "{\"name\" : \"Ludwig van Beethoven\", \"someDouble\" : \"1.6035\"}"; + + this.mockServer.expect(requestTo("/composers/42")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess(responseBody, MediaType.APPLICATION_JSON)); + + @SuppressWarnings("unused") + Person ludwig = this.restTemplate.getForObject("/composers/{id}", Person.class, 42); + + // We are only validating the request. The response is mocked out. + // hotel.getId() == 42 + // hotel.getName().equals("Holiday Inn") + + this.mockServer.verify(); + } + + @Test + public void performGetManyTimes() { + + String responseBody = "{\"name\" : \"Ludwig van Beethoven\", \"someDouble\" : \"1.6035\"}"; + + this.mockServer.expect(manyTimes(), requestTo("/composers/42")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess(responseBody, MediaType.APPLICATION_JSON)); + + @SuppressWarnings("unused") + Person ludwig = this.restTemplate.getForObject("/composers/{id}", Person.class, 42); + + // We are only validating the request. The response is mocked out. + // hotel.getId() == 42 + // hotel.getName().equals("Holiday Inn") + + this.restTemplate.getForObject("/composers/{id}", Person.class, 42); + this.restTemplate.getForObject("/composers/{id}", Person.class, 42); + this.restTemplate.getForObject("/composers/{id}", Person.class, 42); + + this.mockServer.verify(); + } + + @Test + public void expectNever() { + + String responseBody = "{\"name\" : \"Ludwig van Beethoven\", \"someDouble\" : \"1.6035\"}"; + + this.mockServer.expect(once(), requestTo("/composers/42")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess(responseBody, MediaType.APPLICATION_JSON)); + this.mockServer.expect(never(), requestTo("/composers/43")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess(responseBody, MediaType.APPLICATION_JSON)); + + this.restTemplate.getForObject("/composers/{id}", Person.class, 42); + + this.mockServer.verify(); + } + + @Test(expected = AssertionError.class) + public void expectNeverViolated() { + + String responseBody = "{\"name\" : \"Ludwig van Beethoven\", \"someDouble\" : \"1.6035\"}"; + + this.mockServer.expect(once(), requestTo("/composers/42")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess(responseBody, MediaType.APPLICATION_JSON)); + this.mockServer.expect(never(), requestTo("/composers/43")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess(responseBody, MediaType.APPLICATION_JSON)); + + this.restTemplate.getForObject("/composers/{id}", Person.class, 42); + this.restTemplate.getForObject("/composers/{id}", Person.class, 43); + } + + @Test + public void performGetWithResponseBodyFromFile() { + + Resource responseBody = new ClassPathResource("ludwig.json", this.getClass()); + + this.mockServer.expect(requestTo("/composers/42")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess(responseBody, MediaType.APPLICATION_JSON)); + + @SuppressWarnings("unused") + Person ludwig = this.restTemplate.getForObject("/composers/{id}", Person.class, 42); + + // hotel.getId() == 42 + // hotel.getName().equals("Holiday Inn") + + this.mockServer.verify(); + } + + @Test + public void verify() { + + this.mockServer.expect(requestTo("/number")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess("1", MediaType.TEXT_PLAIN)); + + this.mockServer.expect(requestTo("/number")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess("2", MediaType.TEXT_PLAIN)); + + this.mockServer.expect(requestTo("/number")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess("4", MediaType.TEXT_PLAIN)); + + this.mockServer.expect(requestTo("/number")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess("8", MediaType.TEXT_PLAIN)); + + @SuppressWarnings("unused") + String result1 = this.restTemplate.getForObject("/number", String.class); + // result1 == "1" + + @SuppressWarnings("unused") + String result2 = this.restTemplate.getForObject("/number", String.class); + // result == "2" + + try { + this.mockServer.verify(); + } + catch (AssertionError error) { + assertTrue(error.getMessage(), error.getMessage().contains("2 unsatisfied expectation(s)")); + } + } + + @Test // SPR-14694 + public void repeatedAccessToResponseViaResource() { + + Resource resource = new ClassPathResource("ludwig.json", this.getClass()); + + RestTemplate restTemplate = new RestTemplate(); + restTemplate.setInterceptors(Collections.singletonList(new ContentInterceptor(resource))); + + MockRestServiceServer mockServer = MockRestServiceServer.bindTo(restTemplate) + .ignoreExpectOrder(true) + .bufferContent() // enable repeated reads of response body + .build(); + + mockServer.expect(requestTo("/composers/42")).andExpect(method(HttpMethod.GET)) + .andRespond(withSuccess(resource, MediaType.APPLICATION_JSON)); + + restTemplate.getForObject("/composers/{id}", Person.class, 42); + + mockServer.verify(); + } + + + private static class ContentInterceptor implements ClientHttpRequestInterceptor { + + private final Resource resource; + + + private ContentInterceptor(Resource resource) { + this.resource = resource; + } + + @Override + public ClientHttpResponse intercept(HttpRequest request, byte[] body, + ClientHttpRequestExecution execution) throws IOException { + + ClientHttpResponse response = execution.execute(request, body); + byte[] expected = FileCopyUtils.copyToByteArray(this.resource.getInputStream()); + byte[] actual = FileCopyUtils.copyToByteArray(response.getBody()); + assertEquals(new String(expected), new String(actual)); + return response; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/samples/matchers/ContentRequestMatchersIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/web/client/samples/matchers/ContentRequestMatchersIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c20e81696c9939b28e20361eeb7a23be73e3d255 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/samples/matchers/ContentRequestMatchersIntegrationTests.java @@ -0,0 +1,108 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client.samples.matchers; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.StringHttpMessageConverter; +import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; +import org.springframework.test.web.Person; +import org.springframework.test.web.client.MockRestServiceServer; +import org.springframework.web.client.RestTemplate; + +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.response.MockRestResponseCreators.*; + +/** + * Examples of defining expectations on request content and content type. + * + * @author Rossen Stoyanchev + * @see JsonPathRequestMatchersIntegrationTests + * @see XmlContentRequestMatchersIntegrationTests + * @see XpathRequestMatchersIntegrationTests + */ +public class ContentRequestMatchersIntegrationTests { + + private MockRestServiceServer mockServer; + + private RestTemplate restTemplate; + + + @Before + public void setup() { + List> converters = new ArrayList<>(); + converters.add(new StringHttpMessageConverter()); + converters.add(new MappingJackson2HttpMessageConverter()); + + this.restTemplate = new RestTemplate(); + this.restTemplate.setMessageConverters(converters); + + this.mockServer = MockRestServiceServer.createServer(this.restTemplate); + } + + + @Test + public void contentType() throws Exception { + this.mockServer.expect(content().contentType("application/json;charset=UTF-8")).andRespond(withSuccess()); + executeAndVerify(new Person()); + } + + @Test + public void contentTypeNoMatch() throws Exception { + this.mockServer.expect(content().contentType("application/json;charset=UTF-8")).andRespond(withSuccess()); + try { + executeAndVerify("foo"); + } + catch (AssertionError error) { + String message = error.getMessage(); + assertTrue(message, message.startsWith("Content type expected:")); + } + } + + @Test + public void contentAsString() throws Exception { + this.mockServer.expect(content().string("foo")).andRespond(withSuccess()); + executeAndVerify("foo"); + } + + @Test + public void contentStringStartsWith() throws Exception { + this.mockServer.expect(content().string(startsWith("foo"))).andRespond(withSuccess()); + executeAndVerify("foo123"); + } + + @Test + public void contentAsBytes() throws Exception { + this.mockServer.expect(content().bytes("foo".getBytes())).andRespond(withSuccess()); + executeAndVerify("foo"); + } + + private void executeAndVerify(Object body) throws URISyntaxException { + this.restTemplate.put(new URI("/foo"), body); + this.mockServer.verify(); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/samples/matchers/HeaderRequestMatchersIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/web/client/samples/matchers/HeaderRequestMatchersIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a1d79576dbe5c99a2336d57e5498a22ae847ca54 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/samples/matchers/HeaderRequestMatchersIntegrationTests.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client.samples.matchers; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.StringHttpMessageConverter; +import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; +import org.springframework.test.web.Person; +import org.springframework.test.web.client.MockRestServiceServer; +import org.springframework.web.client.RestTemplate; + +import static org.hamcrest.Matchers.*; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.response.MockRestResponseCreators.*; + +/** + * Examples of defining expectations on request headers. + * + * @author Rossen Stoyanchev + */ +public class HeaderRequestMatchersIntegrationTests { + + private static final String RESPONSE_BODY = "{\"name\" : \"Ludwig van Beethoven\", \"someDouble\" : \"1.6035\"}"; + + + private MockRestServiceServer mockServer; + + private RestTemplate restTemplate; + + + @Before + public void setup() { + List> converters = new ArrayList<>(); + converters.add(new StringHttpMessageConverter()); + converters.add(new MappingJackson2HttpMessageConverter()); + + this.restTemplate = new RestTemplate(); + this.restTemplate.setMessageConverters(converters); + + this.mockServer = MockRestServiceServer.createServer(this.restTemplate); + } + + + @Test + public void testString() throws Exception { + this.mockServer.expect(requestTo("/person/1")) + .andExpect(header("Accept", "application/json, application/*+json")) + .andRespond(withSuccess(RESPONSE_BODY, MediaType.APPLICATION_JSON)); + + executeAndVerify(); + } + + @Test + public void testStringContains() throws Exception { + this.mockServer.expect(requestTo("/person/1")) + .andExpect(header("Accept", containsString("json"))) + .andRespond(withSuccess(RESPONSE_BODY, MediaType.APPLICATION_JSON)); + + executeAndVerify(); + } + + private void executeAndVerify() throws URISyntaxException { + this.restTemplate.getForObject(new URI("/person/1"), Person.class); + this.mockServer.verify(); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/samples/matchers/JsonPathRequestMatchersIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/web/client/samples/matchers/JsonPathRequestMatchersIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..6d4715357206cf3e7ef3647a5b6d9924cae48f00 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/samples/matchers/JsonPathRequestMatchersIntegrationTests.java @@ -0,0 +1,179 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client.samples.matchers; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Arrays; +import java.util.Collections; + +import org.junit.Test; + +import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; +import org.springframework.test.web.Person; +import org.springframework.test.web.client.MockRestServiceServer; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.RestTemplate; + +import static org.hamcrest.Matchers.*; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.response.MockRestResponseCreators.*; + +/** + * Examples of defining expectations on JSON request content with + * JsonPath expressions. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + * @see org.springframework.test.web.client.match.JsonPathRequestMatchers + * @see org.springframework.test.web.client.match.JsonPathRequestMatchersTests + */ +public class JsonPathRequestMatchersIntegrationTests { + + private static final MultiValueMap people = new LinkedMultiValueMap<>(); + + static { + people.add("composers", new Person("Johann Sebastian Bach")); + people.add("composers", new Person("Johannes Brahms")); + people.add("composers", new Person("Edvard Grieg")); + people.add("composers", new Person("Robert Schumann")); + people.add("performers", new Person("Vladimir Ashkenazy")); + people.add("performers", new Person("Yehudi Menuhin")); + } + + + private final RestTemplate restTemplate = + new RestTemplate(Collections.singletonList(new MappingJackson2HttpMessageConverter())); + + private final MockRestServiceServer mockServer = MockRestServiceServer.createServer(this.restTemplate); + + + @Test + public void exists() throws Exception { + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/json;charset=UTF-8")) + .andExpect(jsonPath("$.composers[0]").exists()) + .andExpect(jsonPath("$.composers[1]").exists()) + .andExpect(jsonPath("$.composers[2]").exists()) + .andExpect(jsonPath("$.composers[3]").exists()) + .andRespond(withSuccess()); + + executeAndVerify(); + } + + @Test + public void doesNotExist() throws Exception { + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/json;charset=UTF-8")) + .andExpect(jsonPath("$.composers[?(@.name == 'Edvard Grieeeeeeg')]").doesNotExist()) + .andExpect(jsonPath("$.composers[?(@.name == 'Robert Schuuuuuuman')]").doesNotExist()) + .andExpect(jsonPath("$.composers[4]").doesNotExist()) + .andRespond(withSuccess()); + + executeAndVerify(); + } + + @Test + public void value() throws Exception { + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/json;charset=UTF-8")) + .andExpect(jsonPath("$.composers[0].name").value("Johann Sebastian Bach")) + .andExpect(jsonPath("$.performers[1].name").value("Yehudi Menuhin")) + .andRespond(withSuccess()); + + executeAndVerify(); + } + + @Test + public void hamcrestMatchers() throws Exception { + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/json;charset=UTF-8")) + .andExpect(jsonPath("$.composers[0].name").value(equalTo("Johann Sebastian Bach"))) + .andExpect(jsonPath("$.performers[1].name").value(equalTo("Yehudi Menuhin"))) + .andExpect(jsonPath("$.composers[0].name", startsWith("Johann"))) + .andExpect(jsonPath("$.performers[0].name", endsWith("Ashkenazy"))) + .andExpect(jsonPath("$.performers[1].name", containsString("di Me"))) + .andExpect(jsonPath("$.composers[1].name", isIn(Arrays.asList("Johann Sebastian Bach", "Johannes Brahms")))) + .andExpect(jsonPath("$.composers[:3].name", hasItem("Johannes Brahms"))) + .andRespond(withSuccess()); + + executeAndVerify(); + } + + @Test + public void hamcrestMatchersWithParameterizedJsonPaths() throws Exception { + String composerName = "$.composers[%s].name"; + String performerName = "$.performers[%s].name"; + + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/json;charset=UTF-8")) + .andExpect(jsonPath(composerName, 0).value(startsWith("Johann"))) + .andExpect(jsonPath(performerName, 0).value(endsWith("Ashkenazy"))) + .andExpect(jsonPath(performerName, 1).value(containsString("di Me"))) + .andExpect(jsonPath(composerName, 1).value(isIn(Arrays.asList("Johann Sebastian Bach", "Johannes Brahms")))) + .andRespond(withSuccess()); + + executeAndVerify(); + } + + @Test + public void isArray() throws Exception { + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/json;charset=UTF-8")) + .andExpect(jsonPath("$.composers").isArray()) + .andRespond(withSuccess()); + + executeAndVerify(); + } + + @Test + public void isString() throws Exception { + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/json;charset=UTF-8")) + .andExpect(jsonPath("$.composers[0].name").isString()) + .andRespond(withSuccess()); + + executeAndVerify(); + } + + @Test + public void isNumber() throws Exception { + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/json;charset=UTF-8")) + .andExpect(jsonPath("$.composers[0].someDouble").isNumber()) + .andRespond(withSuccess()); + + executeAndVerify(); + } + + @Test + public void isBoolean() throws Exception { + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/json;charset=UTF-8")) + .andExpect(jsonPath("$.composers[0].someBoolean").isBoolean()) + .andRespond(withSuccess()); + + executeAndVerify(); + } + + private void executeAndVerify() throws URISyntaxException { + this.restTemplate.put(new URI("/composers"), people); + this.mockServer.verify(); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/samples/matchers/XmlContentRequestMatchersIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/web/client/samples/matchers/XmlContentRequestMatchersIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3453630bf98d837b2d0afd11515f4c9b2eb9d996 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/samples/matchers/XmlContentRequestMatchersIntegrationTests.java @@ -0,0 +1,136 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client.samples.matchers; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import javax.xml.bind.annotation.XmlAccessType; +import javax.xml.bind.annotation.XmlAccessorType; +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlElementWrapper; +import javax.xml.bind.annotation.XmlRootElement; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.xml.Jaxb2RootElementHttpMessageConverter; +import org.springframework.test.web.Person; +import org.springframework.test.web.client.MockRestServiceServer; +import org.springframework.web.client.RestTemplate; + +import static org.hamcrest.Matchers.*; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.response.MockRestResponseCreators.*; + +/** + * Examples of defining expectations on XML request content with XMLUnit. + * + * @author Rossen Stoyanchev + * @see ContentRequestMatchersIntegrationTests + * @see XpathRequestMatchersIntegrationTests + */ +public class XmlContentRequestMatchersIntegrationTests { + + private static final String PEOPLE_XML = + "" + + "" + + "Johann Sebastian Bachfalse21.0" + + "Johannes Brahmsfalse0.0025" + + "Edvard Griegfalse1.6035" + + "Robert SchumannfalseNaN" + + ""; + + + private MockRestServiceServer mockServer; + + private RestTemplate restTemplate; + + private PeopleWrapper people; + + + @Before + public void setup() { + List composers = Arrays.asList( + new Person("Johann Sebastian Bach").setSomeDouble(21), + new Person("Johannes Brahms").setSomeDouble(.0025), + new Person("Edvard Grieg").setSomeDouble(1.6035), + new Person("Robert Schumann").setSomeDouble(Double.NaN)); + + this.people = new PeopleWrapper(composers); + + List> converters = new ArrayList<>(); + converters.add(new Jaxb2RootElementHttpMessageConverter()); + + this.restTemplate = new RestTemplate(); + this.restTemplate.setMessageConverters(converters); + + this.mockServer = MockRestServiceServer.createServer(this.restTemplate); + } + + @Test + public void testXmlEqualTo() throws Exception { + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/xml")) + .andExpect(content().xml(PEOPLE_XML)) + .andRespond(withSuccess()); + + executeAndVerify(); + } + + @Test + public void testHamcrestNodeMatcher() throws Exception { + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/xml")) + .andExpect(content().node(hasXPath("/people/composers/composer[1]"))) + .andRespond(withSuccess()); + + executeAndVerify(); + } + + private void executeAndVerify() throws URISyntaxException { + this.restTemplate.put(new URI("/composers"), this.people); + this.mockServer.verify(); + } + + + @SuppressWarnings("unused") + @XmlRootElement(name="people") + @XmlAccessorType(XmlAccessType.FIELD) + private static class PeopleWrapper { + + @XmlElementWrapper(name="composers") + @XmlElement(name="composer") + private List composers; + + public PeopleWrapper() { + } + + public PeopleWrapper(List composers) { + this.composers = composers; + } + + public List getComposers() { + return this.composers; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/client/samples/matchers/XpathRequestMatchersIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/web/client/samples/matchers/XpathRequestMatchersIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3469cc9966793eafd830099deff4b7b0abcf1e61 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/client/samples/matchers/XpathRequestMatchersIntegrationTests.java @@ -0,0 +1,225 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.client.samples.matchers; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import javax.xml.bind.annotation.XmlAccessType; +import javax.xml.bind.annotation.XmlAccessorType; +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlElementWrapper; +import javax.xml.bind.annotation.XmlRootElement; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.xml.Jaxb2RootElementHttpMessageConverter; +import org.springframework.test.web.Person; +import org.springframework.test.web.client.MockRestServiceServer; +import org.springframework.web.client.RestTemplate; + +import static org.hamcrest.Matchers.*; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.response.MockRestResponseCreators.*; + +/** + * Examples of defining expectations on XML request content with XPath expressions. + * + * @author Rossen Stoyanchev + * @see ContentRequestMatchersIntegrationTests + * @see XmlContentRequestMatchersIntegrationTests + */ +public class XpathRequestMatchersIntegrationTests { + + private static final Map NS = + Collections.singletonMap("ns", "https://example.org/music/people"); + + + private MockRestServiceServer mockServer; + + private RestTemplate restTemplate; + + private PeopleWrapper people; + + + @Before + public void setup() { + List composers = Arrays.asList( + new Person("Johann Sebastian Bach").setSomeDouble(21), + new Person("Johannes Brahms").setSomeDouble(.0025), + new Person("Edvard Grieg").setSomeDouble(1.6035), + new Person("Robert Schumann").setSomeDouble(Double.NaN)); + + List performers = Arrays.asList( + new Person("Vladimir Ashkenazy").setSomeBoolean(false), + new Person("Yehudi Menuhin").setSomeBoolean(true)); + + this.people = new PeopleWrapper(composers, performers); + + List> converters = new ArrayList<>(); + converters.add(new Jaxb2RootElementHttpMessageConverter()); + + this.restTemplate = new RestTemplate(); + this.restTemplate.setMessageConverters(converters); + + this.mockServer = MockRestServiceServer.createServer(this.restTemplate); + } + + + @Test + public void testExists() throws Exception { + String composer = "/ns:people/composers/composer[%s]"; + String performer = "/ns:people/performers/performer[%s]"; + + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/xml")) + .andExpect(xpath(composer, NS, 1).exists()) + .andExpect(xpath(composer, NS, 2).exists()) + .andExpect(xpath(composer, NS, 3).exists()) + .andExpect(xpath(composer, NS, 4).exists()) + .andExpect(xpath(performer, NS, 1).exists()) + .andExpect(xpath(performer, NS, 2).exists()) + .andRespond(withSuccess()); + + executeAndVerify(); + } + + @Test + public void testDoesNotExist() throws Exception { + String composer = "/ns:people/composers/composer[%s]"; + String performer = "/ns:people/performers/performer[%s]"; + + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/xml")) + .andExpect(xpath(composer, NS, 0).doesNotExist()) + .andExpect(xpath(composer, NS, 5).doesNotExist()) + .andExpect(xpath(performer, NS, 0).doesNotExist()) + .andExpect(xpath(performer, NS, 3).doesNotExist()) + .andRespond(withSuccess()); + + executeAndVerify(); + } + + @Test + public void testString() throws Exception { + String composerName = "/ns:people/composers/composer[%s]/name"; + String performerName = "/ns:people/performers/performer[%s]/name"; + + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/xml")) + .andExpect(xpath(composerName, NS, 1).string("Johann Sebastian Bach")) + .andExpect(xpath(composerName, NS, 2).string("Johannes Brahms")) + .andExpect(xpath(composerName, NS, 3).string("Edvard Grieg")) + .andExpect(xpath(composerName, NS, 4).string("Robert Schumann")) + .andExpect(xpath(performerName, NS, 1).string("Vladimir Ashkenazy")) + .andExpect(xpath(performerName, NS, 2).string("Yehudi Menuhin")) + .andExpect(xpath(composerName, NS, 1).string(equalTo("Johann Sebastian Bach"))) // Hamcrest.. + .andExpect(xpath(composerName, NS, 1).string(startsWith("Johann"))) // Hamcrest.. + .andExpect(xpath(composerName, NS, 1).string(notNullValue())) // Hamcrest.. + .andRespond(withSuccess()); + + executeAndVerify(); + } + + @Test + public void testNumber() throws Exception { + String composerDouble = "/ns:people/composers/composer[%s]/someDouble"; + + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/xml")) + .andExpect(xpath(composerDouble, NS, 1).number(21d)) + .andExpect(xpath(composerDouble, NS, 2).number(.0025)) + .andExpect(xpath(composerDouble, NS, 3).number(1.6035)) + .andExpect(xpath(composerDouble, NS, 4).number(Double.NaN)) + .andExpect(xpath(composerDouble, NS, 1).number(equalTo(21d))) // Hamcrest.. + .andExpect(xpath(composerDouble, NS, 3).number(closeTo(1.6, .01))) // Hamcrest.. + .andRespond(withSuccess()); + + executeAndVerify(); + } + + @Test + public void testBoolean() throws Exception { + + String performerBooleanValue = "/ns:people/performers/performer[%s]/someBoolean"; + + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/xml")) + .andExpect(xpath(performerBooleanValue, NS, 1).booleanValue(false)) + .andExpect(xpath(performerBooleanValue, NS, 2).booleanValue(true)) + .andRespond(withSuccess()); + + executeAndVerify(); + } + + @Test + public void testNodeCount() throws Exception { + this.mockServer.expect(requestTo("/composers")) + .andExpect(content().contentType("application/xml")) + .andExpect(xpath("/ns:people/composers/composer", NS).nodeCount(4)) + .andExpect(xpath("/ns:people/performers/performer", NS).nodeCount(2)) + .andExpect(xpath("/ns:people/composers/composer", NS).nodeCount(equalTo(4))) // Hamcrest.. + .andExpect(xpath("/ns:people/performers/performer", NS).nodeCount(equalTo(2))) // Hamcrest.. + .andRespond(withSuccess()); + + executeAndVerify(); + } + + private void executeAndVerify() throws URISyntaxException { + this.restTemplate.put(new URI("/composers"), this.people); + this.mockServer.verify(); + } + + + @SuppressWarnings("unused") + @XmlRootElement(name="people", namespace="https://example.org/music/people") + @XmlAccessorType(XmlAccessType.FIELD) + private static class PeopleWrapper { + + @XmlElementWrapper(name="composers") + @XmlElement(name="composer") + private List composers; + + @XmlElementWrapper(name="performers") + @XmlElement(name="performer") + private List performers; + + public PeopleWrapper() { + } + + public PeopleWrapper(List composers, List performers) { + this.composers = composers; + this.performers = performers; + } + + public List getComposers() { + return this.composers; + } + + public List getPerformers() { + return this.performers; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/ApplicationContextSpecTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/ApplicationContextSpecTests.java new file mode 100644 index 0000000000000000000000000000000000000000..20dcdd649cf0fcddfe12805a62ebc58e2155a0a2 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/ApplicationContextSpecTests.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.test.web.reactive.server; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.mock.web.server.MockWebSession; +import org.springframework.web.reactive.config.EnableWebFlux; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerResponse; +import org.springframework.web.server.session.WebSessionManager; + +/** + * Unit tests with {@link ApplicationContextSpec}. + * @author Rossen Stoyanchev + */ +public class ApplicationContextSpecTests { + + + @Test // SPR-17094 + public void sessionManagerBean() { + ApplicationContext context = new AnnotationConfigApplicationContext(WebConfig.class); + ApplicationContextSpec spec = new ApplicationContextSpec(context); + WebTestClient testClient = spec.configureClient().build(); + + for (int i=0; i < 2; i++) { + testClient.get().uri("/sessionClassName") + .exchange() + .expectStatus().isOk() + .expectBody(String.class).isEqualTo("MockWebSession"); + } + } + + + @Configuration + @EnableWebFlux + static class WebConfig { + + @Bean + public RouterFunction handler() { + return RouterFunctions.route() + .GET("/sessionClassName", request -> + request.session().flatMap(session -> { + String className = session.getClass().getSimpleName(); + return ServerResponse.ok().syncBody(className); + })) + .build(); + } + + @Bean + public WebSessionManager webSessionManager() { + MockWebSession session = new MockWebSession(); + return exchange -> Mono.just(session); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/DefaultControllerSpecTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/DefaultControllerSpecTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3a0e6b086a22d42f41ca6cb888453289736012fe --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/DefaultControllerSpecTests.java @@ -0,0 +1,146 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server; + +import java.util.function.Consumer; + +import org.junit.Test; + +import org.springframework.format.FormatterRegistry; +import org.springframework.http.ResponseEntity; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.web.bind.annotation.ControllerAdvice; +import org.springframework.web.bind.annotation.ExceptionHandler; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder; +import org.springframework.web.reactive.config.CorsRegistry; +import org.springframework.web.reactive.config.PathMatchConfigurer; +import org.springframework.web.reactive.config.ViewResolverRegistry; +import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link DefaultControllerSpec}. + * @author Rossen Stoyanchev + */ +public class DefaultControllerSpecTests { + + @Test + public void controller() { + new DefaultControllerSpec(new MyController()).build() + .get().uri("/") + .exchange() + .expectStatus().isOk() + .expectBody(String.class).isEqualTo("Success"); + } + + @Test + public void controllerAdvice() { + new DefaultControllerSpec(new MyController()) + .controllerAdvice(new MyControllerAdvice()) + .build() + .get().uri("/exception") + .exchange() + .expectStatus().isBadRequest() + .expectBody(String.class).isEqualTo("Handled exception"); + } + + @Test + public void controllerAdviceWithClassArgument() { + new DefaultControllerSpec(MyController.class) + .controllerAdvice(MyControllerAdvice.class) + .build() + .get().uri("/exception") + .exchange() + .expectStatus().isBadRequest() + .expectBody(String.class).isEqualTo("Handled exception"); + } + + @Test + public void configurerConsumers() { + TestConsumer argumentResolverConsumer = new TestConsumer<>(); + TestConsumer contenTypeResolverConsumer = new TestConsumer<>(); + TestConsumer corsRegistryConsumer = new TestConsumer<>(); + TestConsumer formatterConsumer = new TestConsumer<>(); + TestConsumer codecsConsumer = new TestConsumer<>(); + TestConsumer pathMatchingConsumer = new TestConsumer<>(); + TestConsumer viewResolverConsumer = new TestConsumer<>(); + + new DefaultControllerSpec(new MyController()) + .argumentResolvers(argumentResolverConsumer) + .contentTypeResolver(contenTypeResolverConsumer) + .corsMappings(corsRegistryConsumer) + .formatters(formatterConsumer) + .httpMessageCodecs(codecsConsumer) + .pathMatching(pathMatchingConsumer) + .viewResolvers(viewResolverConsumer) + .build(); + + assertNotNull(argumentResolverConsumer.getValue()); + assertNotNull(contenTypeResolverConsumer.getValue()); + assertNotNull(corsRegistryConsumer.getValue()); + assertNotNull(formatterConsumer.getValue()); + assertNotNull(codecsConsumer.getValue()); + assertNotNull(pathMatchingConsumer.getValue()); + assertNotNull(viewResolverConsumer.getValue()); + + } + + + @RestController + private static class MyController { + + @GetMapping("/") + public String handle() { + return "Success"; + } + + @GetMapping("/exception") + public void handleWithError() { + throw new IllegalStateException(); + } + + } + + + @ControllerAdvice + private static class MyControllerAdvice { + + @ExceptionHandler + public ResponseEntity handle(IllegalStateException ex) { + return ResponseEntity.status(400).body("Handled exception"); + } + } + + + private static class TestConsumer implements Consumer { + + private T value; + + public T getValue() { + return this.value; + } + + @Override + public void accept(T t) { + this.value = t; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/DefaultRouterFunctionSpecTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/DefaultRouterFunctionSpecTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e17f289c8f47b4382de9bd8b34db8af7e7d1ab2c --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/DefaultRouterFunctionSpecTests.java @@ -0,0 +1,75 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.test.web.reactive.server; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpStatus; +import org.springframework.web.reactive.function.server.HandlerStrategies; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerResponse; + +/** + * Unit tests for {@link DefaultRouterFunctionSpec}. + * @author Rossen Stoyanchev + */ +public class DefaultRouterFunctionSpecTests { + + @Test + public void webFilter() { + + RouterFunction routerFunction = RouterFunctions.route() + .GET("/", request -> ServerResponse.ok().build()) + .build(); + + new DefaultRouterFunctionSpec(routerFunction) + .handlerStrategies(HandlerStrategies.builder() + .webFilter((exchange, chain) -> { + exchange.getResponse().getHeaders().set("foo", "123"); + return chain.filter(exchange); + }) + .build()) + .build() + .get() + .uri("/") + .exchange() + .expectStatus().isOk() + .expectHeader().valueEquals("foo", "123"); + } + + @Test + public void exceptionHandler() { + + RouterFunction routerFunction = RouterFunctions.route() + .GET("/error", request -> Mono.error(new IllegalStateException("boo"))) + .build(); + + new DefaultRouterFunctionSpec(routerFunction) + .handlerStrategies(HandlerStrategies.builder() + .exceptionHandler((exchange, ex) -> { + exchange.getResponse().setStatusCode(HttpStatus.BAD_REQUEST); + return Mono.empty(); + }) + .build()) + .build() + .get() + .uri("/error") + .exchange() + .expectStatus().isBadRequest(); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/HeaderAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/HeaderAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..22282033df8938a42512768271ee8bebcf10b9d2 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/HeaderAssertionTests.java @@ -0,0 +1,265 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server; + +import java.net.URI; +import java.time.Duration; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.util.concurrent.TimeUnit; + +import org.junit.Test; +import reactor.core.publisher.MonoProcessor; + +import org.springframework.http.CacheControl; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.mock.http.client.reactive.MockClientHttpRequest; +import org.springframework.mock.http.client.reactive.MockClientHttpResponse; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link HeaderAssertions}. + * @author Rossen Stoyanchev + * @author Sam Brannen + */ +public class HeaderAssertionTests { + + @Test + public void valueEquals() { + HttpHeaders headers = new HttpHeaders(); + headers.add("foo", "bar"); + HeaderAssertions assertions = headerAssertions(headers); + + // Success + assertions.valueEquals("foo", "bar"); + + try { + assertions.valueEquals("what?!", "bar"); + fail("Missing header expected"); + } + catch (AssertionError error) { + // expected + } + + try { + assertions.valueEquals("foo", "what?!"); + fail("Wrong value expected"); + } + catch (AssertionError error) { + // expected + } + + try { + assertions.valueEquals("foo", "bar", "what?!"); + fail("Wrong # of values expected"); + } + catch (AssertionError error) { + // expected + } + } + + @Test + public void valueEqualsWithMultipleValues() { + HttpHeaders headers = new HttpHeaders(); + headers.add("foo", "bar"); + headers.add("foo", "baz"); + HeaderAssertions assertions = headerAssertions(headers); + + // Success + assertions.valueEquals("foo", "bar", "baz"); + + try { + assertions.valueEquals("foo", "bar", "what?!"); + fail("Wrong value expected"); + } + catch (AssertionError error) { + // expected + } + + try { + assertions.valueEquals("foo", "bar"); + fail("Too few values expected"); + } + catch (AssertionError error) { + // expected + } + + } + + @Test + public void valueMatches() { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON_UTF8); + HeaderAssertions assertions = headerAssertions(headers); + + // Success + assertions.valueMatches("Content-Type", ".*UTF-8.*"); + + try { + assertions.valueMatches("Content-Type", ".*ISO-8859-1.*"); + fail("Wrong pattern expected"); + } + catch (AssertionError error) { + Throwable cause = error.getCause(); + assertNotNull(cause); + assertEquals("Response header 'Content-Type'=[application/json;charset=UTF-8] " + + "does not match [.*ISO-8859-1.*]", cause.getMessage()); + } + } + + @Test + public void valueMatcher() { + HttpHeaders headers = new HttpHeaders(); + headers.add("foo", "bar"); + HeaderAssertions assertions = headerAssertions(headers); + + assertions.value("foo", containsString("a")); + } + + @Test + public void exists() { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON_UTF8); + HeaderAssertions assertions = headerAssertions(headers); + + // Success + assertions.exists("Content-Type"); + + try { + assertions.exists("Framework"); + fail("Header should not exist"); + } + catch (AssertionError error) { + Throwable cause = error.getCause(); + assertNotNull(cause); + assertEquals("Response header 'Framework' does not exist", cause.getMessage()); + } + } + + @Test + public void doesNotExist() { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON_UTF8); + HeaderAssertions assertions = headerAssertions(headers); + + // Success + assertions.doesNotExist("Framework"); + + try { + assertions.doesNotExist("Content-Type"); + fail("Existing header expected"); + } + catch (AssertionError error) { + Throwable cause = error.getCause(); + assertNotNull(cause); + assertEquals("Response header 'Content-Type' exists with " + + "value=[application/json;charset=UTF-8]", cause.getMessage()); + } + } + + @Test + public void contentTypeCompatibleWith() { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_XML); + HeaderAssertions assertions = headerAssertions(headers); + + // Success + assertions.contentTypeCompatibleWith(MediaType.parseMediaType("application/*")); + + try { + assertions.contentTypeCompatibleWith(MediaType.TEXT_XML); + fail("MediaTypes not compatible expected"); + } + catch (AssertionError error) { + Throwable cause = error.getCause(); + assertNotNull(cause); + assertEquals("Response header 'Content-Type'=[application/xml] " + + "is not compatible with [text/xml]", cause.getMessage()); + } + } + + @Test + public void cacheControl() { + CacheControl control = CacheControl.maxAge(1, TimeUnit.HOURS).noTransform(); + + HttpHeaders headers = new HttpHeaders(); + headers.setCacheControl(control.getHeaderValue()); + HeaderAssertions assertions = headerAssertions(headers); + + // Success + assertions.cacheControl(control); + + try { + assertions.cacheControl(CacheControl.noStore()); + fail("Wrong value expected"); + } + catch (AssertionError error) { + // Expected + } + } + + @Test + public void expires() { + HttpHeaders headers = new HttpHeaders(); + ZonedDateTime expires = ZonedDateTime.of(2018, 1, 1, 0, 0, 0, 0, ZoneId.of("UTC")); + headers.setExpires(expires); + HeaderAssertions assertions = headerAssertions(headers); + assertions.expires(expires.toInstant().toEpochMilli()); + try { + assertions.expires(expires.toInstant().toEpochMilli() + 1); + fail("Wrong value expected"); + } + catch (AssertionError error) { + // Expected + } + } + + @Test + public void lastModified() { + HttpHeaders headers = new HttpHeaders(); + ZonedDateTime lastModified = ZonedDateTime.of(2018, 1, 1, 0, 0, 0, 0, ZoneId.of("UTC")); + headers.setLastModified(lastModified.toInstant().toEpochMilli()); + HeaderAssertions assertions = headerAssertions(headers); + assertions.lastModified(lastModified.toInstant().toEpochMilli()); + try { + assertions.lastModified(lastModified.toInstant().toEpochMilli() + 1); + fail("Wrong value expected"); + } + catch (AssertionError error) { + // Expected + } + } + + private HeaderAssertions headerAssertions(HttpHeaders responseHeaders) { + MockClientHttpRequest request = new MockClientHttpRequest(HttpMethod.GET, URI.create("/")); + MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK); + response.getHeaders().putAll(responseHeaders); + + MonoProcessor emptyContent = MonoProcessor.create(); + emptyContent.onComplete(); + + ExchangeResult result = new ExchangeResult(request, response, emptyContent, emptyContent, Duration.ZERO, null); + return new HeaderAssertions(result, mock(WebTestClient.ResponseSpec.class)); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/HttpHandlerConnectorTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/HttpHandlerConnectorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..dd687157d4dee34797db224f4daa26b62c8ed6ac --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/HttpHandlerConnectorTests.java @@ -0,0 +1,134 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server; + +import java.net.URI; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.function.Function; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.core.io.buffer.support.DataBufferTestUtils; +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.ReactiveHttpOutputMessage; +import org.springframework.http.ResponseCookie; +import org.springframework.http.client.reactive.ClientHttpResponse; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.Assert.assertEquals; + +/** + * Unit tests for {@link HttpHandlerConnector}. + * @author Rossen Stoyanchev + */ +public class HttpHandlerConnectorTests { + + + @Test + public void adaptRequest() throws Exception { + + TestHttpHandler handler = new TestHttpHandler(response -> { + response.setStatusCode(HttpStatus.OK); + return response.setComplete(); + }); + + new HttpHandlerConnector(handler).connect(HttpMethod.POST, URI.create("/custom-path"), + request -> { + request.getHeaders().put("custom-header", Arrays.asList("h0", "h1")); + request.getCookies().add("custom-cookie", new HttpCookie("custom-cookie", "c0")); + return request.writeWith(Mono.just(toDataBuffer("Custom body"))); + }).block(Duration.ofSeconds(5)); + + MockServerHttpRequest request = (MockServerHttpRequest) handler.getSavedRequest(); + assertEquals(HttpMethod.POST, request.getMethod()); + assertEquals("/custom-path", request.getURI().toString()); + + HttpHeaders headers = request.getHeaders(); + assertEquals(Arrays.asList("h0", "h1"), headers.get("custom-header")); + assertEquals(new HttpCookie("custom-cookie", "c0"), request.getCookies().getFirst("custom-cookie")); + assertEquals(Collections.singletonList("custom-cookie=c0"), headers.get(HttpHeaders.COOKIE)); + + DataBuffer buffer = request.getBody().blockFirst(Duration.ZERO); + assertEquals("Custom body", DataBufferTestUtils.dumpString(buffer, UTF_8)); + } + + @Test + public void adaptResponse() throws Exception { + + ResponseCookie cookie = ResponseCookie.from("custom-cookie", "c0").build(); + + TestHttpHandler handler = new TestHttpHandler(response -> { + response.setStatusCode(HttpStatus.OK); + response.getHeaders().put("custom-header", Arrays.asList("h0", "h1")); + response.addCookie(cookie); + return response.writeWith(Mono.just(toDataBuffer("Custom body"))); + }); + + ClientHttpResponse response = new HttpHandlerConnector(handler) + .connect(HttpMethod.GET, URI.create("/custom-path"), ReactiveHttpOutputMessage::setComplete) + .block(Duration.ofSeconds(5)); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + HttpHeaders headers = response.getHeaders(); + assertEquals(Arrays.asList("h0", "h1"), headers.get("custom-header")); + assertEquals(cookie, response.getCookies().getFirst("custom-cookie")); + assertEquals(Collections.singletonList("custom-cookie=c0"), headers.get(HttpHeaders.SET_COOKIE)); + + DataBuffer buffer = response.getBody().blockFirst(Duration.ZERO); + assertEquals("Custom body", DataBufferTestUtils.dumpString(buffer, UTF_8)); + } + + private DataBuffer toDataBuffer(String body) { + return new DefaultDataBufferFactory().wrap(body.getBytes(UTF_8)); + } + + + private static class TestHttpHandler implements HttpHandler { + + private ServerHttpRequest savedRequest; + + private final Function> responseMonoFunction; + + + public TestHttpHandler(Function> function) { + this.responseMonoFunction = function; + } + + public ServerHttpRequest getSavedRequest() { + return this.savedRequest; + } + + @Override + public Mono handle(ServerHttpRequest request, ServerHttpResponse response) { + this.savedRequest = request; + return this.responseMonoFunction.apply(response); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/MockServerSpecTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/MockServerSpecTests.java new file mode 100644 index 0000000000000000000000000000000000000000..4a865f715d7bb985f23eaa69ca576d1d529b1aaa --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/MockServerSpecTests.java @@ -0,0 +1,117 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.test.web.reactive.server; + +import java.nio.charset.StandardCharsets; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; +import org.springframework.web.server.adapter.WebHttpHandlerBuilder; + +import static org.hamcrest.MatcherAssert.*; +import static org.hamcrest.core.StringContains.*; + +/** + * Unit tests for {@link AbstractMockServerSpec}. + * @author Rossen Stoyanchev + */ +public class MockServerSpecTests { + + private final TestMockServerSpec serverSpec = new TestMockServerSpec(); + + + @Test + public void applyFiltersAfterConfigurerAdded() { + + this.serverSpec.webFilter(new TestWebFilter("A")); + + this.serverSpec.apply(new MockServerConfigurer() { + + @Override + public void afterConfigureAdded(WebTestClient.MockServerSpec spec) { + spec.webFilter(new TestWebFilter("B")); + } + }); + + this.serverSpec.build().get().uri("/") + .exchange() + .expectBody(String.class) + .consumeWith(result -> assertThat( + result.getResponseBody(), containsString("test-attribute=:A:B"))); + } + + @Test + public void applyFiltersBeforeServerCreated() { + + this.serverSpec.webFilter(new TestWebFilter("App-A")); + this.serverSpec.webFilter(new TestWebFilter("App-B")); + + this.serverSpec.apply(new MockServerConfigurer() { + + @Override + public void beforeServerCreated(WebHttpHandlerBuilder builder) { + builder.filters(filters -> { + filters.add(0, new TestWebFilter("Fwk-A")); + filters.add(1, new TestWebFilter("Fwk-B")); + }); + } + }); + + this.serverSpec.build().get().uri("/") + .exchange() + .expectBody(String.class) + .consumeWith(result -> assertThat( + result.getResponseBody(), containsString("test-attribute=:Fwk-A:Fwk-B:App-A:App-B"))); + } + + + private static class TestMockServerSpec extends AbstractMockServerSpec { + + @Override + protected WebHttpHandlerBuilder initHttpHandlerBuilder() { + return WebHttpHandlerBuilder.webHandler(exchange -> { + DefaultDataBufferFactory factory = new DefaultDataBufferFactory(); + String text = exchange.getAttributes().toString(); + DataBuffer buffer = factory.wrap(text.getBytes(StandardCharsets.UTF_8)); + return exchange.getResponse().writeWith(Mono.just(buffer)); + }); + } + } + + private static class TestWebFilter implements WebFilter { + + private final String name; + + TestWebFilter(String name) { + this.name = name; + } + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + String name = "test-attribute"; + String value = exchange.getAttributeOrDefault(name, ""); + exchange.getAttributes().put(name, value + ":" + this.name); + return chain.filter(exchange); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/MockServerTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/MockServerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3c4503f2d7e3e4e0144be0846e2c4129f242e098 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/MockServerTests.java @@ -0,0 +1,168 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.test.web.reactive.server; + +import java.util.Arrays; + +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseCookie; +import org.springframework.http.server.reactive.ServerHttpResponse; + +import static java.nio.charset.StandardCharsets.*; +import static org.junit.Assert.*; + +/** + * Test scenarios involving a mock server. + * @author Rossen Stoyanchev + */ +public class MockServerTests { + + + @Test // SPR-15674 (in comments) + public void mutateDoesNotCreateNewSession() { + + WebTestClient client = WebTestClient + .bindToWebHandler(exchange -> { + if (exchange.getRequest().getURI().getPath().equals("/set")) { + return exchange.getSession() + .doOnNext(session -> session.getAttributes().put("foo", "bar")) + .then(); + } + else { + return exchange.getSession() + .map(session -> session.getAttributeOrDefault("foo", "none")) + .flatMap(value -> { + DataBuffer buffer = toDataBuffer(value); + return exchange.getResponse().writeWith(Mono.just(buffer)); + }); + } + }) + .build(); + + // Set the session attribute + EntityExchangeResult result = client.get().uri("/set").exchange() + .expectStatus().isOk().expectBody().isEmpty(); + + ResponseCookie session = result.getResponseCookies().getFirst("SESSION"); + + // Now get attribute + client.mutate().build() + .get().uri("/get") + .cookie(session.getName(), session.getValue()) + .exchange() + .expectBody(String.class).isEqualTo("bar"); + } + + @Test // SPR-16059 + public void mutateDoesCopy() { + + WebTestClient.Builder builder = WebTestClient + .bindToWebHandler(exchange -> exchange.getResponse().setComplete()) + .configureClient(); + + builder.filter((request, next) -> next.exchange(request)); + builder.defaultHeader("foo", "bar"); + builder.defaultCookie("foo", "bar"); + WebTestClient client1 = builder.build(); + + builder.filter((request, next) -> next.exchange(request)); + builder.defaultHeader("baz", "qux"); + builder.defaultCookie("baz", "qux"); + WebTestClient client2 = builder.build(); + + WebTestClient.Builder mutatedBuilder = client1.mutate(); + + mutatedBuilder.filter((request, next) -> next.exchange(request)); + mutatedBuilder.defaultHeader("baz", "qux"); + mutatedBuilder.defaultCookie("baz", "qux"); + WebTestClient clientFromMutatedBuilder = mutatedBuilder.build(); + + client1.mutate().filters(filters -> assertEquals(1, filters.size())); + client1.mutate().defaultHeaders(headers -> assertEquals(1, headers.size())); + client1.mutate().defaultCookies(cookies -> assertEquals(1, cookies.size())); + + client2.mutate().filters(filters -> assertEquals(2, filters.size())); + client2.mutate().defaultHeaders(headers -> assertEquals(2, headers.size())); + client2.mutate().defaultCookies(cookies -> assertEquals(2, cookies.size())); + + clientFromMutatedBuilder.mutate().filters(filters -> assertEquals(2, filters.size())); + clientFromMutatedBuilder.mutate().defaultHeaders(headers -> assertEquals(2, headers.size())); + clientFromMutatedBuilder.mutate().defaultCookies(cookies -> assertEquals(2, cookies.size())); + } + + @Test // SPR-16124 + public void exchangeResultHasCookieHeaders() { + + ExchangeResult result = WebTestClient + .bindToWebHandler(exchange -> { + ServerHttpResponse response = exchange.getResponse(); + if (exchange.getRequest().getURI().getPath().equals("/cookie")) { + response.addCookie(ResponseCookie.from("a", "alpha").path("/pathA").build()); + response.addCookie(ResponseCookie.from("b", "beta").path("/pathB").build()); + } + else { + response.setStatusCode(HttpStatus.NOT_FOUND); + } + return response.setComplete(); + }) + .build() + .get().uri("/cookie").cookie("a", "alpha").cookie("b", "beta") + .exchange() + .expectStatus().isOk() + .expectHeader().valueEquals(HttpHeaders.SET_COOKIE, "a=alpha; Path=/pathA", "b=beta; Path=/pathB") + .expectBody().isEmpty(); + + assertEquals(Arrays.asList("a=alpha", "b=beta"), + result.getRequestHeaders().get(HttpHeaders.COOKIE)); + } + + @Test + public void responseBodyContentWithFluxExchangeResult() { + + FluxExchangeResult result = WebTestClient + .bindToWebHandler(exchange -> { + ServerHttpResponse response = exchange.getResponse(); + response.getHeaders().setContentType(MediaType.TEXT_PLAIN); + return response.writeWith(Flux.just(toDataBuffer("body"))); + }) + .build() + .get().uri("/") + .exchange() + .expectStatus().isOk() + .returnResult(String.class); + + // Get the raw content without consuming the response body flux.. + byte[] bytes = result.getResponseBodyContent(); + + assertNotNull(bytes); + assertEquals("body", new String(bytes, UTF_8)); + } + + + private DataBuffer toDataBuffer(String value) { + byte[] bytes = value.getBytes(UTF_8); + return new DefaultDataBufferFactory().wrap(bytes); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/StatusAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/StatusAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5b6f1c459a34df0dbc7adb29f7f1b17c81bb24cd --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/StatusAssertionTests.java @@ -0,0 +1,199 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server; + +import java.net.URI; +import java.time.Duration; + +import org.junit.Test; +import reactor.core.publisher.MonoProcessor; + +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.mock.http.client.reactive.MockClientHttpRequest; +import org.springframework.mock.http.client.reactive.MockClientHttpResponse; + +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link StatusAssertions}. + * @author Rossen Stoyanchev + */ +public class StatusAssertionTests { + + @Test + public void isEqualTo() { + StatusAssertions assertions = statusAssertions(HttpStatus.CONFLICT); + + // Success + assertions.isEqualTo(HttpStatus.CONFLICT); + assertions.isEqualTo(409); + + try { + assertions.isEqualTo(HttpStatus.REQUEST_TIMEOUT); + fail("Wrong status expected"); + } + catch (AssertionError error) { + // Expected + } + + try { + assertions.isEqualTo(408); + fail("Wrong status value expected"); + } + catch (AssertionError error) { + // Expected + } + } + + @Test // gh-23630 + public void isEqualToWithCustomStatus() { + statusAssertions(600).isEqualTo(600); + } + + @Test + public void reasonEquals() { + StatusAssertions assertions = statusAssertions(HttpStatus.CONFLICT); + + // Success + assertions.reasonEquals("Conflict"); + + try { + assertions.reasonEquals("Request Timeout"); + fail("Wrong reason expected"); + } + catch (AssertionError error) { + // Expected + } + } + + @Test + public void statusSerius1xx() { + StatusAssertions assertions = statusAssertions(HttpStatus.CONTINUE); + + // Success + assertions.is1xxInformational(); + + try { + assertions.is2xxSuccessful(); + fail("Wrong series expected"); + } + catch (AssertionError error) { + // Expected + } + } + + @Test + public void statusSerius2xx() { + StatusAssertions assertions = statusAssertions(HttpStatus.OK); + + // Success + assertions.is2xxSuccessful(); + + try { + assertions.is5xxServerError(); + fail("Wrong series expected"); + } + catch (AssertionError error) { + // Expected + } + } + + @Test + public void statusSerius3xx() { + StatusAssertions assertions = statusAssertions(HttpStatus.PERMANENT_REDIRECT); + + // Success + assertions.is3xxRedirection(); + + try { + assertions.is2xxSuccessful(); + fail("Wrong series expected"); + } + catch (AssertionError error) { + // Expected + } + } + + @Test + public void statusSerius4xx() { + StatusAssertions assertions = statusAssertions(HttpStatus.BAD_REQUEST); + + // Success + assertions.is4xxClientError(); + + try { + assertions.is2xxSuccessful(); + fail("Wrong series expected"); + } + catch (AssertionError error) { + // Expected + } + } + + @Test + public void statusSerius5xx() { + StatusAssertions assertions = statusAssertions(HttpStatus.INTERNAL_SERVER_ERROR); + + // Success + assertions.is5xxServerError(); + + try { + assertions.is2xxSuccessful(); + fail("Wrong series expected"); + } + catch (AssertionError error) { + // Expected + } + } + + @Test + public void matches() { + StatusAssertions assertions = statusAssertions(HttpStatus.CONFLICT); + + // Success + assertions.value(equalTo(409)); + assertions.value(greaterThan(400)); + + try { + assertions.value(equalTo(200)); + fail("Wrong status expected"); + } + catch (AssertionError error) { + // Expected + } + } + + + private StatusAssertions statusAssertions(HttpStatus status) { + return statusAssertions(status.value()); + } + + private StatusAssertions statusAssertions(int status) { + MockClientHttpRequest request = new MockClientHttpRequest(HttpMethod.GET, URI.create("/")); + MockClientHttpResponse response = new MockClientHttpResponse(status); + + MonoProcessor emptyContent = MonoProcessor.create(); + emptyContent.onComplete(); + + ExchangeResult result = new ExchangeResult(request, response, emptyContent, emptyContent, Duration.ZERO, null); + return new StatusAssertions(result, mock(WebTestClient.ResponseSpec.class)); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/WiretapConnectorTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/WiretapConnectorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5c84c1db87afa9ad9920cbc7e3250b48f661f78c --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/WiretapConnectorTests.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server; + +import java.net.URI; +import java.time.Duration; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.client.reactive.ClientHttpConnector; +import org.springframework.http.client.reactive.ClientHttpRequest; +import org.springframework.http.client.reactive.ClientHttpResponse; +import org.springframework.mock.http.client.reactive.MockClientHttpRequest; +import org.springframework.mock.http.client.reactive.MockClientHttpResponse; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ExchangeFunction; +import org.springframework.web.reactive.function.client.ExchangeFunctions; + +import static java.time.Duration.ofMillis; +import static org.junit.Assert.*; + +/** + * Unit tests for {@link WiretapConnector}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class WiretapConnectorTests { + + @Test + public void captureAndClaim() { + ClientHttpRequest request = new MockClientHttpRequest(HttpMethod.GET, "/test"); + ClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK); + ClientHttpConnector connector = (method, uri, fn) -> fn.apply(request).then(Mono.just(response)); + + ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("/test")) + .header(WebTestClient.WEBTESTCLIENT_REQUEST_ID, "1").build(); + + WiretapConnector wiretapConnector = new WiretapConnector(connector); + ExchangeFunction function = ExchangeFunctions.create(wiretapConnector); + function.exchange(clientRequest).block(ofMillis(0)); + + WiretapConnector.Info actual = wiretapConnector.claimRequest("1"); + ExchangeResult result = actual.createExchangeResult(Duration.ZERO, null); + assertEquals(HttpMethod.GET, result.getMethod()); + assertEquals("/test", result.getUrl().toString()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/ErrorTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/ErrorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..fef80e1c32b31d3e2a72417859fa58f20b5cd45e --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/ErrorTests.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server.samples; + +import java.nio.charset.StandardCharsets; + +import org.junit.Test; + +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.test.web.reactive.server.EntityExchangeResult; +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RestController; + +import static org.junit.Assert.*; + +/** + * Tests with error status codes or error conditions. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ErrorTests { + + private final WebTestClient client = WebTestClient.bindToController(new TestController()).build(); + + + @Test + public void notFound(){ + this.client.get().uri("/invalid") + .exchange() + .expectStatus().isNotFound() + .expectBody(Void.class); + } + + @Test + public void serverException() { + this.client.get().uri("/server-error") + .exchange() + .expectStatus().isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR) + .expectBody(Void.class); + } + + @Test // SPR-17363 + public void badRequestBeforeRequestBodyConsumed() { + EntityExchangeResult result = this.client.post() + .uri("/post") + .contentType(MediaType.APPLICATION_JSON_UTF8) + .syncBody(new Person("Dan")) + .exchange() + .expectStatus().isBadRequest() + .expectBody().isEmpty(); + + byte[] content = result.getRequestBodyContent(); + assertNotNull(content); + assertEquals("{\"name\":\"Dan\"}", new String(content, StandardCharsets.UTF_8)); + } + + + @RestController + static class TestController { + + @GetMapping("/server-error") + void handleAndThrowException() { + throw new IllegalStateException("server error"); + } + + @PostMapping(path = "/post", params = "p") + void handlePost(@RequestBody Person person) { + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/ExchangeMutatorTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/ExchangeMutatorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..ce3a85f92b7e287c77ee1cf02f8ac65bdc8ad5fc --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/ExchangeMutatorTests.java @@ -0,0 +1,145 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.test.web.reactive.server.samples; + +import java.security.Principal; + +import org.junit.Before; +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.http.client.reactive.ClientHttpConnector; +import org.springframework.lang.Nullable; +import org.springframework.test.web.reactive.server.MockServerConfigurer; +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.test.web.reactive.server.WebTestClientConfigurer; +import org.springframework.util.Assert; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; +import org.springframework.web.server.adapter.WebHttpHandlerBuilder; + +/** + * Samples tests that demonstrate applying ServerWebExchange initialization. + * @author Rossen Stoyanchev + */ +public class ExchangeMutatorTests { + + private WebTestClient webTestClient; + + + @Before + public void setUp() throws Exception { + + this.webTestClient = WebTestClient.bindToController(new TestController()) + .apply(identity("Pablo")) + .build(); + } + + @Test + public void useGloballyConfiguredIdentity() throws Exception { + this.webTestClient.get().uri("/userIdentity") + .exchange() + .expectStatus().isOk() + .expectBody(String.class).isEqualTo("Hello Pablo!"); + } + + @Test + public void useLocallyConfiguredIdentity() throws Exception { + + this.webTestClient + .mutateWith(identity("Giovanni")) + .get().uri("/userIdentity") + .exchange() + .expectStatus().isOk() + .expectBody(String.class).isEqualTo("Hello Giovanni!"); + } + + + private static IdentityConfigurer identity(String userName) { + return new IdentityConfigurer(userName); + } + + + @RestController + static class TestController { + + @GetMapping("/userIdentity") + public String handle(Principal principal) { + return "Hello " + principal.getName() + "!"; + } + } + + private static class TestUser implements Principal { + + private final String name; + + TestUser(String name) { + this.name = name; + } + + @Override + public String getName() { + return this.name; + } + } + + private static class IdentityConfigurer implements MockServerConfigurer, WebTestClientConfigurer { + + private final IdentityFilter filter; + + + public IdentityConfigurer(String userName) { + this.filter = new IdentityFilter(userName); + } + + @Override + public void beforeServerCreated(WebHttpHandlerBuilder builder) { + builder.filters(filters -> filters.add(0, this.filter)); + } + + @Override + public void afterConfigurerAdded(WebTestClient.Builder builder, + @Nullable WebHttpHandlerBuilder httpHandlerBuilder, + @Nullable ClientHttpConnector connector) { + + Assert.notNull(httpHandlerBuilder, "Not a mock server"); + httpHandlerBuilder.filters(filters -> { + filters.removeIf(filter -> filter instanceof IdentityFilter); + filters.add(0, this.filter); + }); + } + } + + private static class IdentityFilter implements WebFilter { + + private final Mono userMono; + + + IdentityFilter(String userName) { + this.userMono = Mono.just(new TestUser(userName)); + } + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + exchange = exchange.mutate().principal(this.userMono).build(); + return chain.filter(exchange); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/HeaderAndCookieTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/HeaderAndCookieTests.java new file mode 100644 index 0000000000000000000000000000000000000000..9e6397fafc2d19ed06ab4f217aa0d52476913826 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/HeaderAndCookieTests.java @@ -0,0 +1,87 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server.samples; + +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.bind.annotation.CookieValue; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestHeader; +import org.springframework.web.bind.annotation.RestController; + +/** + * Tests with headers and cookies. + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class HeaderAndCookieTests { + + private final WebTestClient client = WebTestClient.bindToController(new TestController()).build(); + + + @Test + public void requestResponseHeaderPair() throws Exception { + this.client.get().uri("/header-echo").header("h1", "in") + .exchange() + .expectStatus().isOk() + .expectHeader().valueEquals("h1", "in-out"); + } + + @Test + public void headerMultipleValues() throws Exception { + this.client.get().uri("/header-multi-value") + .exchange() + .expectStatus().isOk() + .expectHeader().valueEquals("h1", "v1", "v2", "v3"); + } + + @Test + public void setCookies() { + this.client.get().uri("/cookie-echo") + .cookies(cookies -> cookies.add("k1", "v1")) + .exchange() + .expectHeader().valueMatches("Set-Cookie", "k1=v1"); + } + + + @RestController + static class TestController { + + @GetMapping("header-echo") + ResponseEntity handleHeader(@RequestHeader("h1") String myHeader) { + String value = myHeader + "-out"; + return ResponseEntity.ok().header("h1", value).build(); + } + + @GetMapping("header-multi-value") + ResponseEntity multiValue() { + return ResponseEntity.ok().header("h1", "v1", "v2", "v3").build(); + } + + @GetMapping("cookie-echo") + ResponseEntity handleCookie(@CookieValue("k1") String cookieValue) { + HttpHeaders headers = new HttpHeaders(); + headers.set("Set-Cookie", "k1=" + cookieValue); + return new ResponseEntity<>(headers, HttpStatus.OK); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/JsonContentTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/JsonContentTests.java new file mode 100644 index 0000000000000000000000000000000000000000..74704a9f9093f1b19118dab535a1376f0d7dc427 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/JsonContentTests.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server.samples; + +import java.net.URI; + +import org.junit.Test; +import reactor.core.publisher.Flux; + +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import static org.hamcrest.Matchers.*; + +/** + * Samples of tests using {@link WebTestClient} with serialized JSON content. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + * @since 5.0 + */ +public class JsonContentTests { + + private final WebTestClient client = WebTestClient.bindToController(new PersonController()).build(); + + + @Test + public void jsonContent() { + this.client.get().uri("/persons") + .accept(MediaType.APPLICATION_JSON_UTF8) + .exchange() + .expectStatus().isOk() + .expectBody().json("[{\"name\":\"Jane\"},{\"name\":\"Jason\"},{\"name\":\"John\"}]"); + } + + @Test + public void jsonPathIsEqualTo() { + this.client.get().uri("/persons") + .accept(MediaType.APPLICATION_JSON_UTF8) + .exchange() + .expectStatus().isOk() + .expectBody() + .jsonPath("$[0].name").isEqualTo("Jane") + .jsonPath("$[1].name").isEqualTo("Jason") + .jsonPath("$[2].name").isEqualTo("John"); + } + + @Test + public void jsonPathMatches() { + this.client.get().uri("/persons/John") + .accept(MediaType.APPLICATION_JSON_UTF8) + .exchange() + .expectStatus().isOk() + .expectBody() + .jsonPath("$.name").value(containsString("oh")); + } + + @Test + public void postJsonContent() { + this.client.post().uri("/persons") + .contentType(MediaType.APPLICATION_JSON_UTF8) + .syncBody("{\"name\":\"John\"}") + .exchange() + .expectStatus().isCreated() + .expectBody().isEmpty(); + } + + + @RestController + @RequestMapping("/persons") + static class PersonController { + + @GetMapping + Flux getPersons() { + return Flux.just(new Person("Jane"), new Person("Jason"), new Person("John")); + } + + @GetMapping("/{name}") + Person getPerson(@PathVariable String name) { + return new Person(name); + } + + @PostMapping + ResponseEntity savePerson(@RequestBody Person person) { + return ResponseEntity.created(URI.create("/persons/" + person.getName())).build(); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/Person.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/Person.java new file mode 100644 index 0000000000000000000000000000000000000000..23c2115e1d7606d73b18408974a9dfb9a01c5fec --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/Person.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.test.web.reactive.server.samples; + +import javax.xml.bind.annotation.XmlRootElement; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +@XmlRootElement +class Person { + + private String name; + + + // No-arg constructor for XML + public Person() { + } + + @JsonCreator + public Person(@JsonProperty("name") String name) { + this.name = name; + } + + public void setName(String name) { + this.name = name; + } + + public String getName() { + return this.name; + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + Person person = (Person) other; + return getName().equals(person.getName()); + } + + @Override + public int hashCode() { + return getName().hashCode(); + } + + @Override + public String toString() { + return "Person[name='" + name + "']"; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/ResponseEntityTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/ResponseEntityTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0ce3752d9a1023e2edf6da3fc76f1c0e8aebe79b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/ResponseEntityTests.java @@ -0,0 +1,192 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server.samples; + +import java.net.URI; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.hamcrest.MatcherAssert; +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.test.web.reactive.server.FluxExchangeResult; +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import static java.time.Duration.*; +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; +import static org.springframework.http.MediaType.*; + +/** + * Annotated controllers accepting and returning typed Objects. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ResponseEntityTests { + + private final WebTestClient client = WebTestClient.bindToController(new PersonController()) + .configureClient() + .baseUrl("/persons") + .build(); + + + @Test + public void entity() { + this.client.get().uri("/John") + .exchange() + .expectStatus().isOk() + .expectHeader().contentType(MediaType.APPLICATION_JSON_UTF8) + .expectBody(Person.class).isEqualTo(new Person("John")); + } + + @Test + public void entityMatcher() { + this.client.get().uri("/John") + .exchange() + .expectStatus().isOk() + .expectHeader().contentType(MediaType.APPLICATION_JSON_UTF8) + .expectBody(Person.class).value(Person::getName, startsWith("Joh")); + } + + @Test + public void entityWithConsumer() { + this.client.get().uri("/John") + .exchange() + .expectStatus().isOk() + .expectHeader().contentType(MediaType.APPLICATION_JSON_UTF8) + .expectBody(Person.class) + .consumeWith(result -> assertEquals(new Person("John"), result.getResponseBody())); + } + + @Test + public void entityList() { + + List expected = Arrays.asList( + new Person("Jane"), new Person("Jason"), new Person("John")); + + this.client.get() + .exchange() + .expectStatus().isOk() + .expectHeader().contentType(MediaType.APPLICATION_JSON_UTF8) + .expectBodyList(Person.class).isEqualTo(expected); + } + + @Test + public void entityListWithConsumer() { + + this.client.get() + .exchange() + .expectStatus().isOk() + .expectHeader().contentType(MediaType.APPLICATION_JSON_UTF8) + .expectBodyList(Person.class).value(people -> { + MatcherAssert.assertThat(people, hasItem(new Person("Jason"))); + }); + } + + @Test + public void entityMap() { + + Map map = new LinkedHashMap<>(); + map.put("Jane", new Person("Jane")); + map.put("Jason", new Person("Jason")); + map.put("John", new Person("John")); + + this.client.get().uri("?map=true") + .exchange() + .expectStatus().isOk() + .expectBody(new ParameterizedTypeReference>() {}).isEqualTo(map); + } + + @Test + public void entityStream() { + + FluxExchangeResult result = this.client.get() + .accept(TEXT_EVENT_STREAM) + .exchange() + .expectStatus().isOk() + .expectHeader().contentTypeCompatibleWith(TEXT_EVENT_STREAM) + .returnResult(Person.class); + + StepVerifier.create(result.getResponseBody()) + .expectNext(new Person("N0"), new Person("N1"), new Person("N2")) + .expectNextCount(4) + .consumeNextWith(person -> assertThat(person.getName(), endsWith("7"))) + .thenCancel() + .verify(); + } + + @Test + public void postEntity() { + this.client.post() + .syncBody(new Person("John")) + .exchange() + .expectStatus().isCreated() + .expectHeader().valueEquals("location", "/persons/John") + .expectBody().isEmpty(); + } + + + @RestController + @RequestMapping("/persons") + static class PersonController { + + @GetMapping("/{name}") + Person getPerson(@PathVariable String name) { + return new Person(name); + } + + @GetMapping + Flux getPersons() { + return Flux.just(new Person("Jane"), new Person("Jason"), new Person("John")); + } + + @GetMapping(params = "map") + Map getPersonsAsMap() { + Map map = new LinkedHashMap<>(); + map.put("Jane", new Person("Jane")); + map.put("Jason", new Person("Jason")); + map.put("John", new Person("John")); + return map; + } + + @GetMapping(produces = "text/event-stream") + Flux getPersonStream() { + return Flux.interval(ofMillis(100)).take(50).onBackpressureBuffer(50) + .map(index -> new Person("N" + index)); + } + + @PostMapping + ResponseEntity savePerson(@RequestBody Person person) { + return ResponseEntity.created(URI.create("/persons/" + person.getName())).build(); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/XmlContentTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/XmlContentTests.java new file mode 100644 index 0000000000000000000000000000000000000000..846c269e31bc2b3b2716cc2e28cad4c622e66ea5 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/XmlContentTests.java @@ -0,0 +1,171 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server.samples; + +import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import javax.xml.bind.annotation.XmlAccessType; +import javax.xml.bind.annotation.XmlAccessorType; +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlRootElement; + +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import static org.hamcrest.Matchers.*; + +/** + * Samples of tests using {@link WebTestClient} with XML content. + * + * @author Eric Deandrea + * @since 5.1 + */ +public class XmlContentTests { + + private static final String persons_XML = + "" + + "" + + "Jane" + + "Jason" + + "John" + + ""; + + + private final WebTestClient client = WebTestClient.bindToController(new PersonController()).build(); + + + @Test + public void xmlContent() { + this.client.get().uri("/persons") + .accept(MediaType.APPLICATION_XML) + .exchange() + .expectStatus().isOk() + .expectBody().xml(persons_XML); + } + + @Test + public void xpathIsEqualTo() { + this.client.get().uri("/persons") + .accept(MediaType.APPLICATION_XML) + .exchange() + .expectStatus().isOk() + .expectBody() + .xpath("/").exists() + .xpath("/persons").exists() + .xpath("/persons/person").exists() + .xpath("/persons/person").nodeCount(3) + .xpath("/persons/person[1]/name").isEqualTo("Jane") + .xpath("/persons/person[2]/name").isEqualTo("Jason") + .xpath("/persons/person[3]/name").isEqualTo("John"); + } + + @Test + public void xpathMatches() { + this.client.get().uri("/persons") + .accept(MediaType.APPLICATION_XML) + .exchange() + .expectStatus().isOk() + .expectBody() + .xpath("//person/name").string(startsWith("J")); + } + + @Test + public void xpathContainsSubstringViaRegex() { + this.client.get().uri("/persons/John") + .accept(MediaType.APPLICATION_XML) + .exchange() + .expectStatus().isOk() + .expectBody() + .xpath("//name[contains(text(), 'oh')]").exists(); + } + + @Test + public void postXmlContent() { + + String content = + "" + + "John"; + + this.client.post().uri("/persons") + .contentType(MediaType.APPLICATION_XML) + .syncBody(content) + .exchange() + .expectStatus().isCreated() + .expectHeader().valueEquals(HttpHeaders.LOCATION, "/persons/John") + .expectBody().isEmpty(); + } + + + @SuppressWarnings("unused") + @XmlRootElement(name="persons") + @XmlAccessorType(XmlAccessType.FIELD) + private static class PersonsWrapper { + + @XmlElement(name="person") + private final List persons = new ArrayList<>(); + + public PersonsWrapper() { + } + + public PersonsWrapper(List persons) { + this.persons.addAll(persons); + } + + public PersonsWrapper(Person... persons) { + this.persons.addAll(Arrays.asList(persons)); + } + + public List getpersons() { + return this.persons; + } + } + + @RestController + @RequestMapping("/persons") + static class PersonController { + + @GetMapping(produces = MediaType.APPLICATION_XML_VALUE) + PersonsWrapper getPersons() { + return new PersonsWrapper(new Person("Jane"), new Person("Jason"), new Person("John")); + } + + @GetMapping(path = "/{name}", produces = MediaType.APPLICATION_XML_VALUE) + Person getPerson(@PathVariable String name) { + return new Person(name); + } + + @PostMapping(consumes = MediaType.APPLICATION_XML_VALUE) + ResponseEntity savepersons(@RequestBody Person person) { + URI location = URI.create(String.format("/persons/%s", person.getName())); + return ResponseEntity.created(location).build(); + } + } + +} \ No newline at end of file diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ApplicationContextTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ApplicationContextTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8e71fc88e6cec0c2816effdac666c866d98c5760 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ApplicationContextTests.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server.samples.bind; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.reactive.config.EnableWebFlux; + +/** + * Sample tests demonstrating "mock" server tests binding to server infrastructure + * declared in a Spring ApplicationContext. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ApplicationContextTests { + + private WebTestClient client; + + + @Before + public void setUp() throws Exception { + + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(WebConfig.class); + context.refresh(); + + this.client = WebTestClient.bindToApplicationContext(context).build(); + } + + @Test + public void test() throws Exception { + this.client.get().uri("/test") + .exchange() + .expectStatus().isOk() + .expectBody(String.class).isEqualTo("It works!"); + } + + + @Configuration + @EnableWebFlux + static class WebConfig { + + @Bean + public TestController controller() { + return new TestController(); + } + + } + + @RestController + static class TestController { + + @GetMapping("/test") + public String handle() { + return "It works!"; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ControllerTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ControllerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0dbddfcae412d55986939c36624f35572468e931 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ControllerTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server.samples.bind; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; + +/** + * Sample tests demonstrating "mock" server tests binding to an annotated + * controller. + * + * @author Rossen Stoyanchev + */ +public class ControllerTests { + + private WebTestClient client; + + + @Before + public void setUp() throws Exception { + this.client = WebTestClient.bindToController(new TestController()).build(); + } + + + @Test + public void test() throws Exception { + this.client.get().uri("/test") + .exchange() + .expectStatus().isOk() + .expectBody(String.class).isEqualTo("It works!"); + } + + + @RestController + static class TestController { + + @GetMapping("/test") + public String handle() { + return "It works!"; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/HttpServerTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/HttpServerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..002b83536fd24412fbb319f5357fdaee55baf6dc --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/HttpServerTests.java @@ -0,0 +1,74 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server.samples.bind; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.bootstrap.ReactorHttpServer; +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerResponse; + +import static org.springframework.web.reactive.function.server.RequestPredicates.*; +import static org.springframework.web.reactive.function.server.RouterFunctions.*; + +/** + * Sample tests demonstrating live server integration tests. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class HttpServerTests { + + private ReactorHttpServer server; + + private WebTestClient client; + + + @Before + public void start() throws Exception { + HttpHandler httpHandler = RouterFunctions.toHttpHandler( + route(GET("/test"), request -> ServerResponse.ok().syncBody("It works!"))); + + this.server = new ReactorHttpServer(); + this.server.setHandler(httpHandler); + this.server.afterPropertiesSet(); + this.server.start(); + + this.client = WebTestClient.bindToServer() + .baseUrl("http://localhost:" + this.server.getPort()) + .build(); + } + + @After + public void stop() { + this.server.stop(); + } + + + @Test + public void test() { + this.client.get().uri("/test") + .exchange() + .expectStatus().isOk() + .expectBody(String.class).isEqualTo("It works!"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/RouterFunctionTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/RouterFunctionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..32f4f1e770f8c934370fbca580af0389e85c14a3 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/RouterFunctionTests.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server.samples.bind; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.ServerResponse; + +import static org.springframework.web.reactive.function.server.RequestPredicates.GET; +import static org.springframework.web.reactive.function.server.RouterFunctions.route; + +/** + * Sample tests demonstrating "mock" server tests binding to a RouterFunction. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class RouterFunctionTests { + + private WebTestClient testClient; + + + @Before + public void setUp() throws Exception { + + RouterFunction route = route(GET("/test"), request -> + ServerResponse.ok().syncBody("It works!")); + + this.testClient = WebTestClient.bindToRouterFunction(route).build(); + } + + @Test + public void test() throws Exception { + this.testClient.get().uri("/test") + .exchange() + .expectStatus().isOk() + .expectBody(String.class).isEqualTo("It works!"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/WebFilterTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/WebFilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5bca558d88032c7738e42ca60dd64927ad17e8a2 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/WebFilterTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server.samples.bind; + +import java.nio.charset.StandardCharsets; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.server.WebFilter; + +/** + * Tests for a {@link WebFilter}. + * @author Rossen Stoyanchev + */ +public class WebFilterTests { + + @Test + public void testWebFilter() throws Exception { + + WebFilter filter = (exchange, chain) -> { + DataBuffer buffer = new DefaultDataBufferFactory().allocateBuffer(); + buffer.write("It works!".getBytes(StandardCharsets.UTF_8)); + return exchange.getResponse().writeWith(Mono.just(buffer)); + }; + + WebTestClient client = WebTestClient.bindToWebHandler(exchange -> Mono.empty()) + .webFilter(filter) + .build(); + + client.get().uri("/") + .exchange() + .expectStatus().isOk() + .expectBody(String.class).isEqualTo("It works!"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java new file mode 100644 index 0000000000000000000000000000000000000000..453f9b919931fb48be82c354542964a89f1b24ed --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet; + +import java.util.concurrent.CountDownLatch; + +import org.junit.Test; + +import org.springframework.mock.web.MockHttpServletRequest; + +/** + * Test fixture for {@link DefaultMvcResult}. + * + * @author Rossen Stoyanchev + */ +public class DefaultMvcResultTests { + + private final DefaultMvcResult mvcResult = new DefaultMvcResult(new MockHttpServletRequest(), null); + + @Test + public void getAsyncResultSuccess() { + this.mvcResult.setAsyncResult("Foo"); + this.mvcResult.setAsyncDispatchLatch(new CountDownLatch(0)); + this.mvcResult.getAsyncResult(); + } + + @Test(expected = IllegalStateException.class) + public void getAsyncResultFailure() { + this.mvcResult.getAsyncResult(0); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/MockMvcReuseTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/MockMvcReuseTests.java new file mode 100644 index 0000000000000000000000000000000000000000..30d36f948dc52bc551a1581e6e41a8de6e6e6c2a --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/MockMvcReuseTests.java @@ -0,0 +1,123 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; + +import static org.hamcrest.CoreMatchers.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Integration tests that verify that {@link MockMvc} can be reused multiple + * times within the same test method without side effects between independent + * requests. + *

See SPR-13260. + * + * @author Sam Brannen + * @author Rob Winch + * @since 4.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +@WebAppConfiguration +public class MockMvcReuseTests { + + private static final String HELLO = "hello"; + private static final String ENIGMA = "enigma"; + private static final String FOO = "foo"; + private static final String BAR = "bar"; + + @Autowired + private WebApplicationContext wac; + + private MockMvc mvc; + + + @Before + public void setUp() { + this.mvc = webAppContextSetup(this.wac).build(); + } + + @Test + public void sessionAttributesAreClearedBetweenInvocations() throws Exception { + + this.mvc.perform(get("/")) + .andExpect(content().string(HELLO)) + .andExpect(request().sessionAttribute(FOO, nullValue())); + + this.mvc.perform(get("/").sessionAttr(FOO, BAR)) + .andExpect(content().string(HELLO)) + .andExpect(request().sessionAttribute(FOO, BAR)); + + this.mvc.perform(get("/")) + .andExpect(content().string(HELLO)) + .andExpect(request().sessionAttribute(FOO, nullValue())); + } + + @Test + public void requestParametersAreClearedBetweenInvocations() throws Exception { + this.mvc.perform(get("/")) + .andExpect(content().string(HELLO)); + + this.mvc.perform(get("/").param(ENIGMA, "")) + .andExpect(content().string(ENIGMA)); + + this.mvc.perform(get("/")) + .andExpect(content().string(HELLO)); + } + + + @Configuration + @EnableWebMvc + static class Config { + + @Bean + public MyController myController() { + return new MyController(); + } + } + + @RestController + static class MyController { + + @RequestMapping("/") + public String hello() { + return HELLO; + } + + @RequestMapping(path = "/", params = ENIGMA) + public String enigma() { + return ENIGMA; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/StubMvcResult.java b/spring-test/src/test/java/org/springframework/test/web/servlet/StubMvcResult.java new file mode 100644 index 0000000000000000000000000000000000000000..12bde4ea41aa25a7b7e2175755923f761430471e --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/StubMvcResult.java @@ -0,0 +1,139 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.web.servlet.FlashMap; +import org.springframework.web.servlet.HandlerInterceptor; +import org.springframework.web.servlet.ModelAndView; + +/** + * A stub implementation of the {@link MvcResult} contract. + * + * @author Rossen Stoyanchev + */ +public class StubMvcResult implements MvcResult { + + private MockHttpServletRequest request; + + private Object handler; + + private HandlerInterceptor[] interceptors; + + private Exception resolvedException; + + private ModelAndView mav; + + private FlashMap flashMap; + + private MockHttpServletResponse response; + + public StubMvcResult(MockHttpServletRequest request, + Object handler, + HandlerInterceptor[] interceptors, + Exception resolvedException, + ModelAndView mav, + FlashMap flashMap, + MockHttpServletResponse response) { + this.request = request; + this.handler = handler; + this.interceptors = interceptors; + this.resolvedException = resolvedException; + this.mav = mav; + this.flashMap = flashMap; + this.response = response; + } + + @Override + public MockHttpServletRequest getRequest() { + return request; + } + + @Override + public Object getHandler() { + return handler; + } + + @Override + public HandlerInterceptor[] getInterceptors() { + return interceptors; + } + + @Override + public Exception getResolvedException() { + return resolvedException; + } + + @Override + public ModelAndView getModelAndView() { + return mav; + } + + @Override + public FlashMap getFlashMap() { + return flashMap; + } + + @Override + public MockHttpServletResponse getResponse() { + return response; + } + + public ModelAndView getMav() { + return mav; + } + + public void setMav(ModelAndView mav) { + this.mav = mav; + } + + public void setRequest(MockHttpServletRequest request) { + this.request = request; + } + + public void setHandler(Object handler) { + this.handler = handler; + } + + public void setInterceptors(HandlerInterceptor[] interceptors) { + this.interceptors = interceptors; + } + + public void setResolvedException(Exception resolvedException) { + this.resolvedException = resolvedException; + } + + public void setFlashMap(FlashMap flashMap) { + this.flashMap = flashMap; + } + + public void setResponse(MockHttpServletResponse response) { + this.response = response; + } + + @Override + public Object getAsyncResult() { + return null; + } + + @Override + public Object getAsyncResult(long timeToWait) { + return null; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/AbstractWebRequestMatcherTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/AbstractWebRequestMatcherTests.java new file mode 100644 index 0000000000000000000000000000000000000000..289df88f1b205733fc7ca29b8608948fcd56a2d0 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/AbstractWebRequestMatcherTests.java @@ -0,0 +1,42 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.htmlunit; + +import java.net.MalformedURLException; +import java.net.URL; + +import com.gargoylesoftware.htmlunit.WebRequest; + +import static org.junit.Assert.*; + +/** + * Abstract base class for testing {@link WebRequestMatcher} implementations. + * + * @author Sam Brannen + * @since 4.2 + */ +public class AbstractWebRequestMatcherTests { + + protected void assertMatches(WebRequestMatcher matcher, String url) throws MalformedURLException { + assertTrue(matcher.matches(new WebRequest(new URL(url)))); + } + + protected void assertDoesNotMatch(WebRequestMatcher matcher, String url) throws MalformedURLException { + assertFalse(matcher.matches(new WebRequest(new URL(url)))); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/DelegatingWebConnectionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/DelegatingWebConnectionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..675f6e57aa05772b4ee0fd5f5ce5b9969bd2f1d0 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/DelegatingWebConnectionTests.java @@ -0,0 +1,152 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.htmlunit; + +import java.net.URL; +import java.util.Collections; + +import com.gargoylesoftware.htmlunit.HttpWebConnection; +import com.gargoylesoftware.htmlunit.Page; +import com.gargoylesoftware.htmlunit.WebClient; +import com.gargoylesoftware.htmlunit.WebConnection; +import com.gargoylesoftware.htmlunit.WebRequest; +import com.gargoylesoftware.htmlunit.WebResponse; +import com.gargoylesoftware.htmlunit.WebResponseData; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +import org.springframework.stereotype.Controller; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.htmlunit.DelegatingWebConnection.DelegateWebConnection; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.tests.Assume; +import org.springframework.tests.TestGroup; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.sameInstance; +import static org.hamcrest.Matchers.*; +import static org.hamcrest.core.IsNot.not; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * Unit and integration tests for {@link DelegatingWebConnection}. + * + * @author Rob Winch + * @since 4.2 + */ +@RunWith(MockitoJUnitRunner.class) +public class DelegatingWebConnectionTests { + + private DelegatingWebConnection webConnection; + + private WebRequest request; + + private WebResponse expectedResponse; + + + @Mock + private WebRequestMatcher matcher1; + + @Mock + private WebRequestMatcher matcher2; + + @Mock + private WebConnection defaultConnection; + + @Mock + private WebConnection connection1; + + @Mock + private WebConnection connection2; + + + @Before + public void setup() throws Exception { + request = new WebRequest(new URL("http://localhost/")); + WebResponseData data = new WebResponseData("".getBytes("UTF-8"), 200, "", Collections.emptyList()); + expectedResponse = new WebResponse(data, request, 100L); + webConnection = new DelegatingWebConnection(defaultConnection, + new DelegateWebConnection(matcher1, connection1), new DelegateWebConnection(matcher2, connection2)); + } + + + @Test + public void getResponseDefault() throws Exception { + when(defaultConnection.getResponse(request)).thenReturn(expectedResponse); + WebResponse response = webConnection.getResponse(request); + + assertThat(response, sameInstance(expectedResponse)); + verify(matcher1).matches(request); + verify(matcher2).matches(request); + verifyNoMoreInteractions(connection1, connection2); + verify(defaultConnection).getResponse(request); + } + + @Test + public void getResponseAllMatches() throws Exception { + when(matcher1.matches(request)).thenReturn(true); + when(connection1.getResponse(request)).thenReturn(expectedResponse); + WebResponse response = webConnection.getResponse(request); + + assertThat(response, sameInstance(expectedResponse)); + verify(matcher1).matches(request); + verifyNoMoreInteractions(matcher2, connection2, defaultConnection); + verify(connection1).getResponse(request); + } + + @Test + public void getResponseSecondMatches() throws Exception { + when(matcher2.matches(request)).thenReturn(true); + when(connection2.getResponse(request)).thenReturn(expectedResponse); + WebResponse response = webConnection.getResponse(request); + + assertThat(response, sameInstance(expectedResponse)); + verify(matcher1).matches(request); + verify(matcher2).matches(request); + verifyNoMoreInteractions(connection1, defaultConnection); + verify(connection2).getResponse(request); + } + + @Test + public void verifyExampleInClassLevelJavadoc() throws Exception { + Assume.group(TestGroup.PERFORMANCE); + + WebClient webClient = new WebClient(); + + MockMvc mockMvc = MockMvcBuilders.standaloneSetup().build(); + MockMvcWebConnection mockConnection = new MockMvcWebConnection(mockMvc, webClient); + + WebRequestMatcher cdnMatcher = new UrlRegexRequestMatcher(".*?//code.jquery.com/.*"); + WebConnection httpConnection = new HttpWebConnection(webClient); + webClient.setWebConnection( + new DelegatingWebConnection(mockConnection, new DelegateWebConnection(cdnMatcher, httpConnection))); + + Page page = webClient.getPage("https://code.jquery.com/jquery-1.11.0.min.js"); + assertThat(page.getWebResponse().getStatusCode(), equalTo(200)); + assertThat(page.getWebResponse().getContentAsString(), not(isEmptyString())); + } + + + @Controller + static class TestController { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/ForwardController.java b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/ForwardController.java new file mode 100644 index 0000000000000000000000000000000000000000..0ab357e274d24ae19e707bdef6eff0efec7aad69 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/ForwardController.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.htmlunit; + +import org.springframework.stereotype.Controller; +import org.springframework.web.bind.annotation.RequestMapping; + +/** + * @author Rob Winch + * @since 4.2 + */ +@Controller +public class ForwardController { + + @RequestMapping("/forward") + public String forward() { + return "forward:/"; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/HelloController.java b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/HelloController.java new file mode 100644 index 0000000000000000000000000000000000000000..98e09a810ee297173db25e269fa3766e949f010f --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/HelloController.java @@ -0,0 +1,36 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.htmlunit; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +/** + * @author Rob Winch + * @since 4.2 + */ +@RestController +public class HelloController { + + @RequestMapping + public String header(HttpServletRequest request) { + return "hello"; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/HostRequestMatcherTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/HostRequestMatcherTests.java new file mode 100644 index 0000000000000000000000000000000000000000..ecf4f24591896dc9befb956d92730ebf5e1014bc --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/HostRequestMatcherTests.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.htmlunit; + +import org.junit.Test; + +/** + * Unit tests for {@link HostRequestMatcher}. + * + * @author Rob Winch + * @author Sam Brannen + * @since 4.2 + */ +public class HostRequestMatcherTests extends AbstractWebRequestMatcherTests { + + @Test + public void localhost() throws Exception { + WebRequestMatcher matcher = new HostRequestMatcher("localhost"); + assertMatches(matcher, "http://localhost/jquery-1.11.0.min.js"); + assertDoesNotMatch(matcher, "http://example.com/jquery-1.11.0.min.js"); + } + + @Test + public void multipleHosts() throws Exception { + WebRequestMatcher matcher = new HostRequestMatcher("localhost", "example.com"); + assertMatches(matcher, "http://localhost/jquery-1.11.0.min.js"); + assertMatches(matcher, "http://example.com/jquery-1.11.0.min.js"); + } + + @Test + public void specificPort() throws Exception { + WebRequestMatcher matcher = new HostRequestMatcher("localhost:8080"); + assertMatches(matcher, "http://localhost:8080/jquery-1.11.0.min.js"); + assertDoesNotMatch(matcher, "http://localhost:9090/jquery-1.11.0.min.js"); + } + + @Test + public void defaultHttpPort() throws Exception { + WebRequestMatcher matcher = new HostRequestMatcher("localhost:80"); + assertMatches(matcher, "http://localhost:80/jquery-1.11.0.min.js"); + assertMatches(matcher, "http://localhost/jquery-1.11.0.min.js"); + assertDoesNotMatch(matcher, "https://localhost/jquery-1.11.0.min.js"); + assertDoesNotMatch(matcher, "http://localhost:9090/jquery-1.11.0.min.js"); + } + + @Test + public void defaultHttpsPort() throws Exception { + WebRequestMatcher matcher = new HostRequestMatcher("localhost:443"); + assertMatches(matcher, "https://localhost:443/jquery-1.11.0.min.js"); + assertMatches(matcher, "https://localhost/jquery-1.11.0.min.js"); + assertDoesNotMatch(matcher, "http://localhost/jquery-1.11.0.min.js"); + assertDoesNotMatch(matcher, "https://localhost:9090/jquery-1.11.0.min.js"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/HtmlUnitRequestBuilderTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/HtmlUnitRequestBuilderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..cc60186aa29e1a5a442960c06dbfa2736dff58e0 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/HtmlUnitRequestBuilderTests.java @@ -0,0 +1,945 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.htmlunit; + +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import javax.servlet.ServletContext; +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpSession; + +import com.gargoylesoftware.htmlunit.FormEncodingType; +import com.gargoylesoftware.htmlunit.HttpMethod; +import com.gargoylesoftware.htmlunit.WebClient; +import com.gargoylesoftware.htmlunit.WebRequest; +import com.gargoylesoftware.htmlunit.util.NameValuePair; +import org.apache.commons.io.IOUtils; +import org.apache.http.auth.UsernamePasswordCredentials; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpSession; +import org.springframework.mock.web.MockServletContext; +import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; + +import static java.util.Arrays.asList; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.isEmptyString; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.sameInstance; +import static org.junit.Assert.assertThat; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; + +/** + * Unit tests for {@link HtmlUnitRequestBuilder}. + * + * @author Rob Winch + * @author Sam Brannen + * @since 4.2 + */ +public class HtmlUnitRequestBuilderTests { + + private final WebClient webClient = new WebClient(); + + private final ServletContext servletContext = new MockServletContext(); + + private final Map sessions = new HashMap<>(); + + private WebRequest webRequest; + + private HtmlUnitRequestBuilder requestBuilder; + + + @Before + public void setup() throws Exception { + webRequest = new WebRequest(new URL("http://example.com:80/test/this/here")); + webRequest.setHttpMethod(HttpMethod.GET); + requestBuilder = new HtmlUnitRequestBuilder(sessions, webClient, webRequest); + } + + + // --- constructor + + @Test(expected = IllegalArgumentException.class) + public void constructorNullSessions() { + new HtmlUnitRequestBuilder(null, webClient, webRequest); + } + + @Test(expected = IllegalArgumentException.class) + public void constructorNullWebClient() { + new HtmlUnitRequestBuilder(sessions, null, webRequest); + } + + @Test(expected = IllegalArgumentException.class) + public void constructorNullWebRequest() { + new HtmlUnitRequestBuilder(sessions, webClient, null); + } + + + // --- buildRequest + + @Test + @SuppressWarnings("deprecation") + public void buildRequestBasicAuth() { + String base64Credentials = "dXNlcm5hbWU6cGFzc3dvcmQ="; + String authzHeaderValue = "Basic: " + base64Credentials; + UsernamePasswordCredentials credentials = new UsernamePasswordCredentials(base64Credentials); + webRequest.setCredentials(credentials); + webRequest.setAdditionalHeader("Authorization", authzHeaderValue); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getAuthType(), equalTo("Basic")); + assertThat(actualRequest.getHeader("Authorization"), equalTo(authzHeaderValue)); + } + + @Test + public void buildRequestCharacterEncoding() { + webRequest.setCharset(StandardCharsets.UTF_8); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getCharacterEncoding(), equalTo("UTF-8")); + } + + @Test + public void buildRequestDefaultCharacterEncoding() { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getCharacterEncoding(), equalTo("ISO-8859-1")); + } + + @Test + public void buildRequestContentLength() { + String content = "some content that has length"; + webRequest.setHttpMethod(HttpMethod.POST); + webRequest.setRequestBody(content); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getContentLength(), equalTo(content.length())); + } + + @Test + public void buildRequestContentType() { + String contentType = "text/html;charset=UTF-8"; + webRequest.setAdditionalHeader("Content-Type", contentType); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getContentType(), equalTo(contentType)); + assertThat(actualRequest.getHeader("Content-Type"), equalTo(contentType)); + } + + @Test // SPR-14916 + public void buildRequestContentTypeWithFormSubmission() { + webRequest.setEncodingType(FormEncodingType.URL_ENCODED); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getContentType(), equalTo("application/x-www-form-urlencoded")); + assertThat(actualRequest.getHeader("Content-Type"), + equalTo("application/x-www-form-urlencoded;charset=ISO-8859-1")); + } + + + @Test + public void buildRequestContextPathUsesFirstSegmentByDefault() { + String contextPath = requestBuilder.buildRequest(servletContext).getContextPath(); + + assertThat(contextPath, equalTo("/test")); + } + + @Test + public void buildRequestContextPathUsesNoFirstSegmentWithDefault() throws MalformedURLException { + webRequest.setUrl(new URL("https://example.com/")); + String contextPath = requestBuilder.buildRequest(servletContext).getContextPath(); + + assertThat(contextPath, equalTo("")); + } + + @Test(expected = IllegalArgumentException.class) + public void buildRequestContextPathInvalid() { + requestBuilder.setContextPath("/invalid"); + + requestBuilder.buildRequest(servletContext).getContextPath(); + } + + @Test + public void buildRequestContextPathEmpty() { + String expected = ""; + requestBuilder.setContextPath(expected); + + String contextPath = requestBuilder.buildRequest(servletContext).getContextPath(); + + assertThat(contextPath, equalTo(expected)); + } + + @Test + public void buildRequestContextPathExplicit() { + String expected = "/test"; + requestBuilder.setContextPath(expected); + + String contextPath = requestBuilder.buildRequest(servletContext).getContextPath(); + + assertThat(contextPath, equalTo(expected)); + } + + @Test + public void buildRequestContextPathMulti() { + String expected = "/test/this"; + requestBuilder.setContextPath(expected); + + String contextPath = requestBuilder.buildRequest(servletContext).getContextPath(); + + assertThat(contextPath, equalTo(expected)); + } + + @Test + public void buildRequestCookiesNull() { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getCookies(), nullValue()); + } + + @Test + public void buildRequestCookiesSingle() { + webRequest.setAdditionalHeader("Cookie", "name=value"); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + Cookie[] cookies = actualRequest.getCookies(); + assertThat(cookies.length, equalTo(1)); + assertThat(cookies[0].getName(), equalTo("name")); + assertThat(cookies[0].getValue(), equalTo("value")); + } + + @Test + public void buildRequestCookiesMulti() { + webRequest.setAdditionalHeader("Cookie", "name=value; name2=value2"); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + Cookie[] cookies = actualRequest.getCookies(); + assertThat(cookies.length, equalTo(2)); + Cookie cookie = cookies[0]; + assertThat(cookie.getName(), equalTo("name")); + assertThat(cookie.getValue(), equalTo("value")); + cookie = cookies[1]; + assertThat(cookie.getName(), equalTo("name2")); + assertThat(cookie.getValue(), equalTo("value2")); + } + + @Test + @SuppressWarnings("deprecation") + public void buildRequestInputStream() throws Exception { + String content = "some content that has length"; + webRequest.setHttpMethod(HttpMethod.POST); + webRequest.setRequestBody(content); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(IOUtils.toString(actualRequest.getInputStream()), equalTo(content)); + } + + @Test + public void buildRequestLocalAddr() { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getLocalAddr(), equalTo("127.0.0.1")); + } + + @Test + public void buildRequestLocaleDefault() { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getLocale(), equalTo(Locale.getDefault())); + } + + @Test + public void buildRequestLocaleDa() { + webRequest.setAdditionalHeader("Accept-Language", "da"); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getLocale(), equalTo(new Locale("da"))); + } + + @Test + public void buildRequestLocaleEnGbQ08() { + webRequest.setAdditionalHeader("Accept-Language", "en-gb;q=0.8"); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getLocale(), equalTo(new Locale("en", "gb"))); + } + + @Test + public void buildRequestLocaleEnQ07() { + webRequest.setAdditionalHeader("Accept-Language", "en"); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getLocale(), equalTo(new Locale("en", ""))); + } + + @Test + public void buildRequestLocaleEnUs() { + webRequest.setAdditionalHeader("Accept-Language", "en-US"); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getLocale(), equalTo(Locale.US)); + } + + @Test + public void buildRequestLocaleFr() { + webRequest.setAdditionalHeader("Accept-Language", "fr"); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getLocale(), equalTo(Locale.FRENCH)); + } + + @Test + public void buildRequestLocaleMulti() { + webRequest.setAdditionalHeader("Accept-Language", "en-gb;q=0.8, da, en;q=0.7"); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + List expected = asList(new Locale("da"), new Locale("en", "gb"), new Locale("en", "")); + assertThat(Collections.list(actualRequest.getLocales()), equalTo(expected)); + } + + @Test + public void buildRequestLocalName() { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getLocalName(), equalTo("localhost")); + } + + @Test + public void buildRequestLocalPort() { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getLocalPort(), equalTo(80)); + } + + @Test + public void buildRequestLocalMissing() throws Exception { + webRequest.setUrl(new URL("http://localhost/test/this")); + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getLocalPort(), equalTo(-1)); + } + + @Test + public void buildRequestMethods() { + for (HttpMethod expectedMethod : HttpMethod.values()) { + webRequest.setHttpMethod(expectedMethod); + String actualMethod = requestBuilder.buildRequest(servletContext).getMethod(); + assertThat(actualMethod, equalTo(expectedMethod.name())); + } + } + + @Test + public void buildRequestParameterMapViaWebRequestDotSetRequestParametersWithSingleRequestParam() { + webRequest.setRequestParameters(asList(new NameValuePair("name", "value"))); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getParameterMap().size(), equalTo(1)); + assertThat(actualRequest.getParameter("name"), equalTo("value")); + } + + @Test + public void buildRequestParameterMapViaWebRequestDotSetRequestParametersWithSingleRequestParamWithNullValue() { + webRequest.setRequestParameters(asList(new NameValuePair("name", null))); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getParameterMap().size(), equalTo(1)); + assertThat(actualRequest.getParameter("name"), nullValue()); + } + + @Test + public void buildRequestParameterMapViaWebRequestDotSetRequestParametersWithSingleRequestParamWithEmptyValue() { + webRequest.setRequestParameters(asList(new NameValuePair("name", ""))); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getParameterMap().size(), equalTo(1)); + assertThat(actualRequest.getParameter("name"), equalTo("")); + } + + @Test + public void buildRequestParameterMapViaWebRequestDotSetRequestParametersWithSingleRequestParamWithValueSetToSpace() { + webRequest.setRequestParameters(asList(new NameValuePair("name", " "))); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getParameterMap().size(), equalTo(1)); + assertThat(actualRequest.getParameter("name"), equalTo(" ")); + } + + @Test + public void buildRequestParameterMapViaWebRequestDotSetRequestParametersWithMultipleRequestParams() { + webRequest.setRequestParameters(asList(new NameValuePair("name1", "value1"), new NameValuePair("name2", "value2"))); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getParameterMap().size(), equalTo(2)); + assertThat(actualRequest.getParameter("name1"), equalTo("value1")); + assertThat(actualRequest.getParameter("name2"), equalTo("value2")); + } + + @Test + public void buildRequestParameterMapFromSingleQueryParam() throws Exception { + webRequest.setUrl(new URL("https://example.com/example/?name=value")); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getParameterMap().size(), equalTo(1)); + assertThat(actualRequest.getParameter("name"), equalTo("value")); + } + + // SPR-14177 + @Test + public void buildRequestParameterMapDecodesParameterName() throws Exception { + webRequest.setUrl(new URL("https://example.com/example/?row%5B0%5D=value")); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getParameterMap().size(), equalTo(1)); + assertThat(actualRequest.getParameter("row[0]"), equalTo("value")); + } + + @Test + public void buildRequestParameterMapDecodesParameterValue() throws Exception { + webRequest.setUrl(new URL("https://example.com/example/?name=row%5B0%5D")); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getParameterMap().size(), equalTo(1)); + assertThat(actualRequest.getParameter("name"), equalTo("row[0]")); + } + + @Test + public void buildRequestParameterMapFromSingleQueryParamWithoutValueAndWithoutEqualsSign() throws Exception { + webRequest.setUrl(new URL("https://example.com/example/?name")); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getParameterMap().size(), equalTo(1)); + assertThat(actualRequest.getParameter("name"), equalTo("")); + } + + @Test + public void buildRequestParameterMapFromSingleQueryParamWithoutValueButWithEqualsSign() throws Exception { + webRequest.setUrl(new URL("https://example.com/example/?name=")); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getParameterMap().size(), equalTo(1)); + assertThat(actualRequest.getParameter("name"), equalTo("")); + } + + @Test + public void buildRequestParameterMapFromSingleQueryParamWithValueSetToEncodedSpace() throws Exception { + webRequest.setUrl(new URL("https://example.com/example/?name=%20")); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getParameterMap().size(), equalTo(1)); + assertThat(actualRequest.getParameter("name"), equalTo(" ")); + } + + @Test + public void buildRequestParameterMapFromMultipleQueryParams() throws Exception { + webRequest.setUrl(new URL("https://example.com/example/?name=value¶m2=value+2")); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getParameterMap().size(), equalTo(2)); + assertThat(actualRequest.getParameter("name"), equalTo("value")); + assertThat(actualRequest.getParameter("param2"), equalTo("value 2")); + } + + @Test + public void buildRequestPathInfo() throws Exception { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getPathInfo(), nullValue()); + } + + @Test + public void buildRequestPathInfoNull() throws Exception { + webRequest.setUrl(new URL("https://example.com/example")); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getPathInfo(), nullValue()); + } + + @Test + public void buildRequestAndAntPathRequestMatcher() throws Exception { + webRequest.setUrl(new URL("https://example.com/app/login/authenticate")); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + // verify it is going to work with Spring Security's AntPathRequestMatcher + assertThat(actualRequest.getPathInfo(), nullValue()); + assertThat(actualRequest.getServletPath(), equalTo("/login/authenticate")); + } + + @Test + public void buildRequestProtocol() throws Exception { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getProtocol(), equalTo("HTTP/1.1")); + } + + @Test + public void buildRequestQueryWithSingleQueryParam() throws Exception { + String expectedQuery = "param=value"; + webRequest.setUrl(new URL("https://example.com/example?" + expectedQuery)); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getQueryString(), equalTo(expectedQuery)); + } + + @Test + public void buildRequestQueryWithSingleQueryParamWithoutValueAndWithoutEqualsSign() throws Exception { + String expectedQuery = "param"; + webRequest.setUrl(new URL("https://example.com/example?" + expectedQuery)); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getQueryString(), equalTo(expectedQuery)); + } + + @Test + public void buildRequestQueryWithSingleQueryParamWithoutValueButWithEqualsSign() throws Exception { + String expectedQuery = "param="; + webRequest.setUrl(new URL("https://example.com/example?" + expectedQuery)); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getQueryString(), equalTo(expectedQuery)); + } + + @Test + public void buildRequestQueryWithSingleQueryParamWithValueSetToEncodedSpace() throws Exception { + String expectedQuery = "param=%20"; + webRequest.setUrl(new URL("https://example.com/example?" + expectedQuery)); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getQueryString(), equalTo(expectedQuery)); + } + + @Test + public void buildRequestQueryWithMultipleQueryParams() throws Exception { + String expectedQuery = "param1=value1¶m2=value2"; + webRequest.setUrl(new URL("https://example.com/example?" + expectedQuery)); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getQueryString(), equalTo(expectedQuery)); + } + + @Test + public void buildRequestReader() throws Exception { + String expectedBody = "request body"; + webRequest.setHttpMethod(HttpMethod.POST); + webRequest.setRequestBody(expectedBody); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(IOUtils.toString(actualRequest.getReader()), equalTo(expectedBody)); + } + + @Test + public void buildRequestRemoteAddr() throws Exception { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getRemoteAddr(), equalTo("127.0.0.1")); + } + + @Test + public void buildRequestRemoteHost() throws Exception { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getRemoteAddr(), equalTo("127.0.0.1")); + } + + @Test + public void buildRequestRemotePort() throws Exception { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getRemotePort(), equalTo(80)); + } + + @Test + public void buildRequestRemotePort8080() throws Exception { + webRequest.setUrl(new URL("https://example.com:8080/")); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getRemotePort(), equalTo(8080)); + } + + @Test + public void buildRequestRemotePort80WithDefault() throws Exception { + webRequest.setUrl(new URL("http://example.com/")); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getRemotePort(), equalTo(80)); + } + + @Test + public void buildRequestRequestedSessionId() throws Exception { + String sessionId = "session-id"; + webRequest.setAdditionalHeader("Cookie", "JSESSIONID=" + sessionId); + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getRequestedSessionId(), equalTo(sessionId)); + } + + @Test + public void buildRequestRequestedSessionIdNull() throws Exception { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getRequestedSessionId(), nullValue()); + } + + @Test + public void buildRequestUri() { + String uri = requestBuilder.buildRequest(servletContext).getRequestURI(); + assertThat(uri, equalTo("/test/this/here")); + } + + @Test + public void buildRequestUrl() { + String uri = requestBuilder.buildRequest(servletContext).getRequestURL().toString(); + assertThat(uri, equalTo("http://example.com/test/this/here")); + } + + @Test + public void buildRequestSchemeHttp() throws Exception { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getScheme(), equalTo("http")); + } + + @Test + public void buildRequestSchemeHttps() throws Exception { + webRequest.setUrl(new URL("https://example.com/")); + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getScheme(), equalTo("https")); + } + + @Test + public void buildRequestServerName() throws Exception { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getServerName(), equalTo("example.com")); + } + + @Test + public void buildRequestServerPort() throws Exception { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getServerPort(), equalTo(80)); + } + + @Test + public void buildRequestServerPortDefault() throws Exception { + webRequest.setUrl(new URL("https://example.com/")); + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getServerPort(), equalTo(-1)); + } + + @Test + public void buildRequestServletContext() throws Exception { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getServletContext(), equalTo(servletContext)); + } + + @Test + public void buildRequestServletPath() throws Exception { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getServletPath(), equalTo("/this/here")); + } + + @Test + public void buildRequestSession() throws Exception { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + HttpSession newSession = actualRequest.getSession(); + assertThat(newSession, notNullValue()); + assertSingleSessionCookie( + "JSESSIONID=" + newSession.getId() + "; Path=/test; Domain=example.com"); + + webRequest.setAdditionalHeader("Cookie", "JSESSIONID=" + newSession.getId()); + + requestBuilder = new HtmlUnitRequestBuilder(sessions, webClient, webRequest); + actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getSession(), sameInstance(newSession)); + } + + @Test + public void buildRequestSessionWithExistingSession() throws Exception { + String sessionId = "session-id"; + webRequest.setAdditionalHeader("Cookie", "JSESSIONID=" + sessionId); + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + HttpSession session = actualRequest.getSession(); + assertThat(session.getId(), equalTo(sessionId)); + assertSingleSessionCookie("JSESSIONID=" + session.getId() + "; Path=/test; Domain=example.com"); + + requestBuilder = new HtmlUnitRequestBuilder(sessions, webClient, webRequest); + actualRequest = requestBuilder.buildRequest(servletContext); + assertThat(actualRequest.getSession(), equalTo(session)); + + webRequest.setAdditionalHeader("Cookie", "JSESSIONID=" + sessionId + "NEW"); + actualRequest = requestBuilder.buildRequest(servletContext); + assertThat(actualRequest.getSession(), not(equalTo(session))); + assertSingleSessionCookie("JSESSIONID=" + actualRequest.getSession().getId() + + "; Path=/test; Domain=example.com"); + } + + @Test + public void buildRequestSessionTrue() throws Exception { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + HttpSession session = actualRequest.getSession(true); + assertThat(session, notNullValue()); + } + + @Test + public void buildRequestSessionFalseIsNull() throws Exception { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + HttpSession session = actualRequest.getSession(false); + assertThat(session, nullValue()); + } + + @Test + public void buildRequestSessionFalseWithExistingSession() throws Exception { + String sessionId = "session-id"; + webRequest.setAdditionalHeader("Cookie", "JSESSIONID=" + sessionId); + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + HttpSession session = actualRequest.getSession(false); + assertThat(session, notNullValue()); + } + + @Test + public void buildRequestSessionIsNew() throws Exception { + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getSession().isNew(), equalTo(true)); + } + + @Test + public void buildRequestSessionIsNewFalse() throws Exception { + String sessionId = "session-id"; + webRequest.setAdditionalHeader("Cookie", "JSESSIONID=" + sessionId); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getSession().isNew(), equalTo(false)); + } + + @Test + public void buildRequestSessionInvalidate() throws Exception { + String sessionId = "session-id"; + webRequest.setAdditionalHeader("Cookie", "JSESSIONID=" + sessionId); + + MockHttpServletRequest actualRequest = requestBuilder.buildRequest(servletContext); + HttpSession sessionToRemove = actualRequest.getSession(); + sessionToRemove.invalidate(); + + assertThat(sessions.containsKey(sessionToRemove.getId()), equalTo(false)); + assertSingleSessionCookie("JSESSIONID=" + sessionToRemove.getId() + + "; Expires=Thu, 01-Jan-1970 00:00:01 GMT; Path=/test; Domain=example.com"); + + webRequest.removeAdditionalHeader("Cookie"); + requestBuilder = new HtmlUnitRequestBuilder(sessions, webClient, webRequest); + + actualRequest = requestBuilder.buildRequest(servletContext); + + assertThat(actualRequest.getSession().isNew(), equalTo(true)); + assertThat(sessions.containsKey(sessionToRemove.getId()), equalTo(false)); + } + + // --- setContextPath + + @Test + public void setContextPathNull() { + requestBuilder.setContextPath(null); + + assertThat(getContextPath(), nullValue()); + } + + @Test + public void setContextPathEmptyString() { + requestBuilder.setContextPath(""); + + assertThat(getContextPath(), isEmptyString()); + } + + @Test(expected = IllegalArgumentException.class) + public void setContextPathDoesNotStartWithSlash() { + requestBuilder.setContextPath("abc/def"); + } + + @Test(expected = IllegalArgumentException.class) + public void setContextPathEndsWithSlash() { + requestBuilder.setContextPath("/abc/def/"); + } + + @Test + public void setContextPath() { + String expectedContextPath = "/abc/def"; + requestBuilder.setContextPath(expectedContextPath); + + assertThat(getContextPath(), equalTo(expectedContextPath)); + } + + @Test + public void mergeHeader() throws Exception { + String headerName = "PARENT"; + String headerValue = "VALUE"; + MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new HelloController()) + .defaultRequest(get("/").header(headerName, headerValue)) + .build(); + + assertThat(mockMvc.perform(requestBuilder).andReturn().getRequest().getHeader(headerName), equalTo(headerValue)); + } + + @Test + public void mergeSession() throws Exception { + String attrName = "PARENT"; + String attrValue = "VALUE"; + MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new HelloController()) + .defaultRequest(get("/").sessionAttr(attrName, attrValue)) + .build(); + + assertThat(mockMvc.perform(requestBuilder).andReturn().getRequest().getSession().getAttribute(attrName), equalTo(attrValue)); + } + + @Test + public void mergeSessionNotInitialized() throws Exception { + MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new HelloController()) + .defaultRequest(get("/")) + .build(); + + assertThat(mockMvc.perform(requestBuilder).andReturn().getRequest().getSession(false), nullValue()); + } + + @Test + public void mergeParameter() throws Exception { + String paramName = "PARENT"; + String paramValue = "VALUE"; + String paramValue2 = "VALUE2"; + MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new HelloController()) + .defaultRequest(get("/").param(paramName, paramValue, paramValue2)) + .build(); + + MockHttpServletRequest performedRequest = mockMvc.perform(requestBuilder).andReturn().getRequest(); + assertThat(asList(performedRequest.getParameterValues(paramName)), contains(paramValue, paramValue2)); + } + + @Test + public void mergeCookie() throws Exception { + String cookieName = "PARENT"; + String cookieValue = "VALUE"; + MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new HelloController()) + .defaultRequest(get("/").cookie(new Cookie(cookieName, cookieValue))) + .build(); + + Cookie[] cookies = mockMvc.perform(requestBuilder).andReturn().getRequest().getCookies(); + assertThat(cookies, notNullValue()); + assertThat(cookies.length, equalTo(1)); + Cookie cookie = cookies[0]; + assertThat(cookie.getName(), equalTo(cookieName)); + assertThat(cookie.getValue(), equalTo(cookieValue)); + } + + @Test + public void mergeRequestAttribute() throws Exception { + String attrName = "PARENT"; + String attrValue = "VALUE"; + MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new HelloController()) + .defaultRequest(get("/").requestAttr(attrName, attrValue)) + .build(); + + assertThat(mockMvc.perform(requestBuilder).andReturn().getRequest().getAttribute(attrName), equalTo(attrValue)); + } + + @Test // SPR-14584 + public void mergeDoesNotCorruptPathInfoOnParent() throws Exception { + String pathInfo = "/foo/bar"; + MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new HelloController()) + .defaultRequest(get("/")) + .build(); + + assertThat(mockMvc.perform(get(pathInfo)).andReturn().getRequest().getPathInfo(), equalTo(pathInfo)); + + mockMvc.perform(requestBuilder); + + assertThat(mockMvc.perform(get(pathInfo)).andReturn().getRequest().getPathInfo(), equalTo(pathInfo)); + } + + + private void assertSingleSessionCookie(String expected) { + com.gargoylesoftware.htmlunit.util.Cookie jsessionidCookie = webClient.getCookieManager().getCookie("JSESSIONID"); + if (expected == null || expected.contains("Expires=Thu, 01-Jan-1970 00:00:01 GMT")) { + assertThat(jsessionidCookie, nullValue()); + return; + } + String actual = jsessionidCookie.getValue(); + assertThat("JSESSIONID=" + actual + "; Path=/test; Domain=example.com", equalTo(expected)); + } + + private String getContextPath() { + return (String) ReflectionTestUtils.getField(requestBuilder, "contextPath"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/MockMvcConnectionBuilderSupportTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/MockMvcConnectionBuilderSupportTests.java new file mode 100644 index 0000000000000000000000000000000000000000..dcca42664291b1e5b98a6dca236cf1420b265d62 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/MockMvcConnectionBuilderSupportTests.java @@ -0,0 +1,162 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.htmlunit; + +import java.io.IOException; +import java.net.URL; + +import javax.servlet.http.HttpServletRequest; + +import com.gargoylesoftware.htmlunit.WebClient; +import com.gargoylesoftware.htmlunit.WebConnection; +import com.gargoylesoftware.htmlunit.WebRequest; +import com.gargoylesoftware.htmlunit.WebResponse; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.Matchers.notNullValue; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Integration tests for {@link MockMvcWebConnectionBuilderSupport}. + * + * @author Rob Winch + * @author Rossen Stoyanchev + * @since 4.2 + */ +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +@WebAppConfiguration +@SuppressWarnings("rawtypes") +public class MockMvcConnectionBuilderSupportTests { + + private final WebClient client = mock(WebClient.class); + + private MockMvcWebConnectionBuilderSupport builder; + + @Autowired + private WebApplicationContext wac; + + + @Before + public void setup() { + when(this.client.getWebConnection()).thenReturn(mock(WebConnection.class)); + this.builder = new MockMvcWebConnectionBuilderSupport(this.wac) {}; + } + + + @Test(expected = IllegalArgumentException.class) + public void constructorMockMvcNull() { + new MockMvcWebConnectionBuilderSupport((MockMvc) null){}; + } + + @Test(expected = IllegalArgumentException.class) + public void constructorContextNull() { + new MockMvcWebConnectionBuilderSupport((WebApplicationContext) null){}; + } + + @Test + public void context() throws Exception { + WebConnection conn = this.builder.createConnection(this.client); + + assertMockMvcUsed(conn, "http://localhost/"); + assertMockMvcNotUsed(conn, "https://example.com/"); + } + + @Test + public void mockMvc() throws Exception { + MockMvc mockMvc = MockMvcBuilders.webAppContextSetup(wac).build(); + WebConnection conn = new MockMvcWebConnectionBuilderSupport(mockMvc) {}.createConnection(this.client); + + assertMockMvcUsed(conn, "http://localhost/"); + assertMockMvcNotUsed(conn, "https://example.com/"); + } + + @Test + public void mockMvcExampleDotCom() throws Exception { + WebConnection conn = this.builder.useMockMvcForHosts("example.com").createConnection(this.client); + + assertMockMvcUsed(conn, "http://localhost/"); + assertMockMvcUsed(conn, "https://example.com/"); + assertMockMvcNotUsed(conn, "http://other.com/"); + } + + @Test + public void mockMvcAlwaysUseMockMvc() throws Exception { + WebConnection conn = this.builder.alwaysUseMockMvc().createConnection(this.client); + assertMockMvcUsed(conn, "http://other.com/"); + } + + @Test + public void defaultContextPathEmpty() throws Exception { + WebConnection conn = this.builder.createConnection(this.client); + assertThat(getResponse(conn, "http://localhost/abc").getContentAsString(), equalTo("")); + } + + @Test + public void defaultContextPathCustom() throws Exception { + WebConnection conn = this.builder.contextPath("/abc").createConnection(this.client); + assertThat(getResponse(conn, "http://localhost/abc/def").getContentAsString(), equalTo("/abc")); + } + + + private void assertMockMvcUsed(WebConnection connection, String url) throws Exception { + assertThat(getResponse(connection, url), notNullValue()); + } + + private void assertMockMvcNotUsed(WebConnection connection, String url) throws Exception { + assertThat(getResponse(connection, url), nullValue()); + } + + private WebResponse getResponse(WebConnection connection, String url) throws IOException { + return connection.getResponse(new WebRequest(new URL(url))); + } + + + @Configuration + @EnableWebMvc + static class Config { + + @RestController + static class ContextPathController { + + @RequestMapping + public String contextPath(HttpServletRequest request) { + return request.getContextPath(); + } + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/MockMvcWebClientBuilderTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/MockMvcWebClientBuilderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..da03c36ba23b51d6d167fc94cb286659f926984c --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/MockMvcWebClientBuilderTests.java @@ -0,0 +1,193 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.htmlunit; + +import java.io.IOException; +import java.net.URL; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import com.gargoylesoftware.htmlunit.HttpMethod; +import com.gargoylesoftware.htmlunit.WebClient; +import com.gargoylesoftware.htmlunit.WebRequest; +import com.gargoylesoftware.htmlunit.WebResponse; +import com.gargoylesoftware.htmlunit.util.Cookie; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.tests.Assume; +import org.springframework.tests.TestGroup; +import org.springframework.web.bind.annotation.CookieValue; +import org.springframework.web.bind.annotation.DeleteMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +/** + * Integration tests for {@link MockMvcWebClientBuilder}. + * + * @author Rob Winch + * @author Sam Brannen + * @author Rossen Stoyanchev + * @since 4.2 + */ +@RunWith(SpringRunner.class) +@ContextConfiguration +@WebAppConfiguration +public class MockMvcWebClientBuilderTests { + + @Autowired + private WebApplicationContext wac; + + private MockMvc mockMvc; + + + @Before + public void setup() { + this.mockMvc = MockMvcBuilders.webAppContextSetup(this.wac).build(); + } + + + @Test(expected = IllegalArgumentException.class) + public void mockMvcSetupNull() { + MockMvcWebClientBuilder.mockMvcSetup(null); + } + + @Test(expected = IllegalArgumentException.class) + public void webAppContextSetupNull() { + MockMvcWebClientBuilder.webAppContextSetup(null); + } + + @Test + public void mockMvcSetupWithDefaultWebClientDelegate() throws Exception { + WebClient client = MockMvcWebClientBuilder.mockMvcSetup(this.mockMvc).build(); + + assertMockMvcUsed(client, "http://localhost/test"); + Assume.group(TestGroup.PERFORMANCE, () -> assertMockMvcNotUsed(client, "https://example.com/")); + } + + @Test + public void mockMvcSetupWithCustomWebClientDelegate() throws Exception { + WebClient otherClient = new WebClient(); + WebClient client = MockMvcWebClientBuilder.mockMvcSetup(this.mockMvc).withDelegate(otherClient).build(); + + assertMockMvcUsed(client, "http://localhost/test"); + Assume.group(TestGroup.PERFORMANCE, () -> assertMockMvcNotUsed(client, "https://example.com/")); + } + + @Test // SPR-14066 + public void cookieManagerShared() throws Exception { + this.mockMvc = MockMvcBuilders.standaloneSetup(new CookieController()).build(); + WebClient client = MockMvcWebClientBuilder.mockMvcSetup(this.mockMvc).build(); + + assertThat(getResponse(client, "http://localhost/").getContentAsString(), equalTo("NA")); + client.getCookieManager().addCookie(new Cookie("localhost", "cookie", "cookieManagerShared")); + assertThat(getResponse(client, "http://localhost/").getContentAsString(), equalTo("cookieManagerShared")); + } + + @Test // SPR-14265 + public void cookiesAreManaged() throws Exception { + this.mockMvc = MockMvcBuilders.standaloneSetup(new CookieController()).build(); + WebClient client = MockMvcWebClientBuilder.mockMvcSetup(this.mockMvc).build(); + + assertThat(getResponse(client, "http://localhost/").getContentAsString(), equalTo("NA")); + assertThat(postResponse(client, "http://localhost/?cookie=foo").getContentAsString(), equalTo("Set")); + assertThat(getResponse(client, "http://localhost/").getContentAsString(), equalTo("foo")); + assertThat(deleteResponse(client, "http://localhost/").getContentAsString(), equalTo("Delete")); + assertThat(getResponse(client, "http://localhost/").getContentAsString(), equalTo("NA")); + } + + private void assertMockMvcUsed(WebClient client, String url) throws Exception { + assertThat(getResponse(client, url).getContentAsString(), equalTo("mvc")); + } + + private void assertMockMvcNotUsed(WebClient client, String url) throws Exception { + assertThat(getResponse(client, url).getContentAsString(), not(equalTo("mvc"))); + } + + private WebResponse getResponse(WebClient client, String url) throws IOException { + return createResponse(client, new WebRequest(new URL(url))); + } + + private WebResponse postResponse(WebClient client, String url) throws IOException { + return createResponse(client, new WebRequest(new URL(url), HttpMethod.POST)); + } + + private WebResponse deleteResponse(WebClient client, String url) throws IOException { + return createResponse(client, new WebRequest(new URL(url), HttpMethod.DELETE)); + } + + private WebResponse createResponse(WebClient client, WebRequest request) throws IOException { + return client.getWebConnection().getResponse(request); + } + + + @Configuration + @EnableWebMvc + static class Config { + + @RestController + static class ContextPathController { + + @RequestMapping + public String contextPath(HttpServletRequest request) { + return "mvc"; + } + } + } + + @RestController + static class CookieController { + + static final String COOKIE_NAME = "cookie"; + + @RequestMapping(path = "/", produces = "text/plain") + String cookie(@CookieValue(name = COOKIE_NAME, defaultValue = "NA") String cookie) { + return cookie; + } + + @PostMapping(path = "/", produces = "text/plain") + String setCookie(@RequestParam String cookie, HttpServletResponse response) { + response.addCookie(new javax.servlet.http.Cookie(COOKIE_NAME, cookie)); + return "Set"; + } + + @DeleteMapping(path = "/", produces = "text/plain") + String deleteCookie(HttpServletResponse response) { + javax.servlet.http.Cookie cookie = new javax.servlet.http.Cookie(COOKIE_NAME, ""); + cookie.setMaxAge(0); + response.addCookie(cookie); + return "Delete"; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/MockMvcWebConnectionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/MockMvcWebConnectionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..6a3d0c5157d4f3623f88fa1a4a36826cedd9f8df --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/MockMvcWebConnectionTests.java @@ -0,0 +1,85 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.htmlunit; + +import java.io.IOException; + +import com.gargoylesoftware.htmlunit.Page; +import com.gargoylesoftware.htmlunit.WebClient; +import org.junit.Test; + +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +/** + * Integration tests for {@link MockMvcWebConnection}. + * + * @author Rob Winch + * @since 4.2 + */ +public class MockMvcWebConnectionTests { + + private final WebClient webClient = new WebClient(); + + private final MockMvc mockMvc = + MockMvcBuilders.standaloneSetup(new HelloController(), new ForwardController()).build(); + + + @Test + public void contextPathNull() throws IOException { + this.webClient.setWebConnection(new MockMvcWebConnection(this.mockMvc, this.webClient)); + Page page = this.webClient.getPage("http://localhost/context/a"); + assertThat(page.getWebResponse().getStatusCode(), equalTo(200)); + } + + @Test + public void contextPathExplicit() throws IOException { + this.webClient.setWebConnection(new MockMvcWebConnection(this.mockMvc, this.webClient, "/context")); + Page page = this.webClient.getPage("http://localhost/context/a"); + assertThat(page.getWebResponse().getStatusCode(), equalTo(200)); + } + + @Test + public void contextPathEmpty() throws IOException { + this.webClient.setWebConnection(new MockMvcWebConnection(this.mockMvc, this.webClient, "")); + Page page = this.webClient.getPage("http://localhost/context/a"); + assertThat(page.getWebResponse().getStatusCode(), equalTo(200)); + } + + @Test + public void forward() throws IOException { + this.webClient.setWebConnection(new MockMvcWebConnection(this.mockMvc, this.webClient, "")); + Page page = this.webClient.getPage("http://localhost/forward"); + assertThat(page.getWebResponse().getContentAsString(), equalTo("hello")); + } + + @Test(expected = IllegalArgumentException.class) + @SuppressWarnings("resource") + public void contextPathDoesNotStartWithSlash() throws IOException { + new MockMvcWebConnection(this.mockMvc, this.webClient, "context"); + } + + @Test(expected = IllegalArgumentException.class) + @SuppressWarnings("resource") + public void contextPathEndsWithSlash() throws IOException { + new MockMvcWebConnection(this.mockMvc, this.webClient, "/context/"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/MockWebResponseBuilderTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/MockWebResponseBuilderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..db4207c4c77534fc356d191d5adfb110e9fe85cf --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/MockWebResponseBuilderTests.java @@ -0,0 +1,171 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.htmlunit; + +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import javax.servlet.http.Cookie; + +import com.gargoylesoftware.htmlunit.WebRequest; +import com.gargoylesoftware.htmlunit.WebResponse; +import com.gargoylesoftware.htmlunit.util.NameValuePair; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.mock.web.MockHttpServletResponse; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +/** + * Tests for {@link MockWebResponseBuilder}. + * + * @author Rob Winch + * @since 4.2 + */ +public class MockWebResponseBuilderTests { + + private final MockHttpServletResponse response = new MockHttpServletResponse(); + + private WebRequest webRequest; + + private MockWebResponseBuilder responseBuilder; + + + @Before + public void setup() throws Exception { + this.webRequest = new WebRequest(new URL("http://example.com:80/test/this/here")); + this.responseBuilder = new MockWebResponseBuilder(System.currentTimeMillis(), this.webRequest, this.response); + } + + + // --- constructor + + @Test(expected = IllegalArgumentException.class) + public void constructorWithNullWebRequest() { + new MockWebResponseBuilder(0L, null, this.response); + } + + @Test(expected = IllegalArgumentException.class) + public void constructorWithNullResponse() throws Exception { + new MockWebResponseBuilder(0L, new WebRequest(new URL("http://example.com:80/test/this/here")), null); + } + + + // --- build + + @Test + public void buildContent() throws Exception { + this.response.getWriter().write("expected content"); + WebResponse webResponse = this.responseBuilder.build(); + + assertThat(webResponse.getContentAsString(), equalTo("expected content")); + } + + @Test + public void buildContentCharset() throws Exception { + this.response.addHeader("Content-Type", "text/html; charset=UTF-8"); + WebResponse webResponse = this.responseBuilder.build(); + + assertThat(webResponse.getContentCharset(), equalTo(StandardCharsets.UTF_8)); + } + + @Test + public void buildContentType() throws Exception { + this.response.addHeader("Content-Type", "text/html; charset-UTF-8"); + WebResponse webResponse = this.responseBuilder.build(); + + assertThat(webResponse.getContentType(), equalTo("text/html")); + } + + @Test + public void buildResponseHeaders() throws Exception { + this.response.addHeader("Content-Type", "text/html"); + this.response.addHeader("X-Test", "value"); + Cookie cookie = new Cookie("cookieA", "valueA"); + cookie.setDomain("domain"); + cookie.setPath("/path"); + cookie.setMaxAge(1800); + cookie.setSecure(true); + cookie.setHttpOnly(true); + this.response.addCookie(cookie); + WebResponse webResponse = this.responseBuilder.build(); + + List responseHeaders = webResponse.getResponseHeaders(); + assertThat(responseHeaders.size(), equalTo(3)); + NameValuePair header = responseHeaders.get(0); + assertThat(header.getName(), equalTo("Content-Type")); + assertThat(header.getValue(), equalTo("text/html")); + header = responseHeaders.get(1); + assertThat(header.getName(), equalTo("X-Test")); + assertThat(header.getValue(), equalTo("value")); + header = responseHeaders.get(2); + assertThat(header.getName(), equalTo("Set-Cookie")); + assertThat(header.getValue(), startsWith("cookieA=valueA; Path=/path; Domain=domain; Max-Age=1800; Expires=")); + assertThat(header.getValue(), endsWith("; Secure; HttpOnly")); + } + + // SPR-14169 + @Test + public void buildResponseHeadersNullDomainDefaulted() throws Exception { + Cookie cookie = new Cookie("cookieA", "valueA"); + this.response.addCookie(cookie); + WebResponse webResponse = this.responseBuilder.build(); + + List responseHeaders = webResponse.getResponseHeaders(); + assertThat(responseHeaders.size(), equalTo(1)); + NameValuePair header = responseHeaders.get(0); + assertThat(header.getName(), equalTo("Set-Cookie")); + assertThat(header.getValue(), equalTo("cookieA=valueA")); + } + + @Test + public void buildStatus() throws Exception { + WebResponse webResponse = this.responseBuilder.build(); + + assertThat(webResponse.getStatusCode(), equalTo(200)); + assertThat(webResponse.getStatusMessage(), equalTo("OK")); + } + + @Test + public void buildStatusNotOk() throws Exception { + this.response.setStatus(401); + WebResponse webResponse = this.responseBuilder.build(); + + assertThat(webResponse.getStatusCode(), equalTo(401)); + assertThat(webResponse.getStatusMessage(), equalTo("Unauthorized")); + } + + @Test + public void buildStatusWithCustomMessage() throws Exception { + this.response.sendError(401, "Custom"); + WebResponse webResponse = this.responseBuilder.build(); + + assertThat(webResponse.getStatusCode(), equalTo(401)); + assertThat(webResponse.getStatusMessage(), equalTo("Custom")); + } + + @Test + public void buildWebRequest() throws Exception { + WebResponse webResponse = this.responseBuilder.build(); + + assertThat(webResponse.getWebRequest(), equalTo(this.webRequest)); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/UrlRegexRequestMatcherTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/UrlRegexRequestMatcherTests.java new file mode 100644 index 0000000000000000000000000000000000000000..4978ecae14fbc8cfd193c2169031460455fa4f58 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/UrlRegexRequestMatcherTests.java @@ -0,0 +1,37 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.htmlunit; + +import org.junit.Test; + +/** + * Unit tests for {@link UrlRegexRequestMatcher}. + * + * @author Rob Winch + * @author Sam Brannen + * @since 4.2 + */ +public class UrlRegexRequestMatcherTests extends AbstractWebRequestMatcherTests { + + @Test + public void verifyExampleInClassLevelJavadoc() throws Exception { + WebRequestMatcher cdnMatcher = new UrlRegexRequestMatcher(".*?//code.jquery.com/.*"); + assertMatches(cdnMatcher, "https://code.jquery.com/jquery-1.11.0.min.js"); + assertDoesNotMatch(cdnMatcher, "http://localhost/jquery-1.11.0.min.js"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/webdriver/MockMvcHtmlUnitDriverBuilderTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/webdriver/MockMvcHtmlUnitDriverBuilderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5fb8a76dd1a60d47ef68758a736b4f67f8fd5de8 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/webdriver/MockMvcHtmlUnitDriverBuilderTests.java @@ -0,0 +1,165 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.htmlunit.webdriver; + +import java.io.IOException; + +import javax.servlet.http.HttpServletRequest; + +import com.gargoylesoftware.htmlunit.util.Cookie; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.openqa.selenium.htmlunit.HtmlUnitDriver; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.tests.Assume; +import org.springframework.tests.TestGroup; +import org.springframework.web.bind.annotation.CookieValue; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +/** + * Integration tests for {@link MockMvcHtmlUnitDriverBuilder}. + * + * @author Rob Winch + * @author Sam Brannen + * @since 4.2 + */ +@RunWith(SpringRunner.class) +@ContextConfiguration +@WebAppConfiguration +public class MockMvcHtmlUnitDriverBuilderTests { + + private static final String EXPECTED_BODY = "MockMvcHtmlUnitDriverBuilderTests mvc"; + + @Autowired + private WebApplicationContext wac; + + private MockMvc mockMvc; + + private HtmlUnitDriver driver; + + + @Before + public void setup() { + this.mockMvc = MockMvcBuilders.webAppContextSetup(this.wac).build(); + } + + + @Test(expected = IllegalArgumentException.class) + public void webAppContextSetupNull() { + MockMvcHtmlUnitDriverBuilder.webAppContextSetup(null); + } + + @Test(expected = IllegalArgumentException.class) + public void mockMvcSetupNull() { + MockMvcHtmlUnitDriverBuilder.mockMvcSetup(null); + } + + @Test + public void mockMvcSetupWithCustomDriverDelegate() throws Exception { + WebConnectionHtmlUnitDriver otherDriver = new WebConnectionHtmlUnitDriver(); + this.driver = MockMvcHtmlUnitDriverBuilder.mockMvcSetup(this.mockMvc).withDelegate(otherDriver).build(); + + assertMockMvcUsed("http://localhost/test"); + Assume.group(TestGroup.PERFORMANCE, () -> assertMockMvcNotUsed("https://example.com/")); + } + + @Test + public void mockMvcSetupWithDefaultDriverDelegate() throws Exception { + this.driver = MockMvcHtmlUnitDriverBuilder.mockMvcSetup(this.mockMvc).build(); + + assertMockMvcUsed("http://localhost/test"); + Assume.group(TestGroup.PERFORMANCE, () -> assertMockMvcNotUsed("https://example.com/")); + } + + @Test + public void javaScriptEnabledByDefault() { + this.driver = MockMvcHtmlUnitDriverBuilder.mockMvcSetup(this.mockMvc).build(); + assertTrue(this.driver.isJavascriptEnabled()); + } + + @Test + public void javaScriptDisabled() { + this.driver = MockMvcHtmlUnitDriverBuilder.mockMvcSetup(this.mockMvc).javascriptEnabled(false).build(); + assertFalse(this.driver.isJavascriptEnabled()); + } + + @Test // SPR-14066 + public void cookieManagerShared() throws Exception { + WebConnectionHtmlUnitDriver otherDriver = new WebConnectionHtmlUnitDriver(); + this.mockMvc = MockMvcBuilders.standaloneSetup(new CookieController()).build(); + this.driver = MockMvcHtmlUnitDriverBuilder.mockMvcSetup(this.mockMvc) + .withDelegate(otherDriver).build(); + + assertThat(get("http://localhost/"), equalTo("")); + Cookie cookie = new Cookie("localhost", "cookie", "cookieManagerShared"); + otherDriver.getWebClient().getCookieManager().addCookie(cookie); + assertThat(get("http://localhost/"), equalTo("cookieManagerShared")); + } + + + private void assertMockMvcUsed(String url) throws Exception { + assertThat(get(url), containsString(EXPECTED_BODY)); + } + + private void assertMockMvcNotUsed(String url) throws Exception { + assertThat(get(url), not(containsString(EXPECTED_BODY))); + } + + private String get(String url) throws IOException { + this.driver.get(url); + return this.driver.getPageSource(); + } + + + @Configuration + @EnableWebMvc + static class Config { + + @RestController + static class ContextPathController { + + @RequestMapping + public String contextPath(HttpServletRequest request) { + return EXPECTED_BODY; + } + } + } + + @RestController + static class CookieController { + + @RequestMapping(path = "/", produces = "text/plain") + String cookie(@CookieValue("cookie") String cookie) { + return cookie; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/webdriver/WebConnectionHtmlUnitDriverTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/webdriver/WebConnectionHtmlUnitDriverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c86bc86f99f16b680e86ee1788e94c5d43abb447 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/htmlunit/webdriver/WebConnectionHtmlUnitDriverTests.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.htmlunit.webdriver; + +import java.io.IOException; + +import com.gargoylesoftware.htmlunit.WebConnection; +import com.gargoylesoftware.htmlunit.WebRequest; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.openqa.selenium.WebDriverException; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.*; +import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link WebConnectionHtmlUnitDriver}. + * + * @author Rob Winch + * @author Sam Brannen + * @since 4.2 + */ +@RunWith(MockitoJUnitRunner.class) +public class WebConnectionHtmlUnitDriverTests { + + private final WebConnectionHtmlUnitDriver driver = new WebConnectionHtmlUnitDriver(); + + @Mock + private WebConnection connection; + + @Rule + public ExpectedException exception = ExpectedException.none(); + + @Before + public void setup() throws Exception { + when(this.connection.getResponse(any(WebRequest.class))).thenThrow(new IOException("")); + } + + + @Test + public void getWebConnectionDefaultNotNull() { + assertThat(this.driver.getWebConnection(), notNullValue()); + } + + @Test + public void setWebConnectionToNull() { + this.exception.expect(IllegalArgumentException.class); + this.driver.setWebConnection(null); + } + + @Test + public void setWebConnection() { + this.driver.setWebConnection(this.connection); + assertThat(this.driver.getWebConnection(), equalTo(this.connection)); + + this.exception.expect(WebDriverException.class); + this.driver.get("https://example.com"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockHttpServletRequestBuilderTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockHttpServletRequestBuilderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..814bd16acfae3272df0a990514b19a89f755533a --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockHttpServletRequestBuilderTests.java @@ -0,0 +1,549 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.request; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.security.Principal; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import javax.servlet.ServletContext; +import javax.servlet.http.Cookie; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpSession; +import org.springframework.mock.web.MockServletContext; +import org.springframework.util.FileCopyUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.servlet.FlashMap; +import org.springframework.web.servlet.support.SessionFlashMapManager; +import org.springframework.web.util.UriComponentsBuilder; + +import static org.junit.Assert.*; + +/** + * Unit tests for building a {@link MockHttpServletRequest} with + * {@link MockHttpServletRequestBuilder}. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + */ +public class MockHttpServletRequestBuilderTests { + + private final ServletContext servletContext = new MockServletContext(); + + private MockHttpServletRequestBuilder builder; + + + @Before + public void setUp() { + this.builder = new MockHttpServletRequestBuilder(HttpMethod.GET, "/foo/bar"); + } + + + @Test + public void method() { + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("GET", request.getMethod()); + } + + @Test + public void uri() { + String uri = "https://java.sun.com:8080/javase/6/docs/api/java/util/BitSet.html?foo=bar#and(java.util.BitSet)"; + this.builder = new MockHttpServletRequestBuilder(HttpMethod.GET, uri); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("https", request.getScheme()); + assertEquals("foo=bar", request.getQueryString()); + assertEquals("java.sun.com", request.getServerName()); + assertEquals(8080, request.getServerPort()); + assertEquals("/javase/6/docs/api/java/util/BitSet.html", request.getRequestURI()); + assertEquals("https://java.sun.com:8080/javase/6/docs/api/java/util/BitSet.html", + request.getRequestURL().toString()); + } + + @Test + public void requestUriWithEncoding() { + this.builder = new MockHttpServletRequestBuilder(HttpMethod.GET, "/foo bar"); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("/foo%20bar", request.getRequestURI()); + } + + @Test // SPR-13435 + public void requestUriWithDoubleSlashes() throws URISyntaxException { + this.builder = new MockHttpServletRequestBuilder(HttpMethod.GET, new URI("/test//currentlyValid/0")); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("/test//currentlyValid/0", request.getRequestURI()); + } + + @Test + public void contextPathEmpty() { + this.builder = new MockHttpServletRequestBuilder(HttpMethod.GET, "/foo"); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("", request.getContextPath()); + assertEquals("", request.getServletPath()); + assertEquals("/foo", request.getPathInfo()); + } + + @Test + public void contextPathServletPathEmpty() { + this.builder = new MockHttpServletRequestBuilder(HttpMethod.GET, "/travel/hotels/42"); + this.builder.contextPath("/travel"); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("/travel", request.getContextPath()); + assertEquals("", request.getServletPath()); + assertEquals("/hotels/42", request.getPathInfo()); + } + + @Test + public void contextPathServletPath() { + this.builder = new MockHttpServletRequestBuilder(HttpMethod.GET, "/travel/main/hotels/42"); + this.builder.contextPath("/travel"); + this.builder.servletPath("/main"); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("/travel", request.getContextPath()); + assertEquals("/main", request.getServletPath()); + assertEquals("/hotels/42", request.getPathInfo()); + } + + @Test + public void contextPathServletPathInfoEmpty() { + this.builder = new MockHttpServletRequestBuilder(HttpMethod.GET, "/travel/hotels/42"); + this.builder.contextPath("/travel"); + this.builder.servletPath("/hotels/42"); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("/travel", request.getContextPath()); + assertEquals("/hotels/42", request.getServletPath()); + assertNull(request.getPathInfo()); + } + + @Test + public void contextPathServletPathInfo() { + this.builder = new MockHttpServletRequestBuilder(HttpMethod.GET, "/"); + this.builder.servletPath("/index.html"); + this.builder.pathInfo(null); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("", request.getContextPath()); + assertEquals("/index.html", request.getServletPath()); + assertNull(request.getPathInfo()); + } + + @Test // SPR-16453 + public void pathInfoIsDecoded() { + this.builder = new MockHttpServletRequestBuilder(HttpMethod.GET, "/travel/hotels 42"); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("/travel/hotels 42", request.getPathInfo()); + } + + @Test + public void contextPathServletPathInvalid() { + testContextPathServletPathInvalid("/Foo", "", "Request URI [/foo/bar] does not start with context path [/Foo]"); + testContextPathServletPathInvalid("foo", "", "Context path must start with a '/'"); + testContextPathServletPathInvalid("/foo/", "", "Context path must not end with a '/'"); + + testContextPathServletPathInvalid("/foo", "/Bar", "Invalid servlet path [/Bar] for request URI [/foo/bar]"); + testContextPathServletPathInvalid("/foo", "bar", "Servlet path must start with a '/'"); + testContextPathServletPathInvalid("/foo", "/bar/", "Servlet path must not end with a '/'"); + } + + private void testContextPathServletPathInvalid(String contextPath, String servletPath, String message) { + try { + this.builder.contextPath(contextPath); + this.builder.servletPath(servletPath); + this.builder.buildRequest(this.servletContext); + } + catch (IllegalArgumentException ex) { + assertEquals(message, ex.getMessage()); + } + } + + @Test + public void requestUriAndFragment() { + this.builder = new MockHttpServletRequestBuilder(HttpMethod.GET, "/foo#bar"); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("/foo", request.getRequestURI()); + } + + @Test + public void requestParameter() { + this.builder.param("foo", "bar", "baz"); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + Map parameterMap = request.getParameterMap(); + + assertArrayEquals(new String[] {"bar", "baz"}, parameterMap.get("foo")); + } + + @Test + public void requestParameterFromQuery() { + this.builder = new MockHttpServletRequestBuilder(HttpMethod.GET, "/?foo=bar&foo=baz"); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + Map parameterMap = request.getParameterMap(); + + assertArrayEquals(new String[] {"bar", "baz"}, parameterMap.get("foo")); + assertEquals("foo=bar&foo=baz", request.getQueryString()); + } + + @Test + public void requestParameterFromQueryList() { + this.builder = new MockHttpServletRequestBuilder(HttpMethod.GET, "/?foo[0]=bar&foo[1]=baz"); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("foo%5B0%5D=bar&foo%5B1%5D=baz", request.getQueryString()); + assertEquals("bar", request.getParameter("foo[0]")); + assertEquals("baz", request.getParameter("foo[1]")); + } + + @Test + public void requestParameterFromQueryWithEncoding() { + this.builder = new MockHttpServletRequestBuilder(HttpMethod.GET, "/?foo={value}", "bar=baz"); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("foo=bar%3Dbaz", request.getQueryString()); + assertEquals("bar=baz", request.getParameter("foo")); + } + + @Test // SPR-11043 + public void requestParameterFromQueryNull() { + this.builder = new MockHttpServletRequestBuilder(HttpMethod.GET, "/?foo"); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + Map parameterMap = request.getParameterMap(); + + assertArrayEquals(new String[] {null}, parameterMap.get("foo")); + assertEquals("foo", request.getQueryString()); + } + + @Test // SPR-13801 + public void requestParameterFromMultiValueMap() throws Exception { + MultiValueMap params = new LinkedMultiValueMap<>(); + params.add("foo", "bar"); + params.add("foo", "baz"); + this.builder = new MockHttpServletRequestBuilder(HttpMethod.POST, "/foo"); + this.builder.params(params); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertArrayEquals(new String[] {"bar", "baz"}, request.getParameterMap().get("foo")); + } + + @Test + public void requestParameterFromRequestBodyFormData() throws Exception { + String contentType = "application/x-www-form-urlencoded;charset=UTF-8"; + String body = "name+1=value+1&name+2=value+A&name+2=value+B&name+3"; + + MockHttpServletRequest request = new MockHttpServletRequestBuilder(HttpMethod.POST, "/foo") + .contentType(contentType).content(body.getBytes(StandardCharsets.UTF_8)) + .buildRequest(this.servletContext); + + assertArrayEquals(new String[] {"value 1"}, request.getParameterMap().get("name 1")); + assertArrayEquals(new String[] {"value A", "value B"}, request.getParameterMap().get("name 2")); + assertArrayEquals(new String[] {null}, request.getParameterMap().get("name 3")); + } + + @Test + public void acceptHeader() { + this.builder.accept(MediaType.TEXT_HTML, MediaType.APPLICATION_XML); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + List accept = Collections.list(request.getHeaders("Accept")); + List result = MediaType.parseMediaTypes(accept.get(0)); + + assertEquals(1, accept.size()); + assertEquals("text/html", result.get(0).toString()); + assertEquals("application/xml", result.get(1).toString()); + } + + @Test + public void contentType() { + this.builder.contentType(MediaType.TEXT_HTML); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + String contentType = request.getContentType(); + List contentTypes = Collections.list(request.getHeaders("Content-Type")); + + assertEquals("text/html", contentType); + assertEquals(1, contentTypes.size()); + assertEquals("text/html", contentTypes.get(0)); + } + + @Test + public void contentTypeViaString() { + this.builder.contentType("text/html"); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + String contentType = request.getContentType(); + List contentTypes = Collections.list(request.getHeaders("Content-Type")); + + assertEquals("text/html", contentType); + assertEquals(1, contentTypes.size()); + assertEquals("text/html", contentTypes.get(0)); + } + + @Test // SPR-11308 + public void contentTypeViaHeader() { + this.builder.header("Content-Type", MediaType.TEXT_HTML_VALUE); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + String contentType = request.getContentType(); + + assertEquals("text/html", contentType); + } + + @Test // SPR-11308 + public void contentTypeViaMultipleHeaderValues() { + this.builder.header("Content-Type", MediaType.TEXT_HTML_VALUE, MediaType.ALL_VALUE); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("text/html", request.getContentType()); + } + + @Test + public void body() throws IOException { + byte[] body = "Hello World".getBytes("UTF-8"); + this.builder.content(body); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + byte[] result = FileCopyUtils.copyToByteArray(request.getInputStream()); + + assertArrayEquals(body, result); + } + + @Test + public void header() { + this.builder.header("foo", "bar", "baz"); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + List headers = Collections.list(request.getHeaders("foo")); + + assertEquals(2, headers.size()); + assertEquals("bar", headers.get(0)); + assertEquals("baz", headers.get(1)); + } + + @Test + public void headers() { + HttpHeaders httpHeaders = new HttpHeaders(); + httpHeaders.setContentType(MediaType.APPLICATION_JSON); + httpHeaders.put("foo", Arrays.asList("bar", "baz")); + this.builder.headers(httpHeaders); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + List headers = Collections.list(request.getHeaders("foo")); + + assertEquals(2, headers.size()); + assertEquals("bar", headers.get(0)); + assertEquals("baz", headers.get(1)); + assertEquals(MediaType.APPLICATION_JSON.toString(), request.getHeader("Content-Type")); + } + + @Test + public void cookie() { + Cookie cookie1 = new Cookie("foo", "bar"); + Cookie cookie2 = new Cookie("baz", "qux"); + this.builder.cookie(cookie1, cookie2); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + Cookie[] cookies = request.getCookies(); + + assertEquals(2, cookies.length); + assertEquals("foo", cookies[0].getName()); + assertEquals("bar", cookies[0].getValue()); + assertEquals("baz", cookies[1].getName()); + assertEquals("qux", cookies[1].getValue()); + } + + @Test + public void noCookies() { + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + assertNull(request.getCookies()); + } + + @Test + public void locale() { + Locale locale = new Locale("nl", "nl"); + this.builder.locale(locale); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals(locale, request.getLocale()); + } + + @Test + public void characterEncoding() { + String encoding = "UTF-8"; + this.builder.characterEncoding(encoding); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals(encoding, request.getCharacterEncoding()); + } + + @Test + public void requestAttribute() { + this.builder.requestAttr("foo", "bar"); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("bar", request.getAttribute("foo")); + } + + @Test + public void sessionAttribute() { + this.builder.sessionAttr("foo", "bar"); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("bar", request.getSession().getAttribute("foo")); + } + + @Test + public void sessionAttributes() { + Map map = new HashMap<>(); + map.put("foo", "bar"); + this.builder.sessionAttrs(map); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals("bar", request.getSession().getAttribute("foo")); + } + + @Test + public void session() { + MockHttpSession session = new MockHttpSession(this.servletContext); + session.setAttribute("foo", "bar"); + this.builder.session(session); + this.builder.sessionAttr("baz", "qux"); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals(session, request.getSession()); + assertEquals("bar", request.getSession().getAttribute("foo")); + assertEquals("qux", request.getSession().getAttribute("baz")); + } + + @Test + public void flashAttribute() { + this.builder.flashAttr("foo", "bar"); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + FlashMap flashMap = new SessionFlashMapManager().retrieveAndUpdate(request, null); + assertNotNull(flashMap); + assertEquals("bar", flashMap.get("foo")); + } + + @Test + public void principal() { + User user = new User(); + this.builder.principal(user); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals(user, request.getUserPrincipal()); + } + + @Test // SPR-12945 + public void mergeInvokesDefaultRequestPostProcessorFirst() { + final String ATTR = "ATTR"; + final String EXPECTED = "override"; + + MockHttpServletRequestBuilder defaultBuilder = + new MockHttpServletRequestBuilder(HttpMethod.GET, "/foo/bar") + .with(requestAttr(ATTR).value("default")) + .with(requestAttr(ATTR).value(EXPECTED)); + + builder.merge(defaultBuilder); + + MockHttpServletRequest request = builder.buildRequest(servletContext); + request = builder.postProcessRequest(request); + + assertEquals(EXPECTED, request.getAttribute(ATTR)); + } + + @Test // SPR-13719 + public void arbitraryMethod() { + String httpMethod = "REPort"; + URI url = UriComponentsBuilder.fromPath("/foo/{bar}").buildAndExpand(42).toUri(); + this.builder = new MockHttpServletRequestBuilder(httpMethod, url); + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + + assertEquals(httpMethod, request.getMethod()); + assertEquals("/foo/42", request.getPathInfo()); + } + + + private static RequestAttributePostProcessor requestAttr(String attrName) { + return new RequestAttributePostProcessor().attr(attrName); + } + + + private final class User implements Principal { + + @Override + public String getName() { + return "Foo"; + } + } + + + private static class RequestAttributePostProcessor implements RequestPostProcessor { + + String attr; + + String value; + + public RequestAttributePostProcessor attr(String attr) { + this.attr = attr; + return this; + } + + public RequestAttributePostProcessor value(String value) { + this.value = value; + return this; + } + + public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { + request.setAttribute(attr, value); + return request; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilderTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8816cef2bd75145989ac6beb693f61d96f9898f1 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilderTests.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.request; + +import org.junit.Test; + +import org.springframework.http.HttpMethod; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockServletContext; + +import static org.junit.Assert.*; + +/** + * @author Rossen Stoyanchev + */ +public class MockMultipartHttpServletRequestBuilderTests { + + @Test + public void test() { + MockHttpServletRequestBuilder parent = new MockHttpServletRequestBuilder(HttpMethod.GET, "/"); + parent.characterEncoding("UTF-8"); + Object result = new MockMultipartHttpServletRequestBuilder("/fileUpload").merge(parent); + + assertNotNull(result); + assertEquals(MockMultipartHttpServletRequestBuilder.class, result.getClass()); + + MockMultipartHttpServletRequestBuilder builder = (MockMultipartHttpServletRequestBuilder) result; + MockHttpServletRequest request = builder.buildRequest(new MockServletContext()); + assertEquals("UTF-8", request.getCharacterEncoding()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/result/ContentResultMatchersTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/result/ContentResultMatchersTests.java new file mode 100644 index 0000000000000000000000000000000000000000..17fb1fd9c72c9e8a9ecab22f471be04ce6d173b9 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/result/ContentResultMatchersTests.java @@ -0,0 +1,112 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.result; + +import org.hamcrest.Matchers; +import org.junit.Test; + +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.test.web.servlet.StubMvcResult; + +/** + * @author Rossen Stoyanchev + */ +public class ContentResultMatchersTests { + + @Test + public void typeMatches() throws Exception { + new ContentResultMatchers().contentType("application/json;charset=UTF-8").match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void typeNoMatch() throws Exception { + new ContentResultMatchers().contentType("text/plain").match(getStubMvcResult()); + } + + @Test + public void encoding() throws Exception { + new ContentResultMatchers().encoding("UTF-8").match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void encodingNoMatch() throws Exception { + new ContentResultMatchers().encoding("ISO-8859-1").match(getStubMvcResult()); + } + + @Test + public void string() throws Exception { + new ContentResultMatchers().string(new String(CONTENT.getBytes("UTF-8"))).match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void stringNoMatch() throws Exception { + new ContentResultMatchers().encoding("bogus").match(getStubMvcResult()); + } + + @Test + public void stringMatcher() throws Exception { + String content = new String(CONTENT.getBytes("UTF-8")); + new ContentResultMatchers().string(Matchers.equalTo(content)).match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void stringMatcherNoMatch() throws Exception { + new ContentResultMatchers().string(Matchers.equalTo("bogus")).match(getStubMvcResult()); + } + + @Test + public void bytes() throws Exception { + new ContentResultMatchers().bytes(CONTENT.getBytes("UTF-8")).match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void bytesNoMatch() throws Exception { + new ContentResultMatchers().bytes("bogus".getBytes()).match(getStubMvcResult()); + } + + @Test + public void jsonLenientMatch() throws Exception { + new ContentResultMatchers().json("{\n \"foo\" : \"bar\" \n}").match(getStubMvcResult()); + new ContentResultMatchers().json("{\n \"foo\" : \"bar\" \n}", false).match(getStubMvcResult()); + } + + @Test + public void jsonStrictMatch() throws Exception { + new ContentResultMatchers().json("{\n \"foo\":\"bar\", \"foo array\":[\"foo\",\"bar\"] \n}", true).match(getStubMvcResult()); + new ContentResultMatchers().json("{\n \"foo array\":[\"foo\",\"bar\"], \"foo\":\"bar\" \n}", true).match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void jsonLenientNoMatch() throws Exception { + new ContentResultMatchers().json("{\n\"fooo\":\"bar\"\n}").match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void jsonStrictNoMatch() throws Exception { + new ContentResultMatchers().json("{\"foo\":\"bar\", \"foo array\":[\"bar\",\"foo\"]}", true).match(getStubMvcResult()); + } + + private static final String CONTENT = "{\"foo\":\"bar\",\"foo array\":[\"foo\",\"bar\"]}"; + + private StubMvcResult getStubMvcResult() throws Exception { + MockHttpServletResponse response = new MockHttpServletResponse(); + response.addHeader("Content-Type", "application/json; charset=UTF-8"); + response.getWriter().print(new String(CONTENT.getBytes("UTF-8"))); + return new StubMvcResult(null, null, null, null, null, null, response); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/result/FlashAttributeResultMatchersTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/result/FlashAttributeResultMatchersTests.java new file mode 100644 index 0000000000000000000000000000000000000000..abcd167429accb2b722414e0d100284b366bd4ec --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/result/FlashAttributeResultMatchersTests.java @@ -0,0 +1,56 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.result; + +import org.junit.Test; + +import org.springframework.test.web.servlet.StubMvcResult; +import org.springframework.web.servlet.FlashMap; + +/** + * @author Craig Walls + */ +public class FlashAttributeResultMatchersTests { + + @Test + public void attributeExists() throws Exception { + new FlashAttributeResultMatchers().attributeExists("good").match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void attributeExists_doesntExist() throws Exception { + new FlashAttributeResultMatchers().attributeExists("bad").match(getStubMvcResult()); + } + + @Test + public void attribute() throws Exception { + new FlashAttributeResultMatchers().attribute("good", "good").match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void attribute_incorrectValue() throws Exception { + new FlashAttributeResultMatchers().attribute("good", "not good").match(getStubMvcResult()); + } + + private StubMvcResult getStubMvcResult() { + FlashMap flashMap = new FlashMap(); + flashMap.put("good", "good"); + StubMvcResult mvcResult = new StubMvcResult(null, null, null, null, null, flashMap, null); + return mvcResult; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/result/HeaderResultMatchersTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/result/HeaderResultMatchersTests.java new file mode 100644 index 0000000000000000000000000000000000000000..401f9cab0c7e7f92d3d6e42ccbc44fb67f1ce48c --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/result/HeaderResultMatchersTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.test.web.servlet.result; + +import java.time.ZoneId; +import java.time.ZonedDateTime; + +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.StubMvcResult; + +/** + * Unit tests for {@link HeaderResultMatchers}. + * @author Rossen Stoyanchev + */ +public class HeaderResultMatchersTests { + + private final HeaderResultMatchers matchers = new HeaderResultMatchers(); + + private final MockHttpServletResponse response = new MockHttpServletResponse(); + + private final MvcResult mvcResult = + new StubMvcResult(new MockHttpServletRequest(), null, null, null, null, null, this.response); + + + @Test // SPR-17330 + public void matchDateFormattedWithHttpHeaders() throws Exception { + + long epochMilli = ZonedDateTime.of(2018, 10, 5, 0, 0, 0, 0, ZoneId.of("GMT")).toInstant().toEpochMilli(); + HttpHeaders headers = new HttpHeaders(); + headers.setDate("myDate", epochMilli); + this.response.setHeader("d", headers.getFirst("myDate")); + + this.matchers.dateValue("d", epochMilli).match(this.mvcResult); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/result/JsonPathResultMatchersTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/result/JsonPathResultMatchersTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b6f8c37e497b214bc8253969a63c3708a5990f1c --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/result/JsonPathResultMatchersTests.java @@ -0,0 +1,284 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.result; + +import org.hamcrest.Matchers; + +import org.junit.Test; + +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.test.web.servlet.StubMvcResult; + +/** + * Unit tests for {@link JsonPathResultMatchers}. + * + * @author Rossen Stoyanchev + * @author Craig Andrews + * @author Sam Brannen + * @author Brian Clozel + */ +public class JsonPathResultMatchersTests { + + private static final String RESPONSE_CONTENT = "{" + // + "'str': 'foo', " + // + "'num': 5, " + // + "'bool': true, " + // + "'arr': [42], " + // + "'colorMap': {'red': 'rojo'}, " + // + "'emptyString': '', " + // + "'emptyArray': [], " + // + "'emptyMap': {} " + // + "}"; + + private static final StubMvcResult stubMvcResult; + + static { + try { + MockHttpServletResponse response = new MockHttpServletResponse(); + response.addHeader("Content-Type", "application/json"); + response.getWriter().print(new String(RESPONSE_CONTENT.getBytes("ISO-8859-1"))); + stubMvcResult = new StubMvcResult(null, null, null, null, null, null, response); + } + catch (Exception e) { + throw new IllegalStateException(e); + } + } + + @Test(expected = AssertionError.class) + public void valueWithMismatch() throws Exception { + new JsonPathResultMatchers("$.str").value("bogus").match(stubMvcResult); + } + + @Test + public void valueWithDirectMatch() throws Exception { + new JsonPathResultMatchers("$.str").value("foo").match(stubMvcResult); + } + + @Test // SPR-16587 + public void valueWithNumberConversion() throws Exception { + new JsonPathResultMatchers("$.num").value(5.0f).match(stubMvcResult); + } + + @Test + public void valueWithMatcher() throws Exception { + new JsonPathResultMatchers("$.str").value(Matchers.equalTo("foo")).match(stubMvcResult); + } + + @Test // SPR-16587 + public void valueWithMatcherAndNumberConversion() throws Exception { + new JsonPathResultMatchers("$.num").value(Matchers.equalTo(5.0f), Float.class).match(stubMvcResult); + } + + @Test(expected = AssertionError.class) + public void valueWithMatcherAndMismatch() throws Exception { + new JsonPathResultMatchers("$.str").value(Matchers.equalTo("bogus")).match(stubMvcResult); + } + + @Test + public void exists() throws Exception { + new JsonPathResultMatchers("$.str").exists().match(stubMvcResult); + } + + @Test + public void existsForAnEmptyArray() throws Exception { + new JsonPathResultMatchers("$.emptyArray").exists().match(stubMvcResult); + } + + @Test + public void existsForAnEmptyMap() throws Exception { + new JsonPathResultMatchers("$.emptyMap").exists().match(stubMvcResult); + } + + @Test(expected = AssertionError.class) + public void existsNoMatch() throws Exception { + new JsonPathResultMatchers("$.bogus").exists().match(stubMvcResult); + } + + @Test + public void doesNotExist() throws Exception { + new JsonPathResultMatchers("$.bogus").doesNotExist().match(stubMvcResult); + } + + @Test(expected = AssertionError.class) + public void doesNotExistNoMatch() throws Exception { + new JsonPathResultMatchers("$.str").doesNotExist().match(stubMvcResult); + } + + @Test(expected = AssertionError.class) + public void doesNotExistForAnEmptyArray() throws Exception { + new JsonPathResultMatchers("$.emptyArray").doesNotExist().match(stubMvcResult); + } + + @Test(expected = AssertionError.class) + public void doesNotExistForAnEmptyMap() throws Exception { + new JsonPathResultMatchers("$.emptyMap").doesNotExist().match(stubMvcResult); + } + + @Test + public void isEmptyForAnEmptyString() throws Exception { + new JsonPathResultMatchers("$.emptyString").isEmpty().match(stubMvcResult); + } + + @Test + public void isEmptyForAnEmptyArray() throws Exception { + new JsonPathResultMatchers("$.emptyArray").isEmpty().match(stubMvcResult); + } + + @Test + public void isEmptyForAnEmptyMap() throws Exception { + new JsonPathResultMatchers("$.emptyMap").isEmpty().match(stubMvcResult); + } + + @Test + public void isNotEmptyForString() throws Exception { + new JsonPathResultMatchers("$.str").isNotEmpty().match(stubMvcResult); + } + + @Test + public void isNotEmptyForNumber() throws Exception { + new JsonPathResultMatchers("$.num").isNotEmpty().match(stubMvcResult); + } + + @Test + public void isNotEmptyForBoolean() throws Exception { + new JsonPathResultMatchers("$.bool").isNotEmpty().match(stubMvcResult); + } + + @Test + public void isNotEmptyForArray() throws Exception { + new JsonPathResultMatchers("$.arr").isNotEmpty().match(stubMvcResult); + } + + @Test + public void isNotEmptyForMap() throws Exception { + new JsonPathResultMatchers("$.colorMap").isNotEmpty().match(stubMvcResult); + } + + @Test(expected = AssertionError.class) + public void isNotEmptyForAnEmptyString() throws Exception { + new JsonPathResultMatchers("$.emptyString").isNotEmpty().match(stubMvcResult); + } + + @Test(expected = AssertionError.class) + public void isNotEmptyForAnEmptyArray() throws Exception { + new JsonPathResultMatchers("$.emptyArray").isNotEmpty().match(stubMvcResult); + } + + @Test(expected = AssertionError.class) + public void isNotEmptyForAnEmptyMap() throws Exception { + new JsonPathResultMatchers("$.emptyMap").isNotEmpty().match(stubMvcResult); + } + + @Test + public void isArray() throws Exception { + new JsonPathResultMatchers("$.arr").isArray().match(stubMvcResult); + } + + @Test + public void isArrayForAnEmptyArray() throws Exception { + new JsonPathResultMatchers("$.emptyArray").isArray().match(stubMvcResult); + } + + @Test(expected = AssertionError.class) + public void isArrayNoMatch() throws Exception { + new JsonPathResultMatchers("$.bar").isArray().match(stubMvcResult); + } + + @Test + public void isMap() throws Exception { + new JsonPathResultMatchers("$.colorMap").isMap().match(stubMvcResult); + } + + @Test + public void isMapForAnEmptyMap() throws Exception { + new JsonPathResultMatchers("$.emptyMap").isMap().match(stubMvcResult); + } + + @Test(expected = AssertionError.class) + public void isMapNoMatch() throws Exception { + new JsonPathResultMatchers("$.str").isMap().match(stubMvcResult); + } + + @Test + public void isBoolean() throws Exception { + new JsonPathResultMatchers("$.bool").isBoolean().match(stubMvcResult); + } + + @Test(expected = AssertionError.class) + public void isBooleanNoMatch() throws Exception { + new JsonPathResultMatchers("$.str").isBoolean().match(stubMvcResult); + } + + @Test + public void isNumber() throws Exception { + new JsonPathResultMatchers("$.num").isNumber().match(stubMvcResult); + } + + @Test(expected = AssertionError.class) + public void isNumberNoMatch() throws Exception { + new JsonPathResultMatchers("$.str").isNumber().match(stubMvcResult); + } + + @Test + public void isString() throws Exception { + new JsonPathResultMatchers("$.str").isString().match(stubMvcResult); + } + + @Test(expected = AssertionError.class) + public void isStringNoMatch() throws Exception { + new JsonPathResultMatchers("$.arr").isString().match(stubMvcResult); + } + + @Test(expected = AssertionError.class) + public void valueWithJsonPrefixNotConfigured() throws Exception { + String jsonPrefix = "prefix"; + StubMvcResult result = createPrefixedStubMvcResult(jsonPrefix); + new JsonPathResultMatchers("$.str").value("foo").match(result); + } + + @Test(expected = AssertionError.class) + public void valueWithJsonWrongPrefix() throws Exception { + String jsonPrefix = "prefix"; + StubMvcResult result = createPrefixedStubMvcResult(jsonPrefix); + new JsonPathResultMatchers("$.str").prefix("wrong").value("foo").match(result); + } + + @Test + public void valueWithJsonPrefix() throws Exception { + String jsonPrefix = "prefix"; + StubMvcResult result = createPrefixedStubMvcResult(jsonPrefix); + new JsonPathResultMatchers("$.str").prefix(jsonPrefix).value("foo").match(result); + } + + @Test(expected = AssertionError.class) + public void prefixWithPayloadNotLongEnough() throws Exception { + MockHttpServletResponse response = new MockHttpServletResponse(); + response.addHeader("Content-Type", "application/json"); + response.getWriter().print(new String("test".getBytes("ISO-8859-1"))); + StubMvcResult result = new StubMvcResult(null, null, null, null, null, null, response); + + new JsonPathResultMatchers("$.str").prefix("prefix").value("foo").match(result); + } + + private StubMvcResult createPrefixedStubMvcResult(String jsonPrefix) throws Exception { + MockHttpServletResponse response = new MockHttpServletResponse(); + response.addHeader("Content-Type", "application/json"); + response.getWriter().print(jsonPrefix + new String(RESPONSE_CONTENT.getBytes("ISO-8859-1"))); + return new StubMvcResult(null, null, null, null, null, null, response); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/result/MockMvcResultMatchersTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/result/MockMvcResultMatchersTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b825dbdc76a0b105532057db00d80802a262ba45 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/result/MockMvcResultMatchersTests.java @@ -0,0 +1,93 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.result; + +import org.junit.Test; + +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.test.web.servlet.StubMvcResult; + +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; + +/** + * Unit tests for {@link MockMvcResultMatchers}. + * + * @author Brian Clozel + * @author Sam Brannen + */ +public class MockMvcResultMatchersTests { + + @Test + public void redirect() throws Exception { + redirectedUrl("/resource/1").match(getRedirectedUrlStubMvcResult("/resource/1")); + } + + @Test + public void redirectWithUrlTemplate() throws Exception { + redirectedUrlTemplate("/orders/{orderId}/items/{itemId}", 1, 2).match(getRedirectedUrlStubMvcResult("/orders/1/items/2")); + } + + @Test + public void redirectWithMatchingPattern() throws Exception { + redirectedUrlPattern("/resource/*").match(getRedirectedUrlStubMvcResult("/resource/1")); + } + + @Test(expected = AssertionError.class) + public void redirectWithNonMatchingPattern() throws Exception { + redirectedUrlPattern("/resource/").match(getRedirectedUrlStubMvcResult("/resource/1")); + } + + @Test + public void forward() throws Exception { + forwardedUrl("/api/resource/1").match(getForwardedUrlStubMvcResult("/api/resource/1")); + } + + @Test + public void forwardWithQueryString() throws Exception { + forwardedUrl("/api/resource/1?arg=value").match(getForwardedUrlStubMvcResult("/api/resource/1?arg=value")); + } + + @Test + public void forwardWithUrlTemplate() throws Exception { + forwardedUrlTemplate("/orders/{orderId}/items/{itemId}", 1, 2).match(getForwardedUrlStubMvcResult("/orders/1/items/2")); + } + + @Test + public void forwardWithMatchingPattern() throws Exception { + forwardedUrlPattern("/api/**/?").match(getForwardedUrlStubMvcResult("/api/resource/1")); + } + + @Test(expected = AssertionError.class) + public void forwardWithNonMatchingPattern() throws Exception { + forwardedUrlPattern("/resource/").match(getForwardedUrlStubMvcResult("/resource/1")); + } + + private StubMvcResult getRedirectedUrlStubMvcResult(String redirectUrl) throws Exception { + MockHttpServletResponse response = new MockHttpServletResponse(); + response.sendRedirect(redirectUrl); + StubMvcResult mvcResult = new StubMvcResult(null, null, null, null, null, null, response); + return mvcResult; + } + + private StubMvcResult getForwardedUrlStubMvcResult(String forwardedUrl) { + MockHttpServletResponse response = new MockHttpServletResponse(); + response.setForwardedUrl(forwardedUrl); + StubMvcResult mvcResult = new StubMvcResult(null, null, null, null, null, null, response); + return mvcResult; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/result/ModelResultMatchersTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/result/ModelResultMatchersTests.java new file mode 100644 index 0000000000000000000000000000000000000000..6afec2968c37b7e98fe85401364428f108357d73 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/result/ModelResultMatchersTests.java @@ -0,0 +1,174 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.result; + +import java.util.Date; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.StubMvcResult; +import org.springframework.validation.BeanPropertyBindingResult; +import org.springframework.validation.BindingResult; +import org.springframework.web.servlet.ModelAndView; + +import static org.hamcrest.Matchers.*; + +/** + * Unit tests for + * {@link org.springframework.test.web.servlet.result.ModelResultMatchers}. + * + * @author Craig Walls + */ +public class ModelResultMatchersTests { + + private ModelResultMatchers matchers; + + private MvcResult mvcResult; + private MvcResult mvcResultWithError; + + @Before + public void setUp() throws Exception { + this.matchers = new ModelResultMatchers(); + + ModelAndView mav = new ModelAndView("view", "good", "good"); + BindingResult bindingResult = new BeanPropertyBindingResult("good", "good"); + mav.addObject(BindingResult.MODEL_KEY_PREFIX + "good", bindingResult); + + this.mvcResult = getMvcResult(mav); + + Date date = new Date(); + BindingResult bindingResultWithError = new BeanPropertyBindingResult(date, "date"); + bindingResultWithError.rejectValue("time", "error"); + + ModelAndView mavWithError = new ModelAndView("view", "good", "good"); + mavWithError.addObject("date", date); + mavWithError.addObject(BindingResult.MODEL_KEY_PREFIX + "date", bindingResultWithError); + + this.mvcResultWithError = getMvcResult(mavWithError); + } + + @Test + public void attributeExists() throws Exception { + this.matchers.attributeExists("good").match(this.mvcResult); + } + + @Test(expected = AssertionError.class) + public void attributeExists_doesNotExist() throws Exception { + this.matchers.attributeExists("bad").match(this.mvcResult); + } + + @Test + public void attributeDoesNotExist() throws Exception { + this.matchers.attributeDoesNotExist("bad").match(this.mvcResult); + } + + @Test(expected = AssertionError.class) + public void attributeDoesNotExist_doesExist() throws Exception { + this.matchers.attributeDoesNotExist("good").match(this.mvcResultWithError); + } + + @Test + public void attribute_equal() throws Exception { + this.matchers.attribute("good", is("good")).match(this.mvcResult); + } + + @Test(expected = AssertionError.class) + public void attribute_notEqual() throws Exception { + this.matchers.attribute("good", is("bad")).match(this.mvcResult); + } + + @Test + public void hasNoErrors() throws Exception { + this.matchers.hasNoErrors().match(this.mvcResult); + } + + @Test(expected = AssertionError.class) + public void hasNoErrors_withErrors() throws Exception { + this.matchers.hasNoErrors().match(this.mvcResultWithError); + } + + @Test + public void attributeHasErrors() throws Exception { + this.matchers.attributeHasErrors("date").match(this.mvcResultWithError); + } + + @Test(expected = AssertionError.class) + public void attributeHasErrors_withoutErrors() throws Exception { + this.matchers.attributeHasErrors("good").match(this.mvcResultWithError); + } + + @Test + public void attributeHasNoErrors() throws Exception { + this.matchers.attributeHasNoErrors("good").match(this.mvcResult); + } + + @Test(expected = AssertionError.class) + public void attributeHasNoErrors_withoutAttribute() throws Exception { + this.matchers.attributeHasNoErrors("missing").match(this.mvcResultWithError); + } + + @Test(expected = AssertionError.class) + public void attributeHasNoErrors_withErrors() throws Exception { + this.matchers.attributeHasNoErrors("date").match(this.mvcResultWithError); + } + + @Test + public void attributeHasFieldErrors() throws Exception { + this.matchers.attributeHasFieldErrors("date", "time").match(this.mvcResultWithError); + } + + @Test(expected = AssertionError.class) + public void attributeHasFieldErrors_withoutAttribute() throws Exception { + this.matchers.attributeHasFieldErrors("missing", "bad").match(this.mvcResult); + } + + @Test(expected = AssertionError.class) + public void attributeHasFieldErrors_withoutErrorsForAttribute() throws Exception { + this.matchers.attributeHasFieldErrors("date", "time").match(this.mvcResult); + } + + @Test(expected = AssertionError.class) + public void attributeHasFieldErrors_withoutErrorsForField() throws Exception { + this.matchers.attributeHasFieldErrors("date", "good", "time").match(this.mvcResultWithError); + } + + @Test + public void attributeHasFieldErrorCode() throws Exception { + this.matchers.attributeHasFieldErrorCode("date", "time", "error").match(this.mvcResultWithError); + } + + @Test(expected = AssertionError.class) + public void attributeHasFieldErrorCode_withoutErrorOnField() throws Exception { + this.matchers.attributeHasFieldErrorCode("date", "time", "incorrectError").match(this.mvcResultWithError); + } + + @Test + public void attributeHasFieldErrorCode_startsWith() throws Exception { + this.matchers.attributeHasFieldErrorCode("date", "time", startsWith("err")).match(this.mvcResultWithError); + } + + @Test(expected = AssertionError.class) + public void attributeHasFieldErrorCode_startsWith_withoutErrorOnField() throws Exception { + this.matchers.attributeHasFieldErrorCode("date", "time", startsWith("inc")).match(this.mvcResultWithError); + } + + private MvcResult getMvcResult(ModelAndView modelAndView) { + return new StubMvcResult(null, null, null, null, modelAndView, null, null); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/result/PrintingResultHandlerTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/result/PrintingResultHandlerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..95870aab2203eedf7a35b281defd721939007838 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/result/PrintingResultHandlerTests.java @@ -0,0 +1,383 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.result; + +import java.net.URI; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpSession; + +import org.junit.Test; +import org.mockito.Mockito; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.test.web.servlet.StubMvcResult; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.validation.BindException; +import org.springframework.validation.BindingResult; +import org.springframework.web.method.HandlerMethod; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.FlashMap; +import org.springframework.web.servlet.ModelAndView; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link PrintingResultHandler}. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + * @see org.springframework.test.web.servlet.samples.standalone.resulthandlers.PrintingResultHandlerSmokeTests + */ +public class PrintingResultHandlerTests { + + private final TestPrintingResultHandler handler = new TestPrintingResultHandler(); + + private final MockHttpServletRequest request = new MockHttpServletRequest("GET", "/") { + @Override + public boolean isAsyncStarted() { + return false; + } + }; + + private final MockHttpServletResponse response = new MockHttpServletResponse(); + + private final StubMvcResult mvcResult = new StubMvcResult( + this.request, null, null, null, null, null, this.response); + + + @Test + public void printRequest() throws Exception { + this.request.addParameter("param", "paramValue"); + this.request.addHeader("header", "headerValue"); + this.request.setCharacterEncoding("UTF-16"); + String palindrome = "ablE was I ere I saw Elba"; + byte[] bytes = palindrome.getBytes("UTF-16"); + this.request.setContent(bytes); + this.request.getSession().setAttribute("foo", "bar"); + + this.handler.handle(this.mvcResult); + + HttpHeaders headers = new HttpHeaders(); + headers.set("header", "headerValue"); + + MultiValueMap params = new LinkedMultiValueMap<>(); + params.add("param", "paramValue"); + + assertValue("MockHttpServletRequest", "HTTP Method", this.request.getMethod()); + assertValue("MockHttpServletRequest", "Request URI", this.request.getRequestURI()); + assertValue("MockHttpServletRequest", "Parameters", params); + assertValue("MockHttpServletRequest", "Headers", headers); + assertValue("MockHttpServletRequest", "Body", palindrome); + assertValue("MockHttpServletRequest", "Session Attrs", Collections.singletonMap("foo", "bar")); + } + + @Test + public void printRequestWithoutSession() throws Exception { + this.request.addParameter("param", "paramValue"); + this.request.addHeader("header", "headerValue"); + this.request.setCharacterEncoding("UTF-16"); + String palindrome = "ablE was I ere I saw Elba"; + byte[] bytes = palindrome.getBytes("UTF-16"); + this.request.setContent(bytes); + + this.handler.handle(this.mvcResult); + + HttpHeaders headers = new HttpHeaders(); + headers.set("header", "headerValue"); + + MultiValueMap params = new LinkedMultiValueMap<>(); + params.add("param", "paramValue"); + + assertValue("MockHttpServletRequest", "HTTP Method", this.request.getMethod()); + assertValue("MockHttpServletRequest", "Request URI", this.request.getRequestURI()); + assertValue("MockHttpServletRequest", "Parameters", params); + assertValue("MockHttpServletRequest", "Headers", headers); + assertValue("MockHttpServletRequest", "Body", palindrome); + } + + @Test + public void printRequestWithEmptySessionMock() throws Exception { + this.request.addParameter("param", "paramValue"); + this.request.addHeader("header", "headerValue"); + this.request.setCharacterEncoding("UTF-16"); + String palindrome = "ablE was I ere I saw Elba"; + byte[] bytes = palindrome.getBytes("UTF-16"); + this.request.setContent(bytes); + this.request.setSession(Mockito.mock(HttpSession.class)); + + this.handler.handle(this.mvcResult); + + HttpHeaders headers = new HttpHeaders(); + headers.set("header", "headerValue"); + + MultiValueMap params = new LinkedMultiValueMap<>(); + params.add("param", "paramValue"); + + assertValue("MockHttpServletRequest", "HTTP Method", this.request.getMethod()); + assertValue("MockHttpServletRequest", "Request URI", this.request.getRequestURI()); + assertValue("MockHttpServletRequest", "Parameters", params); + assertValue("MockHttpServletRequest", "Headers", headers); + assertValue("MockHttpServletRequest", "Body", palindrome); + } + + @Test + @SuppressWarnings("deprecation") + public void printResponse() throws Exception { + Cookie enigmaCookie = new Cookie("enigma", "42"); + enigmaCookie.setComment("This is a comment"); + enigmaCookie.setHttpOnly(true); + enigmaCookie.setMaxAge(1234); + enigmaCookie.setDomain(".example.com"); + enigmaCookie.setPath("/crumbs"); + enigmaCookie.setSecure(true); + + this.response.setStatus(400, "error"); + this.response.addHeader("header", "headerValue"); + this.response.setContentType("text/plain"); + this.response.getWriter().print("content"); + this.response.setForwardedUrl("redirectFoo"); + this.response.sendRedirect("/redirectFoo"); + this.response.addCookie(new Cookie("cookie", "cookieValue")); + this.response.addCookie(enigmaCookie); + + this.handler.handle(this.mvcResult); + + // Manually validate cookie values since maxAge changes... + List cookieValues = this.response.getHeaders("Set-Cookie"); + assertEquals(2, cookieValues.size()); + assertEquals("cookie=cookieValue", cookieValues.get(0)); + assertTrue("Actual: " + cookieValues.get(1), cookieValues.get(1).startsWith( + "enigma=42; Path=/crumbs; Domain=.example.com; Max-Age=1234; Expires=")); + + HttpHeaders headers = new HttpHeaders(); + headers.set("header", "headerValue"); + headers.setContentType(MediaType.TEXT_PLAIN); + headers.setLocation(new URI("/redirectFoo")); + headers.put("Set-Cookie", cookieValues); + + String heading = "MockHttpServletResponse"; + assertValue(heading, "Status", this.response.getStatus()); + assertValue(heading, "Error message", response.getErrorMessage()); + assertValue(heading, "Headers", headers); + assertValue(heading, "Content type", this.response.getContentType()); + assertValue(heading, "Body", this.response.getContentAsString()); + assertValue(heading, "Forwarded URL", this.response.getForwardedUrl()); + assertValue(heading, "Redirected URL", this.response.getRedirectedUrl()); + + Map> printedValues = this.handler.getPrinter().printedValues; + String[] cookies = (String[]) printedValues.get(heading).get("Cookies"); + assertEquals(2, cookies.length); + String cookie1 = cookies[0]; + String cookie2 = cookies[1]; + assertTrue(cookie1.startsWith("[" + Cookie.class.getSimpleName())); + assertTrue(cookie1.contains("name = 'cookie', value = 'cookieValue'")); + assertTrue(cookie1.endsWith("]")); + assertTrue(cookie2.startsWith("[" + Cookie.class.getSimpleName())); + assertTrue(cookie2.contains("name = 'enigma', value = '42', " + + "comment = 'This is a comment', domain = '.example.com', maxAge = 1234, " + + "path = '/crumbs', secure = true, version = 0, httpOnly = true")); + assertTrue(cookie2.endsWith("]")); + } + + @Test + public void printRequestWithCharacterEncoding() throws Exception { + this.request.setCharacterEncoding("UTF-8"); + this.request.setContent("text".getBytes("UTF-8")); + + this.handler.handle(this.mvcResult); + + assertValue("MockHttpServletRequest", "Body", "text"); + } + + @Test + public void printRequestWithoutCharacterEncoding() throws Exception { + this.handler.handle(this.mvcResult); + + assertValue("MockHttpServletRequest", "Body", ""); + } + + @Test + public void printResponseWithCharacterEncoding() throws Exception { + this.response.setCharacterEncoding("UTF-8"); + this.response.getWriter().print("text"); + + this.handler.handle(this.mvcResult); + assertValue("MockHttpServletResponse", "Body", "text"); + } + + @Test + public void printResponseWithDefaultCharacterEncoding() throws Exception { + this.response.getWriter().print("text"); + + this.handler.handle(this.mvcResult); + + assertValue("MockHttpServletResponse", "Body", "text"); + } + + @Test + public void printResponseWithoutCharacterEncoding() throws Exception { + this.response.setCharacterEncoding(null); + this.response.getWriter().print("text"); + + this.handler.handle(this.mvcResult); + + assertValue("MockHttpServletResponse", "Body", ""); + } + + @Test + public void printHandlerNull() throws Exception { + StubMvcResult mvcResult = new StubMvcResult(this.request, null, null, null, null, null, this.response); + this.handler.handle(mvcResult); + + assertValue("Handler", "Type", null); + } + + @Test + public void printHandler() throws Exception { + this.mvcResult.setHandler(new Object()); + this.handler.handle(this.mvcResult); + + assertValue("Handler", "Type", Object.class.getName()); + } + + @Test + public void printHandlerMethod() throws Exception { + HandlerMethod handlerMethod = new HandlerMethod(this, "handle"); + this.mvcResult.setHandler(handlerMethod); + this.handler.handle(mvcResult); + + assertValue("Handler", "Type", this.getClass().getName()); + assertValue("Handler", "Method", handlerMethod); + } + + @Test + public void resolvedExceptionNull() throws Exception { + this.handler.handle(this.mvcResult); + + assertValue("Resolved Exception", "Type", null); + } + + @Test + public void resolvedException() throws Exception { + this.mvcResult.setResolvedException(new Exception()); + this.handler.handle(this.mvcResult); + + assertValue("Resolved Exception", "Type", Exception.class.getName()); + } + + @Test + public void modelAndViewNull() throws Exception { + this.handler.handle(this.mvcResult); + + assertValue("ModelAndView", "View name", null); + assertValue("ModelAndView", "View", null); + assertValue("ModelAndView", "Model", null); + } + + @Test + public void modelAndView() throws Exception { + BindException bindException = new BindException(new Object(), "target"); + bindException.reject("errorCode"); + + ModelAndView mav = new ModelAndView("viewName"); + mav.addObject("attrName", "attrValue"); + mav.addObject(BindingResult.MODEL_KEY_PREFIX + "attrName", bindException); + + this.mvcResult.setMav(mav); + this.handler.handle(this.mvcResult); + + assertValue("ModelAndView", "View name", "viewName"); + assertValue("ModelAndView", "View", null); + assertValue("ModelAndView", "Attribute", "attrName"); + assertValue("ModelAndView", "value", "attrValue"); + assertValue("ModelAndView", "errors", bindException.getAllErrors()); + } + + @Test + public void flashMapNull() throws Exception { + this.handler.handle(mvcResult); + + assertValue("FlashMap", "Type", null); + } + + @Test + public void flashMap() throws Exception { + FlashMap flashMap = new FlashMap(); + flashMap.put("attrName", "attrValue"); + this.request.setAttribute(DispatcherServlet.class.getName() + ".OUTPUT_FLASH_MAP", flashMap); + + this.handler.handle(this.mvcResult); + + assertValue("FlashMap", "Attribute", "attrName"); + assertValue("FlashMap", "value", "attrValue"); + } + + private void assertValue(String heading, String label, Object value) { + Map> printedValues = this.handler.getPrinter().printedValues; + assertTrue("Heading '" + heading + "' not printed", printedValues.containsKey(heading)); + assertEquals("For label '" + label + "' under heading '" + heading + "' =>", value, + printedValues.get(heading).get(label)); + } + + + private static class TestPrintingResultHandler extends PrintingResultHandler { + + TestPrintingResultHandler() { + super(new TestResultValuePrinter()); + } + + @Override + public TestResultValuePrinter getPrinter() { + return (TestResultValuePrinter) super.getPrinter(); + } + + private static class TestResultValuePrinter implements ResultValuePrinter { + + private String printedHeading; + + private Map> printedValues = new HashMap<>(); + + @Override + public void printHeading(String heading) { + this.printedHeading = heading; + this.printedValues.put(heading, new HashMap<>()); + } + + @Override + public void printValue(String label, Object value) { + Assert.notNull(this.printedHeading, + "Heading not printed before label " + label + " with value " + value); + this.printedValues.get(this.printedHeading).put(label, value); + } + } + } + + + public void handle() { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/result/StatusResultMatchersTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/result/StatusResultMatchersTests.java new file mode 100644 index 0000000000000000000000000000000000000000..43972e8c23e2176d2415806f07b5373f6a425b71 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/result/StatusResultMatchersTests.java @@ -0,0 +1,120 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.result; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.Conventions; +import org.springframework.http.HttpStatus; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.ResultMatcher; +import org.springframework.test.web.servlet.StubMvcResult; +import org.springframework.util.ReflectionUtils; +import org.springframework.util.StringUtils; + +import static org.junit.Assert.*; + +/** + * Tests for {@link StatusResultMatchers}. + * + * @author Rossen Stoyanchev + */ +public class StatusResultMatchersTests { + + private StatusResultMatchers matchers; + + private MockHttpServletRequest request; + + + @Before + public void setup() { + this.matchers = new StatusResultMatchers(); + this.request = new MockHttpServletRequest(); + } + + + @Test + public void testHttpStatusCodeResultMatchers() throws Exception { + + List failures = new ArrayList<>(); + + for (HttpStatus status : HttpStatus.values()) { + MockHttpServletResponse response = new MockHttpServletResponse(); + response.setStatus(status.value()); + MvcResult mvcResult = new StubMvcResult(request, null, null, null, null, null, response); + try { + Method method = getMethodForHttpStatus(status); + ResultMatcher matcher = (ResultMatcher) ReflectionUtils.invokeMethod(method, this.matchers); + try { + matcher.match(mvcResult); + } + catch (AssertionError error) { + failures.add(error); + } + } + catch (Exception ex) { + throw new Exception("Failed to obtain ResultMatcher for status " + status, ex); + } + } + + if (!failures.isEmpty()) { + fail("Failed status codes: " + failures); + } + } + + private Method getMethodForHttpStatus(HttpStatus status) throws NoSuchMethodException { + String name = status.name().toLowerCase().replace("_", "-"); + name = "is" + StringUtils.capitalize(Conventions.attributeNameToPropertyName(name)); + return StatusResultMatchers.class.getMethod(name); + } + + @Test + public void statusRanges() throws Exception { + for (HttpStatus status : HttpStatus.values()) { + MockHttpServletResponse response = new MockHttpServletResponse(); + response.setStatus(status.value()); + MvcResult mvcResult = new StubMvcResult(request, null, null, null, null, null, response); + switch (status.series().value()) { + case 1: + this.matchers.is1xxInformational().match(mvcResult); + break; + case 2: + this.matchers.is2xxSuccessful().match(mvcResult); + break; + case 3: + this.matchers.is3xxRedirection().match(mvcResult); + break; + case 4: + this.matchers.is4xxClientError().match(mvcResult); + break; + case 5: + this.matchers.is5xxServerError().match(mvcResult); + break; + default: + fail("Unexpected range for status code value " + status); + } + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/result/XpathResultMatchersTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/result/XpathResultMatchersTests.java new file mode 100644 index 0000000000000000000000000000000000000000..ed0624f59a939760f2e981b1c13d7d219f4c8fd7 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/result/XpathResultMatchersTests.java @@ -0,0 +1,129 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.result; + +import java.nio.charset.StandardCharsets; + +import org.hamcrest.Matchers; +import org.junit.Test; + +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.test.web.servlet.StubMvcResult; +import org.springframework.util.StreamUtils; + +/** + * Tests for {@link XpathResultMatchers}. + * + * @author Rossen Stoyanchev + */ +public class XpathResultMatchersTests { + + private static final String RESPONSE_CONTENT = "111true"; + + + @Test + public void node() throws Exception { + new XpathResultMatchers("/foo/bar", null).node(Matchers.notNullValue()).match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void nodeNoMatch() throws Exception { + new XpathResultMatchers("/foo/bar", null).node(Matchers.nullValue()).match(getStubMvcResult()); + } + + @Test + public void exists() throws Exception { + new XpathResultMatchers("/foo/bar", null).exists().match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void existsNoMatch() throws Exception { + new XpathResultMatchers("/foo/Bar", null).exists().match(getStubMvcResult()); + } + + @Test + public void doesNotExist() throws Exception { + new XpathResultMatchers("/foo/Bar", null).doesNotExist().match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void doesNotExistNoMatch() throws Exception { + new XpathResultMatchers("/foo/bar", null).doesNotExist().match(getStubMvcResult()); + } + + @Test + public void nodeCount() throws Exception { + new XpathResultMatchers("/foo/bar", null).nodeCount(2).match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void nodeCountNoMatch() throws Exception { + new XpathResultMatchers("/foo/bar", null).nodeCount(1).match(getStubMvcResult()); + } + + @Test + public void string() throws Exception { + new XpathResultMatchers("/foo/bar[1]", null).string("111").match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void stringNoMatch() throws Exception { + new XpathResultMatchers("/foo/bar[1]", null).string("112").match(getStubMvcResult()); + } + + @Test + public void number() throws Exception { + new XpathResultMatchers("/foo/bar[1]", null).number(111.0).match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void numberNoMatch() throws Exception { + new XpathResultMatchers("/foo/bar[1]", null).number(111.1).match(getStubMvcResult()); + } + + @Test + public void booleanValue() throws Exception { + new XpathResultMatchers("/foo/bar[2]", null).booleanValue(true).match(getStubMvcResult()); + } + + @Test(expected = AssertionError.class) + public void booleanValueNoMatch() throws Exception { + new XpathResultMatchers("/foo/bar[2]", null).booleanValue(false).match(getStubMvcResult()); + } + + @Test + public void stringEncodingDetection() throws Exception { + String content = "\n" + + "Jürgen"; + byte[] bytes = content.getBytes(StandardCharsets.UTF_8); + MockHttpServletResponse response = new MockHttpServletResponse(); + response.addHeader("Content-Type", "application/xml"); + StreamUtils.copy(bytes, response.getOutputStream()); + StubMvcResult result = new StubMvcResult(null, null, null, null, null, null, response); + + new XpathResultMatchers("/person/name", null).string("Jürgen").match(result); + } + + + private StubMvcResult getStubMvcResult() throws Exception { + MockHttpServletResponse response = new MockHttpServletResponse(); + response.addHeader("Content-Type", "application/xml"); + response.getWriter().print(new String(RESPONSE_CONTENT.getBytes(StandardCharsets.ISO_8859_1))); + return new StubMvcResult(null, null, null, null, null, null, response); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/AsyncControllerJavaConfigTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/AsyncControllerJavaConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..98aae70d15cb97405a5a4c2f81f022d4e18e6a3d --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/AsyncControllerJavaConfigTests.java @@ -0,0 +1,133 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.context; + +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.Callable; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mockito; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.MediaType; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.ContextHierarchy; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.request.async.CallableProcessingInterceptor; +import org.springframework.web.servlet.config.annotation.AsyncSupportConfigurer; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; + +import static org.mockito.ArgumentMatchers.any; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.asyncDispatch; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +/** + * Tests with Java configuration. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + */ +@RunWith(SpringJUnit4ClassRunner.class) +@WebAppConfiguration +@ContextHierarchy(@ContextConfiguration(classes = AsyncControllerJavaConfigTests.WebConfig.class)) +public class AsyncControllerJavaConfigTests { + + @Autowired + private WebApplicationContext wac; + + @Autowired + private CallableProcessingInterceptor callableInterceptor; + + private MockMvc mockMvc; + + + @Before + public void setup() { + this.mockMvc = MockMvcBuilders.webAppContextSetup(this.wac).build(); + } + + // SPR-13615 + + @Test + public void callableInterceptor() throws Exception { + MvcResult mvcResult = this.mockMvc.perform(get("/callable").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(request().asyncStarted()) + .andExpect(request().asyncResult(Collections.singletonMap("key", "value"))) + .andReturn(); + + Mockito.verify(this.callableInterceptor).beforeConcurrentHandling(any(), any()); + Mockito.verify(this.callableInterceptor).preProcess(any(), any()); + Mockito.verify(this.callableInterceptor).postProcess(any(), any(), any()); + Mockito.verifyNoMoreInteractions(this.callableInterceptor); + + this.mockMvc.perform(asyncDispatch(mvcResult)) + .andExpect(status().isOk()) + .andExpect(content().string("{\"key\":\"value\"}")); + + Mockito.verify(this.callableInterceptor).afterCompletion(any(), any()); + Mockito.verifyNoMoreInteractions(this.callableInterceptor); + } + + + @Configuration + @EnableWebMvc + static class WebConfig implements WebMvcConfigurer { + + @Override + public void configureAsyncSupport(AsyncSupportConfigurer configurer) { + configurer.registerCallableInterceptors(callableInterceptor()); + } + + @Bean + public CallableProcessingInterceptor callableInterceptor() { + return Mockito.mock(CallableProcessingInterceptor.class); + } + + @Bean + public AsyncController asyncController() { + return new AsyncController(); + } + + } + + @RestController + static class AsyncController { + + @GetMapping("/callable") + public Callable> getCallable() { + return () -> Collections.singletonMap("key", "value"); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/JavaConfigTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/JavaConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..90c016cf453401765bfb7bf729957a3528a194aa --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/JavaConfigTests.java @@ -0,0 +1,186 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.context; + +import javax.servlet.ServletContext; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mockito; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.MediaType; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.ContextHierarchy; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.Person; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.samples.context.JavaConfigTests.RootConfig; +import org.springframework.test.web.servlet.samples.context.JavaConfigTests.WebConfig; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.servlet.config.annotation.DefaultServletHandlerConfigurer; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry; +import org.springframework.web.servlet.config.annotation.ViewControllerRegistry; +import org.springframework.web.servlet.config.annotation.ViewResolverRegistry; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; +import org.springframework.web.servlet.view.tiles3.TilesConfigurer; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.mockito.BDDMockito.given; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.forwardedUrl; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +/** + * Tests with Java configuration. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + * @author Sebastien Deleuze + */ +@RunWith(SpringJUnit4ClassRunner.class) +@WebAppConfiguration("classpath:META-INF/web-resources") +@ContextHierarchy({ + @ContextConfiguration(classes = RootConfig.class), + @ContextConfiguration(classes = WebConfig.class) +}) +public class JavaConfigTests { + + @Autowired + private WebApplicationContext wac; + + @Autowired + private PersonDao personDao; + + @Autowired + private PersonController personController; + + private MockMvc mockMvc; + + + @Before + public void setup() { + this.mockMvc = MockMvcBuilders.webAppContextSetup(this.wac).build(); + verifyRootWacSupport(); + given(this.personDao.getPerson(5L)).willReturn(new Person("Joe")); + } + + @Test + public void person() throws Exception { + this.mockMvc.perform(get("/person/5").accept(MediaType.APPLICATION_JSON)) + .andDo(print()) + .andExpect(status().isOk()) + .andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.0,\"someBoolean\":false}")); + } + + @Test + public void tilesDefinitions() throws Exception { + this.mockMvc.perform(get("/")) + .andExpect(status().isOk()) + .andExpect(forwardedUrl("/WEB-INF/layouts/standardLayout.jsp")); + } + + /** + * Verify that the breaking change introduced in SPR-12553 has been reverted. + * + *

This code has been copied from + * {@link org.springframework.test.context.hierarchies.web.ControllerIntegrationTests}. + * + * @see org.springframework.test.context.hierarchies.web.ControllerIntegrationTests#verifyRootWacSupport() + */ + private void verifyRootWacSupport() { + assertNotNull(personDao); + assertNotNull(personController); + + ApplicationContext parent = wac.getParent(); + assertNotNull(parent); + assertTrue(parent instanceof WebApplicationContext); + WebApplicationContext root = (WebApplicationContext) parent; + + ServletContext childServletContext = wac.getServletContext(); + assertNotNull(childServletContext); + ServletContext rootServletContext = root.getServletContext(); + assertNotNull(rootServletContext); + assertSame(childServletContext, rootServletContext); + + assertSame(root, rootServletContext.getAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE)); + assertSame(root, childServletContext.getAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE)); + } + + + @Configuration + static class RootConfig { + + @Bean + public PersonDao personDao() { + return Mockito.mock(PersonDao.class); + } + } + + @Configuration + @EnableWebMvc + static class WebConfig implements WebMvcConfigurer { + + @Autowired + private RootConfig rootConfig; + + @Bean + public PersonController personController() { + return new PersonController(this.rootConfig.personDao()); + } + + @Override + public void addResourceHandlers(ResourceHandlerRegistry registry) { + registry.addResourceHandler("/resources/**").addResourceLocations("/resources/"); + } + + @Override + public void addViewControllers(ViewControllerRegistry registry) { + registry.addViewController("/").setViewName("home"); + } + + @Override + public void configureDefaultServletHandling(DefaultServletHandlerConfigurer configurer) { + configurer.enable(); + } + + @Override + public void configureViewResolvers(ViewResolverRegistry registry) { + registry.tiles(); + } + + @Bean + public TilesConfigurer tilesConfigurer() { + TilesConfigurer configurer = new TilesConfigurer(); + configurer.setDefinitions("/WEB-INF/**/tiles.xml"); + return configurer; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/PersonController.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/PersonController.java new file mode 100644 index 0000000000000000000000000000000000000000..fa3ad7fa263a29a663c947c1a17273aad6db4940 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/PersonController.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.context; + +import org.springframework.test.web.Person; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +@RestController +@RequestMapping("/person") +public class PersonController { + + private final PersonDao personDao; + + + PersonController(PersonDao personDao) { + this.personDao = personDao; + } + + @GetMapping("/{id}") + public Person getPerson(@PathVariable long id) { + return this.personDao.getPerson(id); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/PersonDao.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/PersonDao.java new file mode 100644 index 0000000000000000000000000000000000000000..4001526e94c7a019866868bad715507ab3b3ca72 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/PersonDao.java @@ -0,0 +1,25 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.context; + +import org.springframework.test.web.Person; + +public interface PersonDao { + + Person getPerson(Long id); + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/WebAppResourceTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/WebAppResourceTests.java new file mode 100644 index 0000000000000000000000000000000000000000..fbb1b1bc4f6a7855355affe8f96b65bf7b9b66f7 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/WebAppResourceTests.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.context; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.ContextHierarchy; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.servlet.resource.DefaultServletHttpRequestHandler; + +import static org.hamcrest.Matchers.containsString; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.forwardedUrl; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.handler; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +/** + * Tests dependent on access to resources under the web application root directory. + * + * @author Rossen Stoyanchev + */ +@RunWith(SpringJUnit4ClassRunner.class) +@WebAppConfiguration("src/test/resources/META-INF/web-resources") +@ContextHierarchy({ + @ContextConfiguration("root-context.xml"), + @ContextConfiguration("servlet-context.xml") +}) +public class WebAppResourceTests { + + @Autowired + private WebApplicationContext wac; + + private MockMvc mockMvc; + + @Before + public void setup() { + this.mockMvc = MockMvcBuilders.webAppContextSetup(this.wac).alwaysExpect(status().isOk()).build(); + } + + // TilesConfigurer: resources under "/WEB-INF/**/tiles.xml" + + @Test + public void tilesDefinitions() throws Exception { + this.mockMvc.perform(get("/")) + .andExpect(forwardedUrl("/WEB-INF/layouts/standardLayout.jsp")); + } + + // Resources served via + + @Test + public void resourceRequest() throws Exception { + this.mockMvc.perform(get("/resources/Spring.js")) + .andExpect(content().contentType("application/javascript")) + .andExpect(content().string(containsString("Spring={};"))); + } + + // Forwarded to the "default" servlet via + + @Test + public void resourcesViaDefaultServlet() throws Exception { + this.mockMvc.perform(get("/unknown/resource")) + .andExpect(handler().handlerType(DefaultServletHttpRequestHandler.class)) + .andExpect(forwardedUrl("default")); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/XmlConfigTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/XmlConfigTests.java new file mode 100644 index 0000000000000000000000000000000000000000..1f9179053d7909383b56249d775e5a2bde6f66e8 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/context/XmlConfigTests.java @@ -0,0 +1,83 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.context; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.MediaType; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.ContextHierarchy; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.Person; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.web.context.WebApplicationContext; + +import static org.mockito.BDDMockito.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; + +/** + * Tests with XML configuration. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + */ +@RunWith(SpringJUnit4ClassRunner.class) +@WebAppConfiguration("src/test/resources/META-INF/web-resources") +@ContextHierarchy({ + @ContextConfiguration("root-context.xml"), + @ContextConfiguration("servlet-context.xml") +}) +public class XmlConfigTests { + + @Autowired + private WebApplicationContext wac; + + @Autowired + private PersonDao personDao; + + private MockMvc mockMvc; + + + @Before + public void setup() { + this.mockMvc = MockMvcBuilders.webAppContextSetup(this.wac).build(); + given(this.personDao.getPerson(5L)).willReturn(new Person("Joe")); + } + + @Test + public void person() throws Exception { + this.mockMvc.perform(get("/person/5").accept(MediaType.APPLICATION_JSON)) + .andDo(print()) + .andExpect(status().isOk()) + .andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.0,\"someBoolean\":false}")); + } + + @Test + public void tilesDefinitions() throws Exception { + this.mockMvc.perform(get("/"))// + .andExpect(status().isOk())// + .andExpect(forwardedUrl("/WEB-INF/layouts/standardLayout.jsp")); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/ControllerAdviceIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/ControllerAdviceIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..144639e29d7bbd4671983fc320fc27df87d03c08 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/ControllerAdviceIntegrationTests.java @@ -0,0 +1,206 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.spr; + +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Scope; +import org.springframework.stereotype.Controller; +import org.springframework.test.context.junit4.SpringRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.ui.Model; +import org.springframework.web.bind.annotation.ControllerAdvice; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.ModelAttribute; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.annotation.RequestScope; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; + +import static org.junit.Assert.assertEquals; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.forwardedUrl; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.webAppContextSetup; + +/** + * Integration tests for {@link ControllerAdvice @ControllerAdvice}. + * + *

Introduced in conjunction with + * gh-24017. + * + * @author Sam Brannen + * @since 5.1.12 + */ +@RunWith(SpringRunner.class) +@WebAppConfiguration +public class ControllerAdviceIntegrationTests { + + @Autowired + WebApplicationContext wac; + + MockMvc mockMvc; + + @Before + public void setUpMockMvc() { + this.mockMvc = webAppContextSetup(wac).build(); + resetCounters(); + } + + @Test + public void controllerAdviceIsAppliedOnlyOnce() throws Exception { + this.mockMvc.perform(get("/test").param("requestParam", "foo"))// + .andExpect(status().isOk())// + .andExpect(forwardedUrl("singleton:1;prototype:1;request-scoped:1;requestParam:foo")); + + assertEquals(1, SingletonControllerAdvice.invocationCount.get()); + assertEquals(1, PrototypeControllerAdvice.invocationCount.get()); + assertEquals(1, RequestScopedControllerAdvice.invocationCount.get()); + } + + @Test + public void prototypeAndRequestScopedControllerAdviceBeansAreNotCached() throws Exception { + this.mockMvc.perform(get("/test").param("requestParam", "foo"))// + .andExpect(status().isOk())// + .andExpect(forwardedUrl("singleton:1;prototype:1;request-scoped:1;requestParam:foo")); + + // singleton @ControllerAdvice beans should not be instantiated again. + assertEquals(0, SingletonControllerAdvice.instanceCount.get()); + // prototype and request-scoped @ControllerAdvice beans should be instantiated once per request. + assertEquals(1, PrototypeControllerAdvice.instanceCount.get()); + assertEquals(1, RequestScopedControllerAdvice.instanceCount.get()); + + this.mockMvc.perform(get("/test").param("requestParam", "bar"))// + .andExpect(status().isOk())// + .andExpect(forwardedUrl("singleton:2;prototype:2;request-scoped:2;requestParam:bar")); + + // singleton @ControllerAdvice beans should not be instantiated again. + assertEquals(0, SingletonControllerAdvice.instanceCount.get()); + // prototype and request-scoped @ControllerAdvice beans should be instantiated once per request. + assertEquals(2, PrototypeControllerAdvice.instanceCount.get()); + assertEquals(2, RequestScopedControllerAdvice.instanceCount.get()); + } + + private static void resetCounters() { + SingletonControllerAdvice.invocationCount.set(0); + SingletonControllerAdvice.instanceCount.set(0); + PrototypeControllerAdvice.invocationCount.set(0); + PrototypeControllerAdvice.instanceCount.set(0); + RequestScopedControllerAdvice.invocationCount.set(0); + RequestScopedControllerAdvice.instanceCount.set(0); + } + + + @Configuration + @EnableWebMvc + static class Config { + + @Bean + TestController testController() { + return new TestController(); + } + + @Bean + SingletonControllerAdvice singletonControllerAdvice() { + return new SingletonControllerAdvice(); + } + + @Bean + @Scope("prototype") + PrototypeControllerAdvice prototypeControllerAdvice() { + return new PrototypeControllerAdvice(); + } + + @Bean + @RequestScope + RequestScopedControllerAdvice requestScopedControllerAdvice() { + return new RequestScopedControllerAdvice(); + } + } + + @ControllerAdvice + static class SingletonControllerAdvice { + + static final AtomicInteger instanceCount = new AtomicInteger(); + static final AtomicInteger invocationCount = new AtomicInteger(); + + { + instanceCount.incrementAndGet(); + } + + @ModelAttribute + void initModel(Model model) { + model.addAttribute("singleton", invocationCount.incrementAndGet()); + } + } + + @ControllerAdvice + static class PrototypeControllerAdvice { + + static final AtomicInteger instanceCount = new AtomicInteger(); + static final AtomicInteger invocationCount = new AtomicInteger(); + + { + instanceCount.incrementAndGet(); + } + + @ModelAttribute + void initModel(Model model) { + model.addAttribute("prototype", invocationCount.incrementAndGet()); + } + } + + @ControllerAdvice + static class RequestScopedControllerAdvice { + + static final AtomicInteger instanceCount = new AtomicInteger(); + static final AtomicInteger invocationCount = new AtomicInteger(); + + { + instanceCount.incrementAndGet(); + } + + @ModelAttribute + void initModel(@RequestParam String requestParam, Model model) { + model.addAttribute("requestParam", requestParam); + model.addAttribute("request-scoped", invocationCount.incrementAndGet()); + } + } + + @Controller + static class TestController { + + @GetMapping("/test") + String get(Model model) { + Map map = model.asMap(); + return "singleton:" + map.get("singleton") + + ";prototype:" + map.get("prototype") + + ";request-scoped:" + map.get("request-scoped") + + ";requestParam:" + map.get("requestParam"); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/CustomRequestAttributesRequestContextHolderTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/CustomRequestAttributesRequestContextHolderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0f73302b88176c2f42d4ef72370352b304a36344 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/CustomRequestAttributesRequestContextHolderTests.java @@ -0,0 +1,141 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.spr; + +import javax.servlet.ServletContext; +import javax.servlet.ServletRequest; +import javax.servlet.http.HttpServletRequest; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.context.annotation.AnnotatedBeanDefinitionReader; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockServletContext; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.context.support.GenericWebApplicationContext; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; + +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.junit.Assert.assertThat; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.webAppContextSetup; + +/** + * Integration tests for SPR-13211 which verify that a custom mock request + * is not reused by MockMvc. + * + * @author Sam Brannen + * @since 4.2 + * @see RequestContextHolderTests + */ +public class CustomRequestAttributesRequestContextHolderTests { + + private static final String FROM_CUSTOM_MOCK = "fromCustomMock"; + private static final String FROM_MVC_TEST_DEFAULT = "fromSpringMvcTestDefault"; + private static final String FROM_MVC_TEST_MOCK = "fromSpringMvcTestMock"; + + private final GenericWebApplicationContext wac = new GenericWebApplicationContext(); + + private MockMvc mockMvc; + + + @Before + public void setUp() { + ServletContext servletContext = new MockServletContext(); + MockHttpServletRequest mockRequest = new MockHttpServletRequest(servletContext); + mockRequest.setAttribute(FROM_CUSTOM_MOCK, FROM_CUSTOM_MOCK); + RequestContextHolder.setRequestAttributes(new ServletWebRequest(mockRequest, new MockHttpServletResponse())); + + this.wac.setServletContext(servletContext); + new AnnotatedBeanDefinitionReader(this.wac).register(WebConfig.class); + this.wac.refresh(); + + this.mockMvc = webAppContextSetup(this.wac) + .defaultRequest(get("/").requestAttr(FROM_MVC_TEST_DEFAULT, FROM_MVC_TEST_DEFAULT)) + .alwaysExpect(status().isOk()) + .build(); + } + + @Test + public void singletonController() throws Exception { + this.mockMvc.perform(get("/singletonController").requestAttr(FROM_MVC_TEST_MOCK, FROM_MVC_TEST_MOCK)); + } + + @After + public void verifyCustomRequestAttributesAreRestored() { + RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes(); + assertThat(requestAttributes, instanceOf(ServletRequestAttributes.class)); + HttpServletRequest request = ((ServletRequestAttributes) requestAttributes).getRequest(); + + assertThat(request.getAttribute(FROM_CUSTOM_MOCK), is(FROM_CUSTOM_MOCK)); + assertThat(request.getAttribute(FROM_MVC_TEST_DEFAULT), is(nullValue())); + assertThat(request.getAttribute(FROM_MVC_TEST_MOCK), is(nullValue())); + + RequestContextHolder.resetRequestAttributes(); + this.wac.close(); + } + + + // ------------------------------------------------------------------- + + @Configuration + @EnableWebMvc + static class WebConfig implements WebMvcConfigurer { + + @Bean + public SingletonController singletonController() { + return new SingletonController(); + } + } + + @RestController + private static class SingletonController { + + @RequestMapping("/singletonController") + public void handle() { + assertRequestAttributes(); + } + } + + private static void assertRequestAttributes() { + RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes(); + assertThat(requestAttributes, instanceOf(ServletRequestAttributes.class)); + assertRequestAttributes(((ServletRequestAttributes) requestAttributes).getRequest()); + } + + private static void assertRequestAttributes(ServletRequest request) { + assertThat(request.getAttribute(FROM_CUSTOM_MOCK), is(nullValue())); + assertThat(request.getAttribute(FROM_MVC_TEST_DEFAULT), is(FROM_MVC_TEST_DEFAULT)); + assertThat(request.getAttribute(FROM_MVC_TEST_MOCK), is(FROM_MVC_TEST_MOCK)); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/EncodedUriTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/EncodedUriTests.java new file mode 100644 index 0000000000000000000000000000000000000000..2059d3285fdd8c927733d2010f736ea8966786f5 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/EncodedUriTests.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.spr; + + +import java.net.URI; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.config.BeanPostProcessor; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.PriorityOrdered; +import org.springframework.stereotype.Component; +import org.springframework.stereotype.Controller; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.ResultActions; +import org.springframework.ui.Model; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.config.annotation.ViewResolverRegistry; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; +import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping; +import org.springframework.web.util.UriComponentsBuilder; + +import static org.hamcrest.core.Is.is; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.model; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.webAppContextSetup; + +/** + * Tests for SPR-11441 (MockMvc accepts an already encoded URI). + * + * @author Sebastien Deleuze + */ +@RunWith(SpringJUnit4ClassRunner.class) +@WebAppConfiguration +@ContextConfiguration +public class EncodedUriTests { + + @Autowired + private WebApplicationContext wac; + + private MockMvc mockMvc; + + @Before + public void setup() { + this.mockMvc = webAppContextSetup(this.wac).build(); + } + + @Test + public void test() throws Exception { + String id = "a/b"; + URI url = UriComponentsBuilder.fromUriString("/circuit").pathSegment(id).build().encode().toUri(); + ResultActions result = mockMvc.perform(get(url)); + result.andExpect(status().isOk()).andExpect(model().attribute("receivedId", is(id))); + } + + + @Configuration + @EnableWebMvc + static class WebConfig implements WebMvcConfigurer { + + @Bean + public MyController myController() { + return new MyController(); + } + + @Bean + public HandlerMappingConfigurer myHandlerMappingConfigurer() { + return new HandlerMappingConfigurer(); + } + + @Override + public void configureViewResolvers(ViewResolverRegistry registry) { + registry.jsp("", ""); + } + } + + @Controller + private static class MyController { + + @RequestMapping(value = "/circuit/{id}", method = RequestMethod.GET) + public String getCircuit(@PathVariable String id, Model model) { + model.addAttribute("receivedId", id); + return "result"; + } + } + + @Component + private static class HandlerMappingConfigurer implements BeanPostProcessor, PriorityOrdered { + + public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { + if (bean instanceof RequestMappingHandlerMapping) { + RequestMappingHandlerMapping requestMappingHandlerMapping = (RequestMappingHandlerMapping) bean; + + // URL decode after request mapping, not before. + requestMappingHandlerMapping.setUrlDecode(false); + } + return bean; + } + + public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException { + return bean; + } + + public int getOrder() { + return PriorityOrdered.HIGHEST_PRECEDENCE; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/FormContentTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/FormContentTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c7bbe8e1180e76b4922832fed765c80371889c7e --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/FormContentTests.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.spr; + +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.web.bind.annotation.PutMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.filter.FormContentFilter; + +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; + +/** + * Test for issues related to form content. + * + * @author Rossen Stoyanchev + */ +public class FormContentTests { + + @Test // SPR-15753 + public void formContentIsNotDuplicated() throws Exception { + + MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new Spr15753Controller()) + .addFilter(new FormContentFilter()) + .build(); + + mockMvc.perform(put("/").content("d1=a&d2=s").contentType(MediaType.APPLICATION_FORM_URLENCODED)) + .andExpect(content().string("d1:a, d2:s.")); + } + + + @RestController + private static class Spr15753Controller { + + @PutMapping + public String test(Data d) { + return String.format("d1:%s, d2:%s.", d.getD1(), d.getD2()); + } + } + + @SuppressWarnings("unused") + private static class Data { + + private String d1; + + private String d2; + + public Data() { + } + + public String getD1() { + return d1; + } + + public void setD1(String d1) { + this.d1 = d1; + } + + public String getD2() { + return d2; + } + + public void setD2(String d2) { + this.d2 = d2; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/HttpOptionsTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/HttpOptionsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c4fef7c1de841cfd9f0bbe545f052d7fe38bff50 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/HttpOptionsTests.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.spr; + +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.stereotype.Controller; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.web.bind.annotation.ResponseBody; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; + +import static org.junit.Assert.assertEquals; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.options; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.webAppContextSetup; + +/** + * Tests for SPR-10093 (support for OPTIONS requests). + * + * @author Arnaud Cogoluègnes + */ +@RunWith(SpringJUnit4ClassRunner.class) +@WebAppConfiguration +@ContextConfiguration +public class HttpOptionsTests { + + @Autowired + private WebApplicationContext wac; + + private MockMvc mockMvc; + + + @Before + public void setup() { + this.mockMvc = webAppContextSetup(this.wac).dispatchOptions(true).build(); + } + + @Test + public void test() throws Exception { + MyController controller = this.wac.getBean(MyController.class); + int initialCount = controller.counter.get(); + this.mockMvc.perform(options("/myUrl")).andExpect(status().isOk()); + + assertEquals(initialCount + 1, controller.counter.get()); + } + + + @Configuration + @EnableWebMvc + static class WebConfig implements WebMvcConfigurer { + + @Bean + public MyController myController() { + return new MyController(); + } + } + + @Controller + private static class MyController { + + private AtomicInteger counter = new AtomicInteger(0); + + + @RequestMapping(value = "/myUrl", method = RequestMethod.OPTIONS) + @ResponseBody + public void handle() { + counter.incrementAndGet(); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/MockMvcBuilderMethodChainTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/MockMvcBuilderMethodChainTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d40ebaf7858e01e6f221816303702fc132c7269c --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/MockMvcBuilderMethodChainTests.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.spr; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.filter.CharacterEncodingFilter; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; + +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; + +/** + * Test for SPR-10277 (multiple method chaining when building MockMvc). + * + * @author Wesley Hall + */ +@RunWith(SpringJUnit4ClassRunner.class) +@WebAppConfiguration +@ContextConfiguration +public class MockMvcBuilderMethodChainTests { + + @Autowired + private WebApplicationContext wac; + + @Test + public void chainMultiple() { + MockMvcBuilders + .webAppContextSetup(wac) + .addFilter(new CharacterEncodingFilter() ) + .defaultRequest(get("/").contextPath("/mywebapp")) + .build(); + } + + @Configuration + @EnableWebMvc + static class WebConfig implements WebMvcConfigurer { + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/RequestContextHolderTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/RequestContextHolderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..2e989c1d44c05489a88e07bc1c41406c4197cf8a --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/spr/RequestContextHolderTests.java @@ -0,0 +1,329 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.spr; + +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.springframework.aop.support.AopUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Scope; +import org.springframework.context.annotation.ScopedProxyMode; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.annotation.RequestScope; +import org.springframework.web.context.annotation.SessionScope; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import org.springframework.web.filter.GenericFilterBean; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Integration tests for the following use cases. + *

    + *
  • SPR-10025: Access to request attributes via RequestContextHolder
  • + *
  • SPR-13217: Populate RequestAttributes before invoking Filters in MockMvc
  • + *
  • SPR-13260: No reuse of mock requests
  • + *
+ * + * @author Rossen Stoyanchev + * @author Sam Brannen + * @see CustomRequestAttributesRequestContextHolderTests + */ +@RunWith(SpringRunner.class) +@WebAppConfiguration +@ContextConfiguration +@DirtiesContext +public class RequestContextHolderTests { + + private static final String FROM_TCF_MOCK = "fromTestContextFrameworkMock"; + private static final String FROM_MVC_TEST_DEFAULT = "fromSpringMvcTestDefault"; + private static final String FROM_MVC_TEST_MOCK = "fromSpringMvcTestMock"; + private static final String FROM_REQUEST_FILTER = "fromRequestFilter"; + private static final String FROM_REQUEST_ATTRIBUTES_FILTER = "fromRequestAttributesFilter"; + + @Autowired + private WebApplicationContext wac; + + @Autowired + private MockHttpServletRequest mockRequest; + + @Autowired + private RequestScopedController requestScopedController; + + @Autowired + private RequestScopedService requestScopedService; + + @Autowired + private SessionScopedService sessionScopedService; + + @Autowired + private FilterWithSessionScopedService filterWithSessionScopedService; + + private MockMvc mockMvc; + + + @Before + public void setup() { + this.mockRequest.setAttribute(FROM_TCF_MOCK, FROM_TCF_MOCK); + + this.mockMvc = webAppContextSetup(this.wac) + .addFilters(new RequestFilter(), new RequestAttributesFilter(), this.filterWithSessionScopedService) + .defaultRequest(get("/").requestAttr(FROM_MVC_TEST_DEFAULT, FROM_MVC_TEST_DEFAULT)) + .alwaysExpect(status().isOk()) + .build(); + } + + @Test + public void singletonController() throws Exception { + this.mockMvc.perform(get("/singletonController").requestAttr(FROM_MVC_TEST_MOCK, FROM_MVC_TEST_MOCK)); + } + + @Test + public void requestScopedController() throws Exception { + assertTrue("request-scoped controller must be a CGLIB proxy", AopUtils.isCglibProxy(this.requestScopedController)); + this.mockMvc.perform(get("/requestScopedController").requestAttr(FROM_MVC_TEST_MOCK, FROM_MVC_TEST_MOCK)); + } + + @Test + public void requestScopedService() throws Exception { + assertTrue("request-scoped service must be a CGLIB proxy", AopUtils.isCglibProxy(this.requestScopedService)); + this.mockMvc.perform(get("/requestScopedService").requestAttr(FROM_MVC_TEST_MOCK, FROM_MVC_TEST_MOCK)); + } + + @Test + public void sessionScopedService() throws Exception { + assertTrue("session-scoped service must be a CGLIB proxy", AopUtils.isCglibProxy(this.sessionScopedService)); + this.mockMvc.perform(get("/sessionScopedService").requestAttr(FROM_MVC_TEST_MOCK, FROM_MVC_TEST_MOCK)); + } + + @After + public void verifyRestoredRequestAttributes() { + assertRequestAttributes(false); + } + + + // ------------------------------------------------------------------- + + @Configuration + @EnableWebMvc + static class WebConfig implements WebMvcConfigurer { + + @Bean + public SingletonController singletonController() { + return new SingletonController(); + } + + @Bean + @Scope(scopeName = "request", proxyMode = ScopedProxyMode.TARGET_CLASS) + public RequestScopedController requestScopedController() { + return new RequestScopedController(); + } + + @Bean + @RequestScope + public RequestScopedService requestScopedService() { + return new RequestScopedService(); + } + + @Bean + public ControllerWithRequestScopedService controllerWithRequestScopedService() { + return new ControllerWithRequestScopedService(); + } + + @Bean + @SessionScope + public SessionScopedService sessionScopedService() { + return new SessionScopedService(); + } + + @Bean + public ControllerWithSessionScopedService controllerWithSessionScopedService() { + return new ControllerWithSessionScopedService(); + } + + @Bean + public FilterWithSessionScopedService filterWithSessionScopedService() { + return new FilterWithSessionScopedService(); + } + } + + @RestController + static class SingletonController { + + @RequestMapping("/singletonController") + public void handle() { + assertRequestAttributes(); + } + } + + @RestController + static class RequestScopedController { + + @Autowired + private ServletRequest request; + + + @RequestMapping("/requestScopedController") + public void handle() { + assertRequestAttributes(request); + assertRequestAttributes(); + } + } + + static class RequestScopedService { + + @Autowired + private ServletRequest request; + + + void process() { + assertRequestAttributes(request); + } + } + + static class SessionScopedService { + + @Autowired + private ServletRequest request; + + + void process() { + assertRequestAttributes(this.request); + } + } + + @RestController + static class ControllerWithRequestScopedService { + + @Autowired + private RequestScopedService service; + + + @RequestMapping("/requestScopedService") + public void handle() { + this.service.process(); + assertRequestAttributes(); + } + } + + @RestController + static class ControllerWithSessionScopedService { + + @Autowired + private SessionScopedService service; + + + @RequestMapping("/sessionScopedService") + public void handle() { + this.service.process(); + assertRequestAttributes(); + } + } + + static class FilterWithSessionScopedService extends GenericFilterBean { + + @Autowired + private SessionScopedService service; + + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { + this.service.process(); + assertRequestAttributes(request); + assertRequestAttributes(); + chain.doFilter(request, response); + } + } + + static class RequestFilter extends GenericFilterBean { + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { + request.setAttribute(FROM_REQUEST_FILTER, FROM_REQUEST_FILTER); + chain.doFilter(request, response); + } + } + + static class RequestAttributesFilter extends GenericFilterBean { + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { + RequestContextHolder.getRequestAttributes().setAttribute(FROM_REQUEST_ATTRIBUTES_FILTER, FROM_REQUEST_ATTRIBUTES_FILTER, RequestAttributes.SCOPE_REQUEST); + chain.doFilter(request, response); + } + } + + + private static void assertRequestAttributes() { + assertRequestAttributes(true); + } + + private static void assertRequestAttributes(boolean withinMockMvc) { + RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes(); + assertThat(requestAttributes, instanceOf(ServletRequestAttributes.class)); + assertRequestAttributes(((ServletRequestAttributes) requestAttributes).getRequest(), withinMockMvc); + } + + private static void assertRequestAttributes(ServletRequest request) { + assertRequestAttributes(request, true); + } + + private static void assertRequestAttributes(ServletRequest request, boolean withinMockMvc) { + if (withinMockMvc) { + assertThat(request.getAttribute(FROM_TCF_MOCK), is(nullValue())); + assertThat(request.getAttribute(FROM_MVC_TEST_DEFAULT), is(FROM_MVC_TEST_DEFAULT)); + assertThat(request.getAttribute(FROM_MVC_TEST_MOCK), is(FROM_MVC_TEST_MOCK)); + assertThat(request.getAttribute(FROM_REQUEST_FILTER), is(FROM_REQUEST_FILTER)); + assertThat(request.getAttribute(FROM_REQUEST_ATTRIBUTES_FILTER), is(FROM_REQUEST_ATTRIBUTES_FILTER)); + } + else { + assertThat(request.getAttribute(FROM_TCF_MOCK), is(FROM_TCF_MOCK)); + assertThat(request.getAttribute(FROM_MVC_TEST_DEFAULT), is(nullValue())); + assertThat(request.getAttribute(FROM_MVC_TEST_MOCK), is(nullValue())); + assertThat(request.getAttribute(FROM_REQUEST_FILTER), is(nullValue())); + assertThat(request.getAttribute(FROM_REQUEST_ATTRIBUTES_FILTER), is(nullValue())); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/AsyncTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/AsyncTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0763d15da08825847996d400730668d3b3c491f2 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/AsyncTests.java @@ -0,0 +1,294 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone; + +import java.io.StringWriter; +import java.nio.charset.StandardCharsets; +import java.util.Collection; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CopyOnWriteArrayList; + +import org.junit.Test; + +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.test.web.Person; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.ListenableFutureTask; +import org.springframework.web.bind.annotation.ExceptionHandler; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.ResponseStatus; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.context.request.async.DeferredResult; +import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody; + +import static org.junit.Assert.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Tests with asynchronous request handling. + * + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + * @author Sam Brannen + * @author Jacek Suchenia + */ +public class AsyncTests { + + private final AsyncController asyncController = new AsyncController(); + + private final MockMvc mockMvc = standaloneSetup(this.asyncController).build(); + + + @Test + public void callable() throws Exception { + MvcResult mvcResult = this.mockMvc.perform(get("/1").param("callable", "true")) + .andExpect(request().asyncStarted()) + .andExpect(request().asyncResult(new Person("Joe"))) + .andReturn(); + + this.mockMvc.perform(asyncDispatch(mvcResult)) + .andExpect(status().isOk()) + .andExpect(content().contentType(MediaType.APPLICATION_JSON_UTF8)) + .andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.0,\"someBoolean\":false}")); + } + + @Test + public void streaming() throws Exception { + this.mockMvc.perform(get("/1").param("streaming", "true")) + .andExpect(request().asyncStarted()) + .andDo(MvcResult::getAsyncResult) // fetch async result similar to "asyncDispatch" builder + .andExpect(status().isOk()) + .andExpect(content().string("name=Joe")); + } + + @Test + public void streamingSlow() throws Exception { + this.mockMvc.perform(get("/1").param("streamingSlow", "true")) + .andExpect(request().asyncStarted()) + .andDo(MvcResult::getAsyncResult) + .andExpect(status().isOk()) + .andExpect(content().string("name=Joe&someBoolean=true")); + } + + @Test + public void streamingJson() throws Exception { + this.mockMvc.perform(get("/1").param("streamingJson", "true")) + .andExpect(request().asyncStarted()) + .andDo(MvcResult::getAsyncResult) + .andExpect(status().isOk()) + .andExpect(content().contentType(MediaType.APPLICATION_JSON_UTF8)) + .andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.5}")); + } + + @Test + public void deferredResult() throws Exception { + MvcResult mvcResult = this.mockMvc.perform(get("/1").param("deferredResult", "true")) + .andExpect(request().asyncStarted()) + .andReturn(); + + this.asyncController.onMessage("Joe"); + + this.mockMvc.perform(asyncDispatch(mvcResult)) + .andExpect(status().isOk()) + .andExpect(content().contentType(MediaType.APPLICATION_JSON_UTF8)) + .andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.0,\"someBoolean\":false}")); + } + + @Test + public void deferredResultWithImmediateValue() throws Exception { + MvcResult mvcResult = this.mockMvc.perform(get("/1").param("deferredResultWithImmediateValue", "true")) + .andExpect(request().asyncStarted()) + .andExpect(request().asyncResult(new Person("Joe"))) + .andReturn(); + + this.mockMvc.perform(asyncDispatch(mvcResult)) + .andExpect(status().isOk()) + .andExpect(content().contentType(MediaType.APPLICATION_JSON_UTF8)) + .andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.0,\"someBoolean\":false}")); + } + + @Test // SPR-13079 + public void deferredResultWithDelayedError() throws Exception { + MvcResult mvcResult = this.mockMvc.perform(get("/1").param("deferredResultWithDelayedError", "true")) + .andExpect(request().asyncStarted()) + .andReturn(); + + this.mockMvc.perform(asyncDispatch(mvcResult)) + .andExpect(status().is5xxServerError()) + .andExpect(content().string("Delayed Error")); + } + + @Test + public void listenableFuture() throws Exception { + MvcResult mvcResult = this.mockMvc.perform(get("/1").param("listenableFuture", "true")) + .andExpect(request().asyncStarted()) + .andReturn(); + + this.asyncController.onMessage("Joe"); + + this.mockMvc.perform(asyncDispatch(mvcResult)) + .andExpect(status().isOk()) + .andExpect(content().contentType(MediaType.APPLICATION_JSON_UTF8)) + .andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.0,\"someBoolean\":false}")); + } + + @Test // SPR-12597 + public void completableFutureWithImmediateValue() throws Exception { + MvcResult mvcResult = this.mockMvc.perform(get("/1").param("completableFutureWithImmediateValue", "true")) + .andExpect(request().asyncStarted()) + .andReturn(); + + this.mockMvc.perform(asyncDispatch(mvcResult)) + .andExpect(status().isOk()) + .andExpect(content().contentType(MediaType.APPLICATION_JSON_UTF8)) + .andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.0,\"someBoolean\":false}")); + } + + @Test // SPR-12735 + public void printAsyncResult() throws Exception { + StringWriter writer = new StringWriter(); + + MvcResult mvcResult = this.mockMvc.perform(get("/1").param("deferredResult", "true")) + .andDo(print(writer)) + .andExpect(request().asyncStarted()) + .andReturn(); + + assertTrue(writer.toString().contains("Async started = true")); + writer = new StringWriter(); + + this.asyncController.onMessage("Joe"); + + this.mockMvc.perform(asyncDispatch(mvcResult)) + .andDo(print(writer)) + .andExpect(status().isOk()) + .andExpect(content().contentType(MediaType.APPLICATION_JSON_UTF8)) + .andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.0,\"someBoolean\":false}")); + + assertTrue(writer.toString().contains("Async started = false")); + } + + + @RestController + @RequestMapping(path = "/{id}", produces = "application/json") + private static class AsyncController { + + private final Collection> deferredResults = new CopyOnWriteArrayList<>(); + + private final Collection> futureTasks = new CopyOnWriteArrayList<>(); + + @RequestMapping(params = "callable") + public Callable getCallable() { + return () -> new Person("Joe"); + } + + @RequestMapping(params = "streaming") + public StreamingResponseBody getStreaming() { + return os -> os.write("name=Joe".getBytes(StandardCharsets.UTF_8)); + } + + @RequestMapping(params = "streamingSlow") + public StreamingResponseBody getStreamingSlow() { + return os -> { + os.write("name=Joe".getBytes()); + try { + Thread.sleep(200); + os.write("&someBoolean=true".getBytes(StandardCharsets.UTF_8)); + } + catch (InterruptedException e) { + /* no-op */ + } + }; + } + + @RequestMapping(params = "streamingJson") + public ResponseEntity getStreamingJson() { + return ResponseEntity.ok().contentType(MediaType.APPLICATION_JSON_UTF8) + .body(os -> os.write("{\"name\":\"Joe\",\"someDouble\":0.5}".getBytes(StandardCharsets.UTF_8))); + } + + @RequestMapping(params = "deferredResult") + public DeferredResult getDeferredResult() { + DeferredResult deferredResult = new DeferredResult<>(); + this.deferredResults.add(deferredResult); + return deferredResult; + } + + @RequestMapping(params = "deferredResultWithImmediateValue") + public DeferredResult getDeferredResultWithImmediateValue() { + DeferredResult deferredResult = new DeferredResult<>(); + deferredResult.setResult(new Person("Joe")); + return deferredResult; + } + + @RequestMapping(params = "deferredResultWithDelayedError") + public DeferredResult getDeferredResultWithDelayedError() { + final DeferredResult deferredResult = new DeferredResult<>(); + new Thread() { + public void run() { + try { + Thread.sleep(100); + deferredResult.setErrorResult(new RuntimeException("Delayed Error")); + } + catch (InterruptedException e) { + /* no-op */ + } + } + }.start(); + return deferredResult; + } + + @RequestMapping(params = "listenableFuture") + public ListenableFuture getListenableFuture() { + ListenableFutureTask futureTask = new ListenableFutureTask<>(() -> new Person("Joe")); + this.futureTasks.add(futureTask); + return futureTask; + } + + @RequestMapping(params = "completableFutureWithImmediateValue") + public CompletableFuture getCompletableFutureWithImmediateValue() { + CompletableFuture future = new CompletableFuture<>(); + future.complete(new Person("Joe")); + return future; + } + + @ExceptionHandler(Exception.class) + @ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR) + public String errorHandler(Exception e) { + return e.getMessage(); + } + + void onMessage(String name) { + for (DeferredResult deferredResult : this.deferredResults) { + deferredResult.setResult(new Person(name)); + this.deferredResults.remove(deferredResult); + } + for (ListenableFutureTask futureTask : this.futureTasks) { + futureTask.run(); + this.futureTasks.remove(futureTask); + } + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ExceptionHandlerTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ExceptionHandlerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..504927e47036a84966fcb0b9ca0683e1be0b2341 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ExceptionHandlerTests.java @@ -0,0 +1,212 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone; + +import org.junit.Test; + +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.http.MediaType; +import org.springframework.stereotype.Controller; +import org.springframework.web.bind.annotation.ControllerAdvice; +import org.springframework.web.bind.annotation.ExceptionHandler; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.bind.annotation.RestControllerAdvice; + +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Exception handling via {@code @ExceptionHandler} method. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + */ +public class ExceptionHandlerTests { + + @Test + public void mvcLocalExceptionHandlerMethod() throws Exception { + standaloneSetup(new PersonController()).build() + .perform(get("/person/Clyde")) + .andExpect(status().isOk()) + .andExpect(forwardedUrl("errorView")); + } + + @Test + public void mvcGlobalExceptionHandlerMethod() throws Exception { + standaloneSetup(new PersonController()).setControllerAdvice(new GlobalExceptionHandler()).build() + .perform(get("/person/Bonnie")) + .andExpect(status().isOk()) + .andExpect(forwardedUrl("globalErrorView")); + } + + @Test + public void mvcGlobalExceptionHandlerMethodUsingClassArgument() throws Exception { + standaloneSetup(PersonController.class).setControllerAdvice(GlobalExceptionHandler.class).build() + .perform(get("/person/Bonnie")) + .andExpect(status().isOk()) + .andExpect(forwardedUrl("globalErrorView")); + } + + @Test + public void restNoException() throws Exception { + standaloneSetup(RestPersonController.class) + .setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class).build() + .perform(get("/person/Yoda").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.name").value("Yoda")); + } + + @Test + public void restLocalExceptionHandlerMethod() throws Exception { + standaloneSetup(RestPersonController.class) + .setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class).build() + .perform(get("/person/Luke").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.error").value("local - IllegalArgumentException")); + } + + @Test + public void restGlobalExceptionHandlerMethod() throws Exception { + standaloneSetup(RestPersonController.class) + .setControllerAdvice(RestGlobalExceptionHandler.class).build() + .perform(get("/person/Leia").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.error").value("global - IllegalStateException")); + } + + @Test + public void restGlobalRestPersonControllerExceptionHandlerTakesPrecedenceOverGlobalExceptionHandler() throws Exception { + standaloneSetup(RestPersonController.class) + .setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class).build() + .perform(get("/person/Leia").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.error").value("globalPersonController - IllegalStateException")); + } + + @Test // gh-25520 + public void restNoHandlerFound() throws Exception { + standaloneSetup(RestPersonController.class) + .setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class) + .addDispatcherServletCustomizer(dispatcherServlet -> dispatcherServlet.setThrowExceptionIfNoHandlerFound(true)) + .build() + .perform(get("/bogus").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.error").value("global - NoHandlerFoundException")); + } + + + @Controller + private static class PersonController { + + @GetMapping("/person/{name}") + public String show(@PathVariable String name) { + if (name.equals("Clyde")) { + throw new IllegalArgumentException("simulated exception"); + } + else if (name.equals("Bonnie")) { + throw new IllegalStateException("simulated exception"); + } + return "person/show"; + } + + @ExceptionHandler + public String handleException(IllegalArgumentException exception) { + return "errorView"; + } + } + + @ControllerAdvice + private static class GlobalExceptionHandler { + + @ExceptionHandler + public String handleException(IllegalStateException exception) { + return "globalErrorView"; + } + } + + @RestController + private static class RestPersonController { + + @GetMapping("/person/{name}") + Person get(@PathVariable String name) { + switch (name) { + case "Luke": + throw new IllegalArgumentException(); + case "Leia": + throw new IllegalStateException(); + default: + return new Person("Yoda"); + } + } + + @ExceptionHandler + Error handleException(IllegalArgumentException exception) { + return new Error("local - " + exception.getClass().getSimpleName()); + } + } + + @RestControllerAdvice(assignableTypes = RestPersonController.class) + @Order(Ordered.HIGHEST_PRECEDENCE) + private static class RestPersonControllerExceptionHandler { + + @ExceptionHandler + Error handleException(Throwable exception) { + return new Error("globalPersonController - " + exception.getClass().getSimpleName()); + } + } + + @RestControllerAdvice + @Order(Ordered.LOWEST_PRECEDENCE) + private static class RestGlobalExceptionHandler { + + @ExceptionHandler + Error handleException(Throwable exception) { + return new Error( "global - " + exception.getClass().getSimpleName()); + } + } + + static class Person { + + private final String name; + + Person(String name) { + this.name = name; + } + + public String getName() { + return name; + } + } + + static class Error { + + private final String error; + + Error(String error) { + this.error = error; + } + + public String getError() { + return error; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..33719ff9fbdad0335132374810546b2aa211e7f6 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java @@ -0,0 +1,300 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone; + +import java.io.IOException; +import java.security.Principal; +import java.util.concurrent.CompletableFuture; + +import javax.servlet.AsyncContext; +import javax.servlet.AsyncListener; +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; +import javax.validation.Valid; + +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.stereotype.Controller; +import org.springframework.test.web.Person; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.validation.Errors; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.ResponseBody; +import org.springframework.web.filter.OncePerRequestFilter; +import org.springframework.web.filter.ShallowEtagHeaderFilter; +import org.springframework.web.servlet.ModelAndView; +import org.springframework.web.servlet.mvc.support.RedirectAttributes; + +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.asyncDispatch; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.flash; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.model; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.standaloneSetup; + +/** + * Tests with {@link Filter}'s. + * @author Rob Winch + */ +public class FilterTests { + + @Test + public void whenFiltersCompleteMvcProcessesRequest() throws Exception { + standaloneSetup(new PersonController()) + .addFilters(new ContinueFilter()).build() + .perform(post("/persons").param("name", "Andy")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/person/1")) + .andExpect(model().size(1)) + .andExpect(model().attributeExists("id")) + .andExpect(flash().attributeCount(1)) + .andExpect(flash().attribute("message", "success!")); + } + + @Test + public void filtersProcessRequest() throws Exception { + standaloneSetup(new PersonController()) + .addFilters(new ContinueFilter(), new RedirectFilter()).build() + .perform(post("/persons").param("name", "Andy")) + .andExpect(redirectedUrl("/login")); + } + + @Test + public void filterMappedBySuffix() throws Exception { + standaloneSetup(new PersonController()) + .addFilter(new RedirectFilter(), "*.html").build() + .perform(post("/persons.html").param("name", "Andy")) + .andExpect(redirectedUrl("/login")); + } + + @Test + public void filterWithExactMapping() throws Exception { + standaloneSetup(new PersonController()) + .addFilter(new RedirectFilter(), "/p", "/persons").build() + .perform(post("/persons").param("name", "Andy")) + .andExpect(redirectedUrl("/login")); + } + + @Test + public void filterSkipped() throws Exception { + standaloneSetup(new PersonController()) + .addFilter(new RedirectFilter(), "/p", "/person").build() + .perform(post("/persons").param("name", "Andy")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/person/1")) + .andExpect(model().size(1)) + .andExpect(model().attributeExists("id")) + .andExpect(flash().attributeCount(1)) + .andExpect(flash().attribute("message", "success!")); + } + + @Test + public void filterWrapsRequestResponse() throws Exception { + standaloneSetup(new PersonController()) + .addFilters(new WrappingRequestResponseFilter()).build() + .perform(post("/user")) + .andExpect(model().attribute("principal", WrappingRequestResponseFilter.PRINCIPAL_NAME)); + } + + @Test // SPR-16067, SPR-16695 + public void filterWrapsRequestResponseAndPerformsAsyncDispatch() throws Exception { + MockMvc mockMvc = standaloneSetup(new PersonController()) + .addFilters(new WrappingRequestResponseFilter(), new ShallowEtagHeaderFilter()) + .build(); + + MvcResult mvcResult = mockMvc.perform(get("/persons/1").accept(MediaType.APPLICATION_JSON)) + .andExpect(request().asyncStarted()) + .andExpect(request().asyncResult(new Person("Lukas"))) + .andReturn(); + + mockMvc.perform(asyncDispatch(mvcResult)) + .andExpect(status().isOk()) + .andExpect(header().longValue("Content-Length", 53)) + .andExpect(header().string("ETag", "\"0e37becb4f0c90709cb2e1efcc61eaa00\"")) + .andExpect(content().string("{\"name\":\"Lukas\",\"someDouble\":0.0,\"someBoolean\":false}")); + } + + + @Controller + private static class PersonController { + + @PostMapping(path="/persons") + public String save(@Valid Person person, Errors errors, RedirectAttributes redirectAttrs) { + if (errors.hasErrors()) { + return "person/add"; + } + redirectAttrs.addAttribute("id", "1"); + redirectAttrs.addFlashAttribute("message", "success!"); + return "redirect:/person/{id}"; + } + + @PostMapping("/user") + public ModelAndView user(Principal principal) { + return new ModelAndView("user/view", "principal", principal.getName()); + } + + @GetMapping("/forward") + public String forward() { + return "forward:/persons"; + } + + @GetMapping("persons/{id}") + @ResponseBody + public CompletableFuture getPerson() { + return CompletableFuture.completedFuture(new Person("Lukas")); + } + } + + private class ContinueFilter extends OncePerRequestFilter { + + @Override + protected void doFilterInternal(HttpServletRequest request, + HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { + + filterChain.doFilter(request, response); + } + } + + private static class WrappingRequestResponseFilter extends OncePerRequestFilter { + + public static final String PRINCIPAL_NAME = "WrapRequestResponseFilterPrincipal"; + + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + + filterChain.doFilter(new HttpServletRequestWrapper(request) { + + @Override + public Principal getUserPrincipal() { + return () -> PRINCIPAL_NAME; + } + + // Like Spring Security does in HttpServlet3RequestFactory.. + + @Override + public AsyncContext getAsyncContext() { + return super.getAsyncContext() != null ? + new AsyncContextWrapper(super.getAsyncContext()) : null; + } + + }, new HttpServletResponseWrapper(response)); + } + } + + private class RedirectFilter extends OncePerRequestFilter { + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + + response.sendRedirect("/login"); + } + } + + + private static class AsyncContextWrapper implements AsyncContext { + + private final AsyncContext delegate; + + public AsyncContextWrapper(AsyncContext delegate) { + this.delegate = delegate; + } + + @Override + public ServletRequest getRequest() { + return this.delegate.getRequest(); + } + + @Override + public ServletResponse getResponse() { + return this.delegate.getResponse(); + } + + @Override + public boolean hasOriginalRequestAndResponse() { + return this.delegate.hasOriginalRequestAndResponse(); + } + + @Override + public void dispatch() { + this.delegate.dispatch(); + } + + @Override + public void dispatch(String path) { + this.delegate.dispatch(path); + } + + @Override + public void dispatch(ServletContext context, String path) { + this.delegate.dispatch(context, path); + } + + @Override + public void complete() { + this.delegate.complete(); + } + + @Override + public void start(Runnable run) { + this.delegate.start(run); + } + + @Override + public void addListener(AsyncListener listener) { + this.delegate.addListener(listener); + } + + @Override + public void addListener(AsyncListener listener, ServletRequest req, ServletResponse res) { + this.delegate.addListener(listener, req, res); + } + + @Override + public T createListener(Class clazz) throws ServletException { + return this.delegate.createListener(clazz); + } + + @Override + public void setTimeout(long timeout) { + this.delegate.setTimeout(timeout); + } + + @Override + public long getTimeout() { + return this.delegate.getTimeout(); + } + } +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FrameworkExtensionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FrameworkExtensionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..eb676ba94ff2e8b52d02a3212708ed9c9ba3085a --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FrameworkExtensionTests.java @@ -0,0 +1,150 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone; + +import java.security.Principal; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.stereotype.Controller; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.RequestPostProcessor; +import org.springframework.test.web.servlet.setup.ConfigurableMockMvcBuilder; +import org.springframework.test.web.servlet.setup.MockMvcConfigurerAdapter; +import org.springframework.util.Assert; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.ResponseBody; +import org.springframework.web.context.WebApplicationContext; + +import static org.mockito.Mockito.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Demonstrates use of SPI extension points: + *
    + *
  • {@link org.springframework.test.web.servlet.request.RequestPostProcessor} + * for extending request building with custom methods. + *
  • {@link org.springframework.test.web.servlet.setup.MockMvcConfigurer + * MockMvcConfigurer} for extending MockMvc building with some automatic setup. + *
+ * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class FrameworkExtensionTests { + + private MockMvc mockMvc; + + + @Before + public void setup() { + this.mockMvc = standaloneSetup(new SampleController()).apply(defaultSetup()).build(); + } + + @Test + public void fooHeader() throws Exception { + this.mockMvc.perform(get("/").with(headers().foo("a=b"))).andExpect(content().string("Foo")); + } + + @Test + public void barHeader() throws Exception { + this.mockMvc.perform(get("/").with(headers().bar("a=b"))).andExpect(content().string("Bar")); + } + + private static TestMockMvcConfigurer defaultSetup() { + return new TestMockMvcConfigurer(); + } + + private static TestRequestPostProcessor headers() { + return new TestRequestPostProcessor(); + } + + + /** + * Test {@code RequestPostProcessor}. + */ + private static class TestRequestPostProcessor implements RequestPostProcessor { + + private HttpHeaders headers = new HttpHeaders(); + + + public TestRequestPostProcessor foo(String value) { + this.headers.add("Foo", value); + return this; + } + + public TestRequestPostProcessor bar(String value) { + this.headers.add("Bar", value); + return this; + } + + @Override + public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { + for (String headerName : this.headers.keySet()) { + request.addHeader(headerName, this.headers.get(headerName)); + } + return request; + } + } + + + /** + * Test {@code MockMvcConfigurer}. + */ + private static class TestMockMvcConfigurer extends MockMvcConfigurerAdapter { + + @Override + public void afterConfigurerAdded(ConfigurableMockMvcBuilder builder) { + builder.alwaysExpect(status().isOk()); + } + + @Override + public RequestPostProcessor beforeMockMvcCreated(ConfigurableMockMvcBuilder builder, + WebApplicationContext context) { + return request -> { + request.setUserPrincipal(mock(Principal.class)); + return request; + }; + } + } + + + @Controller + @RequestMapping("/") + private static class SampleController { + + @RequestMapping(headers = "Foo") + @ResponseBody + public String handleFoo(Principal principal) { + Assert.notNull(principal, "Principal must not be null"); + return "Foo"; + } + + @RequestMapping(headers = "Bar") + @ResponseBody + public String handleBar(Principal principal) { + Assert.notNull(principal, "Principal must not be null"); + return "Bar"; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/MultipartControllerTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/MultipartControllerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d4d209f7c1e43b5476dc9a41a9c3a9f649cb61ba --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/MultipartControllerTests.java @@ -0,0 +1,372 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.Part; + +import org.junit.Assert; +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.mock.web.MockMultipartFile; +import org.springframework.mock.web.MockPart; +import org.springframework.stereotype.Controller; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.ui.Model; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RequestPart; +import org.springframework.web.filter.OncePerRequestFilter; +import org.springframework.web.multipart.MultipartFile; + +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * @author Rossen Stoyanchev + * @author Juergen Hoeller + */ +public class MultipartControllerTests { + + @Test + public void multipartRequestWithSingleFile() throws Exception { + byte[] fileContent = "bar".getBytes(StandardCharsets.UTF_8); + MockMultipartFile filePart = new MockMultipartFile("file", "orig", null, fileContent); + + byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); + MockMultipartFile jsonPart = new MockMultipartFile("json", "json", "application/json", json); + + standaloneSetup(new MultipartController()).build() + .perform(multipart("/multipartfile").file(filePart).file(jsonPart)) + .andExpect(status().isFound()) + .andExpect(model().attribute("fileContent", fileContent)) + .andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah"))); + } + + @Test + public void multipartRequestWithSingleFileNotPresent() throws Exception { + standaloneSetup(new MultipartController()).build() + .perform(multipart("/multipartfile")) + .andExpect(status().isFound()); + } + + @Test + public void multipartRequestWithFileArray() throws Exception { + byte[] fileContent = "bar".getBytes(StandardCharsets.UTF_8); + MockMultipartFile filePart1 = new MockMultipartFile("file", "orig", null, fileContent); + MockMultipartFile filePart2 = new MockMultipartFile("file", "orig", null, fileContent); + + byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); + MockMultipartFile jsonPart = new MockMultipartFile("json", "json", "application/json", json); + + standaloneSetup(new MultipartController()).build() + .perform(multipart("/multipartfilearray").file(filePart1).file(filePart2).file(jsonPart)) + .andExpect(status().isFound()) + .andExpect(model().attribute("fileContent", fileContent)) + .andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah"))); + } + + @Test + public void multipartRequestWithFileArrayNotPresent() throws Exception { + standaloneSetup(new MultipartController()).build() + .perform(multipart("/multipartfilearray")) + .andExpect(status().isFound()); + } + + @Test + public void multipartRequestWithFileArrayNoMultipart() throws Exception { + standaloneSetup(new MultipartController()).build() + .perform(post("/multipartfilearray")) + .andExpect(status().isFound()); + } + + @Test + public void multipartRequestWithFileList() throws Exception { + byte[] fileContent = "bar".getBytes(StandardCharsets.UTF_8); + MockMultipartFile filePart1 = new MockMultipartFile("file", "orig", null, fileContent); + MockMultipartFile filePart2 = new MockMultipartFile("file", "orig", null, fileContent); + + byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); + MockMultipartFile jsonPart = new MockMultipartFile("json", "json", "application/json", json); + + standaloneSetup(new MultipartController()).build() + .perform(multipart("/multipartfilelist").file(filePart1).file(filePart2).file(jsonPart)) + .andExpect(status().isFound()) + .andExpect(model().attribute("fileContent", fileContent)) + .andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah"))); + } + + @Test + public void multipartRequestWithFileListNotPresent() throws Exception { + standaloneSetup(new MultipartController()).build() + .perform(multipart("/multipartfilelist")) + .andExpect(status().isFound()); + } + + @Test + public void multipartRequestWithFileListNoMultipart() throws Exception { + standaloneSetup(new MultipartController()).build() + .perform(post("/multipartfilelist")) + .andExpect(status().isFound()); + } + + @Test + public void multipartRequestWithOptionalFile() throws Exception { + byte[] fileContent = "bar".getBytes(StandardCharsets.UTF_8); + MockMultipartFile filePart = new MockMultipartFile("file", "orig", null, fileContent); + + byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); + MockMultipartFile jsonPart = new MockMultipartFile("json", "json", "application/json", json); + + standaloneSetup(new MultipartController()).build() + .perform(multipart("/optionalfile").file(filePart).file(jsonPart)) + .andExpect(status().isFound()) + .andExpect(model().attribute("fileContent", fileContent)) + .andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah"))); + } + + @Test + public void multipartRequestWithOptionalFileNotPresent() throws Exception { + byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); + MockMultipartFile jsonPart = new MockMultipartFile("json", "json", "application/json", json); + + standaloneSetup(new MultipartController()).build() + .perform(multipart("/optionalfile").file(jsonPart)) + .andExpect(status().isFound()) + .andExpect(model().attributeDoesNotExist("fileContent")) + .andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah"))); + } + + @Test + public void multipartRequestWithOptionalFileArray() throws Exception { + byte[] fileContent = "bar".getBytes(StandardCharsets.UTF_8); + MockMultipartFile filePart1 = new MockMultipartFile("file", "orig", null, fileContent); + MockMultipartFile filePart2 = new MockMultipartFile("file", "orig", null, fileContent); + + byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); + MockMultipartFile jsonPart = new MockMultipartFile("json", "json", "application/json", json); + + standaloneSetup(new MultipartController()).build() + .perform(multipart("/optionalfilearray").file(filePart1).file(filePart2).file(jsonPart)) + .andExpect(status().isFound()) + .andExpect(model().attribute("fileContent", fileContent)) + .andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah"))); + } + + @Test + public void multipartRequestWithOptionalFileArrayNotPresent() throws Exception { + byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); + MockMultipartFile jsonPart = new MockMultipartFile("json", "json", "application/json", json); + + standaloneSetup(new MultipartController()).build() + .perform(multipart("/optionalfilearray").file(jsonPart)) + .andExpect(status().isFound()) + .andExpect(model().attributeDoesNotExist("fileContent")) + .andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah"))); + } + + @Test + public void multipartRequestWithOptionalFileList() throws Exception { + byte[] fileContent = "bar".getBytes(StandardCharsets.UTF_8); + MockMultipartFile filePart1 = new MockMultipartFile("file", "orig", null, fileContent); + MockMultipartFile filePart2 = new MockMultipartFile("file", "orig", null, fileContent); + + byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); + MockMultipartFile jsonPart = new MockMultipartFile("json", "json", "application/json", json); + + standaloneSetup(new MultipartController()).build() + .perform(multipart("/optionalfilelist").file(filePart1).file(filePart2).file(jsonPart)) + .andExpect(status().isFound()) + .andExpect(model().attribute("fileContent", fileContent)) + .andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah"))); + } + + @Test + public void multipartRequestWithOptionalFileListNotPresent() throws Exception { + byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); + MockMultipartFile jsonPart = new MockMultipartFile("json", "json", "application/json", json); + + standaloneSetup(new MultipartController()).build() + .perform(multipart("/optionalfilelist").file(jsonPart)) + .andExpect(status().isFound()) + .andExpect(model().attributeDoesNotExist("fileContent")) + .andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah"))); + } + + @Test + public void multipartRequestWithServletParts() throws Exception { + byte[] fileContent = "bar".getBytes(StandardCharsets.UTF_8); + MockPart filePart = new MockPart("file", "orig", fileContent); + + byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); + MockPart jsonPart = new MockPart("json", "json", json); + jsonPart.getHeaders().setContentType(MediaType.APPLICATION_JSON); + + standaloneSetup(new MultipartController()).build() + .perform(multipart("/multipartfile").part(filePart).part(jsonPart)) + .andExpect(status().isFound()) + .andExpect(model().attribute("fileContent", fileContent)) + .andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah"))); + } + + @Test // SPR-13317 + public void multipartRequestWrapped() throws Exception { + byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); + MockMultipartFile jsonPart = new MockMultipartFile("json", "json", "application/json", json); + + Filter filter = new RequestWrappingFilter(); + MockMvc mockMvc = standaloneSetup(new MultipartController()).addFilter(filter).build(); + + Map jsonMap = Collections.singletonMap("name", "yeeeah"); + mockMvc.perform(multipart("/json").file(jsonPart)).andExpect(model().attribute("json", jsonMap)); + } + + + @Controller + private static class MultipartController { + + @RequestMapping(value = "/multipartfile", method = RequestMethod.POST) + public String processMultipartFile(@RequestParam(required = false) MultipartFile file, + @RequestPart(required = false) Map json, Model model) throws IOException { + + if (file != null) { + model.addAttribute("fileContent", file.getBytes()); + } + if (json != null) { + model.addAttribute("jsonContent", json); + } + + return "redirect:/index"; + } + + @RequestMapping(value = "/multipartfilearray", method = RequestMethod.POST) + public String processMultipartFileArray(@RequestParam(required = false) MultipartFile[] file, + @RequestPart(required = false) Map json, Model model) throws IOException { + + if (file != null && file.length > 0) { + byte[] content = file[0].getBytes(); + Assert.assertArrayEquals(content, file[1].getBytes()); + model.addAttribute("fileContent", content); + } + if (json != null) { + model.addAttribute("jsonContent", json); + } + + return "redirect:/index"; + } + + @RequestMapping(value = "/multipartfilelist", method = RequestMethod.POST) + public String processMultipartFileList(@RequestParam(required = false) List file, + @RequestPart(required = false) Map json, Model model) throws IOException { + + if (file != null && !file.isEmpty()) { + byte[] content = file.get(0).getBytes(); + Assert.assertArrayEquals(content, file.get(1).getBytes()); + model.addAttribute("fileContent", content); + } + if (json != null) { + model.addAttribute("jsonContent", json); + } + + return "redirect:/index"; + } + + @RequestMapping(value = "/optionalfile", method = RequestMethod.POST) + public String processOptionalFile(@RequestParam Optional file, + @RequestPart Map json, Model model) throws IOException { + + if (file.isPresent()) { + model.addAttribute("fileContent", file.get().getBytes()); + } + model.addAttribute("jsonContent", json); + + return "redirect:/index"; + } + + @RequestMapping(value = "/optionalfilearray", method = RequestMethod.POST) + public String processOptionalFileArray(@RequestParam Optional file, + @RequestPart Map json, Model model) throws IOException { + + if (file.isPresent()) { + byte[] content = file.get()[0].getBytes(); + Assert.assertArrayEquals(content, file.get()[1].getBytes()); + model.addAttribute("fileContent", content); + } + model.addAttribute("jsonContent", json); + + return "redirect:/index"; + } + + @RequestMapping(value = "/optionalfilelist", method = RequestMethod.POST) + public String processOptionalFileList(@RequestParam Optional> file, + @RequestPart Map json, Model model) throws IOException { + + if (file.isPresent()) { + byte[] content = file.get().get(0).getBytes(); + Assert.assertArrayEquals(content, file.get().get(1).getBytes()); + model.addAttribute("fileContent", content); + } + model.addAttribute("jsonContent", json); + + return "redirect:/index"; + } + + @RequestMapping(value = "/part", method = RequestMethod.POST) + public String processPart(@RequestParam Part part, + @RequestPart Map json, Model model) throws IOException { + + model.addAttribute("fileContent", part.getInputStream()); + model.addAttribute("jsonContent", json); + + return "redirect:/index"; + } + + @RequestMapping(value = "/json", method = RequestMethod.POST) + public String processMultipart(@RequestPart Map json, Model model) { + model.addAttribute("json", json); + return "redirect:/index"; + } + } + + + private static class RequestWrappingFilter extends OncePerRequestFilter { + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws IOException, ServletException { + + request = new HttpServletRequestWrapper(request); + filterChain.doFilter(request, response); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ReactiveReturnTypeTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ReactiveReturnTypeTests.java new file mode 100644 index 0000000000000000000000000000000000000000..61bdafacb4299e29dfa5db894aa6c68d5dd8c7b2 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ReactiveReturnTypeTests.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.test.web.servlet.samples.standalone; + +import java.time.Duration; + +import org.junit.Test; +import reactor.core.publisher.Flux; + +import org.springframework.http.MediaType; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; + +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request; + +/** + * Tests with reactive return value types. + * + * @author Rossen Stoyanchev + */ +public class ReactiveReturnTypeTests { + + + @Test // SPR-16869 + public void sseWithFlux() throws Exception { + + MockMvc mockMvc = MockMvcBuilders.standaloneSetup(ReactiveController.class).build(); + + MvcResult mvcResult = mockMvc.perform(get("/spr16869")) + .andExpect(request().asyncStarted()) + .andExpect(status().isOk()) + .andReturn(); + + mockMvc.perform(asyncDispatch(mvcResult)) + .andExpect(content().string("data:event0\n\ndata:event1\n\ndata:event2\n\n")); + } + + + + @RestController + static class ReactiveController { + + @GetMapping(path = "/spr16869", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + Flux sseFlux() { + return Flux.interval(Duration.ofSeconds(1)).take(3) + .map(aLong -> String.format("event%d", aLong)); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/RedirectTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/RedirectTests.java new file mode 100644 index 0000000000000000000000000000000000000000..03d8b858f86baf62194521ef605962564e35742c --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/RedirectTests.java @@ -0,0 +1,139 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone; + +import javax.validation.Valid; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.stereotype.Controller; +import org.springframework.test.web.Person; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.ui.Model; +import org.springframework.validation.Errors; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.servlet.mvc.support.RedirectAttributes; + +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Redirect scenarios including saving and retrieving flash attributes. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + */ +public class RedirectTests { + + private MockMvc mockMvc; + + + @Before + public void setup() { + this.mockMvc = standaloneSetup(new PersonController()).build(); + } + + + @Test + public void save() throws Exception { + this.mockMvc.perform(post("/persons").param("name", "Andy")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/persons/Joe")) + .andExpect(model().size(1)) + .andExpect(model().attributeExists("name")) + .andExpect(flash().attributeCount(1)) + .andExpect(flash().attribute("message", "success!")); + } + + @Test + public void saveSpecial() throws Exception { + this.mockMvc.perform(post("/people").param("name", "Andy")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/persons/Joe")) + .andExpect(model().size(1)) + .andExpect(model().attributeExists("name")) + .andExpect(flash().attributeCount(1)) + .andExpect(flash().attribute("message", "success!")); + } + + @Test + public void saveWithErrors() throws Exception { + this.mockMvc.perform(post("/persons")) + .andExpect(status().isOk()) + .andExpect(forwardedUrl("persons/add")) + .andExpect(model().size(1)) + .andExpect(model().attributeExists("person")) + .andExpect(flash().attributeCount(0)); + } + + @Test + public void saveSpecialWithErrors() throws Exception { + this.mockMvc.perform(post("/people")) + .andExpect(status().isOk()) + .andExpect(forwardedUrl("persons/add")) + .andExpect(model().size(1)) + .andExpect(model().attributeExists("person")) + .andExpect(flash().attributeCount(0)); + } + + @Test + public void getPerson() throws Exception { + this.mockMvc.perform(get("/persons/Joe").flashAttr("message", "success!")) + .andExpect(status().isOk()) + .andExpect(forwardedUrl("persons/index")) + .andExpect(model().size(2)) + .andExpect(model().attribute("person", new Person("Joe"))) + .andExpect(model().attribute("message", "success!")) + .andExpect(flash().attributeCount(0)); + } + + + @Controller + private static class PersonController { + + @GetMapping("/persons/{name}") + public String getPerson(@PathVariable String name, Model model) { + model.addAttribute(new Person(name)); + return "persons/index"; + } + + @PostMapping + public String save(@Valid Person person, Errors errors, RedirectAttributes redirectAttrs) { + if (errors.hasErrors()) { + return "persons/add"; + } + redirectAttrs.addAttribute("name", "Joe"); + redirectAttrs.addFlashAttribute("message", "success!"); + return "redirect:/persons/{name}"; + } + + @PostMapping("/people") + public Object saveSpecial(@Valid Person person, Errors errors, RedirectAttributes redirectAttrs) { + if (errors.hasErrors()) { + return "persons/add"; + } + redirectAttrs.addAttribute("name", "Joe"); + redirectAttrs.addFlashAttribute("message", "success!"); + return new StringBuilder("redirect:").append("/persons").append("/{name}"); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/RequestParameterTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/RequestParameterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..9db137d8be3948af9af79dd2f88113cdf806a3eb --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/RequestParameterTests.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.test.web.servlet.samples.standalone; + +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.stereotype.Controller; +import org.springframework.test.web.Person; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.ResponseBody; + +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Tests demonstrating the use of request parameters. + * + * @author Rossen Stoyanchev + */ +public class RequestParameterTests { + + @Test + public void queryParameter() throws Exception { + + standaloneSetup(new PersonController()).build() + .perform(get("/search?name=George").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(content().contentType("application/json;charset=UTF-8")) + .andExpect(jsonPath("$.name").value("George")); + } + + + @Controller + private class PersonController { + + @RequestMapping(value="/search") + @ResponseBody + public Person get(@RequestParam String name) { + Person person = new Person(name); + return person; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ResponseBodyTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ResponseBodyTests.java new file mode 100644 index 0000000000000000000000000000000000000000..69cc6c53fb55e83a202602f02265190432f0fc5a --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ResponseBodyTests.java @@ -0,0 +1,59 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone; + +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.stereotype.Controller; +import org.springframework.test.web.Person; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.ResponseBody; + +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Response written from {@code @ResponseBody} method. + * + * @author Rossen Stoyanchev + */ +public class ResponseBodyTests { + + @Test + public void json() throws Exception { + standaloneSetup(new PersonController()).build() + .perform(get("/person/Lee").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(content().contentType("application/json;charset=UTF-8")) + .andExpect(jsonPath("$.name").value("Lee")); + } + + + @Controller + private class PersonController { + + @RequestMapping(value="/person/{name}") + @ResponseBody + public Person get(@PathVariable String name) { + return new Person(name); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ViewResolutionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ViewResolutionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a85f75cccbcf38f7a41a8188c4d0957bcffa2c77 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ViewResolutionTests.java @@ -0,0 +1,147 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone; + +import java.util.ArrayList; +import java.util.List; + +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.oxm.jaxb.Jaxb2Marshaller; +import org.springframework.stereotype.Controller; +import org.springframework.test.web.Person; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.ui.Model; +import org.springframework.web.accept.ContentNegotiationManager; +import org.springframework.web.accept.FixedContentNegotiationStrategy; +import org.springframework.web.accept.HeaderContentNegotiationStrategy; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.servlet.View; +import org.springframework.web.servlet.view.ContentNegotiatingViewResolver; +import org.springframework.web.servlet.view.InternalResourceViewResolver; +import org.springframework.web.servlet.view.json.MappingJackson2JsonView; +import org.springframework.web.servlet.view.xml.MarshallingView; + +import static org.hamcrest.Matchers.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Tests with view resolution. + * + * @author Rossen Stoyanchev + */ +public class ViewResolutionTests { + + @Test + public void testJspOnly() throws Exception { + InternalResourceViewResolver viewResolver = new InternalResourceViewResolver("/WEB-INF/", ".jsp"); + + standaloneSetup(new PersonController()).setViewResolvers(viewResolver).build() + .perform(get("/person/Corea")) + .andExpect(status().isOk()) + .andExpect(model().size(1)) + .andExpect(model().attributeExists("person")) + .andExpect(forwardedUrl("/WEB-INF/person/show.jsp")); + } + + @Test + public void testJsonOnly() throws Exception { + standaloneSetup(new PersonController()).setSingleView(new MappingJackson2JsonView()).build() + .perform(get("/person/Corea")) + .andExpect(status().isOk()) + .andExpect(content().contentType(MediaType.APPLICATION_JSON)) + .andExpect(jsonPath("$.person.name").value("Corea")); + } + + @Test + public void testXmlOnly() throws Exception { + Jaxb2Marshaller marshaller = new Jaxb2Marshaller(); + marshaller.setClassesToBeBound(Person.class); + + standaloneSetup(new PersonController()).setSingleView(new MarshallingView(marshaller)).build() + .perform(get("/person/Corea")) + .andExpect(status().isOk()) + .andExpect(content().contentType(MediaType.APPLICATION_XML)) + .andExpect(xpath("/person/name/text()").string(equalTo("Corea"))); + } + + @Test + public void testContentNegotiation() throws Exception { + Jaxb2Marshaller marshaller = new Jaxb2Marshaller(); + marshaller.setClassesToBeBound(Person.class); + + List viewList = new ArrayList<>(); + viewList.add(new MappingJackson2JsonView()); + viewList.add(new MarshallingView(marshaller)); + + ContentNegotiationManager manager = new ContentNegotiationManager( + new HeaderContentNegotiationStrategy(), new FixedContentNegotiationStrategy(MediaType.TEXT_HTML)); + + ContentNegotiatingViewResolver cnViewResolver = new ContentNegotiatingViewResolver(); + cnViewResolver.setDefaultViews(viewList); + cnViewResolver.setContentNegotiationManager(manager); + cnViewResolver.afterPropertiesSet(); + + MockMvc mockMvc = + standaloneSetup(new PersonController()) + .setViewResolvers(cnViewResolver, new InternalResourceViewResolver()) + .build(); + + mockMvc.perform(get("/person/Corea")) + .andExpect(status().isOk()) + .andExpect(model().size(1)) + .andExpect(model().attributeExists("person")) + .andExpect(forwardedUrl("person/show")); + + mockMvc.perform(get("/person/Corea").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(content().contentType(MediaType.APPLICATION_JSON)) + .andExpect(jsonPath("$.person.name").value("Corea")); + + mockMvc.perform(get("/person/Corea").accept(MediaType.APPLICATION_XML)) + .andExpect(status().isOk()) + .andExpect(content().contentType(MediaType.APPLICATION_XML)) + .andExpect(xpath("/person/name/text()").string(equalTo("Corea"))); + } + + @Test + public void defaultViewResolver() throws Exception { + standaloneSetup(new PersonController()).build() + .perform(get("/person/Corea")) + .andExpect(model().attribute("person", hasProperty("name", equalTo("Corea")))) + .andExpect(status().isOk()) + .andExpect(forwardedUrl("person/show")); // InternalResourceViewResolver + } + + + @Controller + private static class PersonController { + + @GetMapping("/person/{name}") + public String show(@PathVariable String name, Model model) { + Person person = new Person(name); + model.addAttribute(person); + return "person/show"; + } + } + +} + diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resulthandlers/PrintingResultHandlerSmokeTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resulthandlers/PrintingResultHandlerSmokeTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d70874cbd3d26ef8663d5c04d4d2bc3e08dc057c --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resulthandlers/PrintingResultHandlerSmokeTests.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone.resulthandlers; + +import java.io.StringWriter; + +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletResponse; + +import org.junit.Ignore; +import org.junit.Test; + +import org.springframework.stereotype.Controller; +import org.springframework.test.web.servlet.result.PrintingResultHandler; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.ResponseBody; + +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Smoke test for {@link PrintingResultHandler}. + * + *

Prints debugging information about the executed request and response to + * various output streams. + * + *

NOTE: this smoke test is not intended to be + * executed with the build. To run this test, comment out the {@code @Ignore} + * declaration and inspect the output manually. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + * @see org.springframework.test.web.servlet.result.PrintingResultHandlerTests + */ +@Ignore("Not intended to be executed with the build. Comment out this line to inspect the output manually.") +public class PrintingResultHandlerSmokeTests { + + @Test + public void testPrint() throws Exception { + StringWriter writer = new StringWriter(); + + standaloneSetup(new SimpleController()) + .build() + .perform(get("/").content("Hello Request".getBytes())) + .andDo(log()) + .andDo(print()) + .andDo(print(System.err)) + .andDo(print(writer)) + ; + + System.out.println(); + System.out.println("==============================================================="); + System.out.println(writer.toString()); + } + + + @Controller + private static class SimpleController { + + @RequestMapping("/") + @ResponseBody + public String hello(HttpServletResponse response) { + response.addCookie(new Cookie("enigma", "42")); + return "Hello Response"; + } + } +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/ContentAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/ContentAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..2a841456ebfa37b287e8ca504f724cd5a0614ca1 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/ContentAssertionTests.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone.resultmatchers; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.stereotype.Controller; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.ResponseBody; + +import static org.hamcrest.Matchers.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Examples of defining expectations on the response content, content type, and + * the character encoding. + * + * @author Rossen Stoyanchev + * + * @see JsonPathAssertionTests + * @see XmlContentAssertionTests + * @see XpathAssertionTests + */ +public class ContentAssertionTests { + + private MockMvc mockMvc; + + @Before + public void setup() { + this.mockMvc = standaloneSetup(new SimpleController()).alwaysExpect(status().isOk()).build(); + } + + @Test + public void testContentType() throws Exception { + this.mockMvc.perform(get("/handle").accept(MediaType.TEXT_PLAIN)) + .andExpect(content().contentType(MediaType.valueOf("text/plain;charset=ISO-8859-1"))) + .andExpect(content().contentType("text/plain;charset=ISO-8859-1")) + .andExpect(content().contentTypeCompatibleWith("text/plain")) + .andExpect(content().contentTypeCompatibleWith(MediaType.TEXT_PLAIN)); + + this.mockMvc.perform(get("/handleUtf8")) + .andExpect(content().contentType(MediaType.valueOf("text/plain;charset=UTF-8"))) + .andExpect(content().contentType("text/plain;charset=UTF-8")) + .andExpect(content().contentTypeCompatibleWith("text/plain")) + .andExpect(content().contentTypeCompatibleWith(MediaType.TEXT_PLAIN)); + } + + @Test + public void testContentAsString() throws Exception { + + this.mockMvc.perform(get("/handle").accept(MediaType.TEXT_PLAIN)) + .andExpect(content().string("Hello world!")); + + this.mockMvc.perform(get("/handleUtf8")) + .andExpect(content().string("\u3053\u3093\u306b\u3061\u306f\u4e16\u754c\uff01")); + + // Hamcrest matchers... + this.mockMvc.perform(get("/handle").accept(MediaType.TEXT_PLAIN)).andExpect(content().string(equalTo("Hello world!"))); + this.mockMvc.perform(get("/handleUtf8")).andExpect(content().string(equalTo("\u3053\u3093\u306b\u3061\u306f\u4e16\u754c\uff01"))); + } + + @Test + public void testContentAsBytes() throws Exception { + + this.mockMvc.perform(get("/handle").accept(MediaType.TEXT_PLAIN)) + .andExpect(content().bytes("Hello world!".getBytes("ISO-8859-1"))); + + this.mockMvc.perform(get("/handleUtf8")) + .andExpect(content().bytes("\u3053\u3093\u306b\u3061\u306f\u4e16\u754c\uff01".getBytes("UTF-8"))); + } + + @Test + public void testContentStringMatcher() throws Exception { + this.mockMvc.perform(get("/handle").accept(MediaType.TEXT_PLAIN)) + .andExpect(content().string(containsString("world"))); + } + + @Test + public void testCharacterEncoding() throws Exception { + + this.mockMvc.perform(get("/handle").accept(MediaType.TEXT_PLAIN)) + .andExpect(content().encoding("ISO-8859-1")) + .andExpect(content().string(containsString("world"))); + + this.mockMvc.perform(get("/handleUtf8")) + .andExpect(content().encoding("UTF-8")) + .andExpect(content().bytes("\u3053\u3093\u306b\u3061\u306f\u4e16\u754c\uff01".getBytes("UTF-8"))); + } + + + @Controller + private static class SimpleController { + + @RequestMapping(value="/handle", produces="text/plain") + @ResponseBody + public String handle() { + return "Hello world!"; + } + + @RequestMapping(value="/handleUtf8", produces="text/plain;charset=UTF-8") + @ResponseBody + public String handleWithCharset() { + return "\u3053\u3093\u306b\u3061\u306f\u4e16\u754c\uff01"; // "Hello world! (Japanese) + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/CookieAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/CookieAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..05b3b7db962840473563f74f340e706828c05931 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/CookieAssertionTests.java @@ -0,0 +1,122 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone.resultmatchers; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.stereotype.Controller; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.servlet.i18n.CookieLocaleResolver; +import org.springframework.web.servlet.i18n.LocaleChangeInterceptor; + +import static org.hamcrest.Matchers.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Examples of expectations on response cookies values. + * + * @author Rossen Stoyanchev + * @author Nikola Yovchev + */ +public class CookieAssertionTests { + + private static final String COOKIE_NAME = CookieLocaleResolver.DEFAULT_COOKIE_NAME; + + private MockMvc mockMvc; + + + @Before + public void setup() { + CookieLocaleResolver localeResolver = new CookieLocaleResolver(); + localeResolver.setCookieDomain("domain"); + localeResolver.setCookieHttpOnly(true); + + this.mockMvc = standaloneSetup(new SimpleController()) + .addInterceptors(new LocaleChangeInterceptor()) + .setLocaleResolver(localeResolver) + .defaultRequest(get("/").param("locale", "en_US")) + .alwaysExpect(status().isOk()) + .build(); + } + + + @Test + public void testExists() throws Exception { + this.mockMvc.perform(get("/")).andExpect(cookie().exists(COOKIE_NAME)); + } + + @Test + public void testNotExists() throws Exception { + this.mockMvc.perform(get("/")).andExpect(cookie().doesNotExist("unknownCookie")); + } + + @Test + public void testEqualTo() throws Exception { + this.mockMvc.perform(get("/")).andExpect(cookie().value(COOKIE_NAME, "en-US")); + this.mockMvc.perform(get("/")).andExpect(cookie().value(COOKIE_NAME, equalTo("en-US"))); + } + + @Test + public void testMatcher() throws Exception { + this.mockMvc.perform(get("/")).andExpect(cookie().value(COOKIE_NAME, startsWith("en"))); + } + + @Test + public void testMaxAge() throws Exception { + this.mockMvc.perform(get("/")).andExpect(cookie().maxAge(COOKIE_NAME, -1)); + } + + @Test + public void testDomain() throws Exception { + this.mockMvc.perform(get("/")).andExpect(cookie().domain(COOKIE_NAME, "domain")); + } + + @Test + public void testVersion() throws Exception { + this.mockMvc.perform(get("/")).andExpect(cookie().version(COOKIE_NAME, 0)); + } + + @Test + public void testPath() throws Exception { + this.mockMvc.perform(get("/")).andExpect(cookie().path(COOKIE_NAME, "/")); + } + + @Test + public void testSecured() throws Exception { + this.mockMvc.perform(get("/")).andExpect(cookie().secure(COOKIE_NAME, false)); + } + + @Test + public void testHttpOnly() throws Exception { + this.mockMvc.perform(get("/")).andExpect(cookie().httpOnly(COOKIE_NAME, true)); + } + + + @Controller + private static class SimpleController { + + @RequestMapping("/") + public String home() { + return "home"; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/FlashAttributeAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/FlashAttributeAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..f533d18d92b78fd5355b3af26cf66397725ede3b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/FlashAttributeAssertionTests.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone.resultmatchers; + +import java.net.URL; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.stereotype.Controller; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.web.servlet.mvc.support.RedirectAttributes; + +import static org.hamcrest.Matchers.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Examples of expectations on flash attributes. + * + * @author Rossen Stoyanchev + */ +public class FlashAttributeAssertionTests { + + private MockMvc mockMvc; + + + @Before + public void setup() { + this.mockMvc = standaloneSetup(new PersonController()) + .alwaysExpect(status().isFound()) + .alwaysExpect(flash().attributeCount(3)) + .build(); + } + + @Test + public void testExists() throws Exception { + this.mockMvc.perform(post("/persons")) + .andExpect(flash().attributeExists("one", "two", "three")); + } + + @Test + public void testEqualTo() throws Exception { + this.mockMvc.perform(post("/persons")) + .andExpect(flash().attribute("one", "1")) + .andExpect(flash().attribute("two", 2.222)) + .andExpect(flash().attribute("three", new URL("https://example.com"))) + .andExpect(flash().attribute("one", equalTo("1"))) // Hamcrest... + .andExpect(flash().attribute("two", equalTo(2.222))) + .andExpect(flash().attribute("three", equalTo(new URL("https://example.com")))); + } + + @Test + public void testMatchers() throws Exception { + this.mockMvc.perform(post("/persons")) + .andExpect(flash().attribute("one", containsString("1"))) + .andExpect(flash().attribute("two", closeTo(2, 0.5))) + .andExpect(flash().attribute("three", notNullValue())); + } + + + @Controller + private static class PersonController { + + @RequestMapping(value="/persons", method=RequestMethod.POST) + public String save(RedirectAttributes redirectAttrs) throws Exception { + redirectAttrs.addFlashAttribute("one", "1"); + redirectAttrs.addFlashAttribute("two", 2.222); + redirectAttrs.addFlashAttribute("three", new URL("https://example.com")); + return "redirect:/person/1"; + } + } +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/HandlerAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/HandlerAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..fb3123760846a91337e2446082a2877fa05c424d --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/HandlerAssertionTests.java @@ -0,0 +1,101 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone.resultmatchers; + +import java.lang.reflect.Method; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.http.ResponseEntity; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.mvc.method.annotation.MvcUriComponentsBuilder; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.handler; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.standaloneSetup; +import static org.springframework.web.servlet.mvc.method.annotation.MvcUriComponentsBuilder.on; + +/** + * Examples of expectations on the controller type and controller method. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + */ +public class HandlerAssertionTests { + + private final MockMvc mockMvc = standaloneSetup(new SimpleController()).alwaysExpect(status().isOk()).build(); + + @Rule + public final ExpectedException exception = ExpectedException.none(); + + + @Test + public void handlerType() throws Exception { + this.mockMvc.perform(get("/")).andExpect(handler().handlerType(SimpleController.class)); + } + + @Test + public void methodCallOnNonMock() throws Exception { + exception.expect(AssertionError.class); + exception.expectMessage("The supplied object [bogus] is not an instance of"); + exception.expectMessage(MvcUriComponentsBuilder.MethodInvocationInfo.class.getName()); + exception.expectMessage("Ensure that you invoke the handler method via MvcUriComponentsBuilder.on()"); + + this.mockMvc.perform(get("/")).andExpect(handler().methodCall("bogus")); + } + + @Test + public void methodCall() throws Exception { + this.mockMvc.perform(get("/")).andExpect(handler().methodCall(on(SimpleController.class).handle())); + } + + @Test + public void methodName() throws Exception { + this.mockMvc.perform(get("/")).andExpect(handler().methodName("handle")); + } + + @Test + public void methodNameMatchers() throws Exception { + this.mockMvc.perform(get("/")).andExpect(handler().methodName(equalTo("handle"))); + this.mockMvc.perform(get("/")).andExpect(handler().methodName(is(not("save")))); + } + + @Test + public void method() throws Exception { + Method method = SimpleController.class.getMethod("handle"); + this.mockMvc.perform(get("/")).andExpect(handler().method(method)); + } + + + @RestController + static class SimpleController { + + @RequestMapping("/") + public ResponseEntity handle() { + return ResponseEntity.ok().build(); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/HeaderAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/HeaderAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3717aac56ba4874edee2139ace68463eb0321d45 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/HeaderAssertionTests.java @@ -0,0 +1,230 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone.resultmatchers; + +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.Locale; +import java.util.TimeZone; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.ResponseEntity; +import org.springframework.stereotype.Controller; +import org.springframework.test.web.Person; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.ResultMatcher; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.context.request.WebRequest; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; +import static org.springframework.http.HttpHeaders.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Examples of expectations on response header values. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + * @author Brian Clozel + */ +public class HeaderAssertionTests { + + private static final String ERROR_MESSAGE = "Should have thrown an AssertionError"; + + + private String now; + + private String minuteAgo; + + private MockMvc mockMvc; + + private final long currentTime = System.currentTimeMillis(); + + private SimpleDateFormat dateFormat; + + + @Before + public void setup() { + this.dateFormat = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US); + this.dateFormat.setTimeZone(TimeZone.getTimeZone("GMT")); + this.now = dateFormat.format(new Date(this.currentTime)); + this.minuteAgo = dateFormat.format(new Date(this.currentTime - (1000 * 60))); + + PersonController controller = new PersonController(); + controller.setStubTimestamp(this.currentTime); + this.mockMvc = standaloneSetup(controller).build(); + } + + + @Test + public void stringWithCorrectResponseHeaderValue() throws Exception { + this.mockMvc.perform(get("/persons/1").header(IF_MODIFIED_SINCE, minuteAgo)) + .andExpect(header().string(LAST_MODIFIED, now)); + } + + @Test + public void stringWithMatcherAndCorrectResponseHeaderValue() throws Exception { + this.mockMvc.perform(get("/persons/1").header(IF_MODIFIED_SINCE, minuteAgo)) + .andExpect(header().string(LAST_MODIFIED, equalTo(now))); + } + + @Test + public void multiStringHeaderValue() throws Exception { + this.mockMvc.perform(get("/persons/1")).andExpect(header().stringValues(VARY, "foo", "bar")); + } + + @SuppressWarnings("unchecked") + @Test + public void multiStringHeaderValueWithMatchers() throws Exception { + this.mockMvc.perform(get("/persons/1")) + .andExpect(header().stringValues(VARY, hasItems(containsString("foo"), startsWith("bar")))); + } + + @Test + public void dateValueWithCorrectResponseHeaderValue() throws Exception { + this.mockMvc.perform(get("/persons/1").header(IF_MODIFIED_SINCE, minuteAgo)) + .andExpect(header().dateValue(LAST_MODIFIED, this.currentTime)); + } + + @Test + public void longValueWithCorrectResponseHeaderValue() throws Exception { + this.mockMvc.perform(get("/persons/1")) + .andExpect(header().longValue("X-Rate-Limiting", 42)); + } + + @Test + public void stringWithMissingResponseHeader() throws Exception { + this.mockMvc.perform(get("/persons/1").header(IF_MODIFIED_SINCE, now)) + .andExpect(status().isNotModified()) + .andExpect(header().stringValues("X-Custom-Header")); + } + + @Test + public void stringWithMatcherAndMissingResponseHeader() throws Exception { + this.mockMvc.perform(get("/persons/1").header(IF_MODIFIED_SINCE, now)) + .andExpect(status().isNotModified()) + .andExpect(header().string("X-Custom-Header", nullValue())); + } + + @Test + public void longValueWithMissingResponseHeader() throws Exception { + try { + this.mockMvc.perform(get("/persons/1").header(IF_MODIFIED_SINCE, now)) + .andExpect(status().isNotModified()) + .andExpect(header().longValue("X-Custom-Header", 99L)); + + fail(ERROR_MESSAGE); + } + catch (AssertionError err) { + if (ERROR_MESSAGE.equals(err.getMessage())) { + throw err; + } + assertEquals("Response does not contain header 'X-Custom-Header'", err.getMessage()); + } + } + + @Test + public void exists() throws Exception { + this.mockMvc.perform(get("/persons/1")).andExpect(header().exists(LAST_MODIFIED)); + } + + @Test(expected = AssertionError.class) + public void existsFail() throws Exception { + this.mockMvc.perform(get("/persons/1")).andExpect(header().exists("X-Custom-Header")); + } + + @Test // SPR-10771 + public void doesNotExist() throws Exception { + this.mockMvc.perform(get("/persons/1")).andExpect(header().doesNotExist("X-Custom-Header")); + } + + @Test(expected = AssertionError.class) // SPR-10771 + public void doesNotExistFail() throws Exception { + this.mockMvc.perform(get("/persons/1")).andExpect(header().doesNotExist(LAST_MODIFIED)); + } + + @Test(expected = AssertionError.class) + public void longValueWithIncorrectResponseHeaderValue() throws Exception { + this.mockMvc.perform(get("/persons/1")).andExpect(header().longValue("X-Rate-Limiting", 1)); + } + + @Test + public void stringWithMatcherAndIncorrectResponseHeaderValue() throws Exception { + long secondLater = this.currentTime + 1000; + String expected = this.dateFormat.format(new Date(secondLater)); + assertIncorrectResponseHeader(header().string(LAST_MODIFIED, expected), expected); + assertIncorrectResponseHeader(header().string(LAST_MODIFIED, equalTo(expected)), expected); + // Comparison by date uses HttpHeaders to format the date in the error message. + HttpHeaders headers = new HttpHeaders(); + headers.setDate("expected", secondLater); + assertIncorrectResponseHeader(header().dateValue(LAST_MODIFIED, secondLater), headers.getFirst("expected")); + } + + private void assertIncorrectResponseHeader(ResultMatcher matcher, String expected) throws Exception { + try { + this.mockMvc.perform(get("/persons/1") + .header(IF_MODIFIED_SINCE, minuteAgo)) + .andExpect(matcher); + + fail(ERROR_MESSAGE); + } + catch (AssertionError err) { + if (ERROR_MESSAGE.equals(err.getMessage())) { + throw err; + } + // SPR-10659: ensure header name is in the message + // Unfortunately, we can't control formatting from JUnit or Hamcrest. + assertMessageContains(err, "Response header '" + LAST_MODIFIED + "'"); + assertMessageContains(err, expected); + assertMessageContains(err, this.now); + } + } + + private void assertMessageContains(AssertionError error, String expected) { + assertTrue("Failure message should contain [" + expected + "], actual is [" + error.getMessage() + "]", + error.getMessage().contains(expected)); + } + + + @Controller + private static class PersonController { + + private long timestamp; + + public void setStubTimestamp(long timestamp) { + this.timestamp = timestamp; + } + + @RequestMapping("/persons/{id}") + public ResponseEntity showEntity(@PathVariable long id, WebRequest request) { + return ResponseEntity + .ok() + .lastModified(this.timestamp) + .header("X-Rate-Limiting", "42") + .header("Vary", "foo", "bar") + .body(new Person("Jason")); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/JsonPathAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/JsonPathAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..927d2eeaa33e88688797819f68f0f1c6eecb615f --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/JsonPathAssertionTests.java @@ -0,0 +1,138 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone.resultmatchers; + +import java.util.Arrays; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.test.web.Person; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import static org.hamcrest.Matchers.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Examples of defining expectations on JSON response content with + * JsonPath expressions. + * + * @author Rossen Stoyanchev + * @see ContentAssertionTests + */ +public class JsonPathAssertionTests { + + private MockMvc mockMvc; + + + @Before + public void setup() { + this.mockMvc = standaloneSetup(new MusicController()) + .defaultRequest(get("/").accept(MediaType.APPLICATION_JSON)) + .alwaysExpect(status().isOk()) + .alwaysExpect(content().contentType("application/json;charset=UTF-8")) + .build(); + } + + + @Test + public void exists() throws Exception { + String composerByName = "$.composers[?(@.name == '%s')]"; + String performerByName = "$.performers[?(@.name == '%s')]"; + + this.mockMvc.perform(get("/music/people")) + .andExpect(jsonPath(composerByName, "Johann Sebastian Bach").exists()) + .andExpect(jsonPath(composerByName, "Johannes Brahms").exists()) + .andExpect(jsonPath(composerByName, "Edvard Grieg").exists()) + .andExpect(jsonPath(composerByName, "Robert Schumann").exists()) + .andExpect(jsonPath(performerByName, "Vladimir Ashkenazy").exists()) + .andExpect(jsonPath(performerByName, "Yehudi Menuhin").exists()) + .andExpect(jsonPath("$.composers[0]").exists()) + .andExpect(jsonPath("$.composers[1]").exists()) + .andExpect(jsonPath("$.composers[2]").exists()) + .andExpect(jsonPath("$.composers[3]").exists()); + } + + @Test + public void doesNotExist() throws Exception { + this.mockMvc.perform(get("/music/people")) + .andExpect(jsonPath("$.composers[?(@.name == 'Edvard Grieeeeeeg')]").doesNotExist()) + .andExpect(jsonPath("$.composers[?(@.name == 'Robert Schuuuuuuman')]").doesNotExist()) + .andExpect(jsonPath("$.composers[4]").doesNotExist()); + } + + @Test + public void equality() throws Exception { + this.mockMvc.perform(get("/music/people")) + .andExpect(jsonPath("$.composers[0].name").value("Johann Sebastian Bach")) + .andExpect(jsonPath("$.performers[1].name").value("Yehudi Menuhin")); + + // Hamcrest matchers... + this.mockMvc.perform(get("/music/people")) + .andExpect(jsonPath("$.composers[0].name").value(equalTo("Johann Sebastian Bach"))) + .andExpect(jsonPath("$.performers[1].name").value(equalTo("Yehudi Menuhin"))); + } + + @Test + public void hamcrestMatcher() throws Exception { + this.mockMvc.perform(get("/music/people")) + .andExpect(jsonPath("$.composers[0].name", startsWith("Johann"))) + .andExpect(jsonPath("$.performers[0].name", endsWith("Ashkenazy"))) + .andExpect(jsonPath("$.performers[1].name", containsString("di Me"))) + .andExpect(jsonPath("$.composers[1].name", isIn(Arrays.asList("Johann Sebastian Bach", "Johannes Brahms")))); + } + + @Test + public void hamcrestMatcherWithParameterizedJsonPath() throws Exception { + String composerName = "$.composers[%s].name"; + String performerName = "$.performers[%s].name"; + + this.mockMvc.perform(get("/music/people")) + .andExpect(jsonPath(composerName, 0).value(startsWith("Johann"))) + .andExpect(jsonPath(performerName, 0).value(endsWith("Ashkenazy"))) + .andExpect(jsonPath(performerName, 1).value(containsString("di Me"))) + .andExpect(jsonPath(composerName, 1).value(isIn(Arrays.asList("Johann Sebastian Bach", "Johannes Brahms")))); + } + + + @RestController + private class MusicController { + + @RequestMapping("/music/people") + public MultiValueMap get() { + MultiValueMap map = new LinkedMultiValueMap<>(); + + map.add("composers", new Person("Johann Sebastian Bach")); + map.add("composers", new Person("Johannes Brahms")); + map.add("composers", new Person("Edvard Grieg")); + map.add("composers", new Person("Robert Schumann")); + + map.add("performers", new Person("Vladimir Ashkenazy")); + map.add("performers", new Person("Yehudi Menuhin")); + + return map; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/ModelAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/ModelAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..514748c6782dc9f4fb83cb9fec879ea100e4fa22 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/ModelAssertionTests.java @@ -0,0 +1,130 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone.resultmatchers; + +import javax.validation.Valid; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.stereotype.Controller; +import org.springframework.test.web.Person; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.ui.Model; +import org.springframework.validation.BindingResult; +import org.springframework.web.bind.annotation.ControllerAdvice; +import org.springframework.web.bind.annotation.ModelAttribute; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestMapping; + +import static org.hamcrest.Matchers.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Examples of expectations on the content of the model prepared by the controller. + * + * @author Rossen Stoyanchev + */ +public class ModelAssertionTests { + + private MockMvc mockMvc; + + + @Before + public void setup() { + + SampleController controller = new SampleController("a string value", 3, new Person("a name")); + + this.mockMvc = standaloneSetup(controller) + .defaultRequest(get("/")) + .alwaysExpect(status().isOk()) + .setControllerAdvice(new ModelAttributeAdvice()) + .build(); + } + + @Test + public void testAttributeEqualTo() throws Exception { + mockMvc.perform(get("/")) + .andExpect(model().attribute("integer", 3)) + .andExpect(model().attribute("string", "a string value")) + .andExpect(model().attribute("integer", equalTo(3))) // Hamcrest... + .andExpect(model().attribute("string", equalTo("a string value"))) + .andExpect(model().attribute("globalAttrName", equalTo("Global Attribute Value"))); + } + + @Test + public void testAttributeExists() throws Exception { + mockMvc.perform(get("/")) + .andExpect(model().attributeExists("integer", "string", "person")) + .andExpect(model().attribute("integer", notNullValue())) // Hamcrest... + .andExpect(model().attribute("INTEGER", nullValue())); + } + + @Test + public void testAttributeHamcrestMatchers() throws Exception { + mockMvc.perform(get("/")) + .andExpect(model().attribute("integer", equalTo(3))) + .andExpect(model().attribute("string", allOf(startsWith("a string"), endsWith("value")))) + .andExpect(model().attribute("person", hasProperty("name", equalTo("a name")))); + } + + @Test + public void testHasErrors() throws Exception { + mockMvc.perform(post("/persons")).andExpect(model().attributeHasErrors("person")); + } + + @Test + public void testHasNoErrors() throws Exception { + mockMvc.perform(get("/")).andExpect(model().hasNoErrors()); + } + + + @Controller + private static class SampleController { + + private final Object[] values; + + public SampleController(Object... values) { + this.values = values; + } + + @RequestMapping("/") + public String handle(Model model) { + for (Object value : this.values) { + model.addAttribute(value); + } + return "view"; + } + + @PostMapping("/persons") + public String create(@Valid Person person, BindingResult result, Model model) { + return "view"; + } + } + + @ControllerAdvice + private static class ModelAttributeAdvice { + + @ModelAttribute("globalAttrName") + public String getAttribute() { + return "Global Attribute Value"; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/RequestAttributeAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/RequestAttributeAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..9afc0ebcb17dd91115696867e8cae0e2118e7cc7 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/RequestAttributeAssertionTests.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone.resultmatchers; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.stereotype.Controller; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.servlet.HandlerMapping; + +import static org.hamcrest.Matchers.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Examples of expectations on created request attributes. + * + * @author Rossen Stoyanchev + */ +public class RequestAttributeAssertionTests { + + private MockMvc mockMvc; + + @Before + public void setup() { + this.mockMvc = standaloneSetup(new SimpleController()).build(); + } + + @Test + public void testRequestAttributeEqualTo() throws Exception { + this.mockMvc.perform(get("/main/1").servletPath("/main")) + .andExpect(request().attribute(HandlerMapping.BEST_MATCHING_PATTERN_ATTRIBUTE, "/{id}")) + .andExpect(request().attribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/1")) + .andExpect(request().attribute(HandlerMapping.BEST_MATCHING_PATTERN_ATTRIBUTE, equalTo("/{id}"))) + .andExpect(request().attribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, equalTo("/1"))); + } + + @Test + public void testRequestAttributeMatcher() throws Exception { + + String producibleMediaTypes = HandlerMapping.PRODUCIBLE_MEDIA_TYPES_ATTRIBUTE; + + this.mockMvc.perform(get("/1")) + .andExpect(request().attribute(producibleMediaTypes, hasItem(MediaType.APPLICATION_JSON))) + .andExpect(request().attribute(producibleMediaTypes, not(hasItem(MediaType.APPLICATION_XML)))); + } + + + @Controller + private static class SimpleController { + + @RequestMapping(value="/{id}", produces="application/json") + public String show() { + return "view"; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/SessionAttributeAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/SessionAttributeAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..291b2faa14e87d305a2ebed25444b428bda56c81 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/SessionAttributeAssertionTests.java @@ -0,0 +1,83 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone.resultmatchers; + +import java.util.Locale; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.stereotype.Controller; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.ui.Model; +import org.springframework.web.bind.annotation.ModelAttribute; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.SessionAttributes; + +import static org.hamcrest.Matchers.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Examples of expectations on created session attributes. + * + * @author Rossen Stoyanchev + */ +public class SessionAttributeAssertionTests { + + private MockMvc mockMvc; + + @Before + public void setup() { + this.mockMvc = standaloneSetup(new SimpleController()) + .defaultRequest(get("/")) + .alwaysExpect(status().isOk()) + .build(); + } + + @Test + public void testSessionAttributeEqualTo() throws Exception { + this.mockMvc.perform(get("/")) + .andExpect(request().sessionAttribute("locale", Locale.UK)) + .andExpect(request().sessionAttribute("locale", equalTo(Locale.UK))); + } + + @Test + public void testSessionAttributeMatcher() throws Exception { + this.mockMvc.perform(get("/")) + .andExpect(request().sessionAttribute("locale", notNullValue())); + } + + + @Controller + @SessionAttributes("locale") + private static class SimpleController { + + @ModelAttribute + public void populate(Model model) { + model.addAttribute("locale", Locale.UK); + } + + @RequestMapping("/") + public String handle() { + return "view"; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/StatusAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/StatusAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..50f53f9559abd211e5722e7b0c8ba5d25c6746e8 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/StatusAssertionTests.java @@ -0,0 +1,116 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone.resultmatchers; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +import org.junit.Test; + +import org.springframework.core.annotation.AliasFor; +import org.springframework.http.HttpStatus; +import org.springframework.stereotype.Controller; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.ResponseBody; +import org.springframework.web.bind.annotation.ResponseStatus; + +import static org.hamcrest.Matchers.*; +import static org.springframework.http.HttpStatus.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Examples of expectations on the status and the status reason found in the response. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + */ +public class StatusAssertionTests { + + private final MockMvc mockMvc = standaloneSetup(new StatusController()).build(); + + @Test + public void testStatusInt() throws Exception { + this.mockMvc.perform(get("/created")).andExpect(status().is(201)); + this.mockMvc.perform(get("/createdWithComposedAnnotation")).andExpect(status().is(201)); + this.mockMvc.perform(get("/badRequest")).andExpect(status().is(400)); + } + + @Test + public void testHttpStatus() throws Exception { + this.mockMvc.perform(get("/created")).andExpect(status().isCreated()); + this.mockMvc.perform(get("/createdWithComposedAnnotation")).andExpect(status().isCreated()); + this.mockMvc.perform(get("/badRequest")).andExpect(status().isBadRequest()); + } + + @Test + public void testMatcher() throws Exception { + this.mockMvc.perform(get("/badRequest")).andExpect(status().is(equalTo(400))); + } + + @Test + public void testReasonEqualTo() throws Exception { + this.mockMvc.perform(get("/badRequest")).andExpect(status().reason("Expired token")); + + // Hamcrest matchers... + this.mockMvc.perform(get("/badRequest")).andExpect(status().reason(equalTo("Expired token"))); + } + + @Test + public void testReasonMatcher() throws Exception { + this.mockMvc.perform(get("/badRequest")).andExpect(status().reason(endsWith("token"))); + } + + + @RequestMapping + @ResponseStatus + @Retention(RetentionPolicy.RUNTIME) + @interface Get { + + @AliasFor(annotation = RequestMapping.class, attribute = "path") + String[] path() default {}; + + @AliasFor(annotation = ResponseStatus.class, attribute = "code") + HttpStatus status() default INTERNAL_SERVER_ERROR; + } + + @Controller + private static class StatusController { + + @RequestMapping("/created") + @ResponseStatus(CREATED) + public @ResponseBody void created(){ + } + + @Get(path = "/createdWithComposedAnnotation", status = CREATED) + public @ResponseBody void createdWithComposedAnnotation() { + } + + @RequestMapping("/badRequest") + @ResponseStatus(code = BAD_REQUEST, reason = "Expired token") + public @ResponseBody void badRequest(){ + } + + @RequestMapping("/notImplemented") + @ResponseStatus(NOT_IMPLEMENTED) + public @ResponseBody void notImplemented(){ + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/UrlAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/UrlAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..fcf9c56cf0a06918cb46616dcc290d4dc109dd02 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/UrlAssertionTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone.resultmatchers; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.stereotype.Controller; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.RequestMapping; + +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Examples of expectations on forwarded or redirected URLs. + * + * @author Rossen Stoyanchev + */ +public class UrlAssertionTests { + + private MockMvc mockMvc; + + @Before + public void setup() { + this.mockMvc = standaloneSetup(new SimpleController()).build(); + } + + @Test + public void testRedirect() throws Exception { + this.mockMvc.perform(get("/persons")).andExpect(redirectedUrl("/persons/1")); + } + + @Test + public void testRedirectPattern() throws Exception { + this.mockMvc.perform(get("/persons")).andExpect(redirectedUrlPattern("/persons/*")); + } + + @Test + public void testForward() throws Exception { + this.mockMvc.perform(get("/")).andExpect(forwardedUrl("/home")); + } + + @Test + public void testForwardPattern() throws Exception { + this.mockMvc.perform(get("/")).andExpect(forwardedUrlPattern("/ho?e")); + } + + @Controller + private static class SimpleController { + + @RequestMapping("/persons") + public String save() { + return "redirect:/persons/1"; + } + + @RequestMapping("/") + public String forward() { + return "forward:/home"; + } + } +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/ViewNameAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/ViewNameAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a3f6785904d1165f96d9591b1b243281199c1ea6 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/ViewNameAssertionTests.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone.resultmatchers; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.stereotype.Controller; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.RequestMapping; + +import static org.hamcrest.Matchers.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Examples of expectations on the view name selected by the controller. + * + * @author Rossen Stoyanchev + */ +public class ViewNameAssertionTests { + + private MockMvc mockMvc; + + @Before + public void setup() { + this.mockMvc = standaloneSetup(new SimpleController()) + .alwaysExpect(status().isOk()) + .build(); + } + + @Test + public void testEqualTo() throws Exception { + this.mockMvc.perform(get("/")) + .andExpect(view().name("mySpecialView")) + .andExpect(view().name(equalTo("mySpecialView"))); + } + + @Test + public void testHamcrestMatcher() throws Exception { + this.mockMvc.perform(get("/")).andExpect(view().name(containsString("Special"))); + } + + + @Controller + private static class SimpleController { + + @RequestMapping("/") + public String handle() { + return "mySpecialView"; + } + } +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/XmlContentAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/XmlContentAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..901f25f3a0c8c70b21344c6973b0b7676c2657a3 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/XmlContentAssertionTests.java @@ -0,0 +1,122 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone.resultmatchers; + +import java.util.Arrays; +import java.util.List; + +import javax.xml.bind.annotation.XmlAccessType; +import javax.xml.bind.annotation.XmlAccessorType; +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlElementWrapper; +import javax.xml.bind.annotation.XmlRootElement; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.stereotype.Controller; +import org.springframework.test.web.Person; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.ResponseBody; + +import static org.hamcrest.Matchers.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Examples of defining expectations on XML response content with XMLUnit. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + * @see ContentAssertionTests + * @see XpathAssertionTests + */ +public class XmlContentAssertionTests { + + private static final String PEOPLE_XML = + "" + + "" + + "Johann Sebastian Bachfalse21.0" + + "Johannes Brahmsfalse0.0025" + + "Edvard Griegfalse1.6035" + + "Robert SchumannfalseNaN" + + ""; + + private MockMvc mockMvc; + + + @Before + public void setup() { + this.mockMvc = standaloneSetup(new MusicController()) + .defaultRequest(get("/").accept(MediaType.APPLICATION_XML, MediaType.parseMediaType("application/xml;charset=UTF-8"))) + .alwaysExpect(status().isOk()) + .alwaysExpect(content().contentType(MediaType.parseMediaType("application/xml;charset=UTF-8"))) + .build(); + } + + @Test + public void testXmlEqualTo() throws Exception { + this.mockMvc.perform(get("/music/people")).andExpect(content().xml(PEOPLE_XML)); + } + + @Test + public void testNodeHamcrestMatcher() throws Exception { + this.mockMvc.perform(get("/music/people")) + .andExpect(content().node(hasXPath("/people/composers/composer[1]"))); + } + + + @Controller + private static class MusicController { + + @RequestMapping(value="/music/people") + public @ResponseBody PeopleWrapper getPeople() { + + List composers = Arrays.asList( + new Person("Johann Sebastian Bach").setSomeDouble(21), + new Person("Johannes Brahms").setSomeDouble(.0025), + new Person("Edvard Grieg").setSomeDouble(1.6035), + new Person("Robert Schumann").setSomeDouble(Double.NaN)); + + return new PeopleWrapper(composers); + } + } + + @SuppressWarnings("unused") + @XmlRootElement(name="people") + @XmlAccessorType(XmlAccessType.FIELD) + private static class PeopleWrapper { + + @XmlElementWrapper(name="composers") + @XmlElement(name="composer") + private List composers; + + public PeopleWrapper() { + } + + public PeopleWrapper(List composers) { + this.composers = composers; + } + + public List getComposers() { + return this.composers; + } + } +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/XpathAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/XpathAssertionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..7155f30874fa321f5caab471399b98912c6d3876 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/resultmatchers/XpathAssertionTests.java @@ -0,0 +1,233 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.samples.standalone.resultmatchers; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import javax.xml.bind.annotation.XmlAccessType; +import javax.xml.bind.annotation.XmlAccessorType; +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlElementWrapper; +import javax.xml.bind.annotation.XmlRootElement; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.stereotype.Controller; +import org.springframework.test.web.Person; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.ResponseBody; + +import static org.hamcrest.Matchers.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; +import static org.springframework.web.bind.annotation.RequestMethod.*; + +/** + * Examples of expectations on XML response content with XPath expressions. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + * @see ContentAssertionTests + * @see XmlContentAssertionTests + */ +public class XpathAssertionTests { + + private static final Map musicNamespace = + Collections.singletonMap("ns", "https://example.org/music/people"); + + private MockMvc mockMvc; + + @Before + public void setup() throws Exception { + this.mockMvc = standaloneSetup(new MusicController()) + .defaultRequest(get("/").accept(MediaType.APPLICATION_XML, MediaType.parseMediaType("application/xml;charset=UTF-8"))) + .alwaysExpect(status().isOk()) + .alwaysExpect(content().contentType(MediaType.parseMediaType("application/xml;charset=UTF-8"))) + .build(); + } + + @Test + public void testExists() throws Exception { + + String composer = "/ns:people/composers/composer[%s]"; + String performer = "/ns:people/performers/performer[%s]"; + + this.mockMvc.perform(get("/music/people")) + .andExpect(xpath(composer, musicNamespace, 1).exists()) + .andExpect(xpath(composer, musicNamespace, 2).exists()) + .andExpect(xpath(composer, musicNamespace, 3).exists()) + .andExpect(xpath(composer, musicNamespace, 4).exists()) + .andExpect(xpath(performer, musicNamespace, 1).exists()) + .andExpect(xpath(performer, musicNamespace, 2).exists()) + .andExpect(xpath(composer, musicNamespace, 1).node(notNullValue())); + } + + @Test + public void testDoesNotExist() throws Exception { + + String composer = "/ns:people/composers/composer[%s]"; + String performer = "/ns:people/performers/performer[%s]"; + + this.mockMvc.perform(get("/music/people")) + .andExpect(xpath(composer, musicNamespace, 0).doesNotExist()) + .andExpect(xpath(composer, musicNamespace, 5).doesNotExist()) + .andExpect(xpath(performer, musicNamespace, 0).doesNotExist()) + .andExpect(xpath(performer, musicNamespace, 3).doesNotExist()) + .andExpect(xpath(composer, musicNamespace, 0).node(nullValue())); + } + + @Test + public void testString() throws Exception { + + String composerName = "/ns:people/composers/composer[%s]/name"; + String performerName = "/ns:people/performers/performer[%s]/name"; + + this.mockMvc.perform(get("/music/people")) + .andExpect(xpath(composerName, musicNamespace, 1).string("Johann Sebastian Bach")) + .andExpect(xpath(composerName, musicNamespace, 2).string("Johannes Brahms")) + .andExpect(xpath(composerName, musicNamespace, 3).string("Edvard Grieg")) + .andExpect(xpath(composerName, musicNamespace, 4).string("Robert Schumann")) + .andExpect(xpath(performerName, musicNamespace, 1).string("Vladimir Ashkenazy")) + .andExpect(xpath(performerName, musicNamespace, 2).string("Yehudi Menuhin")) + .andExpect(xpath(composerName, musicNamespace, 1).string(equalTo("Johann Sebastian Bach"))) // Hamcrest.. + .andExpect(xpath(composerName, musicNamespace, 1).string(startsWith("Johann"))) + .andExpect(xpath(composerName, musicNamespace, 1).string(notNullValue())); + } + + @Test + public void testNumber() throws Exception { + + String composerDouble = "/ns:people/composers/composer[%s]/someDouble"; + + this.mockMvc.perform(get("/music/people")) + .andExpect(xpath(composerDouble, musicNamespace, 1).number(21d)) + .andExpect(xpath(composerDouble, musicNamespace, 2).number(.0025)) + .andExpect(xpath(composerDouble, musicNamespace, 3).number(1.6035)) + .andExpect(xpath(composerDouble, musicNamespace, 4).number(Double.NaN)) + .andExpect(xpath(composerDouble, musicNamespace, 1).number(equalTo(21d))) // Hamcrest.. + .andExpect(xpath(composerDouble, musicNamespace, 3).number(closeTo(1.6, .01))); + } + + @Test + public void testBoolean() throws Exception { + + String performerBooleanValue = "/ns:people/performers/performer[%s]/someBoolean"; + + this.mockMvc.perform(get("/music/people")) + .andExpect(xpath(performerBooleanValue, musicNamespace, 1).booleanValue(false)) + .andExpect(xpath(performerBooleanValue, musicNamespace, 2).booleanValue(true)); + } + + @Test + public void testNodeCount() throws Exception { + + this.mockMvc.perform(get("/music/people")) + .andExpect(xpath("/ns:people/composers/composer", musicNamespace).nodeCount(4)) + .andExpect(xpath("/ns:people/performers/performer", musicNamespace).nodeCount(2)) + .andExpect(xpath("/ns:people/composers/composer", musicNamespace).nodeCount(equalTo(4))) // Hamcrest.. + .andExpect(xpath("/ns:people/performers/performer", musicNamespace).nodeCount(equalTo(2))); + } + + // SPR-10704 + + @Test + public void testFeedWithLinefeedChars() throws Exception { + +// Map namespace = Collections.singletonMap("ns", ""); + + standaloneSetup(new BlogFeedController()).build() + .perform(get("/blog.atom").accept(MediaType.APPLICATION_ATOM_XML)) + .andExpect(status().isOk()) + .andExpect(content().contentTypeCompatibleWith(MediaType.APPLICATION_ATOM_XML)) + .andExpect(xpath("//feed/title").string("Test Feed")) + .andExpect(xpath("//feed/icon").string("https://www.example.com/favicon.ico")); + } + + + @Controller + private static class MusicController { + + @RequestMapping(value="/music/people") + public @ResponseBody PeopleWrapper getPeople() { + + List composers = Arrays.asList( + new Person("Johann Sebastian Bach").setSomeDouble(21), + new Person("Johannes Brahms").setSomeDouble(.0025), + new Person("Edvard Grieg").setSomeDouble(1.6035), + new Person("Robert Schumann").setSomeDouble(Double.NaN)); + + List performers = Arrays.asList( + new Person("Vladimir Ashkenazy").setSomeBoolean(false), + new Person("Yehudi Menuhin").setSomeBoolean(true)); + + return new PeopleWrapper(composers, performers); + } + } + + @SuppressWarnings("unused") + @XmlRootElement(name="people", namespace="https://example.org/music/people") + @XmlAccessorType(XmlAccessType.FIELD) + private static class PeopleWrapper { + + @XmlElementWrapper(name="composers") + @XmlElement(name="composer") + private List composers; + + @XmlElementWrapper(name="performers") + @XmlElement(name="performer") + private List performers; + + public PeopleWrapper() { + } + + public PeopleWrapper(List composers, List performers) { + this.composers = composers; + this.performers = performers; + } + + public List getComposers() { + return this.composers; + } + + public List getPerformers() { + return this.performers; + } + } + + + @Controller + public class BlogFeedController { + + @RequestMapping(value="/blog.atom", method = { GET, HEAD }) + @ResponseBody + public String listPublishedPosts() { + return "\r\n" + + "\r\n" + + " Test Feed\r\n" + + " https://www.example.com/favicon.ico\r\n" + + "\r\n\r\n"; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/setup/ConditionalDelegatingFilterProxyTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/setup/ConditionalDelegatingFilterProxyTests.java new file mode 100644 index 0000000000000000000000000000000000000000..cf2a4b11b596858c4b5ad433b3a0581ff68953f9 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/setup/ConditionalDelegatingFilterProxyTests.java @@ -0,0 +1,288 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.setup; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.mock.web.MockFilterChain; +import org.springframework.mock.web.MockFilterConfig; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; + +import static org.hamcrest.MatcherAssert.*; +import static org.hamcrest.Matchers.*; + +/** + * @author Rob Winch + */ +public class ConditionalDelegatingFilterProxyTests { + + private MockHttpServletRequest request; + + private MockHttpServletResponse response; + + private MockFilterChain filterChain; + + private MockFilter delegate; + + private PatternMappingFilterProxy filter; + + + @Before + public void setup() { + request = new MockHttpServletRequest(); + request.setContextPath("/context"); + response = new MockHttpServletResponse(); + filterChain = new MockFilterChain(); + delegate = new MockFilter(); + } + + + @Test + public void init() throws Exception { + FilterConfig config = new MockFilterConfig(); + filter = new PatternMappingFilterProxy(delegate, "/"); + filter.init(config); + assertThat(delegate.filterConfig, is(config)); + } + + @Test + public void destroy() throws Exception { + filter = new PatternMappingFilterProxy(delegate, "/"); + filter.destroy(); + assertThat(delegate.destroy, is(true)); + } + + @Test + public void matchExact() throws Exception { + assertFilterInvoked("/test", "/test"); + } + + @Test + public void matchExactEmpty() throws Exception { + assertFilterInvoked("", ""); + } + + @Test + public void matchPathMappingAllFolder() throws Exception { + assertFilterInvoked("/test/this", "/*"); + } + + @Test + public void matchPathMappingAll() throws Exception { + assertFilterInvoked("/test", "/*"); + } + + @Test + public void matchPathMappingAllContextRoot() throws Exception { + assertFilterInvoked("", "/*"); + } + + @Test + public void matchPathMappingContextRootAndSlash() throws Exception { + assertFilterInvoked("/", "/*"); + } + + @Test + public void matchPathMappingFolderPatternWithMultiFolderPath() throws Exception { + assertFilterInvoked("/test/this/here", "/test/*"); + } + + @Test + public void matchPathMappingFolderPattern() throws Exception { + assertFilterInvoked("/test/this", "/test/*"); + } + + @Test + public void matchPathMappingNoSuffix() throws Exception { + assertFilterInvoked("/test/", "/test/*"); + } + + @Test + public void matchPathMappingMissingSlash() throws Exception { + assertFilterInvoked("/test", "/test/*"); + } + + @Test + public void noMatchPathMappingMulti() throws Exception { + assertFilterNotInvoked("/this/test/here", "/test/*"); + } + + @Test + public void noMatchPathMappingEnd() throws Exception { + assertFilterNotInvoked("/this/test", "/test/*"); + } + + @Test + public void noMatchPathMappingEndSuffix() throws Exception { + assertFilterNotInvoked("/test2/", "/test/*"); + } + + @Test + public void noMatchPathMappingMissingSlash() throws Exception { + assertFilterNotInvoked("/test2", "/test/*"); + } + + @Test + public void matchExtensionMulti() throws Exception { + assertFilterInvoked("/test/this/here.html", "*.html"); + } + + @Test + public void matchExtension() throws Exception { + assertFilterInvoked("/test/this.html", "*.html"); + } + + @Test + public void matchExtensionNoPrefix() throws Exception { + assertFilterInvoked("/.html", "*.html"); + } + + @Test + public void matchExtensionNoFolder() throws Exception { + assertFilterInvoked("/test.html", "*.html"); + } + + @Test + public void noMatchExtensionNoSlash() throws Exception { + assertFilterNotInvoked(".html", "*.html"); + } + + @Test + public void noMatchExtensionSlashEnd() throws Exception { + assertFilterNotInvoked("/index.html/", "*.html"); + } + + @Test + public void noMatchExtensionPeriodEnd() throws Exception { + assertFilterNotInvoked("/index.html.", "*.html"); + } + + @Test + public void noMatchExtensionLarger() throws Exception { + assertFilterNotInvoked("/index.htm", "*.html"); + } + + @Test + public void noMatchInvalidPattern() throws Exception { + // pattern uses extension mapping but starts with / (treated as exact match) + assertFilterNotInvoked("/index.html", "/*.html"); + } + + /* + * Below are tests from Table 12-1 of the Servlet Specification + */ + @Test + public void specPathMappingMultiFolderPattern() throws Exception { + assertFilterInvoked("/foo/bar/index.html", "/foo/bar/*"); + } + + @Test + public void specPathMappingMultiFolderPatternAlternate() throws Exception { + assertFilterInvoked("/foo/bar/index.bop", "/foo/bar/*"); + } + + @Test + public void specPathMappingNoSlash() throws Exception { + assertFilterInvoked("/baz", "/baz/*"); + } + + @Test + public void specPathMapping() throws Exception { + assertFilterInvoked("/baz/index.html", "/baz/*"); + } + + @Test + public void specExactMatch() throws Exception { + assertFilterInvoked("/catalog", "/catalog"); + } + + @Test + public void specExtensionMappingSingleFolder() throws Exception { + assertFilterInvoked("/catalog/racecar.bop", "*.bop"); + } + + @Test + public void specExtensionMapping() throws Exception { + assertFilterInvoked("/index.bop", "*.bop"); + } + + private void assertFilterNotInvoked(String requestUri, String pattern) throws Exception { + request.setRequestURI(request.getContextPath() + requestUri); + filter = new PatternMappingFilterProxy(delegate, pattern); + filter.doFilter(request, response, filterChain); + + assertThat(delegate.request, equalTo((ServletRequest) null)); + assertThat(delegate.response, equalTo((ServletResponse) null)); + assertThat(delegate.chain, equalTo((FilterChain) null)); + + assertThat(filterChain.getRequest(), equalTo((ServletRequest) request)); + assertThat(filterChain.getResponse(), equalTo((ServletResponse) response)); + filterChain = new MockFilterChain(); + } + + private void assertFilterInvoked(String requestUri, String pattern) throws Exception { + request.setRequestURI(request.getContextPath() + requestUri); + filter = new PatternMappingFilterProxy(delegate, pattern); + filter.doFilter(request, response, filterChain); + + assertThat(delegate.request, equalTo((ServletRequest) request)); + assertThat(delegate.response, equalTo((ServletResponse) response)); + assertThat(delegate.chain, equalTo((FilterChain) filterChain)); + delegate = new MockFilter(); + } + + + private static class MockFilter implements Filter { + + private FilterConfig filterConfig; + + private ServletRequest request; + + private ServletResponse response; + + private FilterChain chain; + + private boolean destroy; + + @Override + public void init(FilterConfig filterConfig) throws ServletException { + this.filterConfig = filterConfig; + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) { + this.request = request; + this.response = response; + this.chain = chain; + } + + @Override + public void destroy() { + this.destroy = true; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/setup/DefaultMockMvcBuilderTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/setup/DefaultMockMvcBuilderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..88f70e676d2c1017db3f362c86d802aec10f8d95 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/setup/DefaultMockMvcBuilderTests.java @@ -0,0 +1,159 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.setup; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.beans.DirectFieldAccessor; +import org.springframework.context.support.StaticApplicationContext; +import org.springframework.mock.web.MockServletContext; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.StaticWebApplicationContext; +import org.springframework.web.context.support.WebApplicationContextUtils; +import org.springframework.web.servlet.DispatcherServlet; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; + +/** + * Tests for {@link DefaultMockMvcBuilder}. + * + * @author Rob Winch + * @author Sebastien Deleuze + * @author Sam Brannen + * @author Stephane Nicoll + */ +public class DefaultMockMvcBuilderTests { + + private final MockServletContext servletContext = new MockServletContext(); + + @Rule + public final ExpectedException exception = ExpectedException.none(); + + + @Test + public void webAppContextSetupWithNullWac() { + exception.expect(IllegalArgumentException.class); + exception.expectMessage(equalTo("WebApplicationContext is required")); + webAppContextSetup(null); + } + + @Test + public void webAppContextSetupWithNullServletContext() { + exception.expect(IllegalArgumentException.class); + exception.expectMessage(equalTo("WebApplicationContext must have a ServletContext")); + webAppContextSetup(new StubWebApplicationContext(null)); + } + + /** + * See SPR-12553 and SPR-13075. + */ + @Test + public void rootWacServletContainerAttributePreviouslySet() { + StubWebApplicationContext child = new StubWebApplicationContext(this.servletContext); + this.servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, child); + + DefaultMockMvcBuilder builder = webAppContextSetup(child); + assertSame(builder.initWebAppContext(), + WebApplicationContextUtils.getRequiredWebApplicationContext(this.servletContext)); + } + + /** + * See SPR-12553 and SPR-13075. + */ + @Test + public void rootWacServletContainerAttributePreviouslySetWithContextHierarchy() { + StubWebApplicationContext root = new StubWebApplicationContext(this.servletContext); + + this.servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, root); + + StaticWebApplicationContext child = new StaticWebApplicationContext(); + child.setParent(root); + child.setServletContext(this.servletContext); + + DefaultMockMvcBuilder builder = webAppContextSetup(child); + assertSame(builder.initWebAppContext().getParent(), + WebApplicationContextUtils.getRequiredWebApplicationContext(this.servletContext)); + } + + /** + * See SPR-12553 and SPR-13075. + */ + @Test + public void rootWacServletContainerAttributeNotPreviouslySet() { + StubWebApplicationContext root = new StubWebApplicationContext(this.servletContext); + DefaultMockMvcBuilder builder = webAppContextSetup(root); + WebApplicationContext wac = builder.initWebAppContext(); + assertSame(root, wac); + assertSame(root, WebApplicationContextUtils.getRequiredWebApplicationContext(this.servletContext)); + } + + /** + * See SPR-12553 and SPR-13075. + */ + @Test + public void rootWacServletContainerAttributeNotPreviouslySetWithContextHierarchy() { + StaticApplicationContext ear = new StaticApplicationContext(); + StaticWebApplicationContext root = new StaticWebApplicationContext(); + root.setParent(ear); + root.setServletContext(this.servletContext); + StaticWebApplicationContext dispatcher = new StaticWebApplicationContext(); + dispatcher.setParent(root); + dispatcher.setServletContext(this.servletContext); + + DefaultMockMvcBuilder builder = webAppContextSetup(dispatcher); + WebApplicationContext wac = builder.initWebAppContext(); + + assertSame(dispatcher, wac); + assertSame(root, wac.getParent()); + assertSame(ear, wac.getParent().getParent()); + assertSame(root, WebApplicationContextUtils.getRequiredWebApplicationContext(this.servletContext)); + } + + /** + * See /SPR-14277 + */ + @Test + public void dispatcherServletCustomizer() { + StubWebApplicationContext root = new StubWebApplicationContext(this.servletContext); + DefaultMockMvcBuilder builder = webAppContextSetup(root); + builder.addDispatcherServletCustomizer(ds -> ds.setContextId("test-id")); + builder.dispatchOptions(true); + MockMvc mvc = builder.build(); + DispatcherServlet ds = (DispatcherServlet) new DirectFieldAccessor(mvc) + .getPropertyValue("servlet"); + assertEquals("test-id", ds.getContextId()); + } + + @Test + public void dispatcherServletCustomizerProcessedInOrder() { + StubWebApplicationContext root = new StubWebApplicationContext(this.servletContext); + DefaultMockMvcBuilder builder = webAppContextSetup(root); + builder.addDispatcherServletCustomizer(ds -> ds.setContextId("test-id")); + builder.addDispatcherServletCustomizer(ds -> ds.setContextId("override-id")); + builder.dispatchOptions(true); + MockMvc mvc = builder.build(); + DispatcherServlet ds = (DispatcherServlet) new DirectFieldAccessor(mvc) + .getPropertyValue("servlet"); + assertEquals("override-id", ds.getContextId()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/setup/SharedHttpSessionTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/setup/SharedHttpSessionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..36d6f037dee50d992ef01ec5e0711a01fd97750a --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/setup/SharedHttpSessionTests.java @@ -0,0 +1,105 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.setup; + +import javax.servlet.http.HttpSession; + +import org.junit.Test; + +import org.springframework.stereotype.Controller; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.web.bind.annotation.GetMapping; + +import static org.junit.Assert.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.setup.SharedHttpSessionConfigurer.*; + +/** + * Tests for {@link SharedHttpSessionConfigurer}. + * + * @author Rossen Stoyanchev + */ +public class SharedHttpSessionTests { + + @Test + public void httpSession() throws Exception { + MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new TestController()) + .apply(sharedHttpSession()) + .build(); + + String url = "/session"; + + MvcResult result = mockMvc.perform(get(url)).andExpect(status().isOk()).andReturn(); + HttpSession session = result.getRequest().getSession(false); + assertNotNull(session); + assertEquals(1, session.getAttribute("counter")); + + result = mockMvc.perform(get(url)).andExpect(status().isOk()).andReturn(); + session = result.getRequest().getSession(false); + assertNotNull(session); + assertEquals(2, session.getAttribute("counter")); + + result = mockMvc.perform(get(url)).andExpect(status().isOk()).andReturn(); + session = result.getRequest().getSession(false); + assertNotNull(session); + assertEquals(3, session.getAttribute("counter")); + } + + @Test + public void noHttpSession() throws Exception { + MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new TestController()) + .apply(sharedHttpSession()) + .build(); + + String url = "/no-session"; + + MvcResult result = mockMvc.perform(get(url)).andExpect(status().isOk()).andReturn(); + HttpSession session = result.getRequest().getSession(false); + assertNull(session); + + result = mockMvc.perform(get(url)).andExpect(status().isOk()).andReturn(); + session = result.getRequest().getSession(false); + assertNull(session); + + url = "/session"; + + result = mockMvc.perform(get(url)).andExpect(status().isOk()).andReturn(); + session = result.getRequest().getSession(false); + assertNotNull(session); + assertEquals(1, session.getAttribute("counter")); + } + + + @Controller + private static class TestController { + + @GetMapping("/session") + public String handle(HttpSession session) { + Integer counter = (Integer) session.getAttribute("counter"); + session.setAttribute("counter", (counter != null ? counter + 1 : 1)); + return "view"; + } + + @GetMapping("/no-session") + public String handle() { + return "view"; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/setup/StandaloneMockMvcBuilderTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/setup/StandaloneMockMvcBuilderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..2272e51242dfdf950f7114c24a8e15aa8b6461c5 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/setup/StandaloneMockMvcBuilderTests.java @@ -0,0 +1,178 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.setup; + +import java.io.IOException; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.ser.impl.UnknownSerializer; +import org.junit.Test; + +import org.springframework.http.converter.json.SpringHandlerInstantiator; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.stereotype.Controller; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.WebApplicationContextUtils; +import org.springframework.web.filter.OncePerRequestFilter; +import org.springframework.web.method.HandlerMethod; +import org.springframework.web.servlet.HandlerExecutionChain; +import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping; + +import static org.junit.Assert.*; + +/** + * Tests for {@link StandaloneMockMvcBuilder} + * + * @author Rossen Stoyanchev + * @author Rob Winch + * @author Sebastien Deleuze + */ +public class StandaloneMockMvcBuilderTests { + + @Test // SPR-10825 + public void placeHoldersInRequestMapping() throws Exception { + TestStandaloneMockMvcBuilder builder = new TestStandaloneMockMvcBuilder(new PlaceholderController()); + builder.addPlaceholderValue("sys.login.ajax", "/foo"); + builder.build(); + + RequestMappingHandlerMapping hm = builder.wac.getBean(RequestMappingHandlerMapping.class); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/foo"); + HandlerExecutionChain chain = hm.getHandler(request); + + assertNotNull(chain); + assertEquals("handleWithPlaceholders", ((HandlerMethod) chain.getHandler()).getMethod().getName()); + } + + @Test // SPR-13637 + public void suffixPatternMatch() throws Exception { + TestStandaloneMockMvcBuilder builder = new TestStandaloneMockMvcBuilder(new PersonController()); + builder.setUseSuffixPatternMatch(false); + builder.build(); + + RequestMappingHandlerMapping hm = builder.wac.getBean(RequestMappingHandlerMapping.class); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/persons"); + HandlerExecutionChain chain = hm.getHandler(request); + assertNotNull(chain); + assertEquals("persons", ((HandlerMethod) chain.getHandler()).getMethod().getName()); + + request = new MockHttpServletRequest("GET", "/persons.xml"); + chain = hm.getHandler(request); + assertNull(chain); + } + + @Test // SPR-12553 + public void applicationContextAttribute() { + TestStandaloneMockMvcBuilder builder = new TestStandaloneMockMvcBuilder(new PlaceholderController()); + builder.addPlaceholderValue("sys.login.ajax", "/foo"); + WebApplicationContext wac = builder.initWebAppContext(); + assertEquals(wac, WebApplicationContextUtils.getRequiredWebApplicationContext(wac.getServletContext())); + } + + @Test(expected = IllegalArgumentException.class) + public void addFiltersFiltersNull() { + StandaloneMockMvcBuilder builder = MockMvcBuilders.standaloneSetup(new PersonController()); + builder.addFilters((Filter[]) null); + } + + @Test(expected = IllegalArgumentException.class) + public void addFiltersFiltersContainsNull() { + StandaloneMockMvcBuilder builder = MockMvcBuilders.standaloneSetup(new PersonController()); + builder.addFilters(new ContinueFilter(), (Filter) null); + } + + @Test(expected = IllegalArgumentException.class) + public void addFilterPatternsNull() { + StandaloneMockMvcBuilder builder = MockMvcBuilders.standaloneSetup(new PersonController()); + builder.addFilter(new ContinueFilter(), (String[]) null); + } + + @Test(expected = IllegalArgumentException.class) + public void addFilterPatternContainsNull() { + StandaloneMockMvcBuilder builder = MockMvcBuilders.standaloneSetup(new PersonController()); + builder.addFilter(new ContinueFilter(), (String) null); + } + + @Test // SPR-13375 + @SuppressWarnings("rawtypes") + public void springHandlerInstantiator() { + TestStandaloneMockMvcBuilder builder = new TestStandaloneMockMvcBuilder(new PersonController()); + builder.build(); + SpringHandlerInstantiator instantiator = new SpringHandlerInstantiator(builder.wac.getAutowireCapableBeanFactory()); + JsonSerializer serializer = instantiator.serializerInstance(null, null, UnknownSerializer.class); + assertNotNull(serializer); + } + + + @Controller + private static class PlaceholderController { + + @RequestMapping(value = "${sys.login.ajax}") + private void handleWithPlaceholders() { } + } + + + private static class TestStandaloneMockMvcBuilder extends StandaloneMockMvcBuilder { + + private WebApplicationContext wac; + + private TestStandaloneMockMvcBuilder(Object... controllers) { + super(controllers); + } + + @Override + protected WebApplicationContext initWebAppContext() { + this.wac = super.initWebAppContext(); + return this.wac; + } + } + + + @Controller + private static class PersonController { + + @RequestMapping(value="/persons") + public String persons() { + return null; + } + + @RequestMapping(value="/forward") + public String forward() { + return "forward:/persons"; + } + } + + + private class ContinueFilter extends OncePerRequestFilter { + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + + filterChain.doFilter(request, response); + } + } + +} diff --git a/spring-test/src/test/kotlin/org/springframework/test/web/reactive/server/WebTestClientExtensionsTests.kt b/spring-test/src/test/kotlin/org/springframework/test/web/reactive/server/WebTestClientExtensionsTests.kt new file mode 100644 index 0000000000000000000000000000000000000000..eb8ca760392e0b4fe55db986ab420651034675b3 --- /dev/null +++ b/spring-test/src/test/kotlin/org/springframework/test/web/reactive/server/WebTestClientExtensionsTests.kt @@ -0,0 +1,98 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.reactive.server + +import com.nhaarman.mockito_kotlin.mock +import com.nhaarman.mockito_kotlin.times +import com.nhaarman.mockito_kotlin.verify +import org.junit.Assert.assertEquals +import org.junit.Test +import org.junit.runner.RunWith +import org.mockito.Answers +import org.mockito.Mock +import org.mockito.junit.MockitoJUnitRunner +import org.reactivestreams.Publisher +import org.springframework.web.reactive.function.server.ServerResponse.* +import org.springframework.web.reactive.function.server.router + +/** + * Mock object based tests for [WebTestClient] Kotlin extensions + * + * @author Sebastien Deleuze + */ +@RunWith(MockitoJUnitRunner::class) +class WebTestClientExtensionsTests { + + @Mock(answer = Answers.RETURNS_MOCKS) + lateinit var requestBodySpec: WebTestClient.RequestBodySpec + + @Mock(answer = Answers.RETURNS_MOCKS) + lateinit var responseSpec: WebTestClient.ResponseSpec + + + @Test + fun `RequestBodySpec#body with Publisher and reified type parameters`() { + val body = mock>() + requestBodySpec.body(body) + verify(requestBodySpec, times(1)).body(body, Foo::class.java) + } + + @Test + fun `ResponseSpec#expectBody with reified type parameters`() { + responseSpec.expectBody() + verify(responseSpec, times(1)).expectBody(Foo::class.java) + } + + @Test + fun `KotlinBodySpec#isEqualTo`() { + WebTestClient + .bindToRouterFunction( router { GET("/") { ok().syncBody("foo") } } ) + .build() + .get().uri("/").exchange().expectBody().isEqualTo("foo") + } + + @Test + fun `KotlinBodySpec#consumeWith`() { + WebTestClient + .bindToRouterFunction( router { GET("/") { ok().syncBody("foo") } } ) + .build() + .get().uri("/").exchange().expectBody().consumeWith { assertEquals("foo", it.responseBody) } + } + + @Test + fun `KotlinBodySpec#returnResult`() { + WebTestClient + .bindToRouterFunction( router { GET("/") { ok().syncBody("foo") } } ) + .build() + .get().uri("/").exchange().expectBody().returnResult().apply { assertEquals("foo", responseBody) } + } + + @Test + fun `ResponseSpec#expectBodyList with reified type parameters`() { + responseSpec.expectBodyList() + verify(responseSpec, times(1)).expectBodyList(Foo::class.java) + } + + @Test + fun `ResponseSpec#returnResult with reified type parameters`() { + responseSpec.returnResult() + verify(responseSpec, times(1)).returnResult(Foo::class.java) + } + + class Foo + +} diff --git a/spring-test/src/test/kotlin/org/springframework/test/web/servlet/result/StatusResultMatchersExtensionsTests.kt b/spring-test/src/test/kotlin/org/springframework/test/web/servlet/result/StatusResultMatchersExtensionsTests.kt new file mode 100644 index 0000000000000000000000000000000000000000..74d5a80449b46ffabf8cac76e148a2773e8207cf --- /dev/null +++ b/spring-test/src/test/kotlin/org/springframework/test/web/servlet/result/StatusResultMatchersExtensionsTests.kt @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.result + +import com.nhaarman.mockito_kotlin.mock +import com.nhaarman.mockito_kotlin.times +import com.nhaarman.mockito_kotlin.verify +import org.hamcrest.Matcher +import org.junit.Test +import org.junit.runner.RunWith +import org.mockito.Answers +import org.mockito.Mock +import org.mockito.junit.MockitoJUnitRunner + +@RunWith(MockitoJUnitRunner::class) +class StatusResultMatchersExtensionsTests { + + @Mock(answer = Answers.RETURNS_MOCKS) + lateinit var matchers: StatusResultMatchers + + @Test + fun `StatusResultMatchers#is with Matcher parameter is called as expected when using isEqualTo extension`() { + val matcher = mock>() + matchers.isEqualTo(matcher) + verify(matchers, times(1)).`is`(matcher) + } + + @Test + fun `StatusResultMatchers#is with int parameter is called as expected when using isEqualTo extension`() { + matchers.isEqualTo(200) + verify(matchers, times(1)).`is`(200) + } + +} diff --git a/spring-test/src/test/resources/META-INF/spring.factories b/spring-test/src/test/resources/META-INF/spring.factories new file mode 100644 index 0000000000000000000000000000000000000000..8b6ce9b41874b3be921a7f2afdb1abd980b5e056 --- /dev/null +++ b/spring-test/src/test/resources/META-INF/spring.factories @@ -0,0 +1,2 @@ +# Test configuration file containing a non-existent default TestExecutionListener. +org.springframework.test.context.TestExecutionListener = org.example.FooListener diff --git a/spring-test/src/test/resources/META-INF/web-resources/WEB-INF/layouts/standardLayout.jsp b/spring-test/src/test/resources/META-INF/web-resources/WEB-INF/layouts/standardLayout.jsp new file mode 100644 index 0000000000000000000000000000000000000000..51499dabc9813485424ac7655539fe95a582eb45 --- /dev/null +++ b/spring-test/src/test/resources/META-INF/web-resources/WEB-INF/layouts/standardLayout.jsp @@ -0,0 +1,12 @@ +<%@ page language="java" contentType="text/html; charset=UTF-8" pageEncoding="UTF-8"%> +<%@ taglib uri="http://tiles.apache.org/tags-tiles" prefix="tiles" %> + + + + +Title + + + + + \ No newline at end of file diff --git a/spring-test/src/test/resources/META-INF/web-resources/WEB-INF/layouts/tiles.xml b/spring-test/src/test/resources/META-INF/web-resources/WEB-INF/layouts/tiles.xml new file mode 100644 index 0000000000000000000000000000000000000000..978b7c187d2886be7b67d1dfc568adf0a33120f4 --- /dev/null +++ b/spring-test/src/test/resources/META-INF/web-resources/WEB-INF/layouts/tiles.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/spring-test/src/test/resources/META-INF/web-resources/WEB-INF/views/home.jsp b/spring-test/src/test/resources/META-INF/web-resources/WEB-INF/views/home.jsp new file mode 100644 index 0000000000000000000000000000000000000000..59990b85d296bd796d44a6a10a7632085369cdd1 --- /dev/null +++ b/spring-test/src/test/resources/META-INF/web-resources/WEB-INF/views/home.jsp @@ -0,0 +1,2 @@ + +

Main page

diff --git a/spring-test/src/test/resources/META-INF/web-resources/WEB-INF/views/tiles.xml b/spring-test/src/test/resources/META-INF/web-resources/WEB-INF/views/tiles.xml new file mode 100644 index 0000000000000000000000000000000000000000..19b92e6ef2e376f888bdd68f669a40c029934cdd --- /dev/null +++ b/spring-test/src/test/resources/META-INF/web-resources/WEB-INF/views/tiles.xml @@ -0,0 +1,7 @@ + + + + + + + diff --git a/spring-test/src/test/resources/META-INF/web-resources/resources/Spring.js b/spring-test/src/test/resources/META-INF/web-resources/resources/Spring.js new file mode 100644 index 0000000000000000000000000000000000000000..93bd975baaa8be4bf187e7833867f429c1ab7783 --- /dev/null +++ b/spring-test/src/test/resources/META-INF/web-resources/resources/Spring.js @@ -0,0 +1,16 @@ +/* + * Copyright 2004-2008 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +Spring={};Spring.debug=true;Spring.decorations={};Spring.decorations.applied=false;Spring.initialize=function(){Spring.applyDecorations();Spring.remoting=new Spring.RemotingHandler();};Spring.addDecoration=function(_1){if(!Spring.decorations[_1.elementId]){Spring.decorations[_1.elementId]=[];Spring.decorations[_1.elementId].push(_1);}else{var _2=false;for(var i=0;i + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/configuration/ContextConfigurationWithPropertiesExtendingPropertiesAndInheritedLoaderTests-context.properties b/spring-test/src/test/resources/org/springframework/test/context/configuration/ContextConfigurationWithPropertiesExtendingPropertiesAndInheritedLoaderTests-context.properties new file mode 100644 index 0000000000000000000000000000000000000000..45d36076bba4a3435dfa04cc9c03d317af5a55e5 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/configuration/ContextConfigurationWithPropertiesExtendingPropertiesAndInheritedLoaderTests-context.properties @@ -0,0 +1,5 @@ +dog.(class)=org.springframework.tests.sample.beans.Pet +dog.$0=Fido + +testString2.(class)=java.lang.String +testString2.$0=Test String #2 diff --git a/spring-test/src/test/resources/org/springframework/test/context/configuration/ContextConfigurationWithPropertiesExtendingPropertiesTests-context.properties b/spring-test/src/test/resources/org/springframework/test/context/configuration/ContextConfigurationWithPropertiesExtendingPropertiesTests-context.properties new file mode 100644 index 0000000000000000000000000000000000000000..45d36076bba4a3435dfa04cc9c03d317af5a55e5 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/configuration/ContextConfigurationWithPropertiesExtendingPropertiesTests-context.properties @@ -0,0 +1,5 @@ +dog.(class)=org.springframework.tests.sample.beans.Pet +dog.$0=Fido + +testString2.(class)=java.lang.String +testString2.$0=Test String #2 diff --git a/spring-test/src/test/resources/org/springframework/test/context/env/ApplicationPropertyOverridePropertiesFileTestPropertySourceTests.properties b/spring-test/src/test/resources/org/springframework/test/context/env/ApplicationPropertyOverridePropertiesFileTestPropertySourceTests.properties new file mode 100644 index 0000000000000000000000000000000000000000..1fafe0c45342ccb5315c3ec95c8175a470dde0b3 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/env/ApplicationPropertyOverridePropertiesFileTestPropertySourceTests.properties @@ -0,0 +1 @@ +explicit = test override \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/env/DefaultPropertiesFileDetectionTestPropertySourceTests.properties b/spring-test/src/test/resources/org/springframework/test/context/env/DefaultPropertiesFileDetectionTestPropertySourceTests.properties new file mode 100644 index 0000000000000000000000000000000000000000..ae10beee1cb108bd5c28d478c576e61e87216a9d --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/env/DefaultPropertiesFileDetectionTestPropertySourceTests.properties @@ -0,0 +1 @@ +riddle = auto detected \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/env/ExtendedDefaultPropertiesFileDetectionTestPropertySourceTests.properties b/spring-test/src/test/resources/org/springframework/test/context/env/ExtendedDefaultPropertiesFileDetectionTestPropertySourceTests.properties new file mode 100644 index 0000000000000000000000000000000000000000..b3abbd1ac45530e0f830b3725d82e1fe2ff5838a --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/env/ExtendedDefaultPropertiesFileDetectionTestPropertySourceTests.properties @@ -0,0 +1 @@ +enigma = auto detected \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/env/SystemPropertyOverridePropertiesFileTestPropertySourceTests.properties b/spring-test/src/test/resources/org/springframework/test/context/env/SystemPropertyOverridePropertiesFileTestPropertySourceTests.properties new file mode 100644 index 0000000000000000000000000000000000000000..09c2b72e240a678e7b0a5ab0703c147c7500b7ec --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/env/SystemPropertyOverridePropertiesFileTestPropertySourceTests.properties @@ -0,0 +1 @@ +SystemPropertyOverridePropertiesFileTestPropertySourceTests.riddle = enigma \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/env/explicit.properties b/spring-test/src/test/resources/org/springframework/test/context/env/explicit.properties new file mode 100644 index 0000000000000000000000000000000000000000..972c736192b95853556c8936b1f0b39acfc3a228 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/env/explicit.properties @@ -0,0 +1 @@ +explicit = enigma \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/env/extended.properties b/spring-test/src/test/resources/org/springframework/test/context/env/extended.properties new file mode 100644 index 0000000000000000000000000000000000000000..6a378359854f208f7d18956ef06570681207e34d --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/env/extended.properties @@ -0,0 +1 @@ +extended = 42 \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/expression/ExpressionUsageTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/expression/ExpressionUsageTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..e14ec42fcd331e24ad337918a65f507a8936f281 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/expression/ExpressionUsageTests-context.xml @@ -0,0 +1,44 @@ + + + + + + + Dave + Andy + + + + + + + + + + #{properties['user.name']} + #{properties['username']} + #{properties[username]} + #{properties.username} + exists + exists also + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/groovy/DefaultScriptDetectionGroovySpringContextTestsContext.groovy b/spring-test/src/test/resources/org/springframework/test/context/groovy/DefaultScriptDetectionGroovySpringContextTestsContext.groovy new file mode 100644 index 0000000000000000000000000000000000000000..f8318a7a6fa95dabc631ebd59c29a636eac122cd --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/groovy/DefaultScriptDetectionGroovySpringContextTestsContext.groovy @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.groovy + +import org.springframework.tests.sample.beans.Employee +import org.springframework.tests.sample.beans.Pet + +/** + * Groovy script for defining Spring beans for integration tests. + * + * @author Sam Brannen + * @since 4.1 + */ +beans { + + foo String, 'Foo' + bar String, 'Bar' + + employee(Employee) { + name = "Dilbert" + age = 42 + company = "???" + } + + pet(Pet, 'Dogbert') +} diff --git a/spring-test/src/test/resources/org/springframework/test/context/groovy/DefaultScriptDetectionXmlSupersedesGroovySpringContextTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/groovy/DefaultScriptDetectionXmlSupersedesGroovySpringContextTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..8e61f9cc803e4af5ca6a9bff73a4f5fa5d6eb014 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/groovy/DefaultScriptDetectionXmlSupersedesGroovySpringContextTests-context.xml @@ -0,0 +1,9 @@ + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/groovy/DefaultScriptDetectionXmlSupersedesGroovySpringContextTestsContext.groovy b/spring-test/src/test/resources/org/springframework/test/context/groovy/DefaultScriptDetectionXmlSupersedesGroovySpringContextTestsContext.groovy new file mode 100644 index 0000000000000000000000000000000000000000..c76dffa16fee212c3f264172d5482a0adc21b584 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/groovy/DefaultScriptDetectionXmlSupersedesGroovySpringContextTestsContext.groovy @@ -0,0 +1,30 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.groovy + +/** + * This is intentionally an empty config file, since the XML file should + * be picked up as the default before a Groovy script. + * + *

See: {@code DefaultScriptDetectionXmlSupersedesGroovySpringContextTests-context.xml} + * + * @author Sam Brannen + * @since 4.1 + */ +beans { + +} diff --git a/spring-test/src/test/resources/org/springframework/test/context/groovy/context.groovy b/spring-test/src/test/resources/org/springframework/test/context/groovy/context.groovy new file mode 100644 index 0000000000000000000000000000000000000000..f8318a7a6fa95dabc631ebd59c29a636eac122cd --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/groovy/context.groovy @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.groovy + +import org.springframework.tests.sample.beans.Employee +import org.springframework.tests.sample.beans.Pet + +/** + * Groovy script for defining Spring beans for integration tests. + * + * @author Sam Brannen + * @since 4.1 + */ +beans { + + foo String, 'Foo' + bar String, 'Bar' + + employee(Employee) { + name = "Dilbert" + age = 42 + company = "???" + } + + pet(Pet, 'Dogbert') +} diff --git a/spring-test/src/test/resources/org/springframework/test/context/groovy/contextA.groovy b/spring-test/src/test/resources/org/springframework/test/context/groovy/contextA.groovy new file mode 100644 index 0000000000000000000000000000000000000000..1362fba0f675375274b3a987ef384cd49566de10 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/groovy/contextA.groovy @@ -0,0 +1,30 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.groovy + +/** + * Groovy script for defining Spring beans for integration tests. + * + * @author Sam Brannen + * @since 4.1 + */ +beans { + + foo String, 'Groovy Foo' + bar String, 'Groovy Bar' + +} diff --git a/spring-test/src/test/resources/org/springframework/test/context/groovy/contextB.xml b/spring-test/src/test/resources/org/springframework/test/context/groovy/contextB.xml new file mode 100644 index 0000000000000000000000000000000000000000..103412b82f50811607ca46e96b2fabbf24850ee1 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/groovy/contextB.xml @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/hierarchies/standard/SingleTestClassWithTwoLevelContextHierarchyAndMixedConfigTypesTests-ChildConfig.xml b/spring-test/src/test/resources/org/springframework/test/context/hierarchies/standard/SingleTestClassWithTwoLevelContextHierarchyAndMixedConfigTypesTests-ChildConfig.xml new file mode 100644 index 0000000000000000000000000000000000000000..8ce6952941a8942ff4fc377817378c5f9383e4a7 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/hierarchies/standard/SingleTestClassWithTwoLevelContextHierarchyAndMixedConfigTypesTests-ChildConfig.xml @@ -0,0 +1,9 @@ + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/hierarchies/standard/TestHierarchyLevelTwoWithSingleLevelContextHierarchyAndMixedConfigTypesTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/hierarchies/standard/TestHierarchyLevelTwoWithSingleLevelContextHierarchyAndMixedConfigTypesTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..91c9f59c2499ec7c8810163e54ea21bdbae9f589 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/hierarchies/standard/TestHierarchyLevelTwoWithSingleLevelContextHierarchyAndMixedConfigTypesTests-context.xml @@ -0,0 +1,9 @@ + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/hierarchies/web/DispatcherWacRootWacEarTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/hierarchies/web/DispatcherWacRootWacEarTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..9f755c03e5e57618430351ac2f9b674eff027548 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/hierarchies/web/DispatcherWacRootWacEarTests-context.xml @@ -0,0 +1,7 @@ + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/jdbc/DefaultScriptDetectionSqlScriptsTests.methodLevel.sql b/spring-test/src/test/resources/org/springframework/test/context/jdbc/DefaultScriptDetectionSqlScriptsTests.methodLevel.sql new file mode 100644 index 0000000000000000000000000000000000000000..06ad7cdec9a110655dfe4a81dba8c9bb6920e8e0 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/jdbc/DefaultScriptDetectionSqlScriptsTests.methodLevel.sql @@ -0,0 +1,10 @@ +DROP TABLE user IF EXISTS; + +CREATE TABLE user ( + name VARCHAR(20) NOT NULL, + PRIMARY KEY(name) +); + +INSERT INTO user VALUES('Dilbert'); +INSERT INTO user VALUES('Dogbert'); +INSERT INTO user VALUES('Catbert'); diff --git a/spring-test/src/test/resources/org/springframework/test/context/jdbc/DefaultScriptDetectionSqlScriptsTests.sql b/spring-test/src/test/resources/org/springframework/test/context/jdbc/DefaultScriptDetectionSqlScriptsTests.sql new file mode 100644 index 0000000000000000000000000000000000000000..e60b7a4b8a5dcc2f7fd62e77e663429ec0b8a269 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/jdbc/DefaultScriptDetectionSqlScriptsTests.sql @@ -0,0 +1,9 @@ +DROP TABLE user IF EXISTS; + +CREATE TABLE user ( + name VARCHAR(20) NOT NULL, + PRIMARY KEY(name) +); + +INSERT INTO user VALUES('Dilbert'); +INSERT INTO user VALUES('Dogbert'); diff --git a/spring-test/src/test/resources/org/springframework/test/context/jdbc/data-add-catbert.sql b/spring-test/src/test/resources/org/springframework/test/context/jdbc/data-add-catbert.sql new file mode 100644 index 0000000000000000000000000000000000000000..6c034c10fa3576ce3b79885daf5ac88eae591230 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/jdbc/data-add-catbert.sql @@ -0,0 +1 @@ +INSERT INTO user VALUES('Catbert'); \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/jdbc/data-add-dogbert.sql b/spring-test/src/test/resources/org/springframework/test/context/jdbc/data-add-dogbert.sql new file mode 100644 index 0000000000000000000000000000000000000000..f20b0a368fcb755a3f0dc97273fdf9c873d64e53 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/jdbc/data-add-dogbert.sql @@ -0,0 +1 @@ +INSERT INTO user VALUES('Dogbert'); \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/jdbc/data-add-users-with-custom-script-syntax.sql b/spring-test/src/test/resources/org/springframework/test/context/jdbc/data-add-users-with-custom-script-syntax.sql new file mode 100644 index 0000000000000000000000000000000000000000..664a6998e5e0231bd614cee6fe756e2929ce2e1f --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/jdbc/data-add-users-with-custom-script-syntax.sql @@ -0,0 +1,22 @@ +` custom single-line comment + +#$ custom +block +comment +$# + +INSERT + +INTO + +user + +VALUES('Dilbert') + +@@ + +` custom single-line comment + + +INSERT INTO user VALUES('Dogbert')@@ +INSERT INTO user VALUES('Catbert')@@ diff --git a/spring-test/src/test/resources/org/springframework/test/context/jdbc/data.sql b/spring-test/src/test/resources/org/springframework/test/context/jdbc/data.sql new file mode 100644 index 0000000000000000000000000000000000000000..182a36582b8c0c8a971ff8bf8b77f1aa7f7002ae --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/jdbc/data.sql @@ -0,0 +1 @@ +INSERT INTO user VALUES('Dilbert'); \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/jdbc/drop-schema.sql b/spring-test/src/test/resources/org/springframework/test/context/jdbc/drop-schema.sql new file mode 100644 index 0000000000000000000000000000000000000000..3a28aade0c420d05fb7110b1b7a329a127002cf8 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/jdbc/drop-schema.sql @@ -0,0 +1 @@ +DROP TABLE user IF EXISTS; diff --git a/spring-test/src/test/resources/org/springframework/test/context/jdbc/schema.sql b/spring-test/src/test/resources/org/springframework/test/context/jdbc/schema.sql new file mode 100644 index 0000000000000000000000000000000000000000..8b0abd27206d5647ce8892133fc4091390a161f5 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/jdbc/schema.sql @@ -0,0 +1,4 @@ +CREATE TABLE user ( + name VARCHAR(20) NOT NULL, + PRIMARY KEY(name) +); diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/ConcreteTransactionalJUnit4SpringContextTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/ConcreteTransactionalJUnit4SpringContextTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..7a130ccbe1c479769eaef5ea4f5fb9c8b38b5273 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/ConcreteTransactionalJUnit4SpringContextTests-context.xml @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/FailingBeforeAndAfterMethodsTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/FailingBeforeAndAfterMethodsTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..e410f71a7bd1c82b6a1e7c378732b7474fc5c428 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/FailingBeforeAndAfterMethodsTests-context.xml @@ -0,0 +1,12 @@ + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/MultipleResourcesSpringJUnit4ClassRunnerAppCtxTests-context1.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/MultipleResourcesSpringJUnit4ClassRunnerAppCtxTests-context1.xml new file mode 100644 index 0000000000000000000000000000000000000000..f7ade734d14920b9dbc9fee79a77e725cbdac790 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/MultipleResourcesSpringJUnit4ClassRunnerAppCtxTests-context1.xml @@ -0,0 +1,11 @@ + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/MultipleResourcesSpringJUnit4ClassRunnerAppCtxTests-context2.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/MultipleResourcesSpringJUnit4ClassRunnerAppCtxTests-context2.xml new file mode 100644 index 0000000000000000000000000000000000000000..949d90c1867ea4ef07a373159a59493036fedf7b --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/MultipleResourcesSpringJUnit4ClassRunnerAppCtxTests-context2.xml @@ -0,0 +1,9 @@ + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/MultipleResourcesSpringJUnit4ClassRunnerAppCtxTests-context3.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/MultipleResourcesSpringJUnit4ClassRunnerAppCtxTests-context3.xml new file mode 100644 index 0000000000000000000000000000000000000000..4a488e2b9413cb231310663a0cb5347658e18c4c --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/MultipleResourcesSpringJUnit4ClassRunnerAppCtxTests-context3.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/ParameterizedDependencyInjectionTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/ParameterizedDependencyInjectionTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..6df0e2936883d5ba0f43e82c916b2a3600349ebb --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/ParameterizedDependencyInjectionTests-context.xml @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/PropertiesBasedSpringJUnit4ClassRunnerAppCtxTests-context.properties b/spring-test/src/test/resources/org/springframework/test/context/junit4/PropertiesBasedSpringJUnit4ClassRunnerAppCtxTests-context.properties new file mode 100644 index 0000000000000000000000000000000000000000..6df81585fb437d560345c149fba26b75af497311 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/PropertiesBasedSpringJUnit4ClassRunnerAppCtxTests-context.properties @@ -0,0 +1,5 @@ +cat.(class)=org.springframework.tests.sample.beans.Pet +cat.$0=Garfield + +testString.(class)=java.lang.String +testString.$0=Test String diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/SpringJUnit4ClassRunnerAppCtxTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/SpringJUnit4ClassRunnerAppCtxTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..27d136cda9a8b0bc8ee1848aa3dceda9fcbdc798 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/SpringJUnit4ClassRunnerAppCtxTests-context.xml @@ -0,0 +1,27 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/aci/xml/MultipleInitializersXmlConfigTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/aci/xml/MultipleInitializersXmlConfigTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..28a19d2639956f19478d39638ba3aba1153f4363 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/aci/xml/MultipleInitializersXmlConfigTests-context.xml @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/hybrid/HybridContextLoaderTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/hybrid/HybridContextLoaderTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..d644ddc9036fcf61fa1d18ab9bbfae03f5fd5266 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/hybrid/HybridContextLoaderTests-context.xml @@ -0,0 +1,9 @@ + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/orm/HibernateSessionFlushingTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/orm/HibernateSessionFlushingTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..7ef4e6b9fad8a80fb44bdbd5f167486a07ae29d8 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/orm/HibernateSessionFlushingTests-context.xml @@ -0,0 +1,39 @@ + + + + + + + + + + + + + + + + org.hibernate.dialect.HSQLDialect + false + + + + + org/springframework/test/context/junit4/orm/domain/Person.hbm.xml + org/springframework/test/context/junit4/orm/domain/DriversLicense.hbm.xml + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/orm/db-schema.sql b/spring-test/src/test/resources/org/springframework/test/context/junit4/orm/db-schema.sql new file mode 100644 index 0000000000000000000000000000000000000000..3dee075af5f07818b6532f4049ac08d2774e0e17 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/orm/db-schema.sql @@ -0,0 +1,16 @@ +DROP TABLE drivers_license IF EXISTS; +DROP TABLE person IF EXISTS; + +CREATE TABLE person ( + id INTEGER NOT NULL IDENTITY, + name VARCHAR(50) NOT NULL, + drivers_license_id INTEGER NOT NULL +); +CREATE UNIQUE INDEX person_name ON person(name); +CREATE UNIQUE INDEX person_drivers_license_id ON person(drivers_license_id); + +CREATE TABLE drivers_license ( + id INTEGER NOT NULL IDENTITY, + license_number INTEGER NOT NULL +); +CREATE UNIQUE INDEX drivers_license_license_number ON drivers_license(license_number); diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/orm/db-test-data.sql b/spring-test/src/test/resources/org/springframework/test/context/junit4/orm/db-test-data.sql new file mode 100644 index 0000000000000000000000000000000000000000..a174e3fbed790378d8e901cf5f3c7f4f917daa39 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/orm/db-test-data.sql @@ -0,0 +1,3 @@ +INSERT INTO drivers_license(id, license_number) values(1, 1234); + +INSERT INTO person(id, name, drivers_license_id) values(1, 'Sam', 1); diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/orm/domain/DriversLicense.hbm.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/orm/domain/DriversLicense.hbm.xml new file mode 100644 index 0000000000000000000000000000000000000000..8f98d7d051e89af9fdcca626811365ba59800e83 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/orm/domain/DriversLicense.hbm.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/orm/domain/Person.hbm.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/orm/domain/Person.hbm.xml new file mode 100644 index 0000000000000000000000000000000000000000..b0598cc2fe22175017587958bff231cf55361b42 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/orm/domain/Person.hbm.xml @@ -0,0 +1,16 @@ + + + + + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/profile/importresource/import.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/profile/importresource/import.xml new file mode 100644 index 0000000000000000000000000000000000000000..c57db6e1fa445cf997b4b9788bc105ee28c9a079 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/profile/importresource/import.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/profile/xml/DefaultProfileXmlConfigTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/profile/xml/DefaultProfileXmlConfigTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..872e9434e4b4378a28d59ab0cf1095b6ad3baa04 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/profile/xml/DefaultProfileXmlConfigTests-context.xml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/spr3896/BeanOverridingDefaultLocationsInheritedTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr3896/BeanOverridingDefaultLocationsInheritedTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..960a6cd0db2976e9566ded63f641fabe72bf6f38 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr3896/BeanOverridingDefaultLocationsInheritedTests-context.xml @@ -0,0 +1,11 @@ + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/spr3896/DefaultLocationsBaseTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr3896/DefaultLocationsBaseTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..f7ade734d14920b9dbc9fee79a77e725cbdac790 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr3896/DefaultLocationsBaseTests-context.xml @@ -0,0 +1,11 @@ + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/spr3896/DefaultLocationsInheritedTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr3896/DefaultLocationsInheritedTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..949d90c1867ea4ef07a373159a59493036fedf7b --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr3896/DefaultLocationsInheritedTests-context.xml @@ -0,0 +1,9 @@ + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/spr6128/AutowiredQualifierTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr6128/AutowiredQualifierTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..e069af3a6253f221b5a7a9edd7d703dacf68d400 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr6128/AutowiredQualifierTests-context.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/spr8849/datasource-config-with-auto-generated-db-name.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr8849/datasource-config-with-auto-generated-db-name.xml new file mode 100644 index 0000000000000000000000000000000000000000..0ee91d260f786d430f6ae1739d6d58599e3f39fd --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr8849/datasource-config-with-auto-generated-db-name.xml @@ -0,0 +1,12 @@ + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/spr8849/datasource-config.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr8849/datasource-config.xml new file mode 100644 index 0000000000000000000000000000000000000000..b0c399df53348df8bd9d62a3355b721f25a9872b --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr8849/datasource-config.xml @@ -0,0 +1,12 @@ + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/spr8849/spr8849-schema.sql b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr8849/spr8849-schema.sql new file mode 100644 index 0000000000000000000000000000000000000000..da1ce4b8c9820129b2c1bd68d72616b40b1ced55 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr8849/spr8849-schema.sql @@ -0,0 +1,3 @@ +CREATE TABLE enigma ( + id INTEGER NOT NULL IDENTITY +); diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/spr9799/Spr9799XmlConfigTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr9799/Spr9799XmlConfigTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..9eb8fced790afd55213af8c811c4ca045e656a32 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/spr9799/Spr9799XmlConfigTests-context.xml @@ -0,0 +1,10 @@ + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/junit4/transactionalTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/junit4/transactionalTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..0aa19a9b800daed96344357f87c95634359e8eaa --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/junit4/transactionalTests-context.xml @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/support/AbstractContextConfigurationUtilsTests$BareAnnotations-context.xml b/spring-test/src/test/resources/org/springframework/test/context/support/AbstractContextConfigurationUtilsTests$BareAnnotations-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..f34dae103e5ef76197856c12d1e4a9a522ee1706 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/support/AbstractContextConfigurationUtilsTests$BareAnnotations-context.xml @@ -0,0 +1,7 @@ + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/support/CustomizedGenericXmlContextLoaderTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/support/CustomizedGenericXmlContextLoaderTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..1718f6085a0b5a0d7e606259750adadd24800cfb --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/support/CustomizedGenericXmlContextLoaderTests-context.xml @@ -0,0 +1,7 @@ + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/support/DelegatingSmartContextLoaderTests$ImproperDuplicateDefaultXmlAndConfigClassTestCase-context.xml b/spring-test/src/test/resources/org/springframework/test/context/support/DelegatingSmartContextLoaderTests$ImproperDuplicateDefaultXmlAndConfigClassTestCase-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..d27d5a6cb45b9fa5345c23f2627274239b6f8a2c --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/support/DelegatingSmartContextLoaderTests$ImproperDuplicateDefaultXmlAndConfigClassTestCase-context.xml @@ -0,0 +1,7 @@ + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/support/DelegatingSmartContextLoaderTests$XmlTestCase-context.xml b/spring-test/src/test/resources/org/springframework/test/context/support/DelegatingSmartContextLoaderTests$XmlTestCase-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..4165b41ea7ae1e5d87a89171622d749dcc2c3af0 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/support/DelegatingSmartContextLoaderTests$XmlTestCase-context.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/support/GenericXmlContextLoaderResourceLocationsTests$1ClasspathExistentDefaultLocationsTestCase-context.xml b/spring-test/src/test/resources/org/springframework/test/context/support/GenericXmlContextLoaderResourceLocationsTests$1ClasspathExistentDefaultLocationsTestCase-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..f34dae103e5ef76197856c12d1e4a9a522ee1706 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/support/GenericXmlContextLoaderResourceLocationsTests$1ClasspathExistentDefaultLocationsTestCase-context.xml @@ -0,0 +1,7 @@ + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/testng/ConcreteTransactionalTestNGSpringContextTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/testng/ConcreteTransactionalTestNGSpringContextTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..7bb8f2cd02a15193d2d125a9222f4b3e659e3d90 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/testng/ConcreteTransactionalTestNGSpringContextTests-context.xml @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/testng/DirtiesContextTransactionalTestNGSpringContextTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/testng/DirtiesContextTransactionalTestNGSpringContextTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..65990dfa9ea1ceb1c0be72691f8857c2d1b9a4c2 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/testng/DirtiesContextTransactionalTestNGSpringContextTests-context.xml @@ -0,0 +1,12 @@ + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/testng/FailingBeforeAndAfterMethodsTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/testng/FailingBeforeAndAfterMethodsTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..65990dfa9ea1ceb1c0be72691f8857c2d1b9a4c2 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/testng/FailingBeforeAndAfterMethodsTests-context.xml @@ -0,0 +1,12 @@ + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/testng/TimedTransactionalTestNGSpringContextTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/testng/TimedTransactionalTestNGSpringContextTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..2d78bf6520f48cfb4cf1f4d9680cf9c33fe28fcf --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/testng/TimedTransactionalTestNGSpringContextTests-context.xml @@ -0,0 +1,12 @@ + + + + + + + + \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/testng/transaction/ejb/testng-package.xml b/spring-test/src/test/resources/org/springframework/test/context/testng/transaction/ejb/testng-package.xml new file mode 100644 index 0000000000000000000000000000000000000000..f609b3c0c7397928c0c493ab3593b01f633938e0 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/testng/transaction/ejb/testng-package.xml @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/testng/transaction/ejb/testng-test-separate.xml b/spring-test/src/test/resources/org/springframework/test/context/testng/transaction/ejb/testng-test-separate.xml new file mode 100644 index 0000000000000000000000000000000000000000..4b92f5f2af6f691094798ff78de2a2bc4552f2cd --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/testng/transaction/ejb/testng-test-separate.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/testng/transaction/ejb/testng-test-together.xml b/spring-test/src/test/resources/org/springframework/test/context/testng/transaction/ejb/testng-test-together.xml new file mode 100644 index 0000000000000000000000000000000000000000..5647d5f064dce889e16e201719a0aa321ee66bde --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/testng/transaction/ejb/testng-test-together.xml @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/transaction/ejb/common-config.xml b/spring-test/src/test/resources/org/springframework/test/context/transaction/ejb/common-config.xml new file mode 100644 index 0000000000000000000000000000000000000000..226eaa5bfb9d013fefb6caf78acce551ded5a065 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/transaction/ejb/common-config.xml @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/transaction/ejb/required-tx-config.xml b/spring-test/src/test/resources/org/springframework/test/context/transaction/ejb/required-tx-config.xml new file mode 100644 index 0000000000000000000000000000000000000000..2ae4e1c81256a006d602321dd2f888120c454720 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/transaction/ejb/required-tx-config.xml @@ -0,0 +1,9 @@ + + + + + + + + \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/transaction/ejb/requires-new-tx-config.xml b/spring-test/src/test/resources/org/springframework/test/context/transaction/ejb/requires-new-tx-config.xml new file mode 100644 index 0000000000000000000000000000000000000000..0c332a4a1eecfc4cf586e2c611703fb989a514ca --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/transaction/ejb/requires-new-tx-config.xml @@ -0,0 +1,9 @@ + + + + + + + + \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/web/BasicGroovyWacTestsContext.groovy b/spring-test/src/test/resources/org/springframework/test/context/web/BasicGroovyWacTestsContext.groovy new file mode 100644 index 0000000000000000000000000000000000000000..dcb7f12fe9f9ed78ef786722b46ade8e7aa8302b --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/web/BasicGroovyWacTestsContext.groovy @@ -0,0 +1,3 @@ +package org.springframework.test.context.web + +beans { foo String, 'Groovy Foo' } \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/context/web/BasicXmlWacTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/web/BasicXmlWacTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..52482f485ddaca8224c7a55ff01cd656dd4cc2a3 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/web/BasicXmlWacTests-context.xml @@ -0,0 +1,9 @@ + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/context/web/RequestAndSessionScopedBeansWacTests-context.xml b/spring-test/src/test/resources/org/springframework/test/context/web/RequestAndSessionScopedBeansWacTests-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..a8412fa9308b21199bdda7917a7491a72570f592 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/context/web/RequestAndSessionScopedBeansWacTests-context.xml @@ -0,0 +1,11 @@ + + + + + + + + + + diff --git a/spring-test/src/test/resources/org/springframework/test/jdbc/data.sql b/spring-test/src/test/resources/org/springframework/test/jdbc/data.sql new file mode 100644 index 0000000000000000000000000000000000000000..10d02a9c9194fccc2bb792ccd69719bdc9d1f680 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/jdbc/data.sql @@ -0,0 +1 @@ +INSERT INTO person VALUES('bob'); \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/jdbc/schema.sql b/spring-test/src/test/resources/org/springframework/test/jdbc/schema.sql new file mode 100644 index 0000000000000000000000000000000000000000..81d7e08db921a5fcacc25e24fca2f103da174299 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/jdbc/schema.sql @@ -0,0 +1,6 @@ +DROP TABLE person IF EXISTS; + +CREATE TABLE person ( + name VARCHAR(20) NOT NULL, + PRIMARY KEY(name) +); \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/web/client/samples/ludwig.json b/spring-test/src/test/resources/org/springframework/test/web/client/samples/ludwig.json new file mode 100644 index 0000000000000000000000000000000000000000..2b1a2f67627c1fd7cebc74a555a5a5bbabf03cc4 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/web/client/samples/ludwig.json @@ -0,0 +1,5 @@ +{ + "name" : "Ludwig van Beethoven", + "someDouble" : "1.6035", + "someBoolean" : "true" +} \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/web/servlet/samples/context/root-context.xml b/spring-test/src/test/resources/org/springframework/test/web/servlet/samples/context/root-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..68812aeff525fa29e17ed77e8b637bf243cd35ec --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/web/servlet/samples/context/root-context.xml @@ -0,0 +1,12 @@ + + + + + + + + \ No newline at end of file diff --git a/spring-test/src/test/resources/org/springframework/test/web/servlet/samples/context/servlet-context.xml b/spring-test/src/test/resources/org/springframework/test/web/servlet/samples/context/servlet-context.xml new file mode 100644 index 0000000000000000000000000000000000000000..0c442c5d7b8a55a6513b4682fe03790734f048b6 --- /dev/null +++ b/spring-test/src/test/resources/org/springframework/test/web/servlet/samples/context/servlet-context.xml @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + + + + + + + + /WEB-INF/**/tiles.xml + + + + + \ No newline at end of file diff --git a/spring-test/src/test/webapp/WEB-INF/layouts/main.jsp b/spring-test/src/test/webapp/WEB-INF/layouts/main.jsp new file mode 100644 index 0000000000000000000000000000000000000000..a3f3f6584ae94d46a811635a62261ba6c1ce99f6 --- /dev/null +++ b/spring-test/src/test/webapp/WEB-INF/layouts/main.jsp @@ -0,0 +1,12 @@ +<%@ page language="java" contentType="text/html; charset=UTF-8" + pageEncoding="UTF-8"%> + + + + +Fake Layout + + + + + \ No newline at end of file diff --git a/spring-test/src/test/webapp/WEB-INF/layouts/tiles.xml b/spring-test/src/test/webapp/WEB-INF/layouts/tiles.xml new file mode 100644 index 0000000000000000000000000000000000000000..d2444d74b1ac508d104762236b36f7bb772afd83 --- /dev/null +++ b/spring-test/src/test/webapp/WEB-INF/layouts/tiles.xml @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/spring-test/src/test/webapp/WEB-INF/readme.txt b/spring-test/src/test/webapp/WEB-INF/readme.txt new file mode 100644 index 0000000000000000000000000000000000000000..3e3b6e0fb1bad0c5161beac448ae88e717dff634 --- /dev/null +++ b/spring-test/src/test/webapp/WEB-INF/readme.txt @@ -0,0 +1,2 @@ + +Dummy web application for testing purposes. \ No newline at end of file diff --git a/spring-test/src/test/webapp/WEB-INF/views/tiles.xml b/spring-test/src/test/webapp/WEB-INF/views/tiles.xml new file mode 100644 index 0000000000000000000000000000000000000000..c332299729133b257e66485358ef0e50dff943a3 --- /dev/null +++ b/spring-test/src/test/webapp/WEB-INF/views/tiles.xml @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/spring-test/src/test/webapp/resources/Spring.js b/spring-test/src/test/webapp/resources/Spring.js new file mode 100644 index 0000000000000000000000000000000000000000..93bd975baaa8be4bf187e7833867f429c1ab7783 --- /dev/null +++ b/spring-test/src/test/webapp/resources/Spring.js @@ -0,0 +1,16 @@ +/* + * Copyright 2004-2008 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +Spring={};Spring.debug=true;Spring.decorations={};Spring.decorations.applied=false;Spring.initialize=function(){Spring.applyDecorations();Spring.remoting=new Spring.RemotingHandler();};Spring.addDecoration=function(_1){if(!Spring.decorations[_1.elementId]){Spring.decorations[_1.elementId]=[];Spring.decorations[_1.elementId].push(_1);}else{var _2=false;for(var i=0;iFor example, this exception or a subclass might be thrown if a JDBC + * Connection couldn't be closed after it had been used successfully. + * + *

Note that data access code might perform resources cleanup in a + * finally block and therefore log cleanup failure rather than rethrow it, + * to keep the original data access exception, if any. + * + * @author Rod Johnson + */ +@SuppressWarnings("serial") +public class CleanupFailureDataAccessException extends NonTransientDataAccessException { + + /** + * Constructor for CleanupFailureDataAccessException. + * @param msg the detail message + * @param cause the root cause from the underlying data access API, + * such as JDBC + */ + public CleanupFailureDataAccessException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/ConcurrencyFailureException.java b/spring-tx/src/main/java/org/springframework/dao/ConcurrencyFailureException.java new file mode 100644 index 0000000000000000000000000000000000000000..cc990803f4837259ce7d658a0a9cfc40a15c9a65 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/ConcurrencyFailureException.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +import org.springframework.lang.Nullable; + +/** + * Exception thrown on concurrency failure. + * + *

This exception should be subclassed to indicate the type of failure: + * optimistic locking, failure to acquire lock, etc. + * + * @author Thomas Risberg + * @since 1.1 + * @see OptimisticLockingFailureException + * @see PessimisticLockingFailureException + * @see CannotAcquireLockException + * @see DeadlockLoserDataAccessException + */ +@SuppressWarnings("serial") +public class ConcurrencyFailureException extends TransientDataAccessException { + + /** + * Constructor for ConcurrencyFailureException. + * @param msg the detail message + */ + public ConcurrencyFailureException(String msg) { + super(msg); + } + + /** + * Constructor for ConcurrencyFailureException. + * @param msg the detail message + * @param cause the root cause from the data access API in use + */ + public ConcurrencyFailureException(String msg, @Nullable Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/DataAccessException.java b/spring-tx/src/main/java/org/springframework/dao/DataAccessException.java new file mode 100644 index 0000000000000000000000000000000000000000..85113ce5a59b6efd60e98debe125ef0b166803db --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/DataAccessException.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +import org.springframework.core.NestedRuntimeException; +import org.springframework.lang.Nullable; + +/** + * Root of the hierarchy of data access exceptions discussed in + * Expert One-On-One J2EE Design and Development. + * Please see Chapter 9 of this book for detailed discussion of the + * motivation for this package. + * + *

This exception hierarchy aims to let user code find and handle the + * kind of error encountered without knowing the details of the particular + * data access API in use (e.g. JDBC). Thus it is possible to react to an + * optimistic locking failure without knowing that JDBC is being used. + * + *

As this class is a runtime exception, there is no need for user code + * to catch it or subclasses if any error is to be considered fatal + * (the usual case). + * + * @author Rod Johnson + */ +@SuppressWarnings("serial") +public abstract class DataAccessException extends NestedRuntimeException { + + /** + * Constructor for DataAccessException. + * @param msg the detail message + */ + public DataAccessException(String msg) { + super(msg); + } + + /** + * Constructor for DataAccessException. + * @param msg the detail message + * @param cause the root cause (usually from using a underlying + * data access API such as JDBC) + */ + public DataAccessException(@Nullable String msg, @Nullable Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/DataAccessResourceFailureException.java b/spring-tx/src/main/java/org/springframework/dao/DataAccessResourceFailureException.java new file mode 100644 index 0000000000000000000000000000000000000000..f3cd851aaa1d922b9d4ab5e17f58de4e51d43250 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/DataAccessResourceFailureException.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +import org.springframework.lang.Nullable; + +/** + * Data access exception thrown when a resource fails completely: + * for example, if we can't connect to a database using JDBC. + * + * @author Rod Johnson + * @author Thomas Risberg + */ +@SuppressWarnings("serial") +public class DataAccessResourceFailureException extends NonTransientDataAccessResourceException { + + /** + * Constructor for DataAccessResourceFailureException. + * @param msg the detail message + */ + public DataAccessResourceFailureException(String msg) { + super(msg); + } + + /** + * Constructor for DataAccessResourceFailureException. + * @param msg the detail message + * @param cause the root cause from the data access API in use + */ + public DataAccessResourceFailureException(String msg, @Nullable Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/DataIntegrityViolationException.java b/spring-tx/src/main/java/org/springframework/dao/DataIntegrityViolationException.java new file mode 100644 index 0000000000000000000000000000000000000000..b1347e1ee94657200ed3a569a82619ec32877d0d --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/DataIntegrityViolationException.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +/** + * Exception thrown when an attempt to insert or update data + * results in violation of an integrity constraint. Note that this + * is not purely a relational concept; unique primary keys are + * required by most database types. + * + * @author Rod Johnson + */ +@SuppressWarnings("serial") +public class DataIntegrityViolationException extends NonTransientDataAccessException { + + /** + * Constructor for DataIntegrityViolationException. + * @param msg the detail message + */ + public DataIntegrityViolationException(String msg) { + super(msg); + } + + /** + * Constructor for DataIntegrityViolationException. + * @param msg the detail message + * @param cause the root cause from the data access API in use + */ + public DataIntegrityViolationException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/DataRetrievalFailureException.java b/spring-tx/src/main/java/org/springframework/dao/DataRetrievalFailureException.java new file mode 100644 index 0000000000000000000000000000000000000000..0ab0c91aafcfb70ad357acdb7a256d47c6a5837c --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/DataRetrievalFailureException.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +import org.springframework.lang.Nullable; + +/** + * Exception thrown if certain expected data could not be retrieved, e.g. + * when looking up specific data via a known identifier. This exception + * will be thrown either by O/R mapping tools or by DAO implementations. + * + * @author Juergen Hoeller + * @since 13.10.2003 + */ +@SuppressWarnings("serial") +public class DataRetrievalFailureException extends NonTransientDataAccessException { + + /** + * Constructor for DataRetrievalFailureException. + * @param msg the detail message + */ + public DataRetrievalFailureException(String msg) { + super(msg); + } + + /** + * Constructor for DataRetrievalFailureException. + * @param msg the detail message + * @param cause the root cause from the data access API in use + */ + public DataRetrievalFailureException(String msg, @Nullable Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/DeadlockLoserDataAccessException.java b/spring-tx/src/main/java/org/springframework/dao/DeadlockLoserDataAccessException.java new file mode 100644 index 0000000000000000000000000000000000000000..3c946cd56408d93e7fcc4be43589a0c79cfba8cd --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/DeadlockLoserDataAccessException.java @@ -0,0 +1,37 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +/** + * Generic exception thrown when the current process was + * a deadlock loser, and its transaction rolled back. + * + * @author Rod Johnson + */ +@SuppressWarnings("serial") +public class DeadlockLoserDataAccessException extends PessimisticLockingFailureException { + + /** + * Constructor for DeadlockLoserDataAccessException. + * @param msg the detail message + * @param cause the root cause from the data access API in use + */ + public DeadlockLoserDataAccessException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/DuplicateKeyException.java b/spring-tx/src/main/java/org/springframework/dao/DuplicateKeyException.java new file mode 100644 index 0000000000000000000000000000000000000000..f31a3a172d91a5d305f72eaef614d884085e95a5 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/DuplicateKeyException.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +/** + * Exception thrown when an attempt to insert or update data + * results in violation of an primary key or unique constraint. + * Note that this is not necessarily a purely relational concept; + * unique primary keys are required by most database types. + * + * @author Thomas Risberg + */ +@SuppressWarnings("serial") +public class DuplicateKeyException extends DataIntegrityViolationException { + + /** + * Constructor for DuplicateKeyException. + * @param msg the detail message + */ + public DuplicateKeyException(String msg) { + super(msg); + } + + /** + * Constructor for DuplicateKeyException. + * @param msg the detail message + * @param cause the root cause from the data access API in use + */ + public DuplicateKeyException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/EmptyResultDataAccessException.java b/spring-tx/src/main/java/org/springframework/dao/EmptyResultDataAccessException.java new file mode 100644 index 0000000000000000000000000000000000000000..535870a423ef536e0127614cfb6865cea2ceea1a --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/EmptyResultDataAccessException.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +/** + * Data access exception thrown when a result was expected to have at least + * one row (or element) but zero rows (or elements) were actually returned. + * + * @author Juergen Hoeller + * @since 2.0 + * @see IncorrectResultSizeDataAccessException + */ +@SuppressWarnings("serial") +public class EmptyResultDataAccessException extends IncorrectResultSizeDataAccessException { + + /** + * Constructor for EmptyResultDataAccessException. + * @param expectedSize the expected result size + */ + public EmptyResultDataAccessException(int expectedSize) { + super(expectedSize, 0); + } + + /** + * Constructor for EmptyResultDataAccessException. + * @param msg the detail message + * @param expectedSize the expected result size + */ + public EmptyResultDataAccessException(String msg, int expectedSize) { + super(msg, expectedSize, 0); + } + + /** + * Constructor for EmptyResultDataAccessException. + * @param msg the detail message + * @param expectedSize the expected result size + * @param ex the wrapped exception + */ + public EmptyResultDataAccessException(String msg, int expectedSize, Throwable ex) { + super(msg, expectedSize, 0, ex); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/IncorrectResultSizeDataAccessException.java b/spring-tx/src/main/java/org/springframework/dao/IncorrectResultSizeDataAccessException.java new file mode 100644 index 0000000000000000000000000000000000000000..d0fdb2a542a7932cb9d38122cec75589d98ed167 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/IncorrectResultSizeDataAccessException.java @@ -0,0 +1,120 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +/** + * Data access exception thrown when a result was not of the expected size, + * for example when expecting a single row but getting 0 or more than 1 rows. + * + * @author Juergen Hoeller + * @author Chris Beams + * @since 1.0.2 + * @see EmptyResultDataAccessException + */ +@SuppressWarnings("serial") +public class IncorrectResultSizeDataAccessException extends DataRetrievalFailureException { + + private final int expectedSize; + + private final int actualSize; + + + /** + * Constructor for IncorrectResultSizeDataAccessException. + * @param expectedSize the expected result size + */ + public IncorrectResultSizeDataAccessException(int expectedSize) { + super("Incorrect result size: expected " + expectedSize); + this.expectedSize = expectedSize; + this.actualSize = -1; + } + + /** + * Constructor for IncorrectResultSizeDataAccessException. + * @param expectedSize the expected result size + * @param actualSize the actual result size (or -1 if unknown) + */ + public IncorrectResultSizeDataAccessException(int expectedSize, int actualSize) { + super("Incorrect result size: expected " + expectedSize + ", actual " + actualSize); + this.expectedSize = expectedSize; + this.actualSize = actualSize; + } + + /** + * Constructor for IncorrectResultSizeDataAccessException. + * @param msg the detail message + * @param expectedSize the expected result size + */ + public IncorrectResultSizeDataAccessException(String msg, int expectedSize) { + super(msg); + this.expectedSize = expectedSize; + this.actualSize = -1; + } + + /** + * Constructor for IncorrectResultSizeDataAccessException. + * @param msg the detail message + * @param expectedSize the expected result size + * @param ex the wrapped exception + */ + public IncorrectResultSizeDataAccessException(String msg, int expectedSize, Throwable ex) { + super(msg, ex); + this.expectedSize = expectedSize; + this.actualSize = -1; + } + + /** + * Constructor for IncorrectResultSizeDataAccessException. + * @param msg the detail message + * @param expectedSize the expected result size + * @param actualSize the actual result size (or -1 if unknown) + */ + public IncorrectResultSizeDataAccessException(String msg, int expectedSize, int actualSize) { + super(msg); + this.expectedSize = expectedSize; + this.actualSize = actualSize; + } + + /** + * Constructor for IncorrectResultSizeDataAccessException. + * @param msg the detail message + * @param expectedSize the expected result size + * @param actualSize the actual result size (or -1 if unknown) + * @param ex the wrapped exception + */ + public IncorrectResultSizeDataAccessException(String msg, int expectedSize, int actualSize, Throwable ex) { + super(msg, ex); + this.expectedSize = expectedSize; + this.actualSize = actualSize; + } + + + /** + * Return the expected result size. + */ + public int getExpectedSize() { + return this.expectedSize; + } + + /** + * Return the actual result size (or -1 if unknown). + */ + public int getActualSize() { + return this.actualSize; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/IncorrectUpdateSemanticsDataAccessException.java b/spring-tx/src/main/java/org/springframework/dao/IncorrectUpdateSemanticsDataAccessException.java new file mode 100644 index 0000000000000000000000000000000000000000..70689977cd2ec8a913f799a692af7ec79ae793e5 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/IncorrectUpdateSemanticsDataAccessException.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +/** + * Data access exception thrown when something unintended appears to have + * happened with an update, but the transaction hasn't already been rolled back. + * Thrown, for example, when we wanted to update 1 row in an RDBMS but actually + * updated 3. + * + * @author Rod Johnson + */ +@SuppressWarnings("serial") +public class IncorrectUpdateSemanticsDataAccessException extends InvalidDataAccessResourceUsageException { + + /** + * Constructor for IncorrectUpdateSemanticsDataAccessException. + * @param msg the detail message + */ + public IncorrectUpdateSemanticsDataAccessException(String msg) { + super(msg); + } + + /** + * Constructor for IncorrectUpdateSemanticsDataAccessException. + * @param msg the detail message + * @param cause the root cause from the underlying API, such as JDBC + */ + public IncorrectUpdateSemanticsDataAccessException(String msg, Throwable cause) { + super(msg, cause); + } + + /** + * Return whether data was updated. + * If this method returns false, there's nothing to roll back. + *

The default implementation always returns true. + * This can be overridden in subclasses. + */ + public boolean wasDataUpdated() { + return true; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/InvalidDataAccessApiUsageException.java b/spring-tx/src/main/java/org/springframework/dao/InvalidDataAccessApiUsageException.java new file mode 100644 index 0000000000000000000000000000000000000000..54daec40dcf773856a540231a55e23ee30646c05 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/InvalidDataAccessApiUsageException.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +/** + * Exception thrown on incorrect usage of the API, such as failing to + * "compile" a query object that needed compilation before execution. + * + *

This represents a problem in our Java data access framework, + * not the underlying data access infrastructure. + * + * @author Rod Johnson + */ +@SuppressWarnings("serial") +public class InvalidDataAccessApiUsageException extends NonTransientDataAccessException { + + /** + * Constructor for InvalidDataAccessApiUsageException. + * @param msg the detail message + */ + public InvalidDataAccessApiUsageException(String msg) { + super(msg); + } + + /** + * Constructor for InvalidDataAccessApiUsageException. + * @param msg the detail message + * @param cause the root cause from the data access API in use + */ + public InvalidDataAccessApiUsageException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/InvalidDataAccessResourceUsageException.java b/spring-tx/src/main/java/org/springframework/dao/InvalidDataAccessResourceUsageException.java new file mode 100644 index 0000000000000000000000000000000000000000..218da9e056b61837e44b60dd85994ddc49e6f909 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/InvalidDataAccessResourceUsageException.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +/** + * Root for exceptions thrown when we use a data access resource incorrectly. + * Thrown for example on specifying bad SQL when using a RDBMS. + * Resource-specific subclasses are supplied by concrete data access packages. + * + * @author Rod Johnson + */ +@SuppressWarnings("serial") +public class InvalidDataAccessResourceUsageException extends NonTransientDataAccessException { + + /** + * Constructor for InvalidDataAccessResourceUsageException. + * @param msg the detail message + */ + public InvalidDataAccessResourceUsageException(String msg) { + super(msg); + } + + /** + * Constructor for InvalidDataAccessResourceUsageException. + * @param msg the detail message + * @param cause the root cause from the data access API in use + */ + public InvalidDataAccessResourceUsageException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/NonTransientDataAccessException.java b/spring-tx/src/main/java/org/springframework/dao/NonTransientDataAccessException.java new file mode 100644 index 0000000000000000000000000000000000000000..acf5770509bc0d5172af0ada10378444deabc9a4 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/NonTransientDataAccessException.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +import org.springframework.lang.Nullable; + +/** + * Root of the hierarchy of data access exceptions that are considered non-transient - + * where a retry of the same operation would fail unless the cause of the Exception + * is corrected. + * + * @author Thomas Risberg + * @since 2.5 + * @see java.sql.SQLNonTransientException + */ +@SuppressWarnings("serial") +public abstract class NonTransientDataAccessException extends DataAccessException { + + /** + * Constructor for NonTransientDataAccessException. + * @param msg the detail message + */ + public NonTransientDataAccessException(String msg) { + super(msg); + } + + /** + * Constructor for NonTransientDataAccessException. + * @param msg the detail message + * @param cause the root cause (usually from using a underlying + * data access API such as JDBC) + */ + public NonTransientDataAccessException(@Nullable String msg, @Nullable Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/NonTransientDataAccessResourceException.java b/spring-tx/src/main/java/org/springframework/dao/NonTransientDataAccessResourceException.java new file mode 100644 index 0000000000000000000000000000000000000000..78094e6ce9b225526ea4fb51867aec999b7a3d5c --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/NonTransientDataAccessResourceException.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +import org.springframework.lang.Nullable; + +/** + * Data access exception thrown when a resource fails completely and the failure is permanent. + * + * @author Thomas Risberg + * @since 2.5 + * @see java.sql.SQLNonTransientConnectionException + */ +@SuppressWarnings("serial") +public class NonTransientDataAccessResourceException extends NonTransientDataAccessException { + + /** + * Constructor for NonTransientDataAccessResourceException. + * @param msg the detail message + */ + public NonTransientDataAccessResourceException(String msg) { + super(msg); + } + + /** + * Constructor for NonTransientDataAccessResourceException. + * @param msg the detail message + * @param cause the root cause from the data access API in use + */ + public NonTransientDataAccessResourceException(String msg, @Nullable Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/OptimisticLockingFailureException.java b/spring-tx/src/main/java/org/springframework/dao/OptimisticLockingFailureException.java new file mode 100644 index 0000000000000000000000000000000000000000..c1c8189d7ab81dce0f2c98a3681d295f721ecd78 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/OptimisticLockingFailureException.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +import org.springframework.lang.Nullable; + +/** + * Exception thrown on an optimistic locking violation. + * + *

This exception will be thrown either by O/R mapping tools + * or by custom DAO implementations. Optimistic locking failure + * is typically not detected by the database itself. + * + * @author Rod Johnson + * @see PessimisticLockingFailureException + */ +@SuppressWarnings("serial") +public class OptimisticLockingFailureException extends ConcurrencyFailureException { + + /** + * Constructor for OptimisticLockingFailureException. + * @param msg the detail message + */ + public OptimisticLockingFailureException(String msg) { + super(msg); + } + + /** + * Constructor for OptimisticLockingFailureException. + * @param msg the detail message + * @param cause the root cause from the data access API in use + */ + public OptimisticLockingFailureException(String msg, @Nullable Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/PermissionDeniedDataAccessException.java b/spring-tx/src/main/java/org/springframework/dao/PermissionDeniedDataAccessException.java new file mode 100644 index 0000000000000000000000000000000000000000..6cfe0be2b49260a0c055ecb45e3e3f5e6017c98a --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/PermissionDeniedDataAccessException.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +/** + * Exception thrown when the underlying resource denied a permission + * to access a specific element, such as a specific database table. + * + * @author Juergen Hoeller + * @since 2.0 + */ +@SuppressWarnings("serial") +public class PermissionDeniedDataAccessException extends NonTransientDataAccessException { + + /** + * Constructor for PermissionDeniedDataAccessException. + * @param msg the detail message + * @param cause the root cause from the underlying data access API, + * such as JDBC + */ + public PermissionDeniedDataAccessException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/PessimisticLockingFailureException.java b/spring-tx/src/main/java/org/springframework/dao/PessimisticLockingFailureException.java new file mode 100644 index 0000000000000000000000000000000000000000..3e60eb141bcac99ec603c86dfd48ebc6ba82f16a --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/PessimisticLockingFailureException.java @@ -0,0 +1,53 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +/** + * Exception thrown on a pessimistic locking violation. + * Thrown by Spring's SQLException translation mechanism + * if a corresponding database error is encountered. + * + *

Serves as superclass for more specific exceptions, like + * CannotAcquireLockException and DeadlockLoserDataAccessException. + * + * @author Thomas Risberg + * @since 1.2 + * @see CannotAcquireLockException + * @see DeadlockLoserDataAccessException + * @see OptimisticLockingFailureException + */ +@SuppressWarnings("serial") +public class PessimisticLockingFailureException extends ConcurrencyFailureException { + + /** + * Constructor for PessimisticLockingFailureException. + * @param msg the detail message + */ + public PessimisticLockingFailureException(String msg) { + super(msg); + } + + /** + * Constructor for PessimisticLockingFailureException. + * @param msg the detail message + * @param cause the root cause from the data access API in use + */ + public PessimisticLockingFailureException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/QueryTimeoutException.java b/spring-tx/src/main/java/org/springframework/dao/QueryTimeoutException.java new file mode 100644 index 0000000000000000000000000000000000000000..31dccdab2065b170170ab6ec4b5b37bf6f18401c --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/QueryTimeoutException.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +/** + * Exception to be thrown on a query timeout. This could have different causes depending on + * the database API in use but most likely thrown after the database interrupts or stops + * the processing of a query before it has completed. + * + *

This exception can be thrown by user code trapping the native database exception or + * by exception translation. + * + * @author Thomas Risberg + * @since 3.1 + */ +@SuppressWarnings("serial") +public class QueryTimeoutException extends TransientDataAccessException { + + /** + * Constructor for QueryTimeoutException. + * @param msg the detail message + */ + public QueryTimeoutException(String msg) { + super(msg); + } + + /** + * Constructor for QueryTimeoutException. + * @param msg the detail message + * @param cause the root cause from the data access API in use + */ + public QueryTimeoutException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/RecoverableDataAccessException.java b/spring-tx/src/main/java/org/springframework/dao/RecoverableDataAccessException.java new file mode 100644 index 0000000000000000000000000000000000000000..21f292e24153568b1068a698db3d3595d3444036 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/RecoverableDataAccessException.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +/** + * Data access exception thrown when a previously failed operation might be able + * to succeed if the application performs some recovery steps and retries the entire + * transaction or in the case of a distributed transaction, the transaction branch. + * At a minimum, the recovery operation must include closing the current connection + * and getting a new connection. + * + * @author Thomas Risberg + * @since 2.5 + * @see java.sql.SQLRecoverableException + */ +@SuppressWarnings("serial") +public class RecoverableDataAccessException extends DataAccessException { + + /** + * Constructor for RecoverableDataAccessException. + * @param msg the detail message + */ + public RecoverableDataAccessException(String msg) { + super(msg); + } + + /** + * Constructor for RecoverableDataAccessException. + * @param msg the detail message + * @param cause the root cause (usually from using a underlying + * data access API such as JDBC) + */ + public RecoverableDataAccessException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/TransientDataAccessException.java b/spring-tx/src/main/java/org/springframework/dao/TransientDataAccessException.java new file mode 100644 index 0000000000000000000000000000000000000000..420715092b5f9a1186581411e94fa2f8e3860780 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/TransientDataAccessException.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +import org.springframework.lang.Nullable; + +/** + * Root of the hierarchy of data access exceptions that are considered transient - + * where a previously failed operation might be able to succeed when the operation + * is retried without any intervention by application-level functionality. + * + * @author Thomas Risberg + * @since 2.5 + * @see java.sql.SQLTransientException + */ +@SuppressWarnings("serial") +public abstract class TransientDataAccessException extends DataAccessException { + + /** + * Constructor for TransientDataAccessException. + * @param msg the detail message + */ + public TransientDataAccessException(String msg) { + super(msg); + } + + /** + * Constructor for TransientDataAccessException. + * @param msg the detail message + * @param cause the root cause (usually from using a underlying + * data access API such as JDBC) + */ + public TransientDataAccessException(String msg, @Nullable Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/TransientDataAccessResourceException.java b/spring-tx/src/main/java/org/springframework/dao/TransientDataAccessResourceException.java new file mode 100644 index 0000000000000000000000000000000000000000..298c53596df3db90354ea5df53dc7a5e07f3bb43 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/TransientDataAccessResourceException.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +/** + * Data access exception thrown when a resource fails temporarily + * and the operation can be retried. + * + * @author Thomas Risberg + * @since 2.5 + * @see java.sql.SQLTransientConnectionException + */ +@SuppressWarnings("serial") +public class TransientDataAccessResourceException extends TransientDataAccessException { + + /** + * Constructor for TransientDataAccessResourceException. + * @param msg the detail message + */ + public TransientDataAccessResourceException(String msg) { + super(msg); + } + + /** + * Constructor for TransientDataAccessResourceException. + * @param msg the detail message + * @param cause the root cause from the data access API in use + */ + public TransientDataAccessResourceException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/TypeMismatchDataAccessException.java b/spring-tx/src/main/java/org/springframework/dao/TypeMismatchDataAccessException.java new file mode 100644 index 0000000000000000000000000000000000000000..2808cfe04efaf515243bc051a06736899b57a8a4 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/TypeMismatchDataAccessException.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +/** + * Exception thrown on mismatch between Java type and database type: + * for example on an attempt to set an object of the wrong type + * in an RDBMS column. + * + * @author Rod Johnson + */ +@SuppressWarnings("serial") +public class TypeMismatchDataAccessException extends InvalidDataAccessResourceUsageException { + + /** + * Constructor for TypeMismatchDataAccessException. + * @param msg the detail message + */ + public TypeMismatchDataAccessException(String msg) { + super(msg); + } + + /** + * Constructor for TypeMismatchDataAccessException. + * @param msg the detail message + * @param cause the root cause from the data access API in use + */ + public TypeMismatchDataAccessException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/UncategorizedDataAccessException.java b/spring-tx/src/main/java/org/springframework/dao/UncategorizedDataAccessException.java new file mode 100644 index 0000000000000000000000000000000000000000..f74ac67a7617c923dfb2dbb7fc1df4cbca7106bf --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/UncategorizedDataAccessException.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao; + +import org.springframework.lang.Nullable; + +/** + * Normal superclass when we can't distinguish anything more specific + * than "something went wrong with the underlying resource": for example, + * a SQLException from JDBC we can't pinpoint more precisely. + * + * @author Rod Johnson + */ +@SuppressWarnings("serial") +public abstract class UncategorizedDataAccessException extends NonTransientDataAccessException { + + /** + * Constructor for UncategorizedDataAccessException. + * @param msg the detail message + * @param cause the exception thrown by underlying data access API + */ + public UncategorizedDataAccessException(@Nullable String msg, @Nullable Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/annotation/PersistenceExceptionTranslationAdvisor.java b/spring-tx/src/main/java/org/springframework/dao/annotation/PersistenceExceptionTranslationAdvisor.java new file mode 100644 index 0000000000000000000000000000000000000000..ce2390707c2ce292627bce17fbf29d09af595e36 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/annotation/PersistenceExceptionTranslationAdvisor.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao.annotation; + +import java.lang.annotation.Annotation; + +import org.aopalliance.aop.Advice; + +import org.springframework.aop.Pointcut; +import org.springframework.aop.support.AbstractPointcutAdvisor; +import org.springframework.aop.support.annotation.AnnotationMatchingPointcut; +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.dao.support.PersistenceExceptionTranslationInterceptor; +import org.springframework.dao.support.PersistenceExceptionTranslator; + +/** + * Spring AOP exception translation aspect for use at Repository or DAO layer level. + * Translates native persistence exceptions into Spring's DataAccessException hierarchy, + * based on a given PersistenceExceptionTranslator. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 2.0 + * @see org.springframework.dao.DataAccessException + * @see org.springframework.dao.support.PersistenceExceptionTranslator + */ +@SuppressWarnings("serial") +public class PersistenceExceptionTranslationAdvisor extends AbstractPointcutAdvisor { + + private final PersistenceExceptionTranslationInterceptor advice; + + private final AnnotationMatchingPointcut pointcut; + + + /** + * Create a new PersistenceExceptionTranslationAdvisor. + * @param persistenceExceptionTranslator the PersistenceExceptionTranslator to use + * @param repositoryAnnotationType the annotation type to check for + */ + public PersistenceExceptionTranslationAdvisor( + PersistenceExceptionTranslator persistenceExceptionTranslator, + Class repositoryAnnotationType) { + + this.advice = new PersistenceExceptionTranslationInterceptor(persistenceExceptionTranslator); + this.pointcut = new AnnotationMatchingPointcut(repositoryAnnotationType, true); + } + + /** + * Create a new PersistenceExceptionTranslationAdvisor. + * @param beanFactory the ListableBeanFactory to obtaining all + * PersistenceExceptionTranslators from + * @param repositoryAnnotationType the annotation type to check for + */ + PersistenceExceptionTranslationAdvisor( + ListableBeanFactory beanFactory, Class repositoryAnnotationType) { + + this.advice = new PersistenceExceptionTranslationInterceptor(beanFactory); + this.pointcut = new AnnotationMatchingPointcut(repositoryAnnotationType, true); + } + + + @Override + public Advice getAdvice() { + return this.advice; + } + + @Override + public Pointcut getPointcut() { + return this.pointcut; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/annotation/PersistenceExceptionTranslationPostProcessor.java b/spring-tx/src/main/java/org/springframework/dao/annotation/PersistenceExceptionTranslationPostProcessor.java new file mode 100644 index 0000000000000000000000000000000000000000..c975e9ec6b056c3f8930743744586501314463f4 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/annotation/PersistenceExceptionTranslationPostProcessor.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao.annotation; + +import java.lang.annotation.Annotation; + +import org.springframework.aop.framework.autoproxy.AbstractBeanFactoryAwareAdvisingPostProcessor; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.stereotype.Repository; +import org.springframework.util.Assert; + +/** + * Bean post-processor that automatically applies persistence exception translation to any + * bean marked with Spring's @{@link org.springframework.stereotype.Repository Repository} + * annotation, adding a corresponding {@link PersistenceExceptionTranslationAdvisor} to + * the exposed proxy (either an existing AOP proxy or a newly generated proxy that + * implements all of the target's interfaces). + * + *

Translates native resource exceptions to Spring's + * {@link org.springframework.dao.DataAccessException DataAccessException} hierarchy. + * Autodetects beans that implement the + * {@link org.springframework.dao.support.PersistenceExceptionTranslator + * PersistenceExceptionTranslator} interface, which are subsequently asked to translate + * candidate exceptions. + * + + *

All of Spring's applicable resource factories (e.g. + * {@link org.springframework.orm.jpa.LocalContainerEntityManagerFactoryBean}) + * implement the {@code PersistenceExceptionTranslator} interface out of the box. + * As a consequence, all that is usually needed to enable automatic exception + * translation is marking all affected beans (such as Repositories or DAOs) + * with the {@code @Repository} annotation, along with defining this post-processor + * as a bean in the application context. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 2.0 + * @see PersistenceExceptionTranslationAdvisor + * @see org.springframework.stereotype.Repository + * @see org.springframework.dao.DataAccessException + * @see org.springframework.dao.support.PersistenceExceptionTranslator + */ +@SuppressWarnings("serial") +public class PersistenceExceptionTranslationPostProcessor extends AbstractBeanFactoryAwareAdvisingPostProcessor { + + private Class repositoryAnnotationType = Repository.class; + + + /** + * Set the 'repository' annotation type. + * The default repository annotation type is the {@link Repository} annotation. + *

This setter property exists so that developers can provide their own + * (non-Spring-specific) annotation type to indicate that a class has a + * repository role. + * @param repositoryAnnotationType the desired annotation type + */ + public void setRepositoryAnnotationType(Class repositoryAnnotationType) { + Assert.notNull(repositoryAnnotationType, "'repositoryAnnotationType' must not be null"); + this.repositoryAnnotationType = repositoryAnnotationType; + } + + @Override + public void setBeanFactory(BeanFactory beanFactory) { + super.setBeanFactory(beanFactory); + + if (!(beanFactory instanceof ListableBeanFactory)) { + throw new IllegalArgumentException( + "Cannot use PersistenceExceptionTranslator autodetection without ListableBeanFactory"); + } + this.advisor = new PersistenceExceptionTranslationAdvisor( + (ListableBeanFactory) beanFactory, this.repositoryAnnotationType); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/annotation/package-info.java b/spring-tx/src/main/java/org/springframework/dao/annotation/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..17d23ed157a126b113aba13cb407f6861513263c --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/annotation/package-info.java @@ -0,0 +1,10 @@ +/** + * Annotation support for DAOs. Contains a bean post-processor for translating + * persistence exceptions based on a repository stereotype annotation. + */ +@NonNullApi +@NonNullFields +package org.springframework.dao.annotation; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/dao/package-info.java b/spring-tx/src/main/java/org/springframework/dao/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..da3fd7911e4eea79e39304bbe8686aa2cc0048d9 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/package-info.java @@ -0,0 +1,21 @@ +/** + * Exception hierarchy enabling sophisticated error handling independent + * of the data access approach in use. For example, when DAOs and data + * access frameworks use the exceptions in this package (and custom + * subclasses), calling code can detect and handle common problems such + * as deadlocks without being tied to a particular data access strategy, + * such as JDBC. + * + *

All these exceptions are unchecked, meaning that calling code can + * leave them uncaught and treat all data access exceptions as fatal. + * + *

The classes in this package are discussed in Chapter 9 of + * Expert One-On-One J2EE Design and Development + * by Rod Johnson (Wrox, 2002). + */ +@NonNullApi +@NonNullFields +package org.springframework.dao; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/dao/support/ChainedPersistenceExceptionTranslator.java b/spring-tx/src/main/java/org/springframework/dao/support/ChainedPersistenceExceptionTranslator.java new file mode 100644 index 0000000000000000000000000000000000000000..0496afe5eac932549ca4f6d9a0e047a4812ac162 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/support/ChainedPersistenceExceptionTranslator.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao.support; + +import java.util.ArrayList; +import java.util.List; + +import org.springframework.dao.DataAccessException; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Implementation of {@link PersistenceExceptionTranslator} that supports chaining, + * allowing the addition of PersistenceExceptionTranslator instances in order. + * Returns {@code non-null} on the first (if any) match. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 2.0 + */ +public class ChainedPersistenceExceptionTranslator implements PersistenceExceptionTranslator { + + /** List of PersistenceExceptionTranslators. */ + private final List delegates = new ArrayList<>(4); + + + /** + * Add a PersistenceExceptionTranslator to the chained delegate list. + */ + public final void addDelegate(PersistenceExceptionTranslator pet) { + Assert.notNull(pet, "PersistenceExceptionTranslator must not be null"); + this.delegates.add(pet); + } + + /** + * Return all registered PersistenceExceptionTranslator delegates (as array). + */ + public final PersistenceExceptionTranslator[] getDelegates() { + return this.delegates.toArray(new PersistenceExceptionTranslator[0]); + } + + + @Override + @Nullable + public DataAccessException translateExceptionIfPossible(RuntimeException ex) { + for (PersistenceExceptionTranslator pet : this.delegates) { + DataAccessException translatedDex = pet.translateExceptionIfPossible(ex); + if (translatedDex != null) { + return translatedDex; + } + } + return null; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/support/DaoSupport.java b/spring-tx/src/main/java/org/springframework/dao/support/DaoSupport.java new file mode 100644 index 0000000000000000000000000000000000000000..16b264ed8a6f2649b7fae81d60373a1ae3187a8e --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/support/DaoSupport.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao.support; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.BeanInitializationException; +import org.springframework.beans.factory.InitializingBean; + +/** + * Generic base class for DAOs, defining template methods for DAO initialization. + * + *

Extended by Spring's specific DAO support classes, such as: + * JdbcDaoSupport, JdoDaoSupport, etc. + * + * @author Juergen Hoeller + * @since 1.2.2 + * @see org.springframework.jdbc.core.support.JdbcDaoSupport + */ +public abstract class DaoSupport implements InitializingBean { + + /** Logger available to subclasses. */ + protected final Log logger = LogFactory.getLog(getClass()); + + + @Override + public final void afterPropertiesSet() throws IllegalArgumentException, BeanInitializationException { + // Let abstract subclasses check their configuration. + checkDaoConfig(); + + // Let concrete implementations initialize themselves. + try { + initDao(); + } + catch (Exception ex) { + throw new BeanInitializationException("Initialization of DAO failed", ex); + } + } + + /** + * Abstract subclasses must override this to check their configuration. + *

Implementors should be marked as {@code final} if concrete subclasses + * are not supposed to override this template method themselves. + * @throws IllegalArgumentException in case of illegal configuration + */ + protected abstract void checkDaoConfig() throws IllegalArgumentException; + + /** + * Concrete subclasses can override this for custom initialization behavior. + * Gets called after population of this instance's bean properties. + * @throws Exception if DAO initialization fails + * (will be rethrown as a BeanInitializationException) + * @see org.springframework.beans.factory.BeanInitializationException + */ + protected void initDao() throws Exception { + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/support/DataAccessUtils.java b/spring-tx/src/main/java/org/springframework/dao/support/DataAccessUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..c615a28fa3d5bdf57945003368c6e21075f79a47 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/support/DataAccessUtils.java @@ -0,0 +1,246 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao.support; + +import java.util.Collection; + +import org.springframework.dao.DataAccessException; +import org.springframework.dao.EmptyResultDataAccessException; +import org.springframework.dao.IncorrectResultSizeDataAccessException; +import org.springframework.dao.TypeMismatchDataAccessException; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.NumberUtils; + +/** + * Miscellaneous utility methods for DAO implementations. + * Useful with any data access technology. + * + * @author Juergen Hoeller + * @since 1.0.2 + */ +public abstract class DataAccessUtils { + + /** + * Return a single result object from the given Collection. + *

Returns {@code null} if 0 result objects found; + * throws an exception if more than 1 element found. + * @param results the result Collection (can be {@code null}) + * @return the single result object, or {@code null} if none + * @throws IncorrectResultSizeDataAccessException if more than one + * element has been found in the given Collection + */ + @Nullable + public static T singleResult(@Nullable Collection results) throws IncorrectResultSizeDataAccessException { + if (CollectionUtils.isEmpty(results)) { + return null; + } + if (results.size() > 1) { + throw new IncorrectResultSizeDataAccessException(1, results.size()); + } + return results.iterator().next(); + } + + /** + * Return a single result object from the given Collection. + *

Throws an exception if 0 or more than 1 element found. + * @param results the result Collection (can be {@code null} + * but is not expected to contain {@code null} elements) + * @return the single result object + * @throws IncorrectResultSizeDataAccessException if more than one + * element has been found in the given Collection + * @throws EmptyResultDataAccessException if no element at all + * has been found in the given Collection + */ + public static T requiredSingleResult(@Nullable Collection results) throws IncorrectResultSizeDataAccessException { + if (CollectionUtils.isEmpty(results)) { + throw new EmptyResultDataAccessException(1); + } + if (results.size() > 1) { + throw new IncorrectResultSizeDataAccessException(1, results.size()); + } + return results.iterator().next(); + } + + /** + * Return a single result object from the given Collection. + *

Throws an exception if 0 or more than 1 element found. + * @param results the result Collection (can be {@code null} + * and is also expected to contain {@code null} elements) + * @return the single result object + * @throws IncorrectResultSizeDataAccessException if more than one + * element has been found in the given Collection + * @throws EmptyResultDataAccessException if no element at all + * has been found in the given Collection + * @since 5.0.2 + */ + @Nullable + public static T nullableSingleResult(@Nullable Collection results) throws IncorrectResultSizeDataAccessException { + // This is identical to the requiredSingleResult implementation but differs in the + // semantics of the incoming Collection (which we currently can't formally express) + if (CollectionUtils.isEmpty(results)) { + throw new EmptyResultDataAccessException(1); + } + if (results.size() > 1) { + throw new IncorrectResultSizeDataAccessException(1, results.size()); + } + return results.iterator().next(); + } + + /** + * Return a unique result object from the given Collection. + *

Returns {@code null} if 0 result objects found; + * throws an exception if more than 1 instance found. + * @param results the result Collection (can be {@code null}) + * @return the unique result object, or {@code null} if none + * @throws IncorrectResultSizeDataAccessException if more than one + * result object has been found in the given Collection + * @see org.springframework.util.CollectionUtils#hasUniqueObject + */ + @Nullable + public static T uniqueResult(@Nullable Collection results) throws IncorrectResultSizeDataAccessException { + if (CollectionUtils.isEmpty(results)) { + return null; + } + if (!CollectionUtils.hasUniqueObject(results)) { + throw new IncorrectResultSizeDataAccessException(1, results.size()); + } + return results.iterator().next(); + } + + /** + * Return a unique result object from the given Collection. + *

Throws an exception if 0 or more than 1 instance found. + * @param results the result Collection (can be {@code null} + * but is not expected to contain {@code null} elements) + * @return the unique result object + * @throws IncorrectResultSizeDataAccessException if more than one + * result object has been found in the given Collection + * @throws EmptyResultDataAccessException if no result object at all + * has been found in the given Collection + * @see org.springframework.util.CollectionUtils#hasUniqueObject + */ + public static T requiredUniqueResult(@Nullable Collection results) throws IncorrectResultSizeDataAccessException { + if (CollectionUtils.isEmpty(results)) { + throw new EmptyResultDataAccessException(1); + } + if (!CollectionUtils.hasUniqueObject(results)) { + throw new IncorrectResultSizeDataAccessException(1, results.size()); + } + return results.iterator().next(); + } + + /** + * Return a unique result object from the given Collection. + * Throws an exception if 0 or more than 1 result objects found, + * of if the unique result object is not convertible to the + * specified required type. + * @param results the result Collection (can be {@code null} + * but is not expected to contain {@code null} elements) + * @return the unique result object + * @throws IncorrectResultSizeDataAccessException if more than one + * result object has been found in the given Collection + * @throws EmptyResultDataAccessException if no result object + * at all has been found in the given Collection + * @throws TypeMismatchDataAccessException if the unique object does + * not match the specified required type + */ + @SuppressWarnings("unchecked") + public static T objectResult(@Nullable Collection results, @Nullable Class requiredType) + throws IncorrectResultSizeDataAccessException, TypeMismatchDataAccessException { + + Object result = requiredUniqueResult(results); + if (requiredType != null && !requiredType.isInstance(result)) { + if (String.class == requiredType) { + result = result.toString(); + } + else if (Number.class.isAssignableFrom(requiredType) && Number.class.isInstance(result)) { + try { + result = NumberUtils.convertNumberToTargetClass(((Number) result), (Class) requiredType); + } + catch (IllegalArgumentException ex) { + throw new TypeMismatchDataAccessException(ex.getMessage()); + } + } + else { + throw new TypeMismatchDataAccessException( + "Result object is of type [" + result.getClass().getName() + + "] and could not be converted to required type [" + requiredType.getName() + "]"); + } + } + return (T) result; + } + + /** + * Return a unique int result from the given Collection. + * Throws an exception if 0 or more than 1 result objects found, + * of if the unique result object is not convertible to an int. + * @param results the result Collection (can be {@code null} + * but is not expected to contain {@code null} elements) + * @return the unique int result + * @throws IncorrectResultSizeDataAccessException if more than one + * result object has been found in the given Collection + * @throws EmptyResultDataAccessException if no result object + * at all has been found in the given Collection + * @throws TypeMismatchDataAccessException if the unique object + * in the collection is not convertible to an int + */ + public static int intResult(@Nullable Collection results) + throws IncorrectResultSizeDataAccessException, TypeMismatchDataAccessException { + + return objectResult(results, Number.class).intValue(); + } + + /** + * Return a unique long result from the given Collection. + * Throws an exception if 0 or more than 1 result objects found, + * of if the unique result object is not convertible to a long. + * @param results the result Collection (can be {@code null} + * but is not expected to contain {@code null} elements) + * @return the unique long result + * @throws IncorrectResultSizeDataAccessException if more than one + * result object has been found in the given Collection + * @throws EmptyResultDataAccessException if no result object + * at all has been found in the given Collection + * @throws TypeMismatchDataAccessException if the unique object + * in the collection is not convertible to a long + */ + public static long longResult(@Nullable Collection results) + throws IncorrectResultSizeDataAccessException, TypeMismatchDataAccessException { + + return objectResult(results, Number.class).longValue(); + } + + + /** + * Return a translated exception if this is appropriate, + * otherwise return the given exception as-is. + * @param rawException an exception that we may wish to translate + * @param pet the PersistenceExceptionTranslator to use to perform the translation + * @return a translated persistence exception if translation is possible, + * or the raw exception if it is not + */ + public static RuntimeException translateIfNecessary( + RuntimeException rawException, PersistenceExceptionTranslator pet) { + + Assert.notNull(pet, "PersistenceExceptionTranslator must not be null"); + DataAccessException dae = pet.translateExceptionIfPossible(rawException); + return (dae != null ? dae : rawException); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/support/PersistenceExceptionTranslationInterceptor.java b/spring-tx/src/main/java/org/springframework/dao/support/PersistenceExceptionTranslationInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..757d1f9b59175dc7731727ba441d4e6c1d7e8c15 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/support/PersistenceExceptionTranslationInterceptor.java @@ -0,0 +1,177 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao.support; + +import java.util.Map; + +import org.aopalliance.intercept.MethodInterceptor; +import org.aopalliance.intercept.MethodInvocation; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; + +/** + * AOP Alliance MethodInterceptor that provides persistence exception translation + * based on a given PersistenceExceptionTranslator. + * + *

Delegates to the given {@link PersistenceExceptionTranslator} to translate + * a RuntimeException thrown into Spring's DataAccessException hierarchy + * (if appropriate). If the RuntimeException in question is declared on the + * target method, it is always propagated as-is (with no translation applied). + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 2.0 + * @see PersistenceExceptionTranslator + */ +public class PersistenceExceptionTranslationInterceptor + implements MethodInterceptor, BeanFactoryAware, InitializingBean { + + @Nullable + private volatile PersistenceExceptionTranslator persistenceExceptionTranslator; + + private boolean alwaysTranslate = false; + + @Nullable + private ListableBeanFactory beanFactory; + + + /** + * Create a new PersistenceExceptionTranslationInterceptor. + * Needs to be configured with a PersistenceExceptionTranslator afterwards. + * @see #setPersistenceExceptionTranslator + */ + public PersistenceExceptionTranslationInterceptor() { + } + + /** + * Create a new PersistenceExceptionTranslationInterceptor + * for the given PersistenceExceptionTranslator. + * @param pet the PersistenceExceptionTranslator to use + */ + public PersistenceExceptionTranslationInterceptor(PersistenceExceptionTranslator pet) { + Assert.notNull(pet, "PersistenceExceptionTranslator must not be null"); + this.persistenceExceptionTranslator = pet; + } + + /** + * Create a new PersistenceExceptionTranslationInterceptor, autodetecting + * PersistenceExceptionTranslators in the given BeanFactory. + * @param beanFactory the ListableBeanFactory to obtaining all + * PersistenceExceptionTranslators from + */ + public PersistenceExceptionTranslationInterceptor(ListableBeanFactory beanFactory) { + Assert.notNull(beanFactory, "ListableBeanFactory must not be null"); + this.beanFactory = beanFactory; + } + + + /** + * Specify the PersistenceExceptionTranslator to use. + *

Default is to autodetect all PersistenceExceptionTranslators + * in the containing BeanFactory, using them in a chain. + * @see #detectPersistenceExceptionTranslators + */ + public void setPersistenceExceptionTranslator(PersistenceExceptionTranslator pet) { + this.persistenceExceptionTranslator = pet; + } + + /** + * Specify whether to always translate the exception ("true"), or whether throw the + * raw exception when declared, i.e. when the originating method signature's exception + * declarations allow for the raw exception to be thrown ("false"). + *

Default is "false". Switch this flag to "true" in order to always translate + * applicable exceptions, independent from the originating method signature. + *

Note that the originating method does not have to declare the specific exception. + * Any base class will do as well, even {@code throws Exception}: As long as the + * originating method does explicitly declare compatible exceptions, the raw exception + * will be rethrown. If you would like to avoid throwing raw exceptions in any case, + * switch this flag to "true". + */ + public void setAlwaysTranslate(boolean alwaysTranslate) { + this.alwaysTranslate = alwaysTranslate; + } + + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + if (this.persistenceExceptionTranslator == null) { + // No explicit exception translator specified - perform autodetection. + if (!(beanFactory instanceof ListableBeanFactory)) { + throw new IllegalArgumentException( + "Cannot use PersistenceExceptionTranslator autodetection without ListableBeanFactory"); + } + this.beanFactory = (ListableBeanFactory) beanFactory; + } + } + + @Override + public void afterPropertiesSet() { + if (this.persistenceExceptionTranslator == null && this.beanFactory == null) { + throw new IllegalArgumentException("Property 'persistenceExceptionTranslator' is required"); + } + } + + + @Override + public Object invoke(MethodInvocation mi) throws Throwable { + try { + return mi.proceed(); + } + catch (RuntimeException ex) { + // Let it throw raw if the type of the exception is on the throws clause of the method. + if (!this.alwaysTranslate && ReflectionUtils.declaresException(mi.getMethod(), ex.getClass())) { + throw ex; + } + else { + PersistenceExceptionTranslator translator = this.persistenceExceptionTranslator; + if (translator == null) { + Assert.state(this.beanFactory != null, + "Cannot use PersistenceExceptionTranslator autodetection without ListableBeanFactory"); + translator = detectPersistenceExceptionTranslators(this.beanFactory); + this.persistenceExceptionTranslator = translator; + } + throw DataAccessUtils.translateIfNecessary(ex, translator); + } + } + } + + /** + * Detect all PersistenceExceptionTranslators in the given BeanFactory. + * @param bf the ListableBeanFactory to obtain PersistenceExceptionTranslators from + * @return a chained PersistenceExceptionTranslator, combining all + * PersistenceExceptionTranslators found in the given bean factory + * @see ChainedPersistenceExceptionTranslator + */ + protected PersistenceExceptionTranslator detectPersistenceExceptionTranslators(ListableBeanFactory bf) { + // Find all translators, being careful not to activate FactoryBeans. + Map pets = BeanFactoryUtils.beansOfTypeIncludingAncestors( + bf, PersistenceExceptionTranslator.class, false, false); + ChainedPersistenceExceptionTranslator cpet = new ChainedPersistenceExceptionTranslator(); + for (PersistenceExceptionTranslator pet : pets.values()) { + cpet.addDelegate(pet); + } + return cpet; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/support/PersistenceExceptionTranslator.java b/spring-tx/src/main/java/org/springframework/dao/support/PersistenceExceptionTranslator.java new file mode 100644 index 0000000000000000000000000000000000000000..282d6c5c1a0437ed9ad37afc83485cd234297be4 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/support/PersistenceExceptionTranslator.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao.support; + +import org.springframework.dao.DataAccessException; +import org.springframework.lang.Nullable; + +/** + * Interface implemented by Spring integrations with data access technologies + * that throw runtime exceptions, such as JPA and Hibernate. + * + *

This allows consistent usage of combined exception translation functionality, + * without forcing a single translator to understand every single possible type + * of exception. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 2.0 + */ +@FunctionalInterface +public interface PersistenceExceptionTranslator { + + /** + * Translate the given runtime exception thrown by a persistence framework to a + * corresponding exception from Spring's generic + * {@link org.springframework.dao.DataAccessException} hierarchy, if possible. + *

Do not translate exceptions that are not understood by this translator: + * for example, if coming from another persistence framework, or resulting + * from user code or otherwise unrelated to persistence. + *

Of particular importance is the correct translation to + * DataIntegrityViolationException, for example on constraint violation. + * Implementations may use Spring JDBC's sophisticated exception translation + * to provide further information in the event of SQLException as a root cause. + * @param ex a RuntimeException to translate + * @return the corresponding DataAccessException (or {@code null} if the + * exception could not be translated, as in this case it may result from + * user code rather than from an actual persistence problem) + * @see org.springframework.dao.DataIntegrityViolationException + * @see org.springframework.jdbc.support.SQLExceptionTranslator + */ + @Nullable + DataAccessException translateExceptionIfPossible(RuntimeException ex); + +} diff --git a/spring-tx/src/main/java/org/springframework/dao/support/package-info.java b/spring-tx/src/main/java/org/springframework/dao/support/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..1efd755dd17cea8f48c67fe72a856ed22f80fbfb --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/dao/support/package-info.java @@ -0,0 +1,10 @@ +/** + * Support classes for DAO implementations, + * providing miscellaneous utility methods. + */ +@NonNullApi +@NonNullFields +package org.springframework.dao.support; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/CannotCreateRecordException.java b/spring-tx/src/main/java/org/springframework/jca/cci/CannotCreateRecordException.java new file mode 100644 index 0000000000000000000000000000000000000000..066bb9ae9d0fd9b0c4a168ecb7bf6d0e0974d200 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/CannotCreateRecordException.java @@ -0,0 +1,42 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci; + +import javax.resource.ResourceException; + +import org.springframework.dao.DataAccessResourceFailureException; + +/** + * Exception thrown when the creating of a CCI Record failed + * for connector-internal reasons. + * + * @author Juergen Hoeller + * @since 1.2 + */ +@SuppressWarnings("serial") +public class CannotCreateRecordException extends DataAccessResourceFailureException { + + /** + * Constructor for CannotCreateRecordException. + * @param msg message + * @param ex the root ResourceException cause + */ + public CannotCreateRecordException(String msg, ResourceException ex) { + super(msg, ex); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/CannotGetCciConnectionException.java b/spring-tx/src/main/java/org/springframework/jca/cci/CannotGetCciConnectionException.java new file mode 100644 index 0000000000000000000000000000000000000000..1d9c23de48d5a6f26d0d8af816d63ff08f69c018 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/CannotGetCciConnectionException.java @@ -0,0 +1,42 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci; + +import javax.resource.ResourceException; + +import org.springframework.dao.DataAccessResourceFailureException; + +/** + * Fatal exception thrown when we can't connect to an EIS using CCI. + * + * @author Thierry Templier + * @author Juergen Hoeller + * @since 1.2 + */ +@SuppressWarnings("serial") +public class CannotGetCciConnectionException extends DataAccessResourceFailureException { + + /** + * Constructor for CannotGetCciConnectionException. + * @param msg message + * @param ex the root ResourceException cause + */ + public CannotGetCciConnectionException(String msg, ResourceException ex) { + super(msg, ex); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/CciOperationNotSupportedException.java b/spring-tx/src/main/java/org/springframework/jca/cci/CciOperationNotSupportedException.java new file mode 100644 index 0000000000000000000000000000000000000000..41d1f4252fd43d085f422b737502713ac8190e81 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/CciOperationNotSupportedException.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci; + +import javax.resource.ResourceException; + +import org.springframework.dao.InvalidDataAccessResourceUsageException; + +/** + * Exception thrown when the connector doesn't support a specific CCI operation. + * + * @author Juergen Hoeller + * @since 1.2 + */ +@SuppressWarnings("serial") +public class CciOperationNotSupportedException extends InvalidDataAccessResourceUsageException { + + /** + * Constructor for CciOperationNotSupportedException. + * @param msg message + * @param ex the root ResourceException cause + */ + public CciOperationNotSupportedException(String msg, ResourceException ex) { + super(msg, ex); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/InvalidResultSetAccessException.java b/spring-tx/src/main/java/org/springframework/jca/cci/InvalidResultSetAccessException.java new file mode 100644 index 0000000000000000000000000000000000000000..090fc528925f83d5e979361dc7451f4affb05ce7 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/InvalidResultSetAccessException.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci; + +import java.sql.SQLException; + +import org.springframework.dao.InvalidDataAccessResourceUsageException; + +/** + * Exception thrown when a ResultSet has been accessed in an invalid fashion. + * Such exceptions always have a {@code java.sql.SQLException} root cause. + * + *

This typically happens when an invalid ResultSet column index or name + * has been specified. + * + * @author Juergen Hoeller + * @since 1.2 + * @see javax.resource.cci.ResultSet + */ +@SuppressWarnings("serial") +public class InvalidResultSetAccessException extends InvalidDataAccessResourceUsageException { + + /** + * Constructor for InvalidResultSetAccessException. + * @param msg message + * @param ex the root cause + */ + public InvalidResultSetAccessException(String msg, SQLException ex) { + super(ex.getMessage(), ex); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/RecordTypeNotSupportedException.java b/spring-tx/src/main/java/org/springframework/jca/cci/RecordTypeNotSupportedException.java new file mode 100644 index 0000000000000000000000000000000000000000..bda7bf24a4a691c854107ef659dbbd378aa26afd --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/RecordTypeNotSupportedException.java @@ -0,0 +1,42 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci; + +import javax.resource.ResourceException; + +import org.springframework.dao.InvalidDataAccessResourceUsageException; + +/** + * Exception thrown when the creating of a CCI Record failed because + * the connector doesn't support the desired CCI Record type. + * + * @author Juergen Hoeller + * @since 1.2 + */ +@SuppressWarnings("serial") +public class RecordTypeNotSupportedException extends InvalidDataAccessResourceUsageException { + + /** + * Constructor for RecordTypeNotSupportedException. + * @param msg message + * @param ex the root ResourceException cause + */ + public RecordTypeNotSupportedException(String msg, ResourceException ex) { + super(msg, ex); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/connection/CciLocalTransactionManager.java b/spring-tx/src/main/java/org/springframework/jca/cci/connection/CciLocalTransactionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..5d5ea978154434e3dbf201f46c0e51261f8e6338 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/connection/CciLocalTransactionManager.java @@ -0,0 +1,294 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.connection; + +import javax.resource.NotSupportedException; +import javax.resource.ResourceException; +import javax.resource.cci.Connection; +import javax.resource.cci.ConnectionFactory; +import javax.resource.spi.LocalTransactionException; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; +import org.springframework.transaction.CannotCreateTransactionException; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.TransactionException; +import org.springframework.transaction.TransactionSystemException; +import org.springframework.transaction.support.AbstractPlatformTransactionManager; +import org.springframework.transaction.support.DefaultTransactionStatus; +import org.springframework.transaction.support.ResourceTransactionManager; +import org.springframework.transaction.support.TransactionSynchronizationManager; +import org.springframework.util.Assert; + +/** + * {@link org.springframework.transaction.PlatformTransactionManager} implementation + * that manages local transactions for a single CCI ConnectionFactory. + * Binds a CCI Connection from the specified ConnectionFactory to the thread, + * potentially allowing for one thread-bound Connection per ConnectionFactory. + * + *

Application code is required to retrieve the CCI Connection via + * {@link ConnectionFactoryUtils#getConnection(ConnectionFactory)} instead of a standard + * Java EE-style {@link ConnectionFactory#getConnection()} call. Spring classes such as + * {@link org.springframework.jca.cci.core.CciTemplate} use this strategy implicitly. + * If not used in combination with this transaction manager, the + * {@link ConnectionFactoryUtils} lookup strategy behaves exactly like the native + * DataSource lookup; it can thus be used in a portable fashion. + * + *

Alternatively, you can allow application code to work with the standard + * Java EE lookup pattern {@link ConnectionFactory#getConnection()}, for example + * for legacy code that is not aware of Spring at all. In that case, define a + * {@link TransactionAwareConnectionFactoryProxy} for your target ConnectionFactory, + * which will automatically participate in Spring-managed transactions. + * + * @author Thierry Templier + * @author Juergen Hoeller + * @since 1.2 + * @see ConnectionFactoryUtils#getConnection(javax.resource.cci.ConnectionFactory) + * @see ConnectionFactoryUtils#releaseConnection + * @see TransactionAwareConnectionFactoryProxy + * @see org.springframework.jca.cci.core.CciTemplate + */ +@SuppressWarnings("serial") +public class CciLocalTransactionManager extends AbstractPlatformTransactionManager + implements ResourceTransactionManager, InitializingBean { + + @Nullable + private ConnectionFactory connectionFactory; + + + /** + * Create a new CciLocalTransactionManager instance. + * A ConnectionFactory has to be set to be able to use it. + * @see #setConnectionFactory + */ + public CciLocalTransactionManager() { + } + + /** + * Create a new CciLocalTransactionManager instance. + * @param connectionFactory the CCI ConnectionFactory to manage local transactions for + */ + public CciLocalTransactionManager(ConnectionFactory connectionFactory) { + setConnectionFactory(connectionFactory); + afterPropertiesSet(); + } + + + /** + * Set the CCI ConnectionFactory that this instance should manage local + * transactions for. + */ + public void setConnectionFactory(@Nullable ConnectionFactory cf) { + if (cf instanceof TransactionAwareConnectionFactoryProxy) { + // If we got a TransactionAwareConnectionFactoryProxy, we need to perform transactions + // for its underlying target ConnectionFactory, else JMS access code won't see + // properly exposed transactions (i.e. transactions for the target ConnectionFactory). + this.connectionFactory = ((TransactionAwareConnectionFactoryProxy) cf).getTargetConnectionFactory(); + } + else { + this.connectionFactory = cf; + } + } + + /** + * Return the CCI ConnectionFactory that this instance manages local + * transactions for. + */ + @Nullable + public ConnectionFactory getConnectionFactory() { + return this.connectionFactory; + } + + private ConnectionFactory obtainConnectionFactory() { + ConnectionFactory connectionFactory = getConnectionFactory(); + Assert.state(connectionFactory != null, "No ConnectionFactory set"); + return connectionFactory; + } + + @Override + public void afterPropertiesSet() { + if (getConnectionFactory() == null) { + throw new IllegalArgumentException("Property 'connectionFactory' is required"); + } + } + + + @Override + public Object getResourceFactory() { + return obtainConnectionFactory(); + } + + @Override + protected Object doGetTransaction() { + CciLocalTransactionObject txObject = new CciLocalTransactionObject(); + ConnectionHolder conHolder = + (ConnectionHolder) TransactionSynchronizationManager.getResource(obtainConnectionFactory()); + txObject.setConnectionHolder(conHolder); + return txObject; + } + + @Override + protected boolean isExistingTransaction(Object transaction) { + CciLocalTransactionObject txObject = (CciLocalTransactionObject) transaction; + // Consider a pre-bound connection as transaction. + return txObject.hasConnectionHolder(); + } + + @Override + protected void doBegin(Object transaction, TransactionDefinition definition) { + CciLocalTransactionObject txObject = (CciLocalTransactionObject) transaction; + ConnectionFactory connectionFactory = obtainConnectionFactory(); + Connection con = null; + + try { + con = connectionFactory.getConnection(); + if (logger.isDebugEnabled()) { + logger.debug("Acquired Connection [" + con + "] for local CCI transaction"); + } + + ConnectionHolder connectionHolder = new ConnectionHolder(con); + connectionHolder.setSynchronizedWithTransaction(true); + + con.getLocalTransaction().begin(); + int timeout = determineTimeout(definition); + if (timeout != TransactionDefinition.TIMEOUT_DEFAULT) { + connectionHolder.setTimeoutInSeconds(timeout); + } + + txObject.setConnectionHolder(connectionHolder); + TransactionSynchronizationManager.bindResource(connectionFactory, connectionHolder); + } + catch (NotSupportedException ex) { + ConnectionFactoryUtils.releaseConnection(con, connectionFactory); + throw new CannotCreateTransactionException("CCI Connection does not support local transactions", ex); + } + catch (LocalTransactionException ex) { + ConnectionFactoryUtils.releaseConnection(con, connectionFactory); + throw new CannotCreateTransactionException("Could not begin local CCI transaction", ex); + } + catch (Throwable ex) { + ConnectionFactoryUtils.releaseConnection(con, connectionFactory); + throw new TransactionSystemException("Unexpected failure on begin of CCI local transaction", ex); + } + } + + @Override + protected Object doSuspend(Object transaction) { + CciLocalTransactionObject txObject = (CciLocalTransactionObject) transaction; + txObject.setConnectionHolder(null); + return TransactionSynchronizationManager.unbindResource(obtainConnectionFactory()); + } + + @Override + protected void doResume(@Nullable Object transaction, Object suspendedResources) { + ConnectionHolder conHolder = (ConnectionHolder) suspendedResources; + TransactionSynchronizationManager.bindResource(obtainConnectionFactory(), conHolder); + } + + protected boolean isRollbackOnly(Object transaction) throws TransactionException { + CciLocalTransactionObject txObject = (CciLocalTransactionObject) transaction; + return txObject.getConnectionHolder().isRollbackOnly(); + } + + @Override + protected void doCommit(DefaultTransactionStatus status) { + CciLocalTransactionObject txObject = (CciLocalTransactionObject) status.getTransaction(); + Connection con = txObject.getConnectionHolder().getConnection(); + if (status.isDebug()) { + logger.debug("Committing CCI local transaction on Connection [" + con + "]"); + } + try { + con.getLocalTransaction().commit(); + } + catch (LocalTransactionException ex) { + throw new TransactionSystemException("Could not commit CCI local transaction", ex); + } + catch (ResourceException ex) { + throw new TransactionSystemException("Unexpected failure on commit of CCI local transaction", ex); + } + } + + @Override + protected void doRollback(DefaultTransactionStatus status) { + CciLocalTransactionObject txObject = (CciLocalTransactionObject) status.getTransaction(); + Connection con = txObject.getConnectionHolder().getConnection(); + if (status.isDebug()) { + logger.debug("Rolling back CCI local transaction on Connection [" + con + "]"); + } + try { + con.getLocalTransaction().rollback(); + } + catch (LocalTransactionException ex) { + throw new TransactionSystemException("Could not roll back CCI local transaction", ex); + } + catch (ResourceException ex) { + throw new TransactionSystemException("Unexpected failure on rollback of CCI local transaction", ex); + } + } + + @Override + protected void doSetRollbackOnly(DefaultTransactionStatus status) { + CciLocalTransactionObject txObject = (CciLocalTransactionObject) status.getTransaction(); + if (status.isDebug()) { + logger.debug("Setting CCI local transaction [" + txObject.getConnectionHolder().getConnection() + + "] rollback-only"); + } + txObject.getConnectionHolder().setRollbackOnly(); + } + + @Override + protected void doCleanupAfterCompletion(Object transaction) { + CciLocalTransactionObject txObject = (CciLocalTransactionObject) transaction; + ConnectionFactory connectionFactory = obtainConnectionFactory(); + + // Remove the connection holder from the thread. + TransactionSynchronizationManager.unbindResource(connectionFactory); + txObject.getConnectionHolder().clear(); + + Connection con = txObject.getConnectionHolder().getConnection(); + if (logger.isDebugEnabled()) { + logger.debug("Releasing CCI Connection [" + con + "] after transaction"); + } + ConnectionFactoryUtils.releaseConnection(con, connectionFactory); + } + + + /** + * CCI local transaction object, representing a ConnectionHolder. + * Used as transaction object by CciLocalTransactionManager. + * @see ConnectionHolder + */ + private static class CciLocalTransactionObject { + + @Nullable + private ConnectionHolder connectionHolder; + + public void setConnectionHolder(@Nullable ConnectionHolder connectionHolder) { + this.connectionHolder = connectionHolder; + } + + public ConnectionHolder getConnectionHolder() { + Assert.state(this.connectionHolder != null, "No ConnectionHolder available"); + return this.connectionHolder; + } + + public boolean hasConnectionHolder() { + return (this.connectionHolder != null); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/connection/ConnectionFactoryUtils.java b/spring-tx/src/main/java/org/springframework/jca/cci/connection/ConnectionFactoryUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..5a30ccfcb3e3bac1a41dfa1d1697ddab45a985b6 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/connection/ConnectionFactoryUtils.java @@ -0,0 +1,216 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.connection; + +import javax.resource.ResourceException; +import javax.resource.cci.Connection; +import javax.resource.cci.ConnectionFactory; +import javax.resource.cci.ConnectionSpec; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.jca.cci.CannotGetCciConnectionException; +import org.springframework.lang.Nullable; +import org.springframework.transaction.support.ResourceHolderSynchronization; +import org.springframework.transaction.support.TransactionSynchronizationManager; +import org.springframework.util.Assert; + +/** + * Helper class that provides static methods for obtaining CCI Connections + * from a {@link javax.resource.cci.ConnectionFactory}. Includes special + * support for Spring-managed transactional Connections, e.g. managed + * by {@link CciLocalTransactionManager} or + * {@link org.springframework.transaction.jta.JtaTransactionManager}. + * + *

Used internally by {@link org.springframework.jca.cci.core.CciTemplate}, + * Spring's CCI operation objects and the {@link CciLocalTransactionManager}. + * Can also be used directly in application code. + * + * @author Thierry Templier + * @author Juergen Hoeller + * @since 1.2 + * @see #getConnection + * @see #releaseConnection + * @see CciLocalTransactionManager + * @see org.springframework.transaction.jta.JtaTransactionManager + * @see org.springframework.transaction.support.TransactionSynchronizationManager + */ +public abstract class ConnectionFactoryUtils { + + private static final Log logger = LogFactory.getLog(ConnectionFactoryUtils.class); + + + /** + * Obtain a Connection from the given ConnectionFactory. Translates ResourceExceptions + * into the Spring hierarchy of unchecked generic data access exceptions, simplifying + * calling code and making any exception that is thrown more meaningful. + *

Is aware of a corresponding Connection bound to the current thread, for example + * when using {@link CciLocalTransactionManager}. Will bind a Connection to the thread + * if transaction synchronization is active (e.g. if in a JTA transaction). + * @param cf the ConnectionFactory to obtain Connection from + * @return a CCI Connection from the given ConnectionFactory + * @throws org.springframework.jca.cci.CannotGetCciConnectionException + * if the attempt to get a Connection failed + * @see #releaseConnection + */ + public static Connection getConnection(ConnectionFactory cf) throws CannotGetCciConnectionException { + return getConnection(cf, null); + } + + /** + * Obtain a Connection from the given ConnectionFactory. Translates ResourceExceptions + * into the Spring hierarchy of unchecked generic data access exceptions, simplifying + * calling code and making any exception that is thrown more meaningful. + *

Is aware of a corresponding Connection bound to the current thread, for example + * when using {@link CciLocalTransactionManager}. Will bind a Connection to the thread + * if transaction synchronization is active (e.g. if in a JTA transaction). + * @param cf the ConnectionFactory to obtain Connection from + * @param spec the ConnectionSpec for the desired Connection (may be {@code null}). + * Note: If this is specified, a new Connection will be obtained for every call, + * without participating in a shared transactional Connection. + * @return a CCI Connection from the given ConnectionFactory + * @throws org.springframework.jca.cci.CannotGetCciConnectionException + * if the attempt to get a Connection failed + * @see #releaseConnection + */ + public static Connection getConnection(ConnectionFactory cf, @Nullable ConnectionSpec spec) + throws CannotGetCciConnectionException { + try { + if (spec != null) { + Assert.notNull(cf, "No ConnectionFactory specified"); + return cf.getConnection(spec); + } + else { + return doGetConnection(cf); + } + } + catch (ResourceException ex) { + throw new CannotGetCciConnectionException("Could not get CCI Connection", ex); + } + } + + /** + * Actually obtain a CCI Connection from the given ConnectionFactory. + * Same as {@link #getConnection}, but throwing the original ResourceException. + *

Is aware of a corresponding Connection bound to the current thread, for example + * when using {@link CciLocalTransactionManager}. Will bind a Connection to the thread + * if transaction synchronization is active (e.g. if in a JTA transaction). + *

Directly accessed by {@link TransactionAwareConnectionFactoryProxy}. + * @param cf the ConnectionFactory to obtain Connection from + * @return a CCI Connection from the given ConnectionFactory + * @throws ResourceException if thrown by CCI API methods + * @see #doReleaseConnection + */ + public static Connection doGetConnection(ConnectionFactory cf) throws ResourceException { + Assert.notNull(cf, "No ConnectionFactory specified"); + + ConnectionHolder conHolder = (ConnectionHolder) TransactionSynchronizationManager.getResource(cf); + if (conHolder != null) { + return conHolder.getConnection(); + } + + logger.debug("Opening CCI Connection"); + Connection con = cf.getConnection(); + + if (TransactionSynchronizationManager.isSynchronizationActive()) { + conHolder = new ConnectionHolder(con); + conHolder.setSynchronizedWithTransaction(true); + TransactionSynchronizationManager.registerSynchronization(new ConnectionSynchronization(conHolder, cf)); + TransactionSynchronizationManager.bindResource(cf, conHolder); + } + + return con; + } + + /** + * Determine whether the given JCA CCI Connection is transactional, that is, + * bound to the current thread by Spring's transaction facilities. + * @param con the Connection to check + * @param cf the ConnectionFactory that the Connection was obtained from + * (may be {@code null}) + * @return whether the Connection is transactional + */ + public static boolean isConnectionTransactional(Connection con, @Nullable ConnectionFactory cf) { + if (cf == null) { + return false; + } + ConnectionHolder conHolder = (ConnectionHolder) TransactionSynchronizationManager.getResource(cf); + return (conHolder != null && conHolder.getConnection() == con); + } + + /** + * Close the given Connection, obtained from the given ConnectionFactory, + * if it is not managed externally (that is, not bound to the thread). + * @param con the Connection to close if necessary + * (if this is {@code null}, the call will be ignored) + * @param cf the ConnectionFactory that the Connection was obtained from + * (can be {@code null}) + * @see #getConnection + */ + public static void releaseConnection(@Nullable Connection con, @Nullable ConnectionFactory cf) { + try { + doReleaseConnection(con, cf); + } + catch (ResourceException ex) { + logger.debug("Could not close CCI Connection", ex); + } + catch (Throwable ex) { + // We don't trust the CCI driver: It might throw RuntimeException or Error. + logger.debug("Unexpected exception on closing CCI Connection", ex); + } + } + + /** + * Actually close the given Connection, obtained from the given ConnectionFactory. + * Same as {@link #releaseConnection}, but throwing the original ResourceException. + *

Directly accessed by {@link TransactionAwareConnectionFactoryProxy}. + * @param con the Connection to close if necessary + * (if this is {@code null}, the call will be ignored) + * @param cf the ConnectionFactory that the Connection was obtained from + * (can be {@code null}) + * @throws ResourceException if thrown by JCA CCI methods + * @see #doGetConnection + */ + public static void doReleaseConnection(@Nullable Connection con, @Nullable ConnectionFactory cf) + throws ResourceException { + + if (con == null || isConnectionTransactional(con, cf)) { + return; + } + con.close(); + } + + + /** + * Callback for resource cleanup at the end of a non-native CCI transaction + * (e.g. when participating in a JTA transaction). + */ + private static class ConnectionSynchronization + extends ResourceHolderSynchronization { + + public ConnectionSynchronization(ConnectionHolder connectionHolder, ConnectionFactory connectionFactory) { + super(connectionHolder, connectionFactory); + } + + @Override + protected void releaseResource(ConnectionHolder resourceHolder, ConnectionFactory resourceKey) { + releaseConnection(resourceHolder.getConnection(), resourceKey); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/connection/ConnectionHolder.java b/spring-tx/src/main/java/org/springframework/jca/cci/connection/ConnectionHolder.java new file mode 100644 index 0000000000000000000000000000000000000000..d705d4831d91cb2962650114c7a10ec261ca0884 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/connection/ConnectionHolder.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.connection; + +import javax.resource.cci.Connection; + +import org.springframework.transaction.support.ResourceHolderSupport; + +/** + * Resource holder wrapping a CCI {@link Connection}. + * {@link CciLocalTransactionManager} binds instances of this class to the thread, + * for a given {@link javax.resource.cci.ConnectionFactory}. + * + *

Note: This is an SPI class, not intended to be used by applications. + * + * @author Thierry Templier + * @author Juergen Hoeller + * @since 1.2 + * @see CciLocalTransactionManager + * @see ConnectionFactoryUtils + */ +public class ConnectionHolder extends ResourceHolderSupport { + + private final Connection connection; + + + public ConnectionHolder(Connection connection) { + this.connection = connection; + } + + + public Connection getConnection() { + return this.connection; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/connection/ConnectionSpecConnectionFactoryAdapter.java b/spring-tx/src/main/java/org/springframework/jca/cci/connection/ConnectionSpecConnectionFactoryAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..a9cca7cb31518a7ac0860882900e4ba8f4ca4594 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/connection/ConnectionSpecConnectionFactoryAdapter.java @@ -0,0 +1,141 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.connection; + +import javax.resource.ResourceException; +import javax.resource.cci.Connection; +import javax.resource.cci.ConnectionFactory; +import javax.resource.cci.ConnectionSpec; + +import org.springframework.core.NamedThreadLocal; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * An adapter for a target CCI {@link javax.resource.cci.ConnectionFactory}, + * applying the given ConnectionSpec to every standard {@code getConnection()} + * call, that is, implicitly invoking {@code getConnection(ConnectionSpec)} + * on the target. All other methods simply delegate to the corresponding methods + * of the target ConnectionFactory. + * + *

Can be used to proxy a target JNDI ConnectionFactory that does not have a + * ConnectionSpec configured. Client code can work with the ConnectionFactory + * without passing in a ConnectionSpec on every {@code getConnection()} call. + * + *

In the following example, client code can simply transparently work with + * the preconfigured "myConnectionFactory", implicitly accessing + * "myTargetConnectionFactory" with the specified user credentials. + * + *

+ * <bean id="myTargetConnectionFactory" class="org.springframework.jndi.JndiObjectFactoryBean">
+ *   <property name="jndiName" value="java:comp/env/cci/mycf"/>
+ * </bean>
+ *
+ * <bean id="myConnectionFactory" class="org.springframework.jca.cci.connection.ConnectionSpecConnectionFactoryAdapter">
+ *   <property name="targetConnectionFactory" ref="myTargetConnectionFactory"/>
+ *   <property name="connectionSpec">
+ *     <bean class="your.resource.adapter.ConnectionSpecImpl">
+ *       <property name="username" value="myusername"/>
+ *       <property name="password" value="mypassword"/>
+ *     </bean>
+ *   </property>
+ * </bean>
+ * + *

If the "connectionSpec" is empty, this proxy will simply delegate to the + * standard {@code getConnection()} method of the target ConnectionFactory. + * This can be used to keep a UserCredentialsConnectionFactoryAdapter bean definition + * just for the option of implicitly passing in a ConnectionSpec if the + * particular target ConnectionFactory requires it. + * + * @author Juergen Hoeller + * @since 1.2 + * @see #getConnection + */ +@SuppressWarnings("serial") +public class ConnectionSpecConnectionFactoryAdapter extends DelegatingConnectionFactory { + + @Nullable + private ConnectionSpec connectionSpec; + + private final ThreadLocal threadBoundSpec = + new NamedThreadLocal<>("Current CCI ConnectionSpec"); + + + /** + * Set the ConnectionSpec that this adapter should use for retrieving Connections. + * Default is none. + */ + public void setConnectionSpec(ConnectionSpec connectionSpec) { + this.connectionSpec = connectionSpec; + } + + /** + * Set a ConnectionSpec for this proxy and the current thread. + * The given ConnectionSpec will be applied to all subsequent + * {@code getConnection()} calls on this ConnectionFactory proxy. + *

This will override any statically specified "connectionSpec" property. + * @param spec the ConnectionSpec to apply + * @see #removeConnectionSpecFromCurrentThread + */ + public void setConnectionSpecForCurrentThread(ConnectionSpec spec) { + this.threadBoundSpec.set(spec); + } + + /** + * Remove any ConnectionSpec for this proxy from the current thread. + * A statically specified ConnectionSpec applies again afterwards. + * @see #setConnectionSpecForCurrentThread + */ + public void removeConnectionSpecFromCurrentThread() { + this.threadBoundSpec.remove(); + } + + + /** + * Determine whether there is currently a thread-bound ConnectionSpec, + * using it if available, falling back to the statically specified + * "connectionSpec" property else. + * @see #doGetConnection + */ + @Override + public final Connection getConnection() throws ResourceException { + ConnectionSpec threadSpec = this.threadBoundSpec.get(); + if (threadSpec != null) { + return doGetConnection(threadSpec); + } + else { + return doGetConnection(this.connectionSpec); + } + } + + /** + * This implementation delegates to the {@code getConnection(ConnectionSpec)} + * method of the target ConnectionFactory, passing in the specified user credentials. + * If the specified username is empty, it will simply delegate to the standard + * {@code getConnection()} method of the target ConnectionFactory. + * @param spec the ConnectionSpec to apply + * @return the Connection + * @see javax.resource.cci.ConnectionFactory#getConnection(javax.resource.cci.ConnectionSpec) + * @see javax.resource.cci.ConnectionFactory#getConnection() + */ + protected Connection doGetConnection(@Nullable ConnectionSpec spec) throws ResourceException { + ConnectionFactory connectionFactory = getTargetConnectionFactory(); + Assert.state(connectionFactory != null, "No 'targetConnectionFactory' set"); + return (spec != null ? connectionFactory.getConnection(spec) : connectionFactory.getConnection()); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/connection/DelegatingConnectionFactory.java b/spring-tx/src/main/java/org/springframework/jca/cci/connection/DelegatingConnectionFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..01edac23855ebab2856fd37e45e1cc0215d3d81b --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/connection/DelegatingConnectionFactory.java @@ -0,0 +1,115 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.connection; + +import javax.naming.NamingException; +import javax.naming.Reference; +import javax.resource.ResourceException; +import javax.resource.cci.Connection; +import javax.resource.cci.ConnectionFactory; +import javax.resource.cci.ConnectionSpec; +import javax.resource.cci.RecordFactory; +import javax.resource.cci.ResourceAdapterMetaData; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * CCI {@link ConnectionFactory} implementation that delegates all calls + * to a given target {@link ConnectionFactory}. + * + *

This class is meant to be subclassed, with subclasses overriding only + * those methods (such as {@link #getConnection()}) that should not simply + * delegate to the target {@link ConnectionFactory}. + * + * @author Juergen Hoeller + * @since 1.2 + * @see #getConnection + */ +@SuppressWarnings("serial") +public class DelegatingConnectionFactory implements ConnectionFactory, InitializingBean { + + @Nullable + private ConnectionFactory targetConnectionFactory; + + + /** + * Set the target ConnectionFactory that this ConnectionFactory should delegate to. + */ + public void setTargetConnectionFactory(@Nullable ConnectionFactory targetConnectionFactory) { + this.targetConnectionFactory = targetConnectionFactory; + } + + /** + * Return the target ConnectionFactory that this ConnectionFactory should delegate to. + */ + @Nullable + public ConnectionFactory getTargetConnectionFactory() { + return this.targetConnectionFactory; + } + + /** + * Obtain the target {@code ConnectionFactory} for actual use (never {@code null}). + * @since 5.0 + */ + protected ConnectionFactory obtainTargetConnectionFactory() { + ConnectionFactory connectionFactory = getTargetConnectionFactory(); + Assert.state(connectionFactory != null, "No 'targetConnectionFactory' set"); + return connectionFactory; + } + + + @Override + public void afterPropertiesSet() { + if (getTargetConnectionFactory() == null) { + throw new IllegalArgumentException("Property 'targetConnectionFactory' is required"); + } + } + + + @Override + public Connection getConnection() throws ResourceException { + return obtainTargetConnectionFactory().getConnection(); + } + + @Override + public Connection getConnection(ConnectionSpec connectionSpec) throws ResourceException { + return obtainTargetConnectionFactory().getConnection(connectionSpec); + } + + @Override + public RecordFactory getRecordFactory() throws ResourceException { + return obtainTargetConnectionFactory().getRecordFactory(); + } + + @Override + public ResourceAdapterMetaData getMetaData() throws ResourceException { + return obtainTargetConnectionFactory().getMetaData(); + } + + @Override + public Reference getReference() throws NamingException { + return obtainTargetConnectionFactory().getReference(); + } + + @Override + public void setReference(Reference reference) { + obtainTargetConnectionFactory().setReference(reference); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/connection/NotSupportedRecordFactory.java b/spring-tx/src/main/java/org/springframework/jca/cci/connection/NotSupportedRecordFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..9a20266fdec3ebaeb19dd4fb71f621ad058b8042 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/connection/NotSupportedRecordFactory.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.connection; + +import javax.resource.NotSupportedException; +import javax.resource.ResourceException; +import javax.resource.cci.IndexedRecord; +import javax.resource.cci.MappedRecord; +import javax.resource.cci.RecordFactory; + +/** + * Implementation of the CCI RecordFactory interface that always throws + * NotSupportedException. + * + *

Useful as a placeholder for a RecordFactory argument (for example as + * defined by the RecordCreator callback), in particular when the connector's + * {@code ConnectionFactory.getRecordFactory()} implementation happens to + * throw NotSupportedException early rather than throwing the exception from + * RecordFactory's methods. + * + * @author Juergen Hoeller + * @since 1.2.4 + * @see org.springframework.jca.cci.core.RecordCreator#createRecord(javax.resource.cci.RecordFactory) + * @see org.springframework.jca.cci.core.CciTemplate#getRecordFactory(javax.resource.cci.ConnectionFactory) + * @see javax.resource.cci.ConnectionFactory#getRecordFactory() + * @see javax.resource.NotSupportedException + */ +public class NotSupportedRecordFactory implements RecordFactory { + + @Override + public MappedRecord createMappedRecord(String name) throws ResourceException { + throw new NotSupportedException("The RecordFactory facility is not supported by the connector"); + } + + @Override + public IndexedRecord createIndexedRecord(String name) throws ResourceException { + throw new NotSupportedException("The RecordFactory facility is not supported by the connector"); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/connection/SingleConnectionFactory.java b/spring-tx/src/main/java/org/springframework/jca/cci/connection/SingleConnectionFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..937f0cb509d1354842e6295383aedf290517d7da --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/connection/SingleConnectionFactory.java @@ -0,0 +1,261 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.connection; + +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; + +import javax.resource.NotSupportedException; +import javax.resource.ResourceException; +import javax.resource.cci.Connection; +import javax.resource.cci.ConnectionFactory; +import javax.resource.cci.ConnectionSpec; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.DisposableBean; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * A CCI ConnectionFactory adapter that returns the same Connection on all + * {@code getConnection} calls, and ignores calls to + * {@code Connection.close()}. + * + *

Useful for testing and standalone environments, to keep using the same + * Connection for multiple CciTemplate calls, without having a pooling + * ConnectionFactory, also spanning any number of transactions. + * + *

You can either pass in a CCI Connection directly, or let this + * factory lazily create a Connection via a given target ConnectionFactory. + * + * @author Juergen Hoeller + * @since 1.2 + * @see #getConnection() + * @see javax.resource.cci.Connection#close() + * @see org.springframework.jca.cci.core.CciTemplate + */ +@SuppressWarnings("serial") +public class SingleConnectionFactory extends DelegatingConnectionFactory implements DisposableBean { + + protected final Log logger = LogFactory.getLog(getClass()); + + /** Wrapped Connection. */ + @Nullable + private Connection target; + + /** Proxy Connection. */ + @Nullable + private Connection connection; + + /** Synchronization monitor for the shared Connection. */ + private final Object connectionMonitor = new Object(); + + + /** + * Create a new SingleConnectionFactory for bean-style usage. + * @see #setTargetConnectionFactory + */ + public SingleConnectionFactory() { + } + + /** + * Create a new SingleConnectionFactory that always returns the + * given Connection. + * @param target the single Connection + */ + public SingleConnectionFactory(Connection target) { + Assert.notNull(target, "Target Connection must not be null"); + this.target = target; + this.connection = getCloseSuppressingConnectionProxy(target); + } + + /** + * Create a new SingleConnectionFactory that always returns a single + * Connection which it will lazily create via the given target + * ConnectionFactory. + * @param targetConnectionFactory the target ConnectionFactory + */ + public SingleConnectionFactory(ConnectionFactory targetConnectionFactory) { + Assert.notNull(targetConnectionFactory, "Target ConnectionFactory must not be null"); + setTargetConnectionFactory(targetConnectionFactory); + } + + + /** + * Make sure a Connection or ConnectionFactory has been set. + */ + @Override + public void afterPropertiesSet() { + if (this.connection == null && getTargetConnectionFactory() == null) { + throw new IllegalArgumentException("Connection or 'targetConnectionFactory' is required"); + } + } + + + @Override + public Connection getConnection() throws ResourceException { + synchronized (this.connectionMonitor) { + if (this.connection == null) { + initConnection(); + } + return this.connection; + } + } + + @Override + public Connection getConnection(ConnectionSpec connectionSpec) throws ResourceException { + throw new NotSupportedException( + "SingleConnectionFactory does not support custom ConnectionSpec"); + } + + /** + * Close the underlying Connection. + * The provider of this ConnectionFactory needs to care for proper shutdown. + *

As this bean implements DisposableBean, a bean factory will + * automatically invoke this on destruction of its cached singletons. + */ + @Override + public void destroy() { + resetConnection(); + } + + + /** + * Initialize the single underlying Connection. + *

Closes and reinitializes the Connection if an underlying + * Connection is present already. + * @throws javax.resource.ResourceException if thrown by CCI API methods + */ + public void initConnection() throws ResourceException { + if (getTargetConnectionFactory() == null) { + throw new IllegalStateException( + "'targetConnectionFactory' is required for lazily initializing a Connection"); + } + synchronized (this.connectionMonitor) { + if (this.target != null) { + closeConnection(this.target); + } + this.target = doCreateConnection(); + prepareConnection(this.target); + if (logger.isDebugEnabled()) { + logger.debug("Established shared CCI Connection: " + this.target); + } + this.connection = getCloseSuppressingConnectionProxy(this.target); + } + } + + /** + * Reset the underlying shared Connection, to be reinitialized on next access. + */ + public void resetConnection() { + synchronized (this.connectionMonitor) { + if (this.target != null) { + closeConnection(this.target); + } + this.target = null; + this.connection = null; + } + } + + /** + * Create a CCI Connection via this template's ConnectionFactory. + * @return the new CCI Connection + * @throws javax.resource.ResourceException if thrown by CCI API methods + */ + protected Connection doCreateConnection() throws ResourceException { + ConnectionFactory connectionFactory = getTargetConnectionFactory(); + Assert.state(connectionFactory != null, "No 'targetConnectionFactory' set"); + return connectionFactory.getConnection(); + } + + /** + * Prepare the given Connection before it is exposed. + *

The default implementation is empty. Can be overridden in subclasses. + * @param con the Connection to prepare + */ + protected void prepareConnection(Connection con) throws ResourceException { + } + + /** + * Close the given Connection. + * @param con the Connection to close + */ + protected void closeConnection(Connection con) { + try { + con.close(); + } + catch (Throwable ex) { + logger.warn("Could not close shared CCI Connection", ex); + } + } + + /** + * Wrap the given Connection with a proxy that delegates every method call to it + * but suppresses close calls. This is useful for allowing application code to + * handle a special framework Connection just like an ordinary Connection from a + * CCI ConnectionFactory. + * @param target the original Connection to wrap + * @return the wrapped Connection + */ + protected Connection getCloseSuppressingConnectionProxy(Connection target) { + return (Connection) Proxy.newProxyInstance( + Connection.class.getClassLoader(), + new Class[] {Connection.class}, + new CloseSuppressingInvocationHandler(target)); + } + + + /** + * Invocation handler that suppresses close calls on CCI Connections. + */ + private static final class CloseSuppressingInvocationHandler implements InvocationHandler { + + private final Connection target; + + private CloseSuppressingInvocationHandler(Connection target) { + this.target = target; + } + + @Override + @Nullable + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + if (method.getName().equals("equals")) { + // Only consider equal when proxies are identical. + return (proxy == args[0]); + } + else if (method.getName().equals("hashCode")) { + // Use hashCode of Connection proxy. + return System.identityHashCode(proxy); + } + else if (method.getName().equals("close")) { + // Handle close method: don't pass the call on. + return null; + } + try { + return method.invoke(this.target, args); + } + catch (InvocationTargetException ex) { + throw ex.getTargetException(); + } + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/connection/TransactionAwareConnectionFactoryProxy.java b/spring-tx/src/main/java/org/springframework/jca/cci/connection/TransactionAwareConnectionFactoryProxy.java new file mode 100644 index 0000000000000000000000000000000000000000..9e9c09f8306813d7dec8c3226beb3f2e0a13b62c --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/connection/TransactionAwareConnectionFactoryProxy.java @@ -0,0 +1,169 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.connection; + +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; + +import javax.resource.ResourceException; +import javax.resource.cci.Connection; +import javax.resource.cci.ConnectionFactory; + +import org.springframework.lang.Nullable; + +/** + * Proxy for a target CCI {@link javax.resource.cci.ConnectionFactory}, adding + * awareness of Spring-managed transactions. Similar to a transactional JNDI + * ConnectionFactory as provided by a Java EE server. + * + *

Data access code that should remain unaware of Spring's data access support + * can work with this proxy to seamlessly participate in Spring-managed transactions. + * Note that the transaction manager, for example the {@link CciLocalTransactionManager}, + * still needs to work with underlying ConnectionFactory, not with this proxy. + * + *

Make sure that TransactionAwareConnectionFactoryProxy is the outermost + * ConnectionFactory of a chain of ConnectionFactory proxies/adapters. + * TransactionAwareConnectionFactoryProxy can delegate either directly to the + * target connection pool or to some intermediate proxy/adapter like + * {@link ConnectionSpecConnectionFactoryAdapter}. + * + *

Delegates to {@link ConnectionFactoryUtils} for automatically participating in + * thread-bound transactions, for example managed by {@link CciLocalTransactionManager}. + * {@code getConnection} calls and {@code close} calls on returned Connections + * will behave properly within a transaction, i.e. always operate on the transactional + * Connection. If not within a transaction, normal ConnectionFactory behavior applies. + * + *

This proxy allows data access code to work with the plain JCA CCI API and still + * participate in Spring-managed transactions, similar to CCI code in a Java EE/JTA + * environment. However, if possible, use Spring's ConnectionFactoryUtils, CciTemplate or + * CCI operation objects to get transaction participation even without a proxy for + * the target ConnectionFactory, avoiding the need to define such a proxy in the first place. + * + *

NOTE: This ConnectionFactory proxy needs to return wrapped Connections + * in order to handle close calls properly. Therefore, the returned Connections cannot + * be cast to a native CCI Connection type or to a connection pool implementation type. + * + * @author Juergen Hoeller + * @since 1.2 + * @see javax.resource.cci.ConnectionFactory#getConnection + * @see javax.resource.cci.Connection#close + * @see ConnectionFactoryUtils#doGetConnection + * @see ConnectionFactoryUtils#doReleaseConnection + */ +@SuppressWarnings("serial") +public class TransactionAwareConnectionFactoryProxy extends DelegatingConnectionFactory { + + /** + * Create a new TransactionAwareConnectionFactoryProxy. + * @see #setTargetConnectionFactory + */ + public TransactionAwareConnectionFactoryProxy() { + } + + /** + * Create a new TransactionAwareConnectionFactoryProxy. + * @param targetConnectionFactory the target ConnectionFactory + */ + public TransactionAwareConnectionFactoryProxy(ConnectionFactory targetConnectionFactory) { + setTargetConnectionFactory(targetConnectionFactory); + afterPropertiesSet(); + } + + + /** + * Delegate to ConnectionFactoryUtils for automatically participating in Spring-managed + * transactions. Throws the original ResourceException, if any. + * @return a transactional Connection if any, a new one else + * @see org.springframework.jca.cci.connection.ConnectionFactoryUtils#doGetConnection + */ + @Override + public Connection getConnection() throws ResourceException { + ConnectionFactory targetConnectionFactory = obtainTargetConnectionFactory(); + Connection con = ConnectionFactoryUtils.doGetConnection(targetConnectionFactory); + return getTransactionAwareConnectionProxy(con, targetConnectionFactory); + } + + /** + * Wrap the given Connection with a proxy that delegates every method call to it + * but delegates {@code close} calls to ConnectionFactoryUtils. + * @param target the original Connection to wrap + * @param cf the ConnectionFactory that the Connection came from + * @return the wrapped Connection + * @see javax.resource.cci.Connection#close() + * @see ConnectionFactoryUtils#doReleaseConnection + */ + protected Connection getTransactionAwareConnectionProxy(Connection target, ConnectionFactory cf) { + return (Connection) Proxy.newProxyInstance( + Connection.class.getClassLoader(), + new Class[] {Connection.class}, + new TransactionAwareInvocationHandler(target, cf)); + } + + + /** + * Invocation handler that delegates close calls on CCI Connections + * to ConnectionFactoryUtils for being aware of thread-bound transactions. + */ + private static class TransactionAwareInvocationHandler implements InvocationHandler { + + private final Connection target; + + private final ConnectionFactory connectionFactory; + + public TransactionAwareInvocationHandler(Connection target, ConnectionFactory cf) { + this.target = target; + this.connectionFactory = cf; + } + + @Override + @Nullable + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + // Invocation on Connection interface coming in... + + if (method.getName().equals("equals")) { + // Only consider equal when proxies are identical. + return (proxy == args[0]); + } + else if (method.getName().equals("hashCode")) { + // Use hashCode of Connection proxy. + return System.identityHashCode(proxy); + } + else if (method.getName().equals("getLocalTransaction")) { + if (ConnectionFactoryUtils.isConnectionTransactional(this.target, this.connectionFactory)) { + throw new javax.resource.spi.IllegalStateException( + "Local transaction handling not allowed within a managed transaction"); + } + } + else if (method.getName().equals("close")) { + // Handle close method: only close if not within a transaction. + ConnectionFactoryUtils.doReleaseConnection(this.target, this.connectionFactory); + return null; + } + + // Invoke method on target Connection. + try { + return method.invoke(this.target, args); + } + catch (InvocationTargetException ex) { + throw ex.getTargetException(); + } + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/connection/package-info.java b/spring-tx/src/main/java/org/springframework/jca/cci/connection/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..b90d61e6211cbe2caa9a4a7397f5233e6f0f845a --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/connection/package-info.java @@ -0,0 +1,11 @@ +/** + * Provides a utility class for easy ConnectionFactory access, + * a PlatformTransactionManager for local CCI transactions, + * and various simple ConnectionFactory proxies/adapters. + */ +@NonNullApi +@NonNullFields +package org.springframework.jca.cci.connection; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/core/CciOperations.java b/spring-tx/src/main/java/org/springframework/jca/cci/core/CciOperations.java new file mode 100644 index 0000000000000000000000000000000000000000..e12342695db2922e98e632af69615b135e5d7144 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/core/CciOperations.java @@ -0,0 +1,127 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.core; + +import javax.resource.cci.InteractionSpec; +import javax.resource.cci.Record; + +import org.springframework.dao.DataAccessException; +import org.springframework.lang.Nullable; + +/** + * Interface that specifies a basic set of CCI operations on an EIS. + * Implemented by CciTemplate. Not often used, but a useful option + * to enhance testability, as it can easily be mocked or stubbed. + * + *

Alternatively, the standard CCI infrastructure can be mocked. + * However, mocking this interface constitutes significantly less work. + * + * @author Juergen Hoeller + * @since 1.2 + * @see CciTemplate + */ +public interface CciOperations { + + /** + * Execute a request on an EIS with CCI, implemented as callback action + * working on a CCI Connection. This allows for implementing arbitrary + * data access operations, within Spring's managed CCI environment: + * that is, participating in Spring-managed transactions and converting + * JCA ResourceExceptions into Spring's DataAccessException hierarchy. + *

The callback action can return a result object, for example a + * domain object or a collection of domain objects. + * @param action the callback object that specifies the action + * @return the result object returned by the action, if any + * @throws DataAccessException if there is any problem + */ + @Nullable + T execute(ConnectionCallback action) throws DataAccessException; + + /** + * Execute a request on an EIS with CCI, implemented as callback action + * working on a CCI Interaction. This allows for implementing arbitrary + * data access operations on a single Interaction, within Spring's managed + * CCI environment: that is, participating in Spring-managed transactions + * and converting JCA ResourceExceptions into Spring's DataAccessException + * hierarchy. + *

The callback action can return a result object, for example a + * domain object or a collection of domain objects. + * @param action the callback object that specifies the action + * @return the result object returned by the action, if any + * @throws DataAccessException if there is any problem + */ + @Nullable + T execute(InteractionCallback action) throws DataAccessException; + + /** + * Execute the specified interaction on an EIS with CCI. + * @param spec the CCI InteractionSpec instance that defines + * the interaction (connector-specific) + * @param inputRecord the input record + * @return the output record + * @throws DataAccessException if there is any problem + */ + @Nullable + Record execute(InteractionSpec spec, Record inputRecord) throws DataAccessException; + + /** + * Execute the specified interaction on an EIS with CCI. + * @param spec the CCI InteractionSpec instance that defines + * the interaction (connector-specific) + * @param inputRecord the input record + * @param outputRecord the output record + * @throws DataAccessException if there is any problem + */ + void execute(InteractionSpec spec, Record inputRecord, Record outputRecord) throws DataAccessException; + + /** + * Execute the specified interaction on an EIS with CCI. + * @param spec the CCI InteractionSpec instance that defines + * the interaction (connector-specific) + * @param inputCreator object that creates the input record to use + * @return the output record + * @throws DataAccessException if there is any problem + */ + Record execute(InteractionSpec spec, RecordCreator inputCreator) throws DataAccessException; + + /** + * Execute the specified interaction on an EIS with CCI. + * @param spec the CCI InteractionSpec instance that defines + * the interaction (connector-specific) + * @param inputRecord the input record + * @param outputExtractor object to convert the output record to a result object + * @return the output data extracted with the RecordExtractor object + * @throws DataAccessException if there is any problem + */ + @Nullable + T execute(InteractionSpec spec, Record inputRecord, RecordExtractor outputExtractor) + throws DataAccessException; + + /** + * Execute the specified interaction on an EIS with CCI. + * @param spec the CCI InteractionSpec instance that defines + * the interaction (connector-specific) + * @param inputCreator object that creates the input record to use + * @param outputExtractor object to convert the output record to a result object + * @return the output data extracted with the RecordExtractor object + * @throws DataAccessException if there is any problem + */ + @Nullable + T execute(InteractionSpec spec, RecordCreator inputCreator, RecordExtractor outputExtractor) + throws DataAccessException; + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/core/CciTemplate.java b/spring-tx/src/main/java/org/springframework/jca/cci/core/CciTemplate.java new file mode 100644 index 0000000000000000000000000000000000000000..d9d2e01b805a847da4dca28348ca7367b3f168c4 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/core/CciTemplate.java @@ -0,0 +1,455 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.core; + +import java.sql.SQLException; + +import javax.resource.NotSupportedException; +import javax.resource.ResourceException; +import javax.resource.cci.Connection; +import javax.resource.cci.ConnectionFactory; +import javax.resource.cci.ConnectionSpec; +import javax.resource.cci.IndexedRecord; +import javax.resource.cci.Interaction; +import javax.resource.cci.InteractionSpec; +import javax.resource.cci.MappedRecord; +import javax.resource.cci.Record; +import javax.resource.cci.RecordFactory; +import javax.resource.cci.ResultSet; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.dao.DataAccessException; +import org.springframework.dao.DataAccessResourceFailureException; +import org.springframework.jca.cci.CannotCreateRecordException; +import org.springframework.jca.cci.CciOperationNotSupportedException; +import org.springframework.jca.cci.InvalidResultSetAccessException; +import org.springframework.jca.cci.RecordTypeNotSupportedException; +import org.springframework.jca.cci.connection.ConnectionFactoryUtils; +import org.springframework.jca.cci.connection.NotSupportedRecordFactory; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * This is the central class in the CCI core package. + * It simplifies the use of CCI and helps to avoid common errors. + * It executes core CCI workflow, leaving application code to provide parameters + * to CCI and extract results. This class executes EIS queries or updates, + * catching ResourceExceptions and translating them to the generic exception + * hierarchy defined in the {@code org.springframework.dao} package. + * + *

Code using this class can pass in and receive {@link javax.resource.cci.Record} + * instances, or alternatively implement callback interfaces for creating input + * Records and extracting result objects from output Records (or CCI ResultSets). + * + *

Can be used within a service implementation via direct instantiation + * with a ConnectionFactory reference, or get prepared in an application context + * and given to services as bean reference. Note: The ConnectionFactory should + * always be configured as a bean in the application context, in the first case + * given to the service directly, in the second case to the prepared template. + * + * @author Thierry Templier + * @author Juergen Hoeller + * @since 1.2 + * @see RecordCreator + * @see RecordExtractor + */ +public class CciTemplate implements CciOperations { + + private final Log logger = LogFactory.getLog(getClass()); + + @Nullable + private ConnectionFactory connectionFactory; + + @Nullable + private ConnectionSpec connectionSpec; + + @Nullable + private RecordCreator outputRecordCreator; + + + /** + * Construct a new CciTemplate for bean usage. + *

Note: The ConnectionFactory has to be set before using the instance. + * @see #setConnectionFactory + */ + public CciTemplate() { + } + + /** + * Construct a new CciTemplate, given a ConnectionFactory to obtain Connections from. + * Note: This will trigger eager initialization of the exception translator. + * @param connectionFactory the JCA ConnectionFactory to obtain Connections from + */ + public CciTemplate(ConnectionFactory connectionFactory) { + setConnectionFactory(connectionFactory); + afterPropertiesSet(); + } + + /** + * Construct a new CciTemplate, given a ConnectionFactory to obtain Connections from. + * Note: This will trigger eager initialization of the exception translator. + * @param connectionFactory the JCA ConnectionFactory to obtain Connections from + * @param connectionSpec the CCI ConnectionSpec to obtain Connections for + * (may be {@code null}) + */ + public CciTemplate(ConnectionFactory connectionFactory, @Nullable ConnectionSpec connectionSpec) { + setConnectionFactory(connectionFactory); + if (connectionSpec != null) { + setConnectionSpec(connectionSpec); + } + afterPropertiesSet(); + } + + + /** + * Set the CCI ConnectionFactory to obtain Connections from. + */ + public void setConnectionFactory(@Nullable ConnectionFactory connectionFactory) { + this.connectionFactory = connectionFactory; + } + + /** + * Return the CCI ConnectionFactory used by this template. + */ + @Nullable + public ConnectionFactory getConnectionFactory() { + return this.connectionFactory; + } + + private ConnectionFactory obtainConnectionFactory() { + ConnectionFactory connectionFactory = getConnectionFactory(); + Assert.state(connectionFactory != null, "No ConnectionFactory set"); + return connectionFactory; + } + + /** + * Set the CCI ConnectionSpec that this template instance is + * supposed to obtain Connections for. + */ + public void setConnectionSpec(@Nullable ConnectionSpec connectionSpec) { + this.connectionSpec = connectionSpec; + } + + /** + * Return the CCI ConnectionSpec used by this template, if any. + */ + @Nullable + public ConnectionSpec getConnectionSpec() { + return this.connectionSpec; + } + + /** + * Set a RecordCreator that should be used for creating default output Records. + *

Default is none: When no explicit output Record gets passed into an + * {@code execute} method, CCI's {@code Interaction.execute} variant + * that returns an output Record will be called. + *

Specify a RecordCreator here if you always need to call CCI's + * {@code Interaction.execute} variant with a passed-in output Record. + * Unless there is an explicitly specified output Record, CciTemplate will + * then invoke this RecordCreator to create a default output Record instance. + * @see javax.resource.cci.Interaction#execute(javax.resource.cci.InteractionSpec, Record) + * @see javax.resource.cci.Interaction#execute(javax.resource.cci.InteractionSpec, Record, Record) + */ + public void setOutputRecordCreator(@Nullable RecordCreator creator) { + this.outputRecordCreator = creator; + } + + /** + * Return a RecordCreator that should be used for creating default output Records. + */ + @Nullable + public RecordCreator getOutputRecordCreator() { + return this.outputRecordCreator; + } + + public void afterPropertiesSet() { + if (getConnectionFactory() == null) { + throw new IllegalArgumentException("Property 'connectionFactory' is required"); + } + } + + + /** + * Create a template derived from this template instance, + * inheriting the ConnectionFactory and other settings but + * overriding the ConnectionSpec used for obtaining Connections. + * @param connectionSpec the CCI ConnectionSpec that the derived template + * instance is supposed to obtain Connections for + * @return the derived template instance + * @see #setConnectionSpec + */ + public CciTemplate getDerivedTemplate(ConnectionSpec connectionSpec) { + CciTemplate derived = new CciTemplate(obtainConnectionFactory(), connectionSpec); + RecordCreator recordCreator = getOutputRecordCreator(); + if (recordCreator != null) { + derived.setOutputRecordCreator(recordCreator); + } + return derived; + } + + + @Override + @Nullable + public T execute(ConnectionCallback action) throws DataAccessException { + Assert.notNull(action, "Callback object must not be null"); + ConnectionFactory connectionFactory = obtainConnectionFactory(); + Connection con = ConnectionFactoryUtils.getConnection(connectionFactory, getConnectionSpec()); + try { + return action.doInConnection(con, connectionFactory); + } + catch (NotSupportedException ex) { + throw new CciOperationNotSupportedException("CCI operation not supported by connector", ex); + } + catch (ResourceException ex) { + throw new DataAccessResourceFailureException("CCI operation failed", ex); + } + catch (SQLException ex) { + throw new InvalidResultSetAccessException("Parsing of CCI ResultSet failed", ex); + } + finally { + ConnectionFactoryUtils.releaseConnection(con, getConnectionFactory()); + } + } + + @Override + @Nullable + public T execute(final InteractionCallback action) throws DataAccessException { + Assert.notNull(action, "Callback object must not be null"); + return execute((ConnectionCallback) (connection, connectionFactory) -> { + Interaction interaction = connection.createInteraction(); + try { + return action.doInInteraction(interaction, connectionFactory); + } + finally { + closeInteraction(interaction); + } + }); + } + + @Override + @Nullable + public Record execute(InteractionSpec spec, Record inputRecord) throws DataAccessException { + return doExecute(spec, inputRecord, null, new SimpleRecordExtractor()); + } + + @Override + public void execute(InteractionSpec spec, Record inputRecord, Record outputRecord) throws DataAccessException { + doExecute(spec, inputRecord, outputRecord, null); + } + + @Override + public Record execute(InteractionSpec spec, RecordCreator inputCreator) throws DataAccessException { + Record output = doExecute(spec, createRecord(inputCreator), null, new SimpleRecordExtractor()); + Assert.state(output != null, "Invalid output record"); + return output; + } + + @Override + public T execute(InteractionSpec spec, Record inputRecord, RecordExtractor outputExtractor) + throws DataAccessException { + + return doExecute(spec, inputRecord, null, outputExtractor); + } + + @Override + public T execute(InteractionSpec spec, RecordCreator inputCreator, RecordExtractor outputExtractor) + throws DataAccessException { + + return doExecute(spec, createRecord(inputCreator), null, outputExtractor); + } + + /** + * Execute the specified interaction on an EIS with CCI. + * All other interaction execution methods go through this. + * @param spec the CCI InteractionSpec instance that defines + * the interaction (connector-specific) + * @param inputRecord the input record + * @param outputRecord output record (can be {@code null}) + * @param outputExtractor object to convert the output record to a result object + * @return the output data extracted with the RecordExtractor object + * @throws DataAccessException if there is any problem + */ + @Nullable + protected T doExecute( + final InteractionSpec spec, final Record inputRecord, @Nullable final Record outputRecord, + @Nullable final RecordExtractor outputExtractor) throws DataAccessException { + + return execute((InteractionCallback) (interaction, connectionFactory) -> { + Record outputRecordToUse = outputRecord; + try { + if (outputRecord != null || getOutputRecordCreator() != null) { + // Use the CCI execute method with output record as parameter. + if (outputRecord == null) { + RecordFactory recordFactory = getRecordFactory(connectionFactory); + outputRecordToUse = getOutputRecordCreator().createRecord(recordFactory); + } + interaction.execute(spec, inputRecord, outputRecordToUse); + } + else { + outputRecordToUse = interaction.execute(spec, inputRecord); + } + return (outputExtractor != null ? outputExtractor.extractData(outputRecordToUse) : null); + } + finally { + if (outputRecordToUse instanceof ResultSet) { + closeResultSet((ResultSet) outputRecordToUse); + } + } + }); + } + + + /** + * Create an indexed Record through the ConnectionFactory's RecordFactory. + * @param name the name of the record + * @return the Record + * @throws DataAccessException if creation of the Record failed + * @see #getRecordFactory(javax.resource.cci.ConnectionFactory) + * @see javax.resource.cci.RecordFactory#createIndexedRecord(String) + */ + public IndexedRecord createIndexedRecord(String name) throws DataAccessException { + try { + RecordFactory recordFactory = getRecordFactory(obtainConnectionFactory()); + return recordFactory.createIndexedRecord(name); + } + catch (NotSupportedException ex) { + throw new RecordTypeNotSupportedException("Creation of indexed Record not supported by connector", ex); + } + catch (ResourceException ex) { + throw new CannotCreateRecordException("Creation of indexed Record failed", ex); + } + } + + /** + * Create a mapped Record from the ConnectionFactory's RecordFactory. + * @param name record name + * @return the Record + * @throws DataAccessException if creation of the Record failed + * @see #getRecordFactory(javax.resource.cci.ConnectionFactory) + * @see javax.resource.cci.RecordFactory#createMappedRecord(String) + */ + public MappedRecord createMappedRecord(String name) throws DataAccessException { + try { + RecordFactory recordFactory = getRecordFactory(obtainConnectionFactory()); + return recordFactory.createMappedRecord(name); + } + catch (NotSupportedException ex) { + throw new RecordTypeNotSupportedException("Creation of mapped Record not supported by connector", ex); + } + catch (ResourceException ex) { + throw new CannotCreateRecordException("Creation of mapped Record failed", ex); + } + } + + /** + * Invoke the given RecordCreator, converting JCA ResourceExceptions + * to Spring's DataAccessException hierarchy. + * @param recordCreator the RecordCreator to invoke + * @return the created Record + * @throws DataAccessException if creation of the Record failed + * @see #getRecordFactory(javax.resource.cci.ConnectionFactory) + * @see RecordCreator#createRecord(javax.resource.cci.RecordFactory) + */ + protected Record createRecord(RecordCreator recordCreator) throws DataAccessException { + try { + RecordFactory recordFactory = getRecordFactory(obtainConnectionFactory()); + return recordCreator.createRecord(recordFactory); + } + catch (NotSupportedException ex) { + throw new RecordTypeNotSupportedException( + "Creation of the desired Record type not supported by connector", ex); + } + catch (ResourceException ex) { + throw new CannotCreateRecordException("Creation of the desired Record failed", ex); + } + } + + /** + * Return a RecordFactory for the given ConnectionFactory. + *

Default implementation returns the connector's RecordFactory if + * available, falling back to a NotSupportedRecordFactory placeholder. + * This allows to invoke a RecordCreator callback with a non-null + * RecordFactory reference in any case. + * @param connectionFactory the CCI ConnectionFactory + * @return the CCI RecordFactory for the ConnectionFactory + * @throws ResourceException if thrown by CCI methods + * @see org.springframework.jca.cci.connection.NotSupportedRecordFactory + */ + protected RecordFactory getRecordFactory(ConnectionFactory connectionFactory) throws ResourceException { + try { + return connectionFactory.getRecordFactory(); + } + catch (NotSupportedException ex) { + return new NotSupportedRecordFactory(); + } + } + + + /** + * Close the given CCI Interaction and ignore any thrown exception. + * This is useful for typical finally blocks in manual CCI code. + * @param interaction the CCI Interaction to close + * @see javax.resource.cci.Interaction#close() + */ + private void closeInteraction(@Nullable Interaction interaction) { + if (interaction != null) { + try { + interaction.close(); + } + catch (ResourceException ex) { + logger.trace("Could not close CCI Interaction", ex); + } + catch (Throwable ex) { + // We don't trust the CCI driver: It might throw RuntimeException or Error. + logger.trace("Unexpected exception on closing CCI Interaction", ex); + } + } + } + + /** + * Close the given CCI ResultSet and ignore any thrown exception. + * This is useful for typical finally blocks in manual CCI code. + * @param resultSet the CCI ResultSet to close + * @see javax.resource.cci.ResultSet#close() + */ + private void closeResultSet(@Nullable ResultSet resultSet) { + if (resultSet != null) { + try { + resultSet.close(); + } + catch (SQLException ex) { + logger.trace("Could not close CCI ResultSet", ex); + } + catch (Throwable ex) { + // We don't trust the CCI driver: It might throw RuntimeException or Error. + logger.trace("Unexpected exception on closing CCI ResultSet", ex); + } + } + } + + + private static class SimpleRecordExtractor implements RecordExtractor { + + @Override + public Record extractData(Record record) { + return record; + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/core/ConnectionCallback.java b/spring-tx/src/main/java/org/springframework/jca/cci/core/ConnectionCallback.java new file mode 100644 index 0000000000000000000000000000000000000000..8893341e21ae3224f6f4de28fa08b840b7ecbf12 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/core/ConnectionCallback.java @@ -0,0 +1,80 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.core; + +import java.sql.SQLException; + +import javax.resource.ResourceException; +import javax.resource.cci.Connection; +import javax.resource.cci.ConnectionFactory; + +import org.springframework.dao.DataAccessException; +import org.springframework.lang.Nullable; + +/** + * Generic callback interface for code that operates on a CCI Connection. + * Allows to execute any number of operations on a single Connection, + * using any type and number of Interaction. + * + *

This is particularly useful for delegating to existing data access code + * that expects a Connection to work on and throws ResourceException. For newly + * written code, it is strongly recommended to use CciTemplate's more specific + * {@code execute} variants. + * + * @author Thierry Templier + * @author Juergen Hoeller + * @since 1.2 + * @param the result type + * @see CciTemplate#execute(ConnectionCallback) + * @see CciTemplate#execute(javax.resource.cci.InteractionSpec, javax.resource.cci.Record) + * @see CciTemplate#execute(javax.resource.cci.InteractionSpec, RecordCreator, RecordExtractor) + */ +@FunctionalInterface +public interface ConnectionCallback { + + /** + * Gets called by {@code CciTemplate.execute} with an active CCI Connection. + * Does not need to care about activating or closing the Connection, or handling + * transactions. + *

If called without a thread-bound CCI transaction (initiated by + * CciLocalTransactionManager), the code will simply get executed on the CCI + * Connection with its transactional semantics. If CciTemplate is configured + * to use a JTA-aware ConnectionFactory, the CCI Connection and thus the callback + * code will be transactional if a JTA transaction is active. + *

Allows for returning a result object created within the callback, i.e. + * a domain object or a collection of domain objects. Note that there's special + * support for single step actions: see the {@code CciTemplate.execute} + * variants. A thrown RuntimeException is treated as application exception: + * it gets propagated to the caller of the template. + * @param connection active CCI Connection + * @param connectionFactory the CCI ConnectionFactory that the Connection was + * created with (gives access to RecordFactory and ResourceAdapterMetaData) + * @return a result object, or {@code null} if none + * @throws ResourceException if thrown by a CCI method, to be auto-converted + * to a DataAccessException + * @throws SQLException if thrown by a ResultSet method, to be auto-converted + * to a DataAccessException + * @throws DataAccessException in case of custom exceptions + * @see javax.resource.cci.ConnectionFactory#getRecordFactory() + * @see javax.resource.cci.ConnectionFactory#getMetaData() + * @see CciTemplate#execute(javax.resource.cci.InteractionSpec, RecordCreator, RecordExtractor) + */ + @Nullable + T doInConnection(Connection connection, ConnectionFactory connectionFactory) + throws ResourceException, SQLException, DataAccessException; + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/core/InteractionCallback.java b/spring-tx/src/main/java/org/springframework/jca/cci/core/InteractionCallback.java new file mode 100644 index 0000000000000000000000000000000000000000..ef7f14add5311f7939cd01f0422cd0ac5f564a01 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/core/InteractionCallback.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.core; + +import java.sql.SQLException; + +import javax.resource.ResourceException; +import javax.resource.cci.ConnectionFactory; +import javax.resource.cci.Interaction; + +import org.springframework.dao.DataAccessException; +import org.springframework.lang.Nullable; + +/** + * Generic callback interface for code that operates on a CCI Interaction. + * Allows to execute any number of operations on a single Interaction, for + * example a single execute call or repeated execute calls with varying + * parameters. + * + *

This is particularly useful for delegating to existing data access code + * that expects an Interaction to work on and throws ResourceException. For newly + * written code, it is strongly recommended to use CciTemplate's more specific + * {@code execute} variants. + * + * @author Thierry Templier + * @author Juergen Hoeller + * @since 1.2 + * @param the result type + * @see CciTemplate#execute(InteractionCallback) + * @see CciTemplate#execute(javax.resource.cci.InteractionSpec, javax.resource.cci.Record) + * @see CciTemplate#execute(javax.resource.cci.InteractionSpec, RecordCreator, RecordExtractor) + */ +@FunctionalInterface +public interface InteractionCallback { + + /** + * Gets called by {@code CciTemplate.execute} with an active CCI Interaction. + * Does not need to care about activating or closing the Interaction, or + * handling transactions. + *

If called without a thread-bound CCI transaction (initiated by + * CciLocalTransactionManager), the code will simply get executed on the CCI + * Interaction with its transactional semantics. If CciTemplate is configured + * to use a JTA-aware ConnectionFactory, the CCI Interaction and thus the callback + * code will be transactional if a JTA transaction is active. + *

Allows for returning a result object created within the callback, i.e. + * a domain object or a collection of domain objects. Note that there's special + * support for single step actions: see the {@code CciTemplate.execute} + * variants. A thrown RuntimeException is treated as application exception: + * it gets propagated to the caller of the template. + * @param interaction active CCI Interaction + * @param connectionFactory the CCI ConnectionFactory that the Connection was + * created with (gives access to RecordFactory and ResourceAdapterMetaData) + * @return a result object, or {@code null} if none + * @throws ResourceException if thrown by a CCI method, to be auto-converted + * to a DataAccessException + * @throws SQLException if thrown by a ResultSet method, to be auto-converted + * to a DataAccessException + * @throws DataAccessException in case of custom exceptions + * @see javax.resource.cci.ConnectionFactory#getRecordFactory() + * @see javax.resource.cci.ConnectionFactory#getMetaData() + * @see CciTemplate#execute(javax.resource.cci.InteractionSpec, RecordCreator, RecordExtractor) + */ + @Nullable + T doInInteraction(Interaction interaction, ConnectionFactory connectionFactory) + throws ResourceException, SQLException, DataAccessException; + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/core/RecordCreator.java b/spring-tx/src/main/java/org/springframework/jca/cci/core/RecordCreator.java new file mode 100644 index 0000000000000000000000000000000000000000..77bf91a21b057ab39b0a3c525e3a7dbbeda135e3 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/core/RecordCreator.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.core; + +import javax.resource.ResourceException; +import javax.resource.cci.Record; +import javax.resource.cci.RecordFactory; + +import org.springframework.dao.DataAccessException; + +/** + * Callback interface for creating a CCI Record instance, + * usually based on the passed-in CCI RecordFactory. + * + *

Used for input Record creation in CciTemplate. Alternatively, + * Record instances can be passed into CciTemplate's corresponding + * {@code execute} methods directly, either instantiated manually + * or created through CciTemplate's Record factory methods. + * + *

Also used for creating default output Records in CciTemplate. + * This is useful when the JCA connector needs an explicit output Record + * instance, but no output Records should be passed into CciTemplate's + * {@code execute} methods. + * + * @author Thierry Templier + * @author Juergen Hoeller + * @since 1.2 + * @see CciTemplate#execute(javax.resource.cci.InteractionSpec, RecordCreator) + * @see CciTemplate#execute(javax.resource.cci.InteractionSpec, RecordCreator, RecordExtractor) + * @see CciTemplate#createIndexedRecord(String) + * @see CciTemplate#createMappedRecord(String) + * @see CciTemplate#setOutputRecordCreator(RecordCreator) + */ +@FunctionalInterface +public interface RecordCreator { + + /** + * Create a CCI Record instance, usually based on the passed-in CCI RecordFactory. + *

For use as input creator with CciTemplate's {@code execute} methods, + * this method should create a populated Record instance. For use as + * output Record creator, it should return an empty Record instance. + * @param recordFactory the CCI RecordFactory (never {@code null}, but not guaranteed to be + * supported by the connector: its create methods might throw NotSupportedException) + * @return the Record instance + * @throws ResourceException if thrown by a CCI method, to be auto-converted + * to a DataAccessException + * @throws DataAccessException in case of custom exceptions + */ + Record createRecord(RecordFactory recordFactory) throws ResourceException, DataAccessException; + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/core/RecordExtractor.java b/spring-tx/src/main/java/org/springframework/jca/cci/core/RecordExtractor.java new file mode 100644 index 0000000000000000000000000000000000000000..f0344b4f22cbba35d82f8139a876055ad26631b6 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/core/RecordExtractor.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.core; + +import java.sql.SQLException; + +import javax.resource.ResourceException; +import javax.resource.cci.Record; + +import org.springframework.dao.DataAccessException; +import org.springframework.lang.Nullable; + +/** + * Callback interface for extracting a result object from a CCI Record instance. + * + *

Used for output object creation in CciTemplate. Alternatively, output + * Records can also be returned to client code as-is. In case of a CCI ResultSet + * as execution result, you will almost always want to implement a RecordExtractor, + * to be able to read the ResultSet in a managed fashion, with the CCI Connection + * still open while reading the ResultSet. + * + *

Implementations of this interface perform the actual work of extracting + * results, but don't need to worry about exception handling. ResourceExceptions + * will be caught and handled correctly by the CciTemplate class. + * + * @author Thierry Templier + * @author Juergen Hoeller + * @since 1.2 + * @param the result type + * @see CciTemplate#execute(javax.resource.cci.InteractionSpec, Record, RecordExtractor) + * @see CciTemplate#execute(javax.resource.cci.InteractionSpec, RecordCreator, RecordExtractor) + * @see javax.resource.cci.ResultSet + */ +@FunctionalInterface +public interface RecordExtractor { + + /** + * Process the data in the given Record, creating a corresponding result object. + * @param record the Record to extract data from + * (possibly a CCI ResultSet) + * @return an arbitrary result object, or {@code null} if none + * (the extractor will typically be stateful in the latter case) + * @throws ResourceException if thrown by a CCI method, to be auto-converted + * to a DataAccessException + * @throws SQLException if thrown by a ResultSet method, to be auto-converted + * to a DataAccessException + * @throws DataAccessException in case of custom exceptions + * @see javax.resource.cci.ResultSet + */ + @Nullable + T extractData(Record record) throws ResourceException, SQLException, DataAccessException; + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/core/package-info.java b/spring-tx/src/main/java/org/springframework/jca/cci/core/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..33d36edf29d88cec9adc27cfd49321033436d4ba --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/core/package-info.java @@ -0,0 +1,10 @@ +/** + * Provides the core JCA CCI support, based on CciTemplate + * and its associated callback interfaces. + */ +@NonNullApi +@NonNullFields +package org.springframework.jca.cci.core; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/core/support/CciDaoSupport.java b/spring-tx/src/main/java/org/springframework/jca/cci/core/support/CciDaoSupport.java new file mode 100644 index 0000000000000000000000000000000000000000..2f523c38fdfe54f57f101eec29ecfb3e095dcb2d --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/core/support/CciDaoSupport.java @@ -0,0 +1,147 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.core.support; + +import javax.resource.cci.Connection; +import javax.resource.cci.ConnectionFactory; +import javax.resource.cci.ConnectionSpec; + +import org.springframework.dao.support.DaoSupport; +import org.springframework.jca.cci.CannotGetCciConnectionException; +import org.springframework.jca.cci.connection.ConnectionFactoryUtils; +import org.springframework.jca.cci.core.CciTemplate; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Convenient super class for CCI-based data access objects. + * + *

Requires a {@link javax.resource.cci.ConnectionFactory} to be set, + * providing a {@link org.springframework.jca.cci.core.CciTemplate} based + * on it to subclasses through the {@link #getCciTemplate()} method. + * + *

This base class is mainly intended for CciTemplate usage but can + * also be used when working with a Connection directly or when using + * {@code org.springframework.jca.cci.object} classes. + * + * @author Thierry Templier + * @author Juergen Hoeller + * @since 1.2 + * @see #setConnectionFactory + * @see #getCciTemplate + * @see org.springframework.jca.cci.core.CciTemplate + */ +public abstract class CciDaoSupport extends DaoSupport { + + @Nullable + private CciTemplate cciTemplate; + + + /** + * Set the ConnectionFactory to be used by this DAO. + */ + public final void setConnectionFactory(ConnectionFactory connectionFactory) { + if (this.cciTemplate == null || connectionFactory != this.cciTemplate.getConnectionFactory()) { + this.cciTemplate = createCciTemplate(connectionFactory); + } + } + + /** + * Create a CciTemplate for the given ConnectionFactory. + * Only invoked if populating the DAO with a ConnectionFactory reference! + *

Can be overridden in subclasses to provide a CciTemplate instance + * with different configuration, or a custom CciTemplate subclass. + * @param connectionFactory the CCI ConnectionFactory to create a CciTemplate for + * @return the new CciTemplate instance + * @see #setConnectionFactory(javax.resource.cci.ConnectionFactory) + */ + protected CciTemplate createCciTemplate(ConnectionFactory connectionFactory) { + return new CciTemplate(connectionFactory); + } + + /** + * Return the ConnectionFactory used by this DAO. + */ + @Nullable + public final ConnectionFactory getConnectionFactory() { + return (this.cciTemplate != null ? this.cciTemplate.getConnectionFactory() : null); + } + + /** + * Set the CciTemplate for this DAO explicitly, + * as an alternative to specifying a ConnectionFactory. + */ + public final void setCciTemplate(CciTemplate cciTemplate) { + this.cciTemplate = cciTemplate; + } + + /** + * Return the CciTemplate for this DAO, + * pre-initialized with the ConnectionFactory or set explicitly. + */ + @Nullable + public final CciTemplate getCciTemplate() { + return this.cciTemplate; + } + + @Override + protected final void checkDaoConfig() { + if (this.cciTemplate == null) { + throw new IllegalArgumentException("'connectionFactory' or 'cciTemplate' is required"); + } + } + + + /** + * Obtain a CciTemplate derived from the main template instance, + * inheriting the ConnectionFactory and other settings but + * overriding the ConnectionSpec used for obtaining Connections. + * @param connectionSpec the CCI ConnectionSpec that the returned + * template instance is supposed to obtain Connections for + * @return the derived template instance + * @see org.springframework.jca.cci.core.CciTemplate#getDerivedTemplate(javax.resource.cci.ConnectionSpec) + */ + protected final CciTemplate getCciTemplate(ConnectionSpec connectionSpec) { + CciTemplate cciTemplate = getCciTemplate(); + Assert.state(cciTemplate != null, "No CciTemplate set"); + return cciTemplate.getDerivedTemplate(connectionSpec); + } + + /** + * Get a CCI Connection, either from the current transaction or a new one. + * @return the CCI Connection + * @throws org.springframework.jca.cci.CannotGetCciConnectionException + * if the attempt to get a Connection failed + * @see org.springframework.jca.cci.connection.ConnectionFactoryUtils#getConnection(javax.resource.cci.ConnectionFactory) + */ + protected final Connection getConnection() throws CannotGetCciConnectionException { + ConnectionFactory connectionFactory = getConnectionFactory(); + Assert.state(connectionFactory != null, "No ConnectionFactory set"); + return ConnectionFactoryUtils.getConnection(connectionFactory); + } + + /** + * Close the given CCI Connection, created via this bean's ConnectionFactory, + * if it isn't bound to the thread. + * @param con the Connection to close + * @see org.springframework.jca.cci.connection.ConnectionFactoryUtils#releaseConnection + */ + protected final void releaseConnection(Connection con) { + ConnectionFactoryUtils.releaseConnection(con, getConnectionFactory()); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/core/support/CommAreaRecord.java b/spring-tx/src/main/java/org/springframework/jca/cci/core/support/CommAreaRecord.java new file mode 100644 index 0000000000000000000000000000000000000000..7011eb1924739d87e18ad2fd1fdffb51a9532978 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/core/support/CommAreaRecord.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.core.support; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +import javax.resource.cci.Record; +import javax.resource.cci.Streamable; + +import org.springframework.util.FileCopyUtils; + +/** + * CCI Record implementation for a COMMAREA, holding a byte array. + * + * @author Thierry Templier + * @author Juergen Hoeller + * @since 1.2 + * @see org.springframework.jca.cci.object.MappingCommAreaOperation + */ +@SuppressWarnings("serial") +public class CommAreaRecord implements Record, Streamable { + + private byte[] bytes = new byte[0]; + + private String recordName = ""; + + private String recordShortDescription = ""; + + + /** + * Create a new CommAreaRecord. + * @see #read(java.io.InputStream) + */ + public CommAreaRecord() { + } + + /** + * Create a new CommAreaRecord. + * @param bytes the bytes to fill the record with + */ + public CommAreaRecord(byte[] bytes) { + this.bytes = bytes; + } + + + @Override + public void setRecordName(String recordName) { + this.recordName = recordName; + } + + @Override + public String getRecordName() { + return this.recordName; + } + + @Override + public void setRecordShortDescription(String recordShortDescription) { + this.recordShortDescription = recordShortDescription; + } + + @Override + public String getRecordShortDescription() { + return this.recordShortDescription; + } + + + @Override + public void read(InputStream in) throws IOException { + this.bytes = FileCopyUtils.copyToByteArray(in); + } + + @Override + public void write(OutputStream out) throws IOException { + out.write(this.bytes); + out.flush(); + } + + public byte[] toByteArray() { + return this.bytes; + } + + + @Override + public Object clone() { + return new CommAreaRecord(this.bytes); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/core/support/package-info.java b/spring-tx/src/main/java/org/springframework/jca/cci/core/support/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..d7bfd6bc0f223e3e0ee590c16f1200aed54de558 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/core/support/package-info.java @@ -0,0 +1,10 @@ +/** + * Classes supporting the {@code org.springframework.jca.cci.core} package. + * Contains a DAO base class for CciTemplate usage. + */ +@NonNullApi +@NonNullFields +package org.springframework.jca.cci.core.support; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/object/EisOperation.java b/spring-tx/src/main/java/org/springframework/jca/cci/object/EisOperation.java new file mode 100644 index 0000000000000000000000000000000000000000..d0d522c115c545085f3947cefe881816192e5836 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/object/EisOperation.java @@ -0,0 +1,96 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.object; + +import javax.resource.cci.ConnectionFactory; +import javax.resource.cci.InteractionSpec; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.jca.cci.core.CciTemplate; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Base class for EIS operation objects that work with the CCI API. + * Encapsulates a CCI ConnectionFactory and a CCI InteractionSpec. + * + *

Works with a CciTemplate instance underneath. EIS operation objects + * are an alternative to working with a CciTemplate directly. + * + * @author Juergen Hoeller + * @since 1.2 + * @see #setConnectionFactory + * @see #setInteractionSpec + */ +public abstract class EisOperation implements InitializingBean { + + private CciTemplate cciTemplate = new CciTemplate(); + + @Nullable + private InteractionSpec interactionSpec; + + + /** + * Set the CciTemplate to be used by this operation. + * Alternatively, specify a CCI ConnectionFactory. + * @see #setConnectionFactory + */ + public void setCciTemplate(CciTemplate cciTemplate) { + Assert.notNull(cciTemplate, "CciTemplate must not be null"); + this.cciTemplate = cciTemplate; + } + + /** + * Return the CciTemplate used by this operation. + */ + public CciTemplate getCciTemplate() { + return this.cciTemplate; + } + + /** + * Set the CCI ConnectionFactory to be used by this operation. + */ + public void setConnectionFactory(ConnectionFactory connectionFactory) { + this.cciTemplate.setConnectionFactory(connectionFactory); + } + + /** + * Set the CCI InteractionSpec for this operation. + */ + public void setInteractionSpec(@Nullable InteractionSpec interactionSpec) { + this.interactionSpec = interactionSpec; + } + + /** + * Return the CCI InteractionSpec for this operation. + */ + @Nullable + public InteractionSpec getInteractionSpec() { + return this.interactionSpec; + } + + + @Override + public void afterPropertiesSet() { + this.cciTemplate.afterPropertiesSet(); + + if (this.interactionSpec == null) { + throw new IllegalArgumentException("InteractionSpec is required"); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/object/MappingCommAreaOperation.java b/spring-tx/src/main/java/org/springframework/jca/cci/object/MappingCommAreaOperation.java new file mode 100644 index 0000000000000000000000000000000000000000..a6f886cdd78fd0858d8c55bb05b0070ad99c7c3d --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/object/MappingCommAreaOperation.java @@ -0,0 +1,97 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.object; + +import java.io.IOException; + +import javax.resource.cci.ConnectionFactory; +import javax.resource.cci.InteractionSpec; +import javax.resource.cci.Record; +import javax.resource.cci.RecordFactory; + +import org.springframework.dao.DataAccessException; +import org.springframework.dao.DataRetrievalFailureException; +import org.springframework.jca.cci.core.support.CommAreaRecord; + +/** + * EIS operation object for access to COMMAREA records. + * Subclass of the generic MappingRecordOperation class. + * + * @author Thierry Templier + * @since 1.2 + */ +public abstract class MappingCommAreaOperation extends MappingRecordOperation { + + /** + * Create a new MappingCommAreaQuery. + * @see #setConnectionFactory + * @see #setInteractionSpec + */ + public MappingCommAreaOperation() { + } + + /** + * Create a new MappingCommAreaQuery. + * @param connectionFactory the ConnectionFactory to use to obtain connections + * @param interactionSpec specification to configure the interaction + */ + public MappingCommAreaOperation(ConnectionFactory connectionFactory, InteractionSpec interactionSpec) { + super(connectionFactory, interactionSpec); + } + + + @Override + protected final Record createInputRecord(RecordFactory recordFactory, Object inObject) { + try { + return new CommAreaRecord(objectToBytes(inObject)); + } + catch (IOException ex) { + throw new DataRetrievalFailureException("I/O exception during bytes conversion", ex); + } + } + + @Override + protected final Object extractOutputData(Record record) throws DataAccessException { + CommAreaRecord commAreaRecord = (CommAreaRecord) record; + try { + return bytesToObject(commAreaRecord.toByteArray()); + } + catch (IOException ex) { + throw new DataRetrievalFailureException("I/O exception during bytes conversion", ex); + } + } + + + /** + * Method used to convert an object into COMMAREA bytes. + * @param inObject the input data + * @return the COMMAREA's bytes + * @throws IOException if thrown by I/O methods + * @throws DataAccessException if conversion failed + */ + protected abstract byte[] objectToBytes(Object inObject) throws IOException, DataAccessException; + + /** + * Method used to convert the COMMAREA's bytes to an object. + * @param bytes the COMMAREA's bytes + * @return the output data + * @throws IOException if thrown by I/O methods + * @throws DataAccessException if conversion failed + */ + protected abstract Object bytesToObject(byte[] bytes) throws IOException, DataAccessException; + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/object/MappingRecordOperation.java b/spring-tx/src/main/java/org/springframework/jca/cci/object/MappingRecordOperation.java new file mode 100644 index 0000000000000000000000000000000000000000..b80630f156d8b503cf21def27d24b9ac7cef0dc3 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/object/MappingRecordOperation.java @@ -0,0 +1,156 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.object; + +import java.sql.SQLException; + +import javax.resource.ResourceException; +import javax.resource.cci.ConnectionFactory; +import javax.resource.cci.InteractionSpec; +import javax.resource.cci.Record; +import javax.resource.cci.RecordFactory; + +import org.springframework.dao.DataAccessException; +import org.springframework.jca.cci.core.RecordCreator; +import org.springframework.jca.cci.core.RecordExtractor; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * EIS operation object that expects mapped input and output objects, + * converting to and from CCI Records. + * + *

Concrete subclasses must implement the abstract + * {@code createInputRecord(RecordFactory, Object)} and + * {@code extractOutputData(Record)} methods, to create an input + * Record from an object and to convert an output Record into an object, + * respectively. + * + * @author Thierry Templier + * @author Juergen Hoeller + * @since 1.2 + * @see #createInputRecord(javax.resource.cci.RecordFactory, Object) + * @see #extractOutputData(javax.resource.cci.Record) + */ +public abstract class MappingRecordOperation extends EisOperation { + + /** + * Constructor that allows use as a JavaBean. + */ + public MappingRecordOperation() { + } + + /** + * Convenient constructor with ConnectionFactory and specifications + * (connection and interaction). + * @param connectionFactory the ConnectionFactory to use to obtain connections + */ + public MappingRecordOperation(ConnectionFactory connectionFactory, InteractionSpec interactionSpec) { + getCciTemplate().setConnectionFactory(connectionFactory); + setInteractionSpec(interactionSpec); + } + + /** + * Set a RecordCreator that should be used for creating default output Records. + *

Default is none: CCI's {@code Interaction.execute} variant + * that returns an output Record will be called. + *

Specify a RecordCreator here if you always need to call CCI's + * {@code Interaction.execute} variant with a passed-in output Record. + * This RecordCreator will then be invoked to create a default output Record instance. + * @see javax.resource.cci.Interaction#execute(javax.resource.cci.InteractionSpec, Record) + * @see javax.resource.cci.Interaction#execute(javax.resource.cci.InteractionSpec, Record, Record) + * @see org.springframework.jca.cci.core.CciTemplate#setOutputRecordCreator + */ + public void setOutputRecordCreator(RecordCreator creator) { + getCciTemplate().setOutputRecordCreator(creator); + } + + /** + * Execute the interaction encapsulated by this operation object. + * @param inputObject the input data, to be converted to a Record + * by the {@code createInputRecord} method + * @return the output data extracted with the {@code extractOutputData} method + * @throws DataAccessException if there is any problem + * @see #createInputRecord + * @see #extractOutputData + */ + @Nullable + public Object execute(Object inputObject) throws DataAccessException { + InteractionSpec interactionSpec = getInteractionSpec(); + Assert.state(interactionSpec != null, "No InteractionSpec set"); + return getCciTemplate().execute( + interactionSpec, new RecordCreatorImpl(inputObject), new RecordExtractorImpl()); + } + + + /** + * Subclasses must implement this method to generate an input Record + * from an input object passed into the {@code execute} method. + * @param inputObject the passed-in input object + * @return the CCI input Record + * @throws ResourceException if thrown by a CCI method, to be auto-converted + * to a DataAccessException + * @see #execute(Object) + */ + protected abstract Record createInputRecord(RecordFactory recordFactory, Object inputObject) + throws ResourceException, DataAccessException; + + /** + * Subclasses must implement this method to convert the Record returned + * by CCI execution into a result object for the {@code execute} method. + * @param outputRecord the Record returned by CCI execution + * @return the result object + * @throws ResourceException if thrown by a CCI method, to be auto-converted + * to a DataAccessException + * @see #execute(Object) + */ + protected abstract Object extractOutputData(Record outputRecord) + throws ResourceException, SQLException, DataAccessException; + + + /** + * Implementation of RecordCreator that calls the enclosing + * class's {@code createInputRecord} method. + */ + protected class RecordCreatorImpl implements RecordCreator { + + private final Object inputObject; + + public RecordCreatorImpl(Object inObject) { + this.inputObject = inObject; + } + + @Override + public Record createRecord(RecordFactory recordFactory) throws ResourceException, DataAccessException { + return createInputRecord(recordFactory, this.inputObject); + } + } + + + /** + * Implementation of RecordExtractor that calls the enclosing + * class's {@code extractOutputData} method. + */ + protected class RecordExtractorImpl implements RecordExtractor { + + @Override + public Object extractData(Record record) throws ResourceException, SQLException, DataAccessException { + return extractOutputData(record); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/object/SimpleRecordOperation.java b/spring-tx/src/main/java/org/springframework/jca/cci/object/SimpleRecordOperation.java new file mode 100644 index 0000000000000000000000000000000000000000..519891dca75372a2179e3a048ba4ccb09efd1dc5 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/object/SimpleRecordOperation.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci.object; + +import javax.resource.cci.ConnectionFactory; +import javax.resource.cci.InteractionSpec; +import javax.resource.cci.Record; + +import org.springframework.dao.DataAccessException; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * EIS operation object that accepts a passed-in CCI input Record + * and returns a corresponding CCI output Record. + * + * @author Juergen Hoeller + * @since 1.2 + */ +public class SimpleRecordOperation extends EisOperation { + + /** + * Constructor that allows use as a JavaBean. + */ + public SimpleRecordOperation() { + } + + /** + * Convenient constructor with ConnectionFactory and specifications + * (connection and interaction). + * @param connectionFactory the ConnectionFactory to use to obtain connections + */ + public SimpleRecordOperation(ConnectionFactory connectionFactory, InteractionSpec interactionSpec) { + getCciTemplate().setConnectionFactory(connectionFactory); + setInteractionSpec(interactionSpec); + } + + + /** + * Execute the CCI interaction encapsulated by this operation object. + *

This method will call CCI's {@code Interaction.execute} variant + * that returns an output Record. + * @param inputRecord the input record + * @return the output record + * @throws DataAccessException if there is any problem + * @see javax.resource.cci.Interaction#execute(javax.resource.cci.InteractionSpec, Record) + */ + @Nullable + public Record execute(Record inputRecord) throws DataAccessException { + InteractionSpec interactionSpec = getInteractionSpec(); + Assert.state(interactionSpec != null, "No InteractionSpec set"); + return getCciTemplate().execute(interactionSpec, inputRecord); + } + + /** + * Execute the CCI interaction encapsulated by this operation object. + *

This method will call CCI's {@code Interaction.execute} variant + * with a passed-in output Record. + * @param inputRecord the input record + * @param outputRecord the output record + * @throws DataAccessException if there is any problem + * @see javax.resource.cci.Interaction#execute(javax.resource.cci.InteractionSpec, Record, Record) + */ + public void execute(Record inputRecord, Record outputRecord) throws DataAccessException { + InteractionSpec interactionSpec = getInteractionSpec(); + Assert.state(interactionSpec != null, "No InteractionSpec set"); + getCciTemplate().execute(interactionSpec, inputRecord, outputRecord); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/object/package-info.java b/spring-tx/src/main/java/org/springframework/jca/cci/object/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..2c51b175d1054e0f075bac45e6db318af9b5509e --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/object/package-info.java @@ -0,0 +1,13 @@ +/** + * The classes in this package represent EIS operations as threadsafe, + * reusable objects. This higher level of CCI abstraction depends on the + * lower-level abstraction in the {@code org.springframework.jca.cci.core} package. + * Exceptions thrown are as in the {@code org.springframework.dao} package, + * meaning that code using this package does not need to worry about error handling. + */ +@NonNullApi +@NonNullFields +package org.springframework.jca.cci.object; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/jca/cci/package-info.java b/spring-tx/src/main/java/org/springframework/jca/cci/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..469cdc926b3b3756ec22dbc98af04702286f168c --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/cci/package-info.java @@ -0,0 +1,12 @@ +/** + * This package contains Spring's support for the Common Client Interface (CCI), + * as defined by the J2EE Connector Architecture. It is conceptually similar + * to the {@code org.springframework.jdbc} package, providing the same + * levels of data access abstraction. + */ +@NonNullApi +@NonNullFields +package org.springframework.jca.cci; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/jca/context/BootstrapContextAware.java b/spring-tx/src/main/java/org/springframework/jca/context/BootstrapContextAware.java new file mode 100644 index 0000000000000000000000000000000000000000..66c6b65c1d3862543e68e5687fd6200775ebb64b --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/context/BootstrapContextAware.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.context; + +import javax.resource.spi.BootstrapContext; + +import org.springframework.beans.factory.Aware; + +/** + * Interface to be implemented by any object that wishes to be + * notified of the BootstrapContext (typically determined by the + * {@link ResourceAdapterApplicationContext}) that it runs in. + * + * @author Juergen Hoeller + * @author Chris Beams + * @since 2.5 + * @see javax.resource.spi.BootstrapContext + */ +public interface BootstrapContextAware extends Aware { + + /** + * Set the BootstrapContext that this object runs in. + *

Invoked after population of normal bean properties but before an init + * callback like InitializingBean's {@code afterPropertiesSet} or a + * custom init-method. Invoked after ApplicationContextAware's + * {@code setApplicationContext}. + * @param bootstrapContext the BootstrapContext object to be used by this object + * @see org.springframework.beans.factory.InitializingBean#afterPropertiesSet + * @see org.springframework.context.ApplicationContextAware#setApplicationContext + */ + void setBootstrapContext(BootstrapContext bootstrapContext); + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/context/BootstrapContextAwareProcessor.java b/spring-tx/src/main/java/org/springframework/jca/context/BootstrapContextAwareProcessor.java new file mode 100644 index 0000000000000000000000000000000000000000..1efb9d9a6915c3bfe04cab80bf65a717cdb6c45b --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/context/BootstrapContextAwareProcessor.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.context; + +import javax.resource.spi.BootstrapContext; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.config.BeanPostProcessor; +import org.springframework.lang.Nullable; + +/** + * {@link org.springframework.beans.factory.config.BeanPostProcessor} + * implementation that passes the BootstrapContext to beans that implement + * the {@link BootstrapContextAware} interface. + * + *

{@link ResourceAdapterApplicationContext} automatically registers + * this processor with its underlying bean factory. + * + * @author Juergen Hoeller + * @since 2.5 + * @see BootstrapContextAware + */ +class BootstrapContextAwareProcessor implements BeanPostProcessor { + + @Nullable + private final BootstrapContext bootstrapContext; + + + /** + * Create a new BootstrapContextAwareProcessor for the given context. + */ + public BootstrapContextAwareProcessor(@Nullable BootstrapContext bootstrapContext) { + this.bootstrapContext = bootstrapContext; + } + + + @Override + public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { + if (this.bootstrapContext != null && bean instanceof BootstrapContextAware) { + ((BootstrapContextAware) bean).setBootstrapContext(this.bootstrapContext); + } + return bean; + } + + @Override + public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException { + return bean; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/context/ResourceAdapterApplicationContext.java b/spring-tx/src/main/java/org/springframework/jca/context/ResourceAdapterApplicationContext.java new file mode 100644 index 0000000000000000000000000000000000000000..4154e223046962273833638d272dd06217e6a5b9 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/context/ResourceAdapterApplicationContext.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.context; + +import javax.resource.spi.BootstrapContext; +import javax.resource.spi.work.WorkManager; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.ObjectFactory; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.util.Assert; + +/** + * {@link org.springframework.context.ApplicationContext} implementation + * for a JCA ResourceAdapter. Needs to be initialized with the JCA + * {@link javax.resource.spi.BootstrapContext}, passing it on to + * Spring-managed beans that implement {@link BootstrapContextAware}. + * + * @author Juergen Hoeller + * @since 2.5 + * @see SpringContextResourceAdapter + * @see BootstrapContextAware + */ +public class ResourceAdapterApplicationContext extends GenericApplicationContext { + + private final BootstrapContext bootstrapContext; + + + /** + * Create a new ResourceAdapterApplicationContext for the given BootstrapContext. + * @param bootstrapContext the JCA BootstrapContext that the ResourceAdapter + * has been started with + */ + public ResourceAdapterApplicationContext(BootstrapContext bootstrapContext) { + Assert.notNull(bootstrapContext, "BootstrapContext must not be null"); + this.bootstrapContext = bootstrapContext; + } + + + @Override + protected void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { + beanFactory.addBeanPostProcessor(new BootstrapContextAwareProcessor(this.bootstrapContext)); + beanFactory.ignoreDependencyInterface(BootstrapContextAware.class); + beanFactory.registerResolvableDependency(BootstrapContext.class, this.bootstrapContext); + + // JCA WorkManager resolved lazily - may not be available. + beanFactory.registerResolvableDependency(WorkManager.class, + (ObjectFactory) this.bootstrapContext::getWorkManager); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/context/SpringContextResourceAdapter.java b/spring-tx/src/main/java/org/springframework/jca/context/SpringContextResourceAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..20e37cd9b452ef3c3838f057998b979e7ddfa6d1 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/context/SpringContextResourceAdapter.java @@ -0,0 +1,260 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.context; + +import javax.resource.NotSupportedException; +import javax.resource.ResourceException; +import javax.resource.spi.ActivationSpec; +import javax.resource.spi.BootstrapContext; +import javax.resource.spi.ResourceAdapter; +import javax.resource.spi.ResourceAdapterInternalException; +import javax.resource.spi.endpoint.MessageEndpointFactory; +import javax.transaction.xa.XAResource; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.env.StandardEnvironment; +import org.springframework.lang.Nullable; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; + +/** + * JCA 1.7 {@link javax.resource.spi.ResourceAdapter} implementation + * that loads a Spring {@link org.springframework.context.ApplicationContext}, + * starting and stopping Spring-managed beans as part of the ResourceAdapter's + * lifecycle. + * + *

Ideal for application contexts that do not need any HTTP entry points + * but rather just consist of message endpoints and scheduled jobs etc. + * Beans in such a context may use application server resources such as the + * JTA transaction manager and JNDI-bound JDBC DataSources and JMS + * ConnectionFactory instances, and may also register with the platform's + * JMX server - all through Spring's standard transaction management and + * JNDI and JMX support facilities. + * + *

If the need for scheduling asynchronous work arises, consider using + * Spring's {@link org.springframework.jca.work.WorkManagerTaskExecutor} + * as a standard bean definition, to be injected into application beans + * through dependency injection. This WorkManagerTaskExecutor will automatically + * use the JCA WorkManager from the BootstrapContext that has been provided + * to this ResourceAdapter. + * + *

The JCA {@link javax.resource.spi.BootstrapContext} may also be + * accessed directly, through application components that implement the + * {@link BootstrapContextAware} interface. When deployed using this + * ResourceAdapter, the BootstrapContext is guaranteed to be passed on + * to such components. + * + *

This ResourceAdapter is to be defined in a "META-INF/ra.xml" file + * within a Java EE ".rar" deployment unit like as follows: + * + *

+ * <?xml version="1.0" encoding="UTF-8"?>
+ * <connector xmlns="http://java.sun.com/xml/ns/j2ee"
+ *		 xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ *		 xsi:schemaLocation="http://java.sun.com/xml/ns/j2ee https://java.sun.com/xml/ns/j2ee/connector_1_5.xsd"
+ *		 version="1.5">
+ *	 <vendor-name>Spring Framework</vendor-name>
+ *	 <eis-type>Spring Connector</eis-type>
+ *	 <resourceadapter-version>1.0</resourceadapter-version>
+ *	 <resourceadapter>
+ *		 <resourceadapter-class>org.springframework.jca.context.SpringContextResourceAdapter</resourceadapter-class>
+ *		 <config-property>
+ *			 <config-property-name>ContextConfigLocation</config-property-name>
+ *			 <config-property-type>java.lang.String</config-property-type>
+ *			 <config-property-value>META-INF/applicationContext.xml</config-property-value>
+ *		 </config-property>
+ *	 </resourceadapter>
+ * </connector>
+ * + * Note that "META-INF/applicationContext.xml" is the default context config + * location, so it doesn't have to specified unless you intend to specify + * different/additional config files. So in the default case, you may remove + * the entire {@code config-property} section above. + * + *

For simple deployment needs, all you need to do is the following: + * Package all application classes into a RAR file (which is just a standard + * JAR file with a different file extension), add all required library jars + * into the root of the RAR archive, add a "META-INF/ra.xml" deployment + * descriptor as shown above as well as the corresponding Spring XML bean + * definition file(s) (typically "META-INF/applicationContext.xml"), + * and drop the resulting RAR file into your application server's + * deployment directory! + * + * @author Juergen Hoeller + * @since 2.5 + * @see #setContextConfigLocation + * @see #loadBeanDefinitions + * @see ResourceAdapterApplicationContext + */ +public class SpringContextResourceAdapter implements ResourceAdapter { + + /** + * Any number of these characters are considered delimiters between + * multiple context config paths in a single String value. + * @see #setContextConfigLocation + */ + public static final String CONFIG_LOCATION_DELIMITERS = ConfigurableApplicationContext.CONFIG_LOCATION_DELIMITERS; + + /** + * The default {@code applicationContext.xml} location. + */ + public static final String DEFAULT_CONTEXT_CONFIG_LOCATION = "META-INF/applicationContext.xml"; + + + protected final Log logger = LogFactory.getLog(getClass()); + + private String contextConfigLocation = DEFAULT_CONTEXT_CONFIG_LOCATION; + + @Nullable + private ConfigurableApplicationContext applicationContext; + + + /** + * Set the location of the context configuration files, within the + * resource adapter's deployment unit. This can be a delimited + * String that consists of multiple resource location, separated + * by commas, semicolons, whitespace, or line breaks. + *

This can be specified as "ContextConfigLocation" config + * property in the {@code ra.xml} deployment descriptor. + *

The default is "classpath:META-INF/applicationContext.xml". + */ + public void setContextConfigLocation(String contextConfigLocation) { + this.contextConfigLocation = contextConfigLocation; + } + + /** + * Return the specified context configuration files. + */ + protected String getContextConfigLocation() { + return this.contextConfigLocation; + } + + /** + * Return a new {@link StandardEnvironment}. + *

Subclasses may override this method in order to supply + * a custom {@link ConfigurableEnvironment} implementation. + */ + protected ConfigurableEnvironment createEnvironment() { + return new StandardEnvironment(); + } + + /** + * This implementation loads a Spring ApplicationContext through the + * {@link #createApplicationContext} template method. + */ + @Override + public void start(BootstrapContext bootstrapContext) throws ResourceAdapterInternalException { + if (logger.isDebugEnabled()) { + logger.debug("Starting SpringContextResourceAdapter with BootstrapContext: " + bootstrapContext); + } + this.applicationContext = createApplicationContext(bootstrapContext); + } + + /** + * Build a Spring ApplicationContext for the given JCA BootstrapContext. + *

The default implementation builds a {@link ResourceAdapterApplicationContext} + * and delegates to {@link #loadBeanDefinitions} for actually parsing the + * specified configuration files. + * @param bootstrapContext this ResourceAdapter's BootstrapContext + * @return the Spring ApplicationContext instance + */ + protected ConfigurableApplicationContext createApplicationContext(BootstrapContext bootstrapContext) { + ResourceAdapterApplicationContext applicationContext = + new ResourceAdapterApplicationContext(bootstrapContext); + + // Set ResourceAdapter's ClassLoader as bean class loader. + applicationContext.setClassLoader(getClass().getClassLoader()); + + // Extract individual config locations. + String[] configLocations = + StringUtils.tokenizeToStringArray(getContextConfigLocation(), CONFIG_LOCATION_DELIMITERS); + + loadBeanDefinitions(applicationContext, configLocations); + applicationContext.refresh(); + + return applicationContext; + } + + /** + * Load the bean definitions into the given registry, + * based on the specified configuration files. + * @param registry the registry to load into + * @param configLocations the parsed config locations + * @see #setContextConfigLocation + */ + protected void loadBeanDefinitions(BeanDefinitionRegistry registry, String[] configLocations) { + new XmlBeanDefinitionReader(registry).loadBeanDefinitions(configLocations); + } + + /** + * This implementation closes the Spring ApplicationContext. + */ + @Override + public void stop() { + logger.debug("Stopping SpringContextResourceAdapter"); + if (this.applicationContext != null) { + this.applicationContext.close(); + } + } + + + /** + * This implementation always throws a NotSupportedException. + */ + @Override + public void endpointActivation(MessageEndpointFactory messageEndpointFactory, ActivationSpec activationSpec) + throws ResourceException { + + throw new NotSupportedException("SpringContextResourceAdapter does not support message endpoints"); + } + + /** + * This implementation does nothing. + */ + @Override + public void endpointDeactivation(MessageEndpointFactory messageEndpointFactory, ActivationSpec activationSpec) { + } + + /** + * This implementation always returns {@code null}. + */ + @Override + @Nullable + public XAResource[] getXAResources(ActivationSpec[] activationSpecs) throws ResourceException { + return null; + } + + + @Override + public boolean equals(Object other) { + return (this == other || (other instanceof SpringContextResourceAdapter && + ObjectUtils.nullSafeEquals(getContextConfigLocation(), + ((SpringContextResourceAdapter) other).getContextConfigLocation()))); + } + + @Override + public int hashCode() { + return ObjectUtils.nullSafeHashCode(getContextConfigLocation()); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/context/package-info.java b/spring-tx/src/main/java/org/springframework/jca/context/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..172bca88efe2e9b76fc612c04802cce8d5c830c6 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/context/package-info.java @@ -0,0 +1,10 @@ +/** + * Integration package that allows for deploying a Spring application context + * as a JCA 1.7 compliant RAR file. + */ +@NonNullApi +@NonNullFields +package org.springframework.jca.context; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/jca/endpoint/AbstractMessageEndpointFactory.java b/spring-tx/src/main/java/org/springframework/jca/endpoint/AbstractMessageEndpointFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..46ee354a773d56229427c3cd2f3064191e2097ad --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/endpoint/AbstractMessageEndpointFactory.java @@ -0,0 +1,370 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.endpoint; + +import java.lang.reflect.Method; + +import javax.resource.ResourceException; +import javax.resource.spi.ApplicationServerInternalException; +import javax.resource.spi.UnavailableException; +import javax.resource.spi.endpoint.MessageEndpoint; +import javax.resource.spi.endpoint.MessageEndpointFactory; +import javax.transaction.Transaction; +import javax.transaction.TransactionManager; +import javax.transaction.xa.XAResource; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.BeanNameAware; +import org.springframework.lang.Nullable; +import org.springframework.transaction.jta.SimpleTransactionFactory; +import org.springframework.transaction.jta.TransactionFactory; +import org.springframework.util.Assert; + +/** + * Abstract base implementation of the JCA 1.7 + * {@link javax.resource.spi.endpoint.MessageEndpointFactory} interface, + * providing transaction management capabilities as well as ClassLoader + * exposure for endpoint invocations. + * + * @author Juergen Hoeller + * @since 2.5 + * @see #setTransactionManager + */ +public abstract class AbstractMessageEndpointFactory implements MessageEndpointFactory, BeanNameAware { + + /** Logger available to subclasses. */ + protected final Log logger = LogFactory.getLog(getClass()); + + @Nullable + private TransactionFactory transactionFactory; + + @Nullable + private String transactionName; + + private int transactionTimeout = -1; + + @Nullable + private String beanName; + + + /** + * Set the XA transaction manager to use for wrapping endpoint + * invocations, enlisting the endpoint resource in each such transaction. + *

The passed-in object may be a transaction manager which implements + * Spring's {@link org.springframework.transaction.jta.TransactionFactory} + * interface, or a plain {@link javax.transaction.TransactionManager}. + *

If no transaction manager is specified, the endpoint invocation + * will simply not be wrapped in an XA transaction. Check out your + * resource provider's ActivationSpec documentation for local + * transaction options of your particular provider. + * @see #setTransactionName + * @see #setTransactionTimeout + */ + public void setTransactionManager(Object transactionManager) { + if (transactionManager instanceof TransactionFactory) { + this.transactionFactory = (TransactionFactory) transactionManager; + } + else if (transactionManager instanceof TransactionManager) { + this.transactionFactory = new SimpleTransactionFactory((TransactionManager) transactionManager); + } + else { + throw new IllegalArgumentException("Transaction manager [" + transactionManager + + "] is neither a [org.springframework.transaction.jta.TransactionFactory} nor a " + + "[javax.transaction.TransactionManager]"); + } + } + + /** + * Set the Spring TransactionFactory to use for wrapping endpoint + * invocations, enlisting the endpoint resource in each such transaction. + *

Alternatively, specify an appropriate transaction manager through + * the {@link #setTransactionManager "transactionManager"} property. + *

If no transaction factory is specified, the endpoint invocation + * will simply not be wrapped in an XA transaction. Check out your + * resource provider's ActivationSpec documentation for local + * transaction options of your particular provider. + * @see #setTransactionName + * @see #setTransactionTimeout + */ + public void setTransactionFactory(TransactionFactory transactionFactory) { + this.transactionFactory = transactionFactory; + } + + /** + * Specify the name of the transaction, if any. + *

Default is none. A specified name will be passed on to the transaction + * manager, allowing to identify the transaction in a transaction monitor. + */ + public void setTransactionName(String transactionName) { + this.transactionName = transactionName; + } + + /** + * Specify the transaction timeout, if any. + *

Default is -1: rely on the transaction manager's default timeout. + * Specify a concrete timeout to restrict the maximum duration of each + * endpoint invocation. + */ + public void setTransactionTimeout(int transactionTimeout) { + this.transactionTimeout = transactionTimeout; + } + + /** + * Set the name of this message endpoint. Populated with the bean name + * automatically when defined within Spring's bean factory. + */ + @Override + public void setBeanName(String beanName) { + this.beanName = beanName; + } + + + /** + * Implementation of the JCA 1.7 {@code #getActivationName()} method, + * returning the bean name as set on this MessageEndpointFactory. + * @see #setBeanName + */ + @Override + @Nullable + public String getActivationName() { + return this.beanName; + } + + /** + * Implementation of the JCA 1.7 {@code #getEndpointClass()} method, + * returning {@code} null in order to indicate a synthetic endpoint type. + */ + @Override + @Nullable + public Class getEndpointClass() { + return null; + } + + /** + * This implementation returns {@code true} if a transaction manager + * has been specified; {@code false} otherwise. + * @see #setTransactionManager + * @see #setTransactionFactory + */ + @Override + public boolean isDeliveryTransacted(Method method) throws NoSuchMethodException { + return (this.transactionFactory != null); + } + + /** + * The standard JCA 1.5 version of {@code createEndpoint}. + *

This implementation delegates to {@link #createEndpointInternal()}, + * initializing the endpoint's XAResource before the endpoint gets invoked. + */ + @Override + public MessageEndpoint createEndpoint(XAResource xaResource) throws UnavailableException { + AbstractMessageEndpoint endpoint = createEndpointInternal(); + endpoint.initXAResource(xaResource); + return endpoint; + } + + /** + * The alternative JCA 1.6 version of {@code createEndpoint}. + *

This implementation delegates to {@link #createEndpointInternal()}, + * ignoring the specified timeout. It is only here for JCA 1.6 compliance. + */ + @Override + public MessageEndpoint createEndpoint(XAResource xaResource, long timeout) throws UnavailableException { + AbstractMessageEndpoint endpoint = createEndpointInternal(); + endpoint.initXAResource(xaResource); + return endpoint; + } + + /** + * Create the actual endpoint instance, as a subclass of the + * {@link AbstractMessageEndpoint} inner class of this factory. + * @return the actual endpoint instance (never {@code null}) + * @throws UnavailableException if no endpoint is available at present + */ + protected abstract AbstractMessageEndpoint createEndpointInternal() throws UnavailableException; + + + /** + * Inner class for actual endpoint implementations, based on template + * method to allow for any kind of concrete endpoint implementation. + */ + protected abstract class AbstractMessageEndpoint implements MessageEndpoint { + + @Nullable + private TransactionDelegate transactionDelegate; + + private boolean beforeDeliveryCalled = false; + + @Nullable + private ClassLoader previousContextClassLoader; + + /** + * Initialize this endpoint's TransactionDelegate. + * @param xaResource the XAResource for this endpoint + */ + void initXAResource(XAResource xaResource) { + this.transactionDelegate = new TransactionDelegate(xaResource); + } + + /** + * This {@code beforeDelivery} implementation starts a transaction, + * if necessary, and exposes the endpoint ClassLoader as current + * thread context ClassLoader. + *

Note that the JCA 1.7 specification does not require a ResourceAdapter + * to call this method before invoking the concrete endpoint. If this method + * has not been called (check {@link #hasBeforeDeliveryBeenCalled()}), the + * concrete endpoint method should call {@code beforeDelivery} and its + * sibling {@link #afterDelivery()} explicitly, as part of its own processing. + */ + @Override + public void beforeDelivery(@Nullable Method method) throws ResourceException { + this.beforeDeliveryCalled = true; + Assert.state(this.transactionDelegate != null, "Not initialized"); + try { + this.transactionDelegate.beginTransaction(); + } + catch (Throwable ex) { + throw new ApplicationServerInternalException("Failed to begin transaction", ex); + } + Thread currentThread = Thread.currentThread(); + this.previousContextClassLoader = currentThread.getContextClassLoader(); + currentThread.setContextClassLoader(getEndpointClassLoader()); + } + + /** + * Template method for exposing the endpoint's ClassLoader + * (typically the ClassLoader that the message listener class + * has been loaded with). + * @return the endpoint ClassLoader (never {@code null}) + */ + protected abstract ClassLoader getEndpointClassLoader(); + + /** + * Return whether the {@link #beforeDelivery} method of this endpoint + * has already been called. + */ + protected final boolean hasBeforeDeliveryBeenCalled() { + return this.beforeDeliveryCalled; + } + + /** + * Callback method for notifying the endpoint base class + * that the concrete endpoint invocation led to an exception. + *

To be invoked by subclasses in case of the concrete + * endpoint throwing an exception. + * @param ex the exception thrown from the concrete endpoint + */ + protected void onEndpointException(Throwable ex) { + Assert.state(this.transactionDelegate != null, "Not initialized"); + this.transactionDelegate.setRollbackOnly(); + logger.debug("Transaction marked as rollback-only after endpoint exception", ex); + } + + /** + * This {@code afterDelivery} implementation resets the thread context + * ClassLoader and completes the transaction, if any. + *

Note that the JCA 1.7 specification does not require a ResourceAdapter + * to call this method after invoking the concrete endpoint. See the + * explanation in {@link #beforeDelivery}'s javadoc. + */ + @Override + public void afterDelivery() throws ResourceException { + Assert.state(this.transactionDelegate != null, "Not initialized"); + this.beforeDeliveryCalled = false; + Thread.currentThread().setContextClassLoader(this.previousContextClassLoader); + this.previousContextClassLoader = null; + try { + this.transactionDelegate.endTransaction(); + } + catch (Throwable ex) { + logger.warn("Failed to complete transaction after endpoint delivery", ex); + throw new ApplicationServerInternalException("Failed to complete transaction", ex); + } + } + + @Override + public void release() { + if (this.transactionDelegate != null) { + try { + this.transactionDelegate.setRollbackOnly(); + this.transactionDelegate.endTransaction(); + } + catch (Throwable ex) { + logger.warn("Could not complete unfinished transaction on endpoint release", ex); + } + } + } + } + + + /** + * Private inner class that performs the actual transaction handling, + * including enlistment of the endpoint's XAResource. + */ + private class TransactionDelegate { + + @Nullable + private final XAResource xaResource; + + @Nullable + private Transaction transaction; + + private boolean rollbackOnly; + + public TransactionDelegate(@Nullable XAResource xaResource) { + if (xaResource == null && transactionFactory != null && + !transactionFactory.supportsResourceAdapterManagedTransactions()) { + throw new IllegalStateException("ResourceAdapter-provided XAResource is required for " + + "transaction management. Check your ResourceAdapter's configuration."); + } + this.xaResource = xaResource; + } + + public void beginTransaction() throws Exception { + if (transactionFactory != null && this.xaResource != null) { + this.transaction = transactionFactory.createTransaction(transactionName, transactionTimeout); + this.transaction.enlistResource(this.xaResource); + } + } + + public void setRollbackOnly() { + if (this.transaction != null) { + this.rollbackOnly = true; + } + } + + public void endTransaction() throws Exception { + if (this.transaction != null) { + try { + if (this.rollbackOnly) { + this.transaction.rollback(); + } + else { + this.transaction.commit(); + } + } + finally { + this.transaction = null; + this.rollbackOnly = false; + } + } + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/endpoint/GenericMessageEndpointFactory.java b/spring-tx/src/main/java/org/springframework/jca/endpoint/GenericMessageEndpointFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..af42a2c837c3fc169dd096f28e13434909116105 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/endpoint/GenericMessageEndpointFactory.java @@ -0,0 +1,175 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.endpoint; + +import javax.resource.ResourceException; +import javax.resource.spi.UnavailableException; +import javax.resource.spi.endpoint.MessageEndpoint; +import javax.transaction.xa.XAResource; + +import org.aopalliance.intercept.MethodInterceptor; +import org.aopalliance.intercept.MethodInvocation; + +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.aop.support.DelegatingIntroductionInterceptor; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; + +/** + * Generic implementation of the JCA 1.7 + * {@link javax.resource.spi.endpoint.MessageEndpointFactory} interface, + * providing transaction management capabilities for any kind of message + * listener object (e.g. {@link javax.jms.MessageListener} objects or + * {@link javax.resource.cci.MessageListener} objects. + * + *

Uses AOP proxies for concrete endpoint instances, simply wrapping + * the specified message listener object and exposing all of its implemented + * interfaces on the endpoint instance. + * + *

Typically used with Spring's {@link GenericMessageEndpointManager}, + * but not tied to it. As a consequence, this endpoint factory could + * also be used with programmatic endpoint management on a native + * {@link javax.resource.spi.ResourceAdapter} instance. + * + * @author Juergen Hoeller + * @since 2.5 + * @see #setMessageListener + * @see #setTransactionManager + * @see GenericMessageEndpointManager + */ +public class GenericMessageEndpointFactory extends AbstractMessageEndpointFactory { + + @Nullable + private Object messageListener; + + + /** + * Specify the message listener object that the endpoint should expose + * (e.g. a {@link javax.jms.MessageListener} objects or + * {@link javax.resource.cci.MessageListener} implementation). + */ + public void setMessageListener(Object messageListener) { + this.messageListener = messageListener; + } + + /** + * Return the message listener object for this endpoint. + * @since 5.0 + */ + protected Object getMessageListener() { + Assert.state(this.messageListener != null, "No message listener set"); + return this.messageListener; + } + + /** + * Wrap each concrete endpoint instance with an AOP proxy, + * exposing the message listener's interfaces as well as the + * endpoint SPI through an AOP introduction. + */ + @Override + public MessageEndpoint createEndpoint(XAResource xaResource) throws UnavailableException { + GenericMessageEndpoint endpoint = (GenericMessageEndpoint) super.createEndpoint(xaResource); + ProxyFactory proxyFactory = new ProxyFactory(getMessageListener()); + DelegatingIntroductionInterceptor introduction = new DelegatingIntroductionInterceptor(endpoint); + introduction.suppressInterface(MethodInterceptor.class); + proxyFactory.addAdvice(introduction); + return (MessageEndpoint) proxyFactory.getProxy(); + } + + /** + * Creates a concrete generic message endpoint, internal to this factory. + */ + @Override + protected AbstractMessageEndpoint createEndpointInternal() throws UnavailableException { + return new GenericMessageEndpoint(); + } + + + /** + * Private inner class that implements the concrete generic message endpoint, + * as an AOP Alliance MethodInterceptor that will be invoked by a proxy. + */ + private class GenericMessageEndpoint extends AbstractMessageEndpoint implements MethodInterceptor { + + @Override + public Object invoke(MethodInvocation methodInvocation) throws Throwable { + Throwable endpointEx = null; + boolean applyDeliveryCalls = !hasBeforeDeliveryBeenCalled(); + if (applyDeliveryCalls) { + try { + beforeDelivery(null); + } + catch (ResourceException ex) { + throw adaptExceptionIfNecessary(methodInvocation, ex); + } + } + try { + return methodInvocation.proceed(); + } + catch (Throwable ex) { + endpointEx = ex; + onEndpointException(ex); + throw ex; + } + finally { + if (applyDeliveryCalls) { + try { + afterDelivery(); + } + catch (ResourceException ex) { + if (endpointEx == null) { + throw adaptExceptionIfNecessary(methodInvocation, ex); + } + } + } + } + } + + private Exception adaptExceptionIfNecessary(MethodInvocation methodInvocation, ResourceException ex) { + if (ReflectionUtils.declaresException(methodInvocation.getMethod(), ex.getClass())) { + return ex; + } + else { + return new InternalResourceException(ex); + } + } + + @Override + protected ClassLoader getEndpointClassLoader() { + return getMessageListener().getClass().getClassLoader(); + } + } + + + /** + * Internal exception thrown when a ResourceException has been encountered + * during the endpoint invocation. + *

Will only be used if the ResourceAdapter does not invoke the + * endpoint's {@code beforeDelivery} and {@code afterDelivery} + * directly, leaving it up to the concrete endpoint to apply those - + * and to handle any ResourceExceptions thrown from them. + */ + @SuppressWarnings("serial") + public static class InternalResourceException extends RuntimeException { + + public InternalResourceException(ResourceException cause) { + super(cause); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/endpoint/GenericMessageEndpointManager.java b/spring-tx/src/main/java/org/springframework/jca/endpoint/GenericMessageEndpointManager.java new file mode 100644 index 0000000000000000000000000000000000000000..268846383ada3640cbb8edfdd1bc5335e7b9642e --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/endpoint/GenericMessageEndpointManager.java @@ -0,0 +1,344 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.endpoint; + +import javax.resource.ResourceException; +import javax.resource.spi.ActivationSpec; +import javax.resource.spi.ResourceAdapter; +import javax.resource.spi.endpoint.MessageEndpointFactory; + +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.context.SmartLifecycle; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Generic bean that manages JCA 1.7 message endpoints within a Spring + * application context, activating and deactivating the endpoint as part + * of the application context's lifecycle. + * + *

This class is completely generic in that it may work with any + * ResourceAdapter, any MessageEndpointFactory, and any ActivationSpec. + * It can be configured in standard bean style, for example through + * Spring's XML bean definition format, as follows: + * + *

+ * <bean class="org.springframework.jca.endpoint.GenericMessageEndpointManager">
+ * 	 <property name="resourceAdapter" ref="resourceAdapter"/>
+ * 	 <property name="messageEndpointFactory">
+ *     <bean class="org.springframework.jca.endpoint.GenericMessageEndpointFactory">
+ *       <property name="messageListener" ref="messageListener"/>
+ *     </bean>
+ * 	 </property>
+ * 	 <property name="activationSpec">
+ *     <bean class="org.apache.activemq.ra.ActiveMQActivationSpec">
+ *       <property name="destination" value="myQueue"/>
+ *       <property name="destinationType" value="javax.jms.Queue"/>
+ *     </bean>
+ *   </property>
+ * </bean>
+ * + * In this example, Spring's own {@link GenericMessageEndpointFactory} is used + * to point to a standard message listener object that happens to be supported + * by the specified target ResourceAdapter: in this case, a JMS + * {@link javax.jms.MessageListener} object as supported by the ActiveMQ + * message broker, defined as a Spring bean: + * + *
+ * <bean id="messageListener" class="com.myorg.messaging.myMessageListener">
+ *   ...
+ * </bean>
+ * + * The target ResourceAdapter may be configured as a local Spring bean as well + * (the typical case) or obtained from JNDI (e.g. on WebLogic). For the + * example above, a local ResourceAdapter bean could be defined as follows + * (matching the "resourceAdapter" bean reference above): + * + *
+ * <bean id="resourceAdapter" class="org.springframework.jca.support.ResourceAdapterFactoryBean">
+ *   <property name="resourceAdapter">
+ *     <bean class="org.apache.activemq.ra.ActiveMQResourceAdapter">
+ *       <property name="serverUrl" value="tcp://localhost:61616"/>
+ *     </bean>
+ *   </property>
+ *   <property name="workManager">
+ *     <bean class="org.springframework.jca.work.SimpleTaskWorkManager"/>
+ *   </property>
+ * </bean>
+ * + * For a different target resource, the configuration would simply point to a + * different ResourceAdapter and a different ActivationSpec object (which are + * both specific to the resource provider), and possibly a different message + * listener (e.g. a CCI {@link javax.resource.cci.MessageListener} for a + * resource adapter which is based on the JCA Common Client Interface). + * + *

The asynchronous execution strategy can be customized through the + * "workManager" property on the ResourceAdapterFactoryBean (as shown above). + * Check out {@link org.springframework.jca.work.SimpleTaskWorkManager}'s + * javadoc for its configuration options; alternatively, any other + * JCA-compliant WorkManager can be used (e.g. Geronimo's). + * + *

Transactional execution is a responsibility of the concrete message endpoint, + * as built by the specified MessageEndpointFactory. {@link GenericMessageEndpointFactory} + * supports XA transaction participation through its "transactionManager" property, + * typically with a Spring {@link org.springframework.transaction.jta.JtaTransactionManager} + * or a plain {@link javax.transaction.TransactionManager} implementation specified there. + * + *

+ * <bean class="org.springframework.jca.endpoint.GenericMessageEndpointManager">
+ * 	 <property name="resourceAdapter" ref="resourceAdapter"/>
+ * 	 <property name="messageEndpointFactory">
+ *     <bean class="org.springframework.jca.endpoint.GenericMessageEndpointFactory">
+ *       <property name="messageListener" ref="messageListener"/>
+ *       <property name="transactionManager" ref="transactionManager"/>
+ *     </bean>
+ * 	 </property>
+ * 	 <property name="activationSpec">
+ *     <bean class="org.apache.activemq.ra.ActiveMQActivationSpec">
+ *       <property name="destination" value="myQueue"/>
+ *       <property name="destinationType" value="javax.jms.Queue"/>
+ *     </bean>
+ *   </property>
+ * </bean>
+ *
+ * <bean id="transactionManager" class="org.springframework.transaction.jta.JtaTransactionManager"/>
+ * + * Alternatively, check out your resource provider's ActivationSpec object, + * which should support local transactions through a provider-specific config flag, + * e.g. ActiveMQActivationSpec's "useRAManagedTransaction" bean property. + * + *
+ * <bean class="org.springframework.jca.endpoint.GenericMessageEndpointManager">
+ * 	 <property name="resourceAdapter" ref="resourceAdapter"/>
+ * 	 <property name="messageEndpointFactory">
+ *     <bean class="org.springframework.jca.endpoint.GenericMessageEndpointFactory">
+ *       <property name="messageListener" ref="messageListener"/>
+ *     </bean>
+ * 	 </property>
+ * 	 <property name="activationSpec">
+ *     <bean class="org.apache.activemq.ra.ActiveMQActivationSpec">
+ *       <property name="destination" value="myQueue"/>
+ *       <property name="destinationType" value="javax.jms.Queue"/>
+ *       <property name="useRAManagedTransaction" value="true"/>
+ *     </bean>
+ *   </property>
+ * </bean>
+ * + * @author Juergen Hoeller + * @since 2.5 + * @see javax.resource.spi.ResourceAdapter#endpointActivation + * @see javax.resource.spi.ResourceAdapter#endpointDeactivation + * @see javax.resource.spi.endpoint.MessageEndpointFactory + * @see javax.resource.spi.ActivationSpec + */ +public class GenericMessageEndpointManager implements SmartLifecycle, InitializingBean, DisposableBean { + + @Nullable + private ResourceAdapter resourceAdapter; + + @Nullable + private MessageEndpointFactory messageEndpointFactory; + + @Nullable + private ActivationSpec activationSpec; + + private boolean autoStartup = true; + + private int phase = DEFAULT_PHASE; + + private volatile boolean running = false; + + private final Object lifecycleMonitor = new Object(); + + + /** + * Set the JCA ResourceAdapter to manage endpoints for. + */ + public void setResourceAdapter(@Nullable ResourceAdapter resourceAdapter) { + this.resourceAdapter = resourceAdapter; + } + + /** + * Return the JCA ResourceAdapter to manage endpoints for. + */ + @Nullable + public ResourceAdapter getResourceAdapter() { + return this.resourceAdapter; + } + + /** + * Set the JCA MessageEndpointFactory to activate, pointing to a + * MessageListener object that the endpoints will delegate to. + *

A MessageEndpointFactory instance may be shared across multiple + * endpoints (i.e. multiple GenericMessageEndpointManager instances), + * with different {@link #setActivationSpec ActivationSpec} objects applied. + * @see GenericMessageEndpointFactory#setMessageListener + */ + public void setMessageEndpointFactory(@Nullable MessageEndpointFactory messageEndpointFactory) { + this.messageEndpointFactory = messageEndpointFactory; + } + + /** + * Return the JCA MessageEndpointFactory to activate. + */ + @Nullable + public MessageEndpointFactory getMessageEndpointFactory() { + return this.messageEndpointFactory; + } + + /** + * Set the JCA ActivationSpec to use for activating the endpoint. + *

Note that this ActivationSpec instance should not be shared + * across multiple ResourceAdapter instances. + */ + public void setActivationSpec(@Nullable ActivationSpec activationSpec) { + this.activationSpec = activationSpec; + } + + /** + * Return the JCA ActivationSpec to use for activating the endpoint. + */ + @Nullable + public ActivationSpec getActivationSpec() { + return this.activationSpec; + } + + /** + * Set whether to auto-start the endpoint activation after this endpoint + * manager has been initialized and the context has been refreshed. + *

Default is "true". Turn this flag off to defer the endpoint + * activation until an explicit {@link #start()} call. + */ + public void setAutoStartup(boolean autoStartup) { + this.autoStartup = autoStartup; + } + + /** + * Return the value for the 'autoStartup' property. If "true", this + * endpoint manager will start upon a ContextRefreshedEvent. + */ + @Override + public boolean isAutoStartup() { + return this.autoStartup; + } + + /** + * Specify the phase in which this endpoint manager should be started + * and stopped. The startup order proceeds from lowest to highest, and + * the shutdown order is the reverse of that. By default this value is + * Integer.MAX_VALUE meaning that this endpoint manager starts as late + * as possible and stops as soon as possible. + */ + public void setPhase(int phase) { + this.phase = phase; + } + + /** + * Return the phase in which this endpoint manager will be started and stopped. + */ + @Override + public int getPhase() { + return this.phase; + } + + /** + * Prepares the message endpoint, and automatically activates it + * if the "autoStartup" flag is set to "true". + */ + @Override + public void afterPropertiesSet() throws ResourceException { + if (getResourceAdapter() == null) { + throw new IllegalArgumentException("Property 'resourceAdapter' is required"); + } + if (getMessageEndpointFactory() == null) { + throw new IllegalArgumentException("Property 'messageEndpointFactory' is required"); + } + ActivationSpec activationSpec = getActivationSpec(); + if (activationSpec == null) { + throw new IllegalArgumentException("Property 'activationSpec' is required"); + } + + if (activationSpec.getResourceAdapter() == null) { + activationSpec.setResourceAdapter(getResourceAdapter()); + } + else if (activationSpec.getResourceAdapter() != getResourceAdapter()) { + throw new IllegalArgumentException("ActivationSpec [" + activationSpec + + "] is associated with a different ResourceAdapter: " + activationSpec.getResourceAdapter()); + } + } + + /** + * Activates the configured message endpoint. + */ + @Override + public void start() { + synchronized (this.lifecycleMonitor) { + if (!this.running) { + ResourceAdapter resourceAdapter = getResourceAdapter(); + Assert.state(resourceAdapter != null, "No ResourceAdapter set"); + try { + resourceAdapter.endpointActivation(getMessageEndpointFactory(), getActivationSpec()); + } + catch (ResourceException ex) { + throw new IllegalStateException("Could not activate message endpoint", ex); + } + this.running = true; + } + } + } + + /** + * Deactivates the configured message endpoint. + */ + @Override + public void stop() { + synchronized (this.lifecycleMonitor) { + if (this.running) { + ResourceAdapter resourceAdapter = getResourceAdapter(); + Assert.state(resourceAdapter != null, "No ResourceAdapter set"); + resourceAdapter.endpointDeactivation(getMessageEndpointFactory(), getActivationSpec()); + this.running = false; + } + } + } + + @Override + public void stop(Runnable callback) { + synchronized (this.lifecycleMonitor) { + stop(); + callback.run(); + } + } + + /** + * Return whether the configured message endpoint is currently active. + */ + @Override + public boolean isRunning() { + return this.running; + } + + /** + * Deactivates the message endpoint, preparing it for shutdown. + */ + @Override + public void destroy() { + stop(); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/endpoint/package-info.java b/spring-tx/src/main/java/org/springframework/jca/endpoint/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..45f3610a9e7f15988cb6fc4f6b1fd7c3bb023330 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/endpoint/package-info.java @@ -0,0 +1,9 @@ +/** + * This package provides a facility for generic JCA message endpoint management. + */ +@NonNullApi +@NonNullFields +package org.springframework.jca.endpoint; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/jca/support/LocalConnectionFactoryBean.java b/spring-tx/src/main/java/org/springframework/jca/support/LocalConnectionFactoryBean.java new file mode 100644 index 0000000000000000000000000000000000000000..fd7d1bd05417ae5372cadf468f7e062f43d1c540 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/support/LocalConnectionFactoryBean.java @@ -0,0 +1,146 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.support; + +import javax.resource.ResourceException; +import javax.resource.spi.ConnectionManager; +import javax.resource.spi.ManagedConnectionFactory; + +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; + +/** + * {@link org.springframework.beans.factory.FactoryBean} that creates + * a local JCA connection factory in "non-managed" mode (as defined by the + * Java Connector Architecture specification). This is a direct alternative + * to a {@link org.springframework.jndi.JndiObjectFactoryBean} definition that + * obtains a connection factory handle from a Java EE server's naming environment. + * + *

The type of the connection factory is dependent on the actual connector: + * the connector can either expose its native API (such as a JDBC + * {@link javax.sql.DataSource} or a JMS {@link javax.jms.ConnectionFactory}) + * or follow the standard Common Client Interface (CCI), as defined by the JCA spec. + * The exposed interface in the CCI case is {@link javax.resource.cci.ConnectionFactory}. + * + *

In order to use this FactoryBean, you must specify the connector's + * {@link #setManagedConnectionFactory "managedConnectionFactory"} (usually + * configured as separate JavaBean), which will be used to create the actual + * connection factory reference as exposed to the application. Optionally, + * you can also specify a {@link #setConnectionManager "connectionManager"}, + * in order to use a custom ConnectionManager instead of the connector's default. + * + *

NOTE: In non-managed mode, a connector is not deployed on an + * application server, or more specifically not interacting with an application + * server. Consequently, it cannot use a Java EE server's system contracts: + * connection management, transaction management, and security management. + * A custom ConnectionManager implementation has to be used for applying those + * services in conjunction with a standalone transaction coordinator etc. + * + *

The connector will use a local ConnectionManager (included in the connector) + * by default, which cannot participate in global transactions due to the lack + * of XA enlistment. You need to specify an XA-capable ConnectionManager in + * order to make the connector interact with an XA transaction coordinator. + * Alternatively, simply use the native local transaction facilities of the + * exposed API (e.g. CCI local transactions), or use a corresponding + * implementation of Spring's PlatformTransactionManager SPI + * (e.g. {@link org.springframework.jca.cci.connection.CciLocalTransactionManager}) + * to drive local transactions. + * + * @author Juergen Hoeller + * @since 1.2 + * @see #setManagedConnectionFactory + * @see #setConnectionManager + * @see javax.resource.cci.ConnectionFactory + * @see javax.resource.cci.Connection#getLocalTransaction + * @see org.springframework.jca.cci.connection.CciLocalTransactionManager + */ +public class LocalConnectionFactoryBean implements FactoryBean, InitializingBean { + + @Nullable + private ManagedConnectionFactory managedConnectionFactory; + + @Nullable + private ConnectionManager connectionManager; + + @Nullable + private Object connectionFactory; + + + /** + * Set the JCA ManagerConnectionFactory that should be used to create + * the desired connection factory. + *

The ManagerConnectionFactory will usually be set up as separate bean + * (potentially as inner bean), populated with JavaBean properties: + * a ManagerConnectionFactory is encouraged to follow the JavaBean pattern + * by the JCA specification, analogous to a JDBC DataSource and a JPA + * EntityManagerFactory. + *

Note that the ManagerConnectionFactory implementation might expect + * a reference to its JCA 1.7 ResourceAdapter, expressed through the + * {@link javax.resource.spi.ResourceAdapterAssociation} interface. + * Simply inject the corresponding ResourceAdapter instance into its + * "resourceAdapter" bean property in this case, before passing the + * ManagerConnectionFactory into this LocalConnectionFactoryBean. + * @see javax.resource.spi.ManagedConnectionFactory#createConnectionFactory() + */ + public void setManagedConnectionFactory(ManagedConnectionFactory managedConnectionFactory) { + this.managedConnectionFactory = managedConnectionFactory; + } + + /** + * Set the JCA ConnectionManager that should be used to create the + * desired connection factory. + *

A ConnectionManager implementation for local usage is often + * included with a JCA connector. Such an included ConnectionManager + * might be set as default, with no need to explicitly specify one. + * @see javax.resource.spi.ManagedConnectionFactory#createConnectionFactory(javax.resource.spi.ConnectionManager) + */ + public void setConnectionManager(ConnectionManager connectionManager) { + this.connectionManager = connectionManager; + } + + @Override + public void afterPropertiesSet() throws ResourceException { + if (this.managedConnectionFactory == null) { + throw new IllegalArgumentException("Property 'managedConnectionFactory' is required"); + } + if (this.connectionManager != null) { + this.connectionFactory = this.managedConnectionFactory.createConnectionFactory(this.connectionManager); + } + else { + this.connectionFactory = this.managedConnectionFactory.createConnectionFactory(); + } + } + + + @Override + @Nullable + public Object getObject() { + return this.connectionFactory; + } + + @Override + public Class getObjectType() { + return (this.connectionFactory != null ? this.connectionFactory.getClass() : null); + } + + @Override + public boolean isSingleton() { + return true; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/support/ResourceAdapterFactoryBean.java b/spring-tx/src/main/java/org/springframework/jca/support/ResourceAdapterFactoryBean.java new file mode 100644 index 0000000000000000000000000000000000000000..7904e0bc614a908279d8e11c847d5603d9992d8a --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/support/ResourceAdapterFactoryBean.java @@ -0,0 +1,159 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.support; + +import javax.resource.ResourceException; +import javax.resource.spi.BootstrapContext; +import javax.resource.spi.ResourceAdapter; +import javax.resource.spi.XATerminator; +import javax.resource.spi.work.WorkManager; + +import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; + +/** + * {@link org.springframework.beans.factory.FactoryBean} that bootstraps + * the specified JCA 1.7 {@link javax.resource.spi.ResourceAdapter}, + * starting it with a local {@link javax.resource.spi.BootstrapContext} + * and exposing it for bean references. It will also stop the ResourceAdapter + * on context shutdown. This corresponds to 'non-managed' bootstrap in a + * local environment, according to the JCA 1.7 specification. + * + *

This is essentially an adapter for bean-style bootstrapping of a + * JCA ResourceAdapter, allowing the BootstrapContext or its elements + * (such as the JCA WorkManager) to be specified through bean properties. + * + * @author Juergen Hoeller + * @since 2.0.3 + * @see #setResourceAdapter + * @see #setBootstrapContext + * @see #setWorkManager + * @see javax.resource.spi.ResourceAdapter#start(javax.resource.spi.BootstrapContext) + * @see javax.resource.spi.ResourceAdapter#stop() + */ +public class ResourceAdapterFactoryBean implements FactoryBean, InitializingBean, DisposableBean { + + @Nullable + private ResourceAdapter resourceAdapter; + + @Nullable + private BootstrapContext bootstrapContext; + + @Nullable + private WorkManager workManager; + + @Nullable + private XATerminator xaTerminator; + + + /** + * Specify the target JCA ResourceAdapter as class, to be instantiated + * with its default configuration. + *

Alternatively, specify a pre-configured ResourceAdapter instance + * through the "resourceAdapter" property. + * @see #setResourceAdapter + */ + public void setResourceAdapterClass(Class resourceAdapterClass) { + this.resourceAdapter = BeanUtils.instantiateClass(resourceAdapterClass); + } + + /** + * Specify the target JCA ResourceAdapter, passed in as configured instance + * which hasn't been started yet. This will typically happen as an + * inner bean definition, configuring the ResourceAdapter instance + * through its vendor-specific bean properties. + */ + public void setResourceAdapter(ResourceAdapter resourceAdapter) { + this.resourceAdapter = resourceAdapter; + } + + /** + * Specify the JCA BootstrapContext to use for starting the ResourceAdapter. + *

Alternatively, you can specify the individual parts (such as the + * JCA WorkManager) as individual references. + * @see #setWorkManager + * @see #setXaTerminator + */ + public void setBootstrapContext(BootstrapContext bootstrapContext) { + this.bootstrapContext = bootstrapContext; + } + + /** + * Specify the JCA WorkManager to use for bootstrapping the ResourceAdapter. + * @see #setBootstrapContext + */ + public void setWorkManager(WorkManager workManager) { + this.workManager = workManager; + } + + /** + * Specify the JCA XATerminator to use for bootstrapping the ResourceAdapter. + * @see #setBootstrapContext + */ + public void setXaTerminator(XATerminator xaTerminator) { + this.xaTerminator = xaTerminator; + } + + + /** + * Builds the BootstrapContext and starts the ResourceAdapter with it. + * @see javax.resource.spi.ResourceAdapter#start(javax.resource.spi.BootstrapContext) + */ + @Override + public void afterPropertiesSet() throws ResourceException { + if (this.resourceAdapter == null) { + throw new IllegalArgumentException("'resourceAdapter' or 'resourceAdapterClass' is required"); + } + if (this.bootstrapContext == null) { + this.bootstrapContext = new SimpleBootstrapContext(this.workManager, this.xaTerminator); + } + this.resourceAdapter.start(this.bootstrapContext); + } + + + @Override + @Nullable + public ResourceAdapter getObject() { + return this.resourceAdapter; + } + + @Override + public Class getObjectType() { + return (this.resourceAdapter != null ? this.resourceAdapter.getClass() : ResourceAdapter.class); + } + + @Override + public boolean isSingleton() { + return true; + } + + + /** + * Stops the ResourceAdapter. + * @see javax.resource.spi.ResourceAdapter#stop() + */ + @Override + public void destroy() { + if (this.resourceAdapter != null) { + this.resourceAdapter.stop(); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/support/SimpleBootstrapContext.java b/spring-tx/src/main/java/org/springframework/jca/support/SimpleBootstrapContext.java new file mode 100644 index 0000000000000000000000000000000000000000..8ca15b9d901b0b7a9ba40680d1b078c1c0d5421e --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/support/SimpleBootstrapContext.java @@ -0,0 +1,120 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.support; + +import java.util.Timer; + +import javax.resource.spi.BootstrapContext; +import javax.resource.spi.UnavailableException; +import javax.resource.spi.XATerminator; +import javax.resource.spi.work.WorkContext; +import javax.resource.spi.work.WorkManager; +import javax.transaction.TransactionSynchronizationRegistry; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Simple implementation of the JCA 1.7 {@link javax.resource.spi.BootstrapContext} + * interface, used for bootstrapping a JCA ResourceAdapter in a local environment. + * + *

Delegates to the given WorkManager and XATerminator, if any. Creates simple + * local instances of {@code java.util.Timer}. + * + * @author Juergen Hoeller + * @since 2.0.3 + * @see javax.resource.spi.ResourceAdapter#start(javax.resource.spi.BootstrapContext) + * @see ResourceAdapterFactoryBean + */ +public class SimpleBootstrapContext implements BootstrapContext { + + @Nullable + private WorkManager workManager; + + @Nullable + private XATerminator xaTerminator; + + @Nullable + private TransactionSynchronizationRegistry transactionSynchronizationRegistry; + + + /** + * Create a new SimpleBootstrapContext for the given WorkManager, + * with no XATerminator available. + * @param workManager the JCA WorkManager to use (may be {@code null}) + */ + public SimpleBootstrapContext(@Nullable WorkManager workManager) { + this.workManager = workManager; + } + + /** + * Create a new SimpleBootstrapContext for the given WorkManager and XATerminator. + * @param workManager the JCA WorkManager to use (may be {@code null}) + * @param xaTerminator the JCA XATerminator to use (may be {@code null}) + */ + public SimpleBootstrapContext(@Nullable WorkManager workManager, @Nullable XATerminator xaTerminator) { + this.workManager = workManager; + this.xaTerminator = xaTerminator; + } + + /** + * Create a new SimpleBootstrapContext for the given WorkManager, XATerminator + * and TransactionSynchronizationRegistry. + * @param workManager the JCA WorkManager to use (may be {@code null}) + * @param xaTerminator the JCA XATerminator to use (may be {@code null}) + * @param transactionSynchronizationRegistry the TransactionSynchronizationRegistry + * to use (may be {@code null}) + * @since 5.0 + */ + public SimpleBootstrapContext(@Nullable WorkManager workManager, @Nullable XATerminator xaTerminator, + @Nullable TransactionSynchronizationRegistry transactionSynchronizationRegistry) { + + this.workManager = workManager; + this.xaTerminator = xaTerminator; + this.transactionSynchronizationRegistry = transactionSynchronizationRegistry; + } + + + @Override + public WorkManager getWorkManager() { + Assert.state(this.workManager != null, "No WorkManager available"); + return this.workManager; + } + + @Override + @Nullable + public XATerminator getXATerminator() { + return this.xaTerminator; + } + + @Override + public Timer createTimer() throws UnavailableException { + return new Timer(); + } + + @Override + public boolean isContextSupported(Class workContextClass) { + return false; + } + + @Override + @Nullable + public TransactionSynchronizationRegistry getTransactionSynchronizationRegistry() { + return this.transactionSynchronizationRegistry; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/support/package-info.java b/spring-tx/src/main/java/org/springframework/jca/support/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..547844092eb0dca8c1c92caa8a6f02add6af7da9 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/support/package-info.java @@ -0,0 +1,10 @@ +/** + * Provides generic support classes for JCA usage within Spring, + * mainly for local setup of a JCA ResourceAdapter and/or ConnectionFactory. + */ +@NonNullApi +@NonNullFields +package org.springframework.jca.support; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/jca/work/DelegatingWork.java b/spring-tx/src/main/java/org/springframework/jca/work/DelegatingWork.java new file mode 100644 index 0000000000000000000000000000000000000000..da52a8536506a93b737992f57412c21d21eb6ed6 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/work/DelegatingWork.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.work; + +import javax.resource.spi.work.Work; + +import org.springframework.util.Assert; + +/** + * Simple Work adapter that delegates to a given Runnable. + * + * @author Juergen Hoeller + * @since 2.0.3 + * @see javax.resource.spi.work.Work + * @see Runnable + */ +public class DelegatingWork implements Work { + + private final Runnable delegate; + + + /** + * Create a new DelegatingWork. + * @param delegate the Runnable implementation to delegate to + */ + public DelegatingWork(Runnable delegate) { + Assert.notNull(delegate, "Delegate must not be null"); + this.delegate = delegate; + } + + /** + * Return the wrapped Runnable implementation. + */ + public final Runnable getDelegate() { + return this.delegate; + } + + + /** + * Delegates execution to the underlying Runnable. + */ + @Override + public void run() { + this.delegate.run(); + } + + /** + * This implementation is empty, since we expect the Runnable + * to terminate based on some specific shutdown signal. + */ + @Override + public void release() { + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/work/SimpleTaskWorkManager.java b/spring-tx/src/main/java/org/springframework/jca/work/SimpleTaskWorkManager.java new file mode 100644 index 0000000000000000000000000000000000000000..90b8214df3441c338beb1afa3ea1df09ba136db3 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/work/SimpleTaskWorkManager.java @@ -0,0 +1,257 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.work; + +import javax.resource.spi.work.ExecutionContext; +import javax.resource.spi.work.Work; +import javax.resource.spi.work.WorkAdapter; +import javax.resource.spi.work.WorkCompletedException; +import javax.resource.spi.work.WorkEvent; +import javax.resource.spi.work.WorkException; +import javax.resource.spi.work.WorkListener; +import javax.resource.spi.work.WorkManager; +import javax.resource.spi.work.WorkRejectedException; + +import org.springframework.core.task.AsyncTaskExecutor; +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.core.task.SyncTaskExecutor; +import org.springframework.core.task.TaskExecutor; +import org.springframework.core.task.TaskRejectedException; +import org.springframework.core.task.TaskTimeoutException; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Simple JCA 1.7 {@link javax.resource.spi.work.WorkManager} implementation that + * delegates to a Spring {@link org.springframework.core.task.TaskExecutor}. + * Provides simple task execution including start timeouts, but without support + * for a JCA ExecutionContext (i.e. without support for imported transactions). + * + *

Uses a {@link org.springframework.core.task.SyncTaskExecutor} for {@link #doWork} + * calls and a {@link org.springframework.core.task.SimpleAsyncTaskExecutor} + * for {@link #startWork} and {@link #scheduleWork} calls, by default. + * These default task executors can be overridden through configuration. + * + *

NOTE: This WorkManager does not provide thread pooling by default! + * Specify a {@link org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor} + * (or any other thread-pooling TaskExecutor) as "asyncTaskExecutor" in order to + * achieve actual thread pooling. + * + *

This WorkManager automatically detects a specified + * {@link org.springframework.core.task.AsyncTaskExecutor} implementation + * and uses its extended timeout functionality where appropriate. + * JCA WorkListeners are fully supported in any case. + * + * @author Juergen Hoeller + * @since 2.0.3 + * @see #setSyncTaskExecutor + * @see #setAsyncTaskExecutor + */ +public class SimpleTaskWorkManager implements WorkManager { + + @Nullable + private TaskExecutor syncTaskExecutor = new SyncTaskExecutor(); + + @Nullable + private AsyncTaskExecutor asyncTaskExecutor = new SimpleAsyncTaskExecutor(); + + + /** + * Specify the TaskExecutor to use for synchronous work execution + * (i.e. {@link #doWork} calls). + *

Default is a {@link org.springframework.core.task.SyncTaskExecutor}. + */ + public void setSyncTaskExecutor(TaskExecutor syncTaskExecutor) { + this.syncTaskExecutor = syncTaskExecutor; + } + + /** + * Specify the TaskExecutor to use for asynchronous work execution + * (i.e. {@link #startWork} and {@link #scheduleWork} calls). + *

This will typically (but not necessarily) be an + * {@link org.springframework.core.task.AsyncTaskExecutor} implementation. + * Default is a {@link org.springframework.core.task.SimpleAsyncTaskExecutor}. + */ + public void setAsyncTaskExecutor(AsyncTaskExecutor asyncTaskExecutor) { + this.asyncTaskExecutor = asyncTaskExecutor; + } + + + @Override + public void doWork(Work work) throws WorkException { + doWork(work, WorkManager.INDEFINITE, null, null); + } + + @Override + public void doWork(Work work, long startTimeout, @Nullable ExecutionContext executionContext, @Nullable WorkListener workListener) + throws WorkException { + + Assert.state(this.syncTaskExecutor != null, "No 'syncTaskExecutor' set"); + executeWork(this.syncTaskExecutor, work, startTimeout, false, executionContext, workListener); + } + + @Override + public long startWork(Work work) throws WorkException { + return startWork(work, WorkManager.INDEFINITE, null, null); + } + + @Override + public long startWork(Work work, long startTimeout, @Nullable ExecutionContext executionContext, @Nullable WorkListener workListener) + throws WorkException { + + Assert.state(this.asyncTaskExecutor != null, "No 'asyncTaskExecutor' set"); + return executeWork(this.asyncTaskExecutor, work, startTimeout, true, executionContext, workListener); + } + + @Override + public void scheduleWork(Work work) throws WorkException { + scheduleWork(work, WorkManager.INDEFINITE, null, null); + } + + @Override + public void scheduleWork(Work work, long startTimeout, @Nullable ExecutionContext executionContext, @Nullable WorkListener workListener) + throws WorkException { + + Assert.state(this.asyncTaskExecutor != null, "No 'asyncTaskExecutor' set"); + executeWork(this.asyncTaskExecutor, work, startTimeout, false, executionContext, workListener); + } + + + /** + * Execute the given Work on the specified TaskExecutor. + * @param taskExecutor the TaskExecutor to use + * @param work the Work to execute + * @param startTimeout the time duration within which the Work is supposed to start + * @param blockUntilStarted whether to block until the Work has started + * @param executionContext the JCA ExecutionContext for the given Work + * @param workListener the WorkListener to clal for the given Work + * @return the time elapsed from Work acceptance until start of execution + * (or -1 if not applicable or not known) + * @throws WorkException if the TaskExecutor did not accept the Work + */ + protected long executeWork(TaskExecutor taskExecutor, Work work, long startTimeout, boolean blockUntilStarted, + @Nullable ExecutionContext executionContext, @Nullable WorkListener workListener) throws WorkException { + + if (executionContext != null && executionContext.getXid() != null) { + throw new WorkException("SimpleTaskWorkManager does not supported imported XIDs: " + executionContext.getXid()); + } + WorkListener workListenerToUse = workListener; + if (workListenerToUse == null) { + workListenerToUse = new WorkAdapter(); + } + + boolean isAsync = (taskExecutor instanceof AsyncTaskExecutor); + DelegatingWorkAdapter workHandle = new DelegatingWorkAdapter(work, workListenerToUse, !isAsync); + try { + if (isAsync) { + ((AsyncTaskExecutor) taskExecutor).execute(workHandle, startTimeout); + } + else { + taskExecutor.execute(workHandle); + } + } + catch (TaskTimeoutException ex) { + WorkException wex = new WorkRejectedException("TaskExecutor rejected Work because of timeout: " + work, ex); + wex.setErrorCode(WorkException.START_TIMED_OUT); + workListenerToUse.workRejected(new WorkEvent(this, WorkEvent.WORK_REJECTED, work, wex)); + throw wex; + } + catch (TaskRejectedException ex) { + WorkException wex = new WorkRejectedException("TaskExecutor rejected Work: " + work, ex); + wex.setErrorCode(WorkException.INTERNAL); + workListenerToUse.workRejected(new WorkEvent(this, WorkEvent.WORK_REJECTED, work, wex)); + throw wex; + } + catch (Throwable ex) { + WorkException wex = new WorkException("TaskExecutor failed to execute Work: " + work, ex); + wex.setErrorCode(WorkException.INTERNAL); + throw wex; + } + if (isAsync) { + workListenerToUse.workAccepted(new WorkEvent(this, WorkEvent.WORK_ACCEPTED, work, null)); + } + + if (blockUntilStarted) { + long acceptanceTime = System.currentTimeMillis(); + synchronized (workHandle.monitor) { + try { + while (!workHandle.started) { + workHandle.monitor.wait(); + } + } + catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + } + } + return (System.currentTimeMillis() - acceptanceTime); + } + else { + return WorkManager.UNKNOWN; + } + } + + + /** + * Work adapter that supports start timeouts and WorkListener callbacks + * for a given Work that it delegates to. + */ + private static class DelegatingWorkAdapter implements Work { + + private final Work work; + + private final WorkListener workListener; + + private final boolean acceptOnExecution; + + public final Object monitor = new Object(); + + public boolean started = false; + + public DelegatingWorkAdapter(Work work, WorkListener workListener, boolean acceptOnExecution) { + this.work = work; + this.workListener = workListener; + this.acceptOnExecution = acceptOnExecution; + } + + @Override + public void run() { + if (this.acceptOnExecution) { + this.workListener.workAccepted(new WorkEvent(this, WorkEvent.WORK_ACCEPTED, this.work, null)); + } + synchronized (this.monitor) { + this.started = true; + this.monitor.notify(); + } + this.workListener.workStarted(new WorkEvent(this, WorkEvent.WORK_STARTED, this.work, null)); + try { + this.work.run(); + } + catch (RuntimeException | Error ex) { + this.workListener.workCompleted( + new WorkEvent(this, WorkEvent.WORK_COMPLETED, this.work, new WorkCompletedException(ex))); + throw ex; + } + this.workListener.workCompleted(new WorkEvent(this, WorkEvent.WORK_COMPLETED, this.work, null)); + } + + @Override + public void release() { + this.work.release(); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/work/WorkManagerTaskExecutor.java b/spring-tx/src/main/java/org/springframework/jca/work/WorkManagerTaskExecutor.java new file mode 100644 index 0000000000000000000000000000000000000000..e4fe06bfeb35ac12d74f1a9b009b1eb157b58af5 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/work/WorkManagerTaskExecutor.java @@ -0,0 +1,337 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.work; + +import java.util.concurrent.Callable; +import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; + +import javax.naming.NamingException; +import javax.resource.spi.BootstrapContext; +import javax.resource.spi.work.ExecutionContext; +import javax.resource.spi.work.Work; +import javax.resource.spi.work.WorkException; +import javax.resource.spi.work.WorkListener; +import javax.resource.spi.work.WorkManager; +import javax.resource.spi.work.WorkRejectedException; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.core.task.AsyncListenableTaskExecutor; +import org.springframework.core.task.TaskDecorator; +import org.springframework.core.task.TaskRejectedException; +import org.springframework.core.task.TaskTimeoutException; +import org.springframework.jca.context.BootstrapContextAware; +import org.springframework.jndi.JndiLocatorSupport; +import org.springframework.lang.Nullable; +import org.springframework.scheduling.SchedulingException; +import org.springframework.scheduling.SchedulingTaskExecutor; +import org.springframework.util.Assert; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.ListenableFutureTask; + +/** + * {@link org.springframework.core.task.TaskExecutor} implementation + * that delegates to a JCA 1.7 WorkManager, implementing the + * {@link javax.resource.spi.work.WorkManager} interface. + * + *

This is mainly intended for use within a JCA ResourceAdapter implementation, + * but may also be used in a standalone environment, delegating to a locally + * embedded WorkManager implementation (such as Geronimo's). + * + *

Also implements the JCA 1.7 WorkManager interface itself, delegating all + * calls to the target WorkManager. Hence, a caller can choose whether it wants + * to talk to this executor through the Spring TaskExecutor interface or the + * WorkManager interface. + * + *

This adapter is also capable of obtaining a JCA WorkManager from JNDI. + * This is for example appropriate on the Geronimo application server, where + * WorkManager GBeans (e.g. Geronimo's default "DefaultWorkManager" GBean) + * can be linked into the Java EE environment through "gbean-ref" entries + * in the {@code geronimo-web.xml} deployment descriptor. + * + * @author Juergen Hoeller + * @since 2.0.3 + * @see #setWorkManager + * @see javax.resource.spi.work.WorkManager#scheduleWork + */ +public class WorkManagerTaskExecutor extends JndiLocatorSupport + implements AsyncListenableTaskExecutor, SchedulingTaskExecutor, WorkManager, BootstrapContextAware, InitializingBean { + + @Nullable + private WorkManager workManager; + + @Nullable + private String workManagerName; + + private boolean blockUntilStarted = false; + + private boolean blockUntilCompleted = false; + + @Nullable + private WorkListener workListener; + + @Nullable + private TaskDecorator taskDecorator; + + + /** + * Create a new WorkManagerTaskExecutor, expecting bean-style configuration. + * @see #setWorkManager + */ + public WorkManagerTaskExecutor() { + } + + /** + * Create a new WorkManagerTaskExecutor for the given WorkManager. + * @param workManager the JCA WorkManager to delegate to + */ + public WorkManagerTaskExecutor(WorkManager workManager) { + setWorkManager(workManager); + } + + + /** + * Specify the JCA WorkManager instance to delegate to. + */ + public void setWorkManager(WorkManager workManager) { + Assert.notNull(workManager, "WorkManager must not be null"); + this.workManager = workManager; + } + + /** + * Set the JNDI name of the JCA WorkManager. + *

This can either be a fully qualified JNDI name, + * or the JNDI name relative to the current environment + * naming context if "resourceRef" is set to "true". + * @see #setWorkManager + * @see #setResourceRef + */ + public void setWorkManagerName(String workManagerName) { + this.workManagerName = workManagerName; + } + + /** + * Specify the JCA BootstrapContext that contains the + * WorkManager to delegate to. + */ + @Override + public void setBootstrapContext(BootstrapContext bootstrapContext) { + Assert.notNull(bootstrapContext, "BootstrapContext must not be null"); + this.workManager = bootstrapContext.getWorkManager(); + } + + /** + * Set whether to let {@link #execute} block until the work + * has been actually started. + *

Uses the JCA {@code startWork} operation underneath, + * instead of the default {@code scheduleWork}. + * @see javax.resource.spi.work.WorkManager#startWork + * @see javax.resource.spi.work.WorkManager#scheduleWork + */ + public void setBlockUntilStarted(boolean blockUntilStarted) { + this.blockUntilStarted = blockUntilStarted; + } + + /** + * Set whether to let {@link #execute} block until the work + * has been completed. + *

Uses the JCA {@code doWork} operation underneath, + * instead of the default {@code scheduleWork}. + * @see javax.resource.spi.work.WorkManager#doWork + * @see javax.resource.spi.work.WorkManager#scheduleWork + */ + public void setBlockUntilCompleted(boolean blockUntilCompleted) { + this.blockUntilCompleted = blockUntilCompleted; + } + + /** + * Specify a JCA WorkListener to apply, if any. + *

This shared WorkListener instance will be passed on to the + * WorkManager by all {@link #execute} calls on this TaskExecutor. + */ + public void setWorkListener(@Nullable WorkListener workListener) { + this.workListener = workListener; + } + + /** + * Specify a custom {@link TaskDecorator} to be applied to any {@link Runnable} + * about to be executed. + *

Note that such a decorator is not necessarily being applied to the + * user-supplied {@code Runnable}/{@code Callable} but rather to the actual + * execution callback (which may be a wrapper around the user-supplied task). + *

The primary use case is to set some execution context around the task's + * invocation, or to provide some monitoring/statistics for task execution. + *

NOTE: Exception handling in {@code TaskDecorator} implementations + * is limited to plain {@code Runnable} execution via {@code execute} calls. + * In case of {@code #submit} calls, the exposed {@code Runnable} will be a + * {@code FutureTask} which does not propagate any exceptions; you might + * have to cast it and call {@code Future#get} to evaluate exceptions. + * @since 4.3 + */ + public void setTaskDecorator(TaskDecorator taskDecorator) { + this.taskDecorator = taskDecorator; + } + + @Override + public void afterPropertiesSet() throws NamingException { + if (this.workManager == null) { + if (this.workManagerName != null) { + this.workManager = lookup(this.workManagerName, WorkManager.class); + } + else { + this.workManager = getDefaultWorkManager(); + } + } + } + + /** + * Obtain a default WorkManager to delegate to. + * Called if no explicit WorkManager or WorkManager JNDI name has been specified. + *

The default implementation returns a {@link SimpleTaskWorkManager}. + * Can be overridden in subclasses. + */ + protected WorkManager getDefaultWorkManager() { + return new SimpleTaskWorkManager(); + } + + private WorkManager obtainWorkManager() { + Assert.state(this.workManager != null, "No WorkManager specified"); + return this.workManager; + } + + + //------------------------------------------------------------------------- + // Implementation of the Spring SchedulingTaskExecutor interface + //------------------------------------------------------------------------- + + @Override + public void execute(Runnable task) { + execute(task, TIMEOUT_INDEFINITE); + } + + @Override + public void execute(Runnable task, long startTimeout) { + Work work = new DelegatingWork(this.taskDecorator != null ? this.taskDecorator.decorate(task) : task); + try { + if (this.blockUntilCompleted) { + if (startTimeout != TIMEOUT_INDEFINITE || this.workListener != null) { + obtainWorkManager().doWork(work, startTimeout, null, this.workListener); + } + else { + obtainWorkManager().doWork(work); + } + } + else if (this.blockUntilStarted) { + if (startTimeout != TIMEOUT_INDEFINITE || this.workListener != null) { + obtainWorkManager().startWork(work, startTimeout, null, this.workListener); + } + else { + obtainWorkManager().startWork(work); + } + } + else { + if (startTimeout != TIMEOUT_INDEFINITE || this.workListener != null) { + obtainWorkManager().scheduleWork(work, startTimeout, null, this.workListener); + } + else { + obtainWorkManager().scheduleWork(work); + } + } + } + catch (WorkRejectedException ex) { + if (WorkException.START_TIMED_OUT.equals(ex.getErrorCode())) { + throw new TaskTimeoutException("JCA WorkManager rejected task because of timeout: " + task, ex); + } + else { + throw new TaskRejectedException("JCA WorkManager rejected task: " + task, ex); + } + } + catch (WorkException ex) { + throw new SchedulingException("Could not schedule task on JCA WorkManager", ex); + } + } + + @Override + public Future submit(Runnable task) { + FutureTask future = new FutureTask<>(task, null); + execute(future, TIMEOUT_INDEFINITE); + return future; + } + + @Override + public Future submit(Callable task) { + FutureTask future = new FutureTask<>(task); + execute(future, TIMEOUT_INDEFINITE); + return future; + } + + @Override + public ListenableFuture submitListenable(Runnable task) { + ListenableFutureTask future = new ListenableFutureTask<>(task, null); + execute(future, TIMEOUT_INDEFINITE); + return future; + } + + @Override + public ListenableFuture submitListenable(Callable task) { + ListenableFutureTask future = new ListenableFutureTask<>(task); + execute(future, TIMEOUT_INDEFINITE); + return future; + } + + + //------------------------------------------------------------------------- + // Implementation of the JCA WorkManager interface + //------------------------------------------------------------------------- + + @Override + public void doWork(Work work) throws WorkException { + obtainWorkManager().doWork(work); + } + + @Override + public void doWork(Work work, long delay, ExecutionContext executionContext, WorkListener workListener) + throws WorkException { + + obtainWorkManager().doWork(work, delay, executionContext, workListener); + } + + @Override + public long startWork(Work work) throws WorkException { + return obtainWorkManager().startWork(work); + } + + @Override + public long startWork(Work work, long delay, ExecutionContext executionContext, WorkListener workListener) + throws WorkException { + + return obtainWorkManager().startWork(work, delay, executionContext, workListener); + } + + @Override + public void scheduleWork(Work work) throws WorkException { + obtainWorkManager().scheduleWork(work); + } + + @Override + public void scheduleWork(Work work, long delay, ExecutionContext executionContext, WorkListener workListener) + throws WorkException { + + obtainWorkManager().scheduleWork(work, delay, executionContext, workListener); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/jca/work/package-info.java b/spring-tx/src/main/java/org/springframework/jca/work/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..a797e7215078fda7a08da9af54fff04b5a4a6a15 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/jca/work/package-info.java @@ -0,0 +1,10 @@ +/** + * Convenience classes for scheduling based on the JCA WorkManager facility, + * as supported within ResourceAdapters. + */ +@NonNullApi +@NonNullFields +package org.springframework.jca.work; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/transaction/CannotCreateTransactionException.java b/spring-tx/src/main/java/org/springframework/transaction/CannotCreateTransactionException.java new file mode 100644 index 0000000000000000000000000000000000000000..e710f2f8cad365c7e9f1adbe32dd4c8a4893d979 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/CannotCreateTransactionException.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +/** + * Exception thrown when a transaction can't be created using an + * underlying transaction API such as JTA. + * + * @author Rod Johnson + * @since 17.03.2003 + */ +@SuppressWarnings("serial") +public class CannotCreateTransactionException extends TransactionException { + + /** + * Constructor for CannotCreateTransactionException. + * @param msg the detail message + */ + public CannotCreateTransactionException(String msg) { + super(msg); + } + + /** + * Constructor for CannotCreateTransactionException. + * @param msg the detail message + * @param cause the root cause from the transaction API in use + */ + public CannotCreateTransactionException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/HeuristicCompletionException.java b/spring-tx/src/main/java/org/springframework/transaction/HeuristicCompletionException.java new file mode 100644 index 0000000000000000000000000000000000000000..c5b6b71f1853b2a91203c26be7976c304e206a9d --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/HeuristicCompletionException.java @@ -0,0 +1,93 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +/** + * Exception that represents a transaction failure caused by a heuristic + * decision on the side of the transaction coordinator. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 17.03.2003 + */ +@SuppressWarnings("serial") +public class HeuristicCompletionException extends TransactionException { + + /** + * Unknown outcome state. + */ + public static final int STATE_UNKNOWN = 0; + + /** + * Committed outcome state. + */ + public static final int STATE_COMMITTED = 1; + + /** + * Rolledback outcome state. + */ + public static final int STATE_ROLLED_BACK = 2; + + /** + * Mixed outcome state. + */ + public static final int STATE_MIXED = 3; + + + public static String getStateString(int state) { + switch (state) { + case STATE_COMMITTED: + return "committed"; + case STATE_ROLLED_BACK: + return "rolled back"; + case STATE_MIXED: + return "mixed"; + default: + return "unknown"; + } + } + + + /** + * The outcome state of the transaction: have some or all resources been committed? + */ + private final int outcomeState; + + + /** + * Constructor for HeuristicCompletionException. + * @param outcomeState the outcome state of the transaction + * @param cause the root cause from the transaction API in use + */ + public HeuristicCompletionException(int outcomeState, Throwable cause) { + super("Heuristic completion: outcome state is " + getStateString(outcomeState), cause); + this.outcomeState = outcomeState; + } + + /** + * Return the outcome state of the transaction state, + * as one of the constants in this class. + * @see #STATE_UNKNOWN + * @see #STATE_COMMITTED + * @see #STATE_ROLLED_BACK + * @see #STATE_MIXED + */ + public int getOutcomeState() { + return this.outcomeState; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/IllegalTransactionStateException.java b/spring-tx/src/main/java/org/springframework/transaction/IllegalTransactionStateException.java new file mode 100644 index 0000000000000000000000000000000000000000..ebe3edbea67dc1f2a26530e7d00f69d711b92d14 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/IllegalTransactionStateException.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +/** + * Exception thrown when the existence or non-existence of a transaction + * amounts to an illegal state according to the transaction propagation + * behavior that applies. + * + * @author Juergen Hoeller + * @since 21.01.2004 + */ +@SuppressWarnings("serial") +public class IllegalTransactionStateException extends TransactionUsageException { + + /** + * Constructor for IllegalTransactionStateException. + * @param msg the detail message + */ + public IllegalTransactionStateException(String msg) { + super(msg); + } + + /** + * Constructor for IllegalTransactionStateException. + * @param msg the detail message + * @param cause the root cause from the transaction API in use + */ + public IllegalTransactionStateException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/InvalidIsolationLevelException.java b/spring-tx/src/main/java/org/springframework/transaction/InvalidIsolationLevelException.java new file mode 100644 index 0000000000000000000000000000000000000000..28f2593f53b4bbf317a4cf148ac213062ea8797d --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/InvalidIsolationLevelException.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +/** + * Exception that gets thrown when an invalid isolation level is specified, + * i.e. an isolation level that the transaction manager implementation + * doesn't support. + * + * @author Juergen Hoeller + * @since 12.05.2003 + */ +@SuppressWarnings("serial") +public class InvalidIsolationLevelException extends TransactionUsageException { + + /** + * Constructor for InvalidIsolationLevelException. + * @param msg the detail message + */ + public InvalidIsolationLevelException(String msg) { + super(msg); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/InvalidTimeoutException.java b/spring-tx/src/main/java/org/springframework/transaction/InvalidTimeoutException.java new file mode 100644 index 0000000000000000000000000000000000000000..0b574ad7d73d0ff6df044b995cb982f10ed5db32 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/InvalidTimeoutException.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +/** + * Exception that gets thrown when an invalid timeout is specified, + * that is, the specified timeout valid is out of range or the + * transaction manager implementation doesn't support timeouts. + * + * @author Juergen Hoeller + * @since 12.05.2003 + */ +@SuppressWarnings("serial") +public class InvalidTimeoutException extends TransactionUsageException { + + private final int timeout; + + + /** + * Constructor for InvalidTimeoutException. + * @param msg the detail message + * @param timeout the invalid timeout value + */ + public InvalidTimeoutException(String msg, int timeout) { + super(msg); + this.timeout = timeout; + } + + /** + * Return the invalid timeout value. + */ + public int getTimeout() { + return this.timeout; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/NestedTransactionNotSupportedException.java b/spring-tx/src/main/java/org/springframework/transaction/NestedTransactionNotSupportedException.java new file mode 100644 index 0000000000000000000000000000000000000000..97bd7d7ebf57ada998b601b8b9e0c8aba6a6ea0d --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/NestedTransactionNotSupportedException.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +/** + * Exception thrown when attempting to work with a nested transaction + * but nested transactions are not supported by the underlying backend. + * + * @author Juergen Hoeller + * @since 1.1 + */ +@SuppressWarnings("serial") +public class NestedTransactionNotSupportedException extends CannotCreateTransactionException { + + /** + * Constructor for NestedTransactionNotSupportedException. + * @param msg the detail message + */ + public NestedTransactionNotSupportedException(String msg) { + super(msg); + } + + /** + * Constructor for NestedTransactionNotSupportedException. + * @param msg the detail message + * @param cause the root cause from the transaction API in use + */ + public NestedTransactionNotSupportedException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/NoTransactionException.java b/spring-tx/src/main/java/org/springframework/transaction/NoTransactionException.java new file mode 100644 index 0000000000000000000000000000000000000000..872334acf91c35a7bf516027153b2bc1e2271891 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/NoTransactionException.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +/** + * Exception thrown when an operation is attempted that + * relies on an existing transaction (such as setting + * rollback status) and there is no existing transaction. + * This represents an illegal usage of the transaction API. + * + * @author Rod Johnson + * @since 17.03.2003 + */ +@SuppressWarnings("serial") +public class NoTransactionException extends TransactionUsageException { + + /** + * Constructor for NoTransactionException. + * @param msg the detail message + */ + public NoTransactionException(String msg) { + super(msg); + } + + /** + * Constructor for NoTransactionException. + * @param msg the detail message + * @param cause the root cause from the transaction API in use + */ + public NoTransactionException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/PlatformTransactionManager.java b/spring-tx/src/main/java/org/springframework/transaction/PlatformTransactionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..39ab82f665c40d26fb3369e0e67cbc9cf8a994ef --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/PlatformTransactionManager.java @@ -0,0 +1,120 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +import org.springframework.lang.Nullable; + +/** + * This is the central interface in Spring's transaction infrastructure. + * Applications can use this directly, but it is not primarily meant as API: + * Typically, applications will work with either TransactionTemplate or + * declarative transaction demarcation through AOP. + * + *

For implementors, it is recommended to derive from the provided + * {@link org.springframework.transaction.support.AbstractPlatformTransactionManager} + * class, which pre-implements the defined propagation behavior and takes care + * of transaction synchronization handling. Subclasses have to implement + * template methods for specific states of the underlying transaction, + * for example: begin, suspend, resume, commit. + * + *

The default implementations of this strategy interface are + * {@link org.springframework.transaction.jta.JtaTransactionManager} and + * {@link org.springframework.jdbc.datasource.DataSourceTransactionManager}, + * which can serve as an implementation guide for other transaction strategies. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 16.05.2003 + * @see org.springframework.transaction.support.TransactionTemplate + * @see org.springframework.transaction.interceptor.TransactionInterceptor + */ +public interface PlatformTransactionManager { + + /** + * Return a currently active transaction or create a new one, according to + * the specified propagation behavior. + *

Note that parameters like isolation level or timeout will only be applied + * to new transactions, and thus be ignored when participating in active ones. + *

Furthermore, not all transaction definition settings will be supported + * by every transaction manager: A proper transaction manager implementation + * should throw an exception when unsupported settings are encountered. + *

An exception to the above rule is the read-only flag, which should be + * ignored if no explicit read-only mode is supported. Essentially, the + * read-only flag is just a hint for potential optimization. + * @param definition the TransactionDefinition instance (can be {@code null} for defaults), + * describing propagation behavior, isolation level, timeout etc. + * @return transaction status object representing the new or current transaction + * @throws TransactionException in case of lookup, creation, or system errors + * @throws IllegalTransactionStateException if the given transaction definition + * cannot be executed (for example, if a currently active transaction is in + * conflict with the specified propagation behavior) + * @see TransactionDefinition#getPropagationBehavior + * @see TransactionDefinition#getIsolationLevel + * @see TransactionDefinition#getTimeout + * @see TransactionDefinition#isReadOnly + */ + TransactionStatus getTransaction(@Nullable TransactionDefinition definition) + throws TransactionException; + + /** + * Commit the given transaction, with regard to its status. If the transaction + * has been marked rollback-only programmatically, perform a rollback. + *

If the transaction wasn't a new one, omit the commit for proper + * participation in the surrounding transaction. If a previous transaction + * has been suspended to be able to create a new one, resume the previous + * transaction after committing the new one. + *

Note that when the commit call completes, no matter if normally or + * throwing an exception, the transaction must be fully completed and + * cleaned up. No rollback call should be expected in such a case. + *

If this method throws an exception other than a TransactionException, + * then some before-commit error caused the commit attempt to fail. For + * example, an O/R Mapping tool might have tried to flush changes to the + * database right before commit, with the resulting DataAccessException + * causing the transaction to fail. The original exception will be + * propagated to the caller of this commit method in such a case. + * @param status object returned by the {@code getTransaction} method + * @throws UnexpectedRollbackException in case of an unexpected rollback + * that the transaction coordinator initiated + * @throws HeuristicCompletionException in case of a transaction failure + * caused by a heuristic decision on the side of the transaction coordinator + * @throws TransactionSystemException in case of commit or system errors + * (typically caused by fundamental resource failures) + * @throws IllegalTransactionStateException if the given transaction + * is already completed (that is, committed or rolled back) + * @see TransactionStatus#setRollbackOnly + */ + void commit(TransactionStatus status) throws TransactionException; + + /** + * Perform a rollback of the given transaction. + *

If the transaction wasn't a new one, just set it rollback-only for proper + * participation in the surrounding transaction. If a previous transaction + * has been suspended to be able to create a new one, resume the previous + * transaction after rolling back the new one. + *

Do not call rollback on a transaction if commit threw an exception. + * The transaction will already have been completed and cleaned up when commit + * returns, even in case of a commit exception. Consequently, a rollback call + * after commit failure will lead to an IllegalTransactionStateException. + * @param status object returned by the {@code getTransaction} method + * @throws TransactionSystemException in case of rollback or system errors + * (typically caused by fundamental resource failures) + * @throws IllegalTransactionStateException if the given transaction + * is already completed (that is, committed or rolled back) + */ + void rollback(TransactionStatus status) throws TransactionException; + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/SavepointManager.java b/spring-tx/src/main/java/org/springframework/transaction/SavepointManager.java new file mode 100644 index 0000000000000000000000000000000000000000..ca12743cb4ef2ef6015cf1d221ee159e78b4319a --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/SavepointManager.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +/** + * Interface that specifies an API to programmatically manage transaction + * savepoints in a generic fashion. Extended by TransactionStatus to + * expose savepoint management functionality for a specific transaction. + * + *

Note that savepoints can only work within an active transaction. + * Just use this programmatic savepoint handling for advanced needs; + * else, a subtransaction with PROPAGATION_NESTED is preferable. + * + *

This interface is inspired by JDBC 3.0's Savepoint mechanism + * but is independent from any specific persistence technology. + * + * @author Juergen Hoeller + * @since 1.1 + * @see TransactionStatus + * @see TransactionDefinition#PROPAGATION_NESTED + * @see java.sql.Savepoint + */ +public interface SavepointManager { + + /** + * Create a new savepoint. You can roll back to a specific savepoint + * via {@code rollbackToSavepoint}, and explicitly release a savepoint + * that you don't need anymore via {@code releaseSavepoint}. + *

Note that most transaction managers will automatically release + * savepoints at transaction completion. + * @return a savepoint object, to be passed into + * {@link #rollbackToSavepoint} or {@link #releaseSavepoint} + * @throws NestedTransactionNotSupportedException if the underlying + * transaction does not support savepoints + * @throws TransactionException if the savepoint could not be created, + * for example because the transaction is not in an appropriate state + * @see java.sql.Connection#setSavepoint + */ + Object createSavepoint() throws TransactionException; + + /** + * Roll back to the given savepoint. + *

The savepoint will not be automatically released afterwards. + * You may explicitly call {@link #releaseSavepoint(Object)} or rely on + * automatic release on transaction completion. + * @param savepoint the savepoint to roll back to + * @throws NestedTransactionNotSupportedException if the underlying + * transaction does not support savepoints + * @throws TransactionException if the rollback failed + * @see java.sql.Connection#rollback(java.sql.Savepoint) + */ + void rollbackToSavepoint(Object savepoint) throws TransactionException; + + /** + * Explicitly release the given savepoint. + *

Note that most transaction managers will automatically release + * savepoints on transaction completion. + *

Implementations should fail as silently as possible if proper + * resource cleanup will eventually happen at transaction completion. + * @param savepoint the savepoint to release + * @throws NestedTransactionNotSupportedException if the underlying + * transaction does not support savepoints + * @throws TransactionException if the release failed + * @see java.sql.Connection#releaseSavepoint + */ + void releaseSavepoint(Object savepoint) throws TransactionException; + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/TransactionDefinition.java b/spring-tx/src/main/java/org/springframework/transaction/TransactionDefinition.java new file mode 100644 index 0000000000000000000000000000000000000000..bcc4f076ee563b5e01b4c16f58934ce2dd672397 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/TransactionDefinition.java @@ -0,0 +1,268 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +import java.sql.Connection; + +import org.springframework.lang.Nullable; + +/** + * Interface that defines Spring-compliant transaction properties. + * Based on the propagation behavior definitions analogous to EJB CMT attributes. + * + *

Note that isolation level and timeout settings will not get applied unless + * an actual new transaction gets started. As only {@link #PROPAGATION_REQUIRED}, + * {@link #PROPAGATION_REQUIRES_NEW} and {@link #PROPAGATION_NESTED} can cause + * that, it usually doesn't make sense to specify those settings in other cases. + * Furthermore, be aware that not all transaction managers will support those + * advanced features and thus might throw corresponding exceptions when given + * non-default values. + * + *

The {@link #isReadOnly() read-only flag} applies to any transaction context, + * whether backed by an actual resource transaction or operating non-transactionally + * at the resource level. In the latter case, the flag will only apply to managed + * resources within the application, such as a Hibernate {@code Session}. + * + * @author Juergen Hoeller + * @since 08.05.2003 + * @see PlatformTransactionManager#getTransaction(TransactionDefinition) + * @see org.springframework.transaction.support.DefaultTransactionDefinition + * @see org.springframework.transaction.interceptor.TransactionAttribute + */ +public interface TransactionDefinition { + + /** + * Support a current transaction; create a new one if none exists. + * Analogous to the EJB transaction attribute of the same name. + *

This is typically the default setting of a transaction definition, + * and typically defines a transaction synchronization scope. + */ + int PROPAGATION_REQUIRED = 0; + + /** + * Support a current transaction; execute non-transactionally if none exists. + * Analogous to the EJB transaction attribute of the same name. + *

NOTE: For transaction managers with transaction synchronization, + * {@code PROPAGATION_SUPPORTS} is slightly different from no transaction + * at all, as it defines a transaction scope that synchronization might apply to. + * As a consequence, the same resources (a JDBC {@code Connection}, a + * Hibernate {@code Session}, etc) will be shared for the entire specified + * scope. Note that the exact behavior depends on the actual synchronization + * configuration of the transaction manager! + *

In general, use {@code PROPAGATION_SUPPORTS} with care! In particular, do + * not rely on {@code PROPAGATION_REQUIRED} or {@code PROPAGATION_REQUIRES_NEW} + * within a {@code PROPAGATION_SUPPORTS} scope (which may lead to + * synchronization conflicts at runtime). If such nesting is unavoidable, make sure + * to configure your transaction manager appropriately (typically switching to + * "synchronization on actual transaction"). + * @see org.springframework.transaction.support.AbstractPlatformTransactionManager#setTransactionSynchronization + * @see org.springframework.transaction.support.AbstractPlatformTransactionManager#SYNCHRONIZATION_ON_ACTUAL_TRANSACTION + */ + int PROPAGATION_SUPPORTS = 1; + + /** + * Support a current transaction; throw an exception if no current transaction + * exists. Analogous to the EJB transaction attribute of the same name. + *

Note that transaction synchronization within a {@code PROPAGATION_MANDATORY} + * scope will always be driven by the surrounding transaction. + */ + int PROPAGATION_MANDATORY = 2; + + /** + * Create a new transaction, suspending the current transaction if one exists. + * Analogous to the EJB transaction attribute of the same name. + *

NOTE: Actual transaction suspension will not work out-of-the-box + * on all transaction managers. This in particular applies to + * {@link org.springframework.transaction.jta.JtaTransactionManager}, + * which requires the {@code javax.transaction.TransactionManager} to be + * made available it to it (which is server-specific in standard Java EE). + *

A {@code PROPAGATION_REQUIRES_NEW} scope always defines its own + * transaction synchronizations. Existing synchronizations will be suspended + * and resumed appropriately. + * @see org.springframework.transaction.jta.JtaTransactionManager#setTransactionManager + */ + int PROPAGATION_REQUIRES_NEW = 3; + + /** + * Do not support a current transaction; rather always execute non-transactionally. + * Analogous to the EJB transaction attribute of the same name. + *

NOTE: Actual transaction suspension will not work out-of-the-box + * on all transaction managers. This in particular applies to + * {@link org.springframework.transaction.jta.JtaTransactionManager}, + * which requires the {@code javax.transaction.TransactionManager} to be + * made available it to it (which is server-specific in standard Java EE). + *

Note that transaction synchronization is not available within a + * {@code PROPAGATION_NOT_SUPPORTED} scope. Existing synchronizations + * will be suspended and resumed appropriately. + * @see org.springframework.transaction.jta.JtaTransactionManager#setTransactionManager + */ + int PROPAGATION_NOT_SUPPORTED = 4; + + /** + * Do not support a current transaction; throw an exception if a current transaction + * exists. Analogous to the EJB transaction attribute of the same name. + *

Note that transaction synchronization is not available within a + * {@code PROPAGATION_NEVER} scope. + */ + int PROPAGATION_NEVER = 5; + + /** + * Execute within a nested transaction if a current transaction exists, + * behave like {@link #PROPAGATION_REQUIRED} otherwise. There is no + * analogous feature in EJB. + *

NOTE: Actual creation of a nested transaction will only work on + * specific transaction managers. Out of the box, this only applies to the JDBC + * {@link org.springframework.jdbc.datasource.DataSourceTransactionManager} + * when working on a JDBC 3.0 driver. Some JTA providers might support + * nested transactions as well. + * @see org.springframework.jdbc.datasource.DataSourceTransactionManager + */ + int PROPAGATION_NESTED = 6; + + + /** + * Use the default isolation level of the underlying datastore. + * All other levels correspond to the JDBC isolation levels. + * @see java.sql.Connection + */ + int ISOLATION_DEFAULT = -1; + + /** + * Indicates that dirty reads, non-repeatable reads and phantom reads + * can occur. + *

This level allows a row changed by one transaction to be read by another + * transaction before any changes in that row have been committed (a "dirty read"). + * If any of the changes are rolled back, the second transaction will have + * retrieved an invalid row. + * @see java.sql.Connection#TRANSACTION_READ_UNCOMMITTED + */ + int ISOLATION_READ_UNCOMMITTED = Connection.TRANSACTION_READ_UNCOMMITTED; + + /** + * Indicates that dirty reads are prevented; non-repeatable reads and + * phantom reads can occur. + *

This level only prohibits a transaction from reading a row + * with uncommitted changes in it. + * @see java.sql.Connection#TRANSACTION_READ_COMMITTED + */ + int ISOLATION_READ_COMMITTED = Connection.TRANSACTION_READ_COMMITTED; + + /** + * Indicates that dirty reads and non-repeatable reads are prevented; + * phantom reads can occur. + *

This level prohibits a transaction from reading a row with uncommitted changes + * in it, and it also prohibits the situation where one transaction reads a row, + * a second transaction alters the row, and the first transaction re-reads the row, + * getting different values the second time (a "non-repeatable read"). + * @see java.sql.Connection#TRANSACTION_REPEATABLE_READ + */ + int ISOLATION_REPEATABLE_READ = Connection.TRANSACTION_REPEATABLE_READ; + + /** + * Indicates that dirty reads, non-repeatable reads and phantom reads + * are prevented. + *

This level includes the prohibitions in {@link #ISOLATION_REPEATABLE_READ} + * and further prohibits the situation where one transaction reads all rows that + * satisfy a {@code WHERE} condition, a second transaction inserts a row + * that satisfies that {@code WHERE} condition, and the first transaction + * re-reads for the same condition, retrieving the additional "phantom" row + * in the second read. + * @see java.sql.Connection#TRANSACTION_SERIALIZABLE + */ + int ISOLATION_SERIALIZABLE = Connection.TRANSACTION_SERIALIZABLE; + + + /** + * Use the default timeout of the underlying transaction system, + * or none if timeouts are not supported. + */ + int TIMEOUT_DEFAULT = -1; + + + /** + * Return the propagation behavior. + *

Must return one of the {@code PROPAGATION_XXX} constants + * defined on {@link TransactionDefinition this interface}. + * @return the propagation behavior + * @see #PROPAGATION_REQUIRED + * @see org.springframework.transaction.support.TransactionSynchronizationManager#isActualTransactionActive() + */ + int getPropagationBehavior(); + + /** + * Return the isolation level. + *

Must return one of the {@code ISOLATION_XXX} constants defined on + * {@link TransactionDefinition this interface}. Those constants are designed + * to match the values of the same constants on {@link java.sql.Connection}. + *

Exclusively designed for use with {@link #PROPAGATION_REQUIRED} or + * {@link #PROPAGATION_REQUIRES_NEW} since it only applies to newly started + * transactions. Consider switching the "validateExistingTransactions" flag to + * "true" on your transaction manager if you'd like isolation level declarations + * to get rejected when participating in an existing transaction with a different + * isolation level. + *

Note that a transaction manager that does not support custom isolation levels + * will throw an exception when given any other level than {@link #ISOLATION_DEFAULT}. + * @return the isolation level + * @see #ISOLATION_DEFAULT + * @see org.springframework.transaction.support.AbstractPlatformTransactionManager#setValidateExistingTransaction + */ + int getIsolationLevel(); + + /** + * Return the transaction timeout. + *

Must return a number of seconds, or {@link #TIMEOUT_DEFAULT}. + *

Exclusively designed for use with {@link #PROPAGATION_REQUIRED} or + * {@link #PROPAGATION_REQUIRES_NEW} since it only applies to newly started + * transactions. + *

Note that a transaction manager that does not support timeouts will throw + * an exception when given any other timeout than {@link #TIMEOUT_DEFAULT}. + * @return the transaction timeout + */ + int getTimeout(); + + /** + * Return whether to optimize as a read-only transaction. + *

The read-only flag applies to any transaction context, whether backed + * by an actual resource transaction ({@link #PROPAGATION_REQUIRED}/ + * {@link #PROPAGATION_REQUIRES_NEW}) or operating non-transactionally at + * the resource level ({@link #PROPAGATION_SUPPORTS}). In the latter case, + * the flag will only apply to managed resources within the application, + * such as a Hibernate {@code Session}. + *

This just serves as a hint for the actual transaction subsystem; + * it will not necessarily cause failure of write access attempts. + * A transaction manager which cannot interpret the read-only hint will + * not throw an exception when asked for a read-only transaction. + * @return {@code true} if the transaction is to be optimized as read-only + * @see org.springframework.transaction.support.TransactionSynchronization#beforeCommit(boolean) + * @see org.springframework.transaction.support.TransactionSynchronizationManager#isCurrentTransactionReadOnly() + */ + boolean isReadOnly(); + + /** + * Return the name of this transaction. Can be {@code null}. + *

This will be used as the transaction name to be shown in a + * transaction monitor, if applicable (for example, WebLogic's). + *

In case of Spring's declarative transactions, the exposed name will be + * the {@code fully-qualified class name + "." + method name} (by default). + * @return the name of this transaction + * @see org.springframework.transaction.interceptor.TransactionAspectSupport + * @see org.springframework.transaction.support.TransactionSynchronizationManager#getCurrentTransactionName() + */ + @Nullable + String getName(); + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/TransactionException.java b/spring-tx/src/main/java/org/springframework/transaction/TransactionException.java new file mode 100644 index 0000000000000000000000000000000000000000..85a0bf21d5038ef8b5db626ed667f0a13fa31716 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/TransactionException.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +import org.springframework.core.NestedRuntimeException; + +/** + * Superclass for all transaction exceptions. + * + * @author Rod Johnson + * @since 17.03.2003 + */ +@SuppressWarnings("serial") +public abstract class TransactionException extends NestedRuntimeException { + + /** + * Constructor for TransactionException. + * @param msg the detail message + */ + public TransactionException(String msg) { + super(msg); + } + + /** + * Constructor for TransactionException. + * @param msg the detail message + * @param cause the root cause from the transaction API in use + */ + public TransactionException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/TransactionStatus.java b/spring-tx/src/main/java/org/springframework/transaction/TransactionStatus.java new file mode 100644 index 0000000000000000000000000000000000000000..557cd77204480e3bd33bf4559aeaf3c0d0667b8c --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/TransactionStatus.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +import java.io.Flushable; + +/** + * Representation of the status of a transaction. + * + *

Transactional code can use this to retrieve status information, + * and to programmatically request a rollback (instead of throwing + * an exception that causes an implicit rollback). + * + *

Includes the {@link SavepointManager} interface to provide access + * to savepoint management facilities. Note that savepoint management + * is only available if supported by the underlying transaction manager. + * + * @author Juergen Hoeller + * @since 27.03.2003 + * @see #setRollbackOnly() + * @see PlatformTransactionManager#getTransaction + * @see org.springframework.transaction.support.TransactionCallback#doInTransaction + * @see org.springframework.transaction.interceptor.TransactionInterceptor#currentTransactionStatus() + */ +public interface TransactionStatus extends SavepointManager, Flushable { + + /** + * Return whether the present transaction is new; otherwise participating + * in an existing transaction, or potentially not running in an actual + * transaction in the first place. + */ + boolean isNewTransaction(); + + /** + * Return whether this transaction internally carries a savepoint, + * that is, has been created as nested transaction based on a savepoint. + *

This method is mainly here for diagnostic purposes, alongside + * {@link #isNewTransaction()}. For programmatic handling of custom + * savepoints, use the operations provided by {@link SavepointManager}. + * @see #isNewTransaction() + * @see #createSavepoint() + * @see #rollbackToSavepoint(Object) + * @see #releaseSavepoint(Object) + */ + boolean hasSavepoint(); + + /** + * Set the transaction rollback-only. This instructs the transaction manager + * that the only possible outcome of the transaction may be a rollback, as + * alternative to throwing an exception which would in turn trigger a rollback. + *

This is mainly intended for transactions managed by + * {@link org.springframework.transaction.support.TransactionTemplate} or + * {@link org.springframework.transaction.interceptor.TransactionInterceptor}, + * where the actual commit/rollback decision is made by the container. + * @see org.springframework.transaction.support.TransactionCallback#doInTransaction + * @see org.springframework.transaction.interceptor.TransactionAttribute#rollbackOn + */ + void setRollbackOnly(); + + /** + * Return whether the transaction has been marked as rollback-only + * (either by the application or by the transaction infrastructure). + */ + boolean isRollbackOnly(); + + /** + * Flush the underlying session to the datastore, if applicable: + * for example, all affected Hibernate/JPA sessions. + *

This is effectively just a hint and may be a no-op if the underlying + * transaction manager does not have a flush concept. A flush signal may + * get applied to the primary resource or to transaction synchronizations, + * depending on the underlying resource. + */ + @Override + void flush(); + + /** + * Return whether this transaction is completed, that is, + * whether it has already been committed or rolled back. + * @see PlatformTransactionManager#commit + * @see PlatformTransactionManager#rollback + */ + boolean isCompleted(); + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/TransactionSuspensionNotSupportedException.java b/spring-tx/src/main/java/org/springframework/transaction/TransactionSuspensionNotSupportedException.java new file mode 100644 index 0000000000000000000000000000000000000000..47d48e3ea79987762d88a69832a11125317281fa --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/TransactionSuspensionNotSupportedException.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +/** + * Exception thrown when attempting to suspend an existing transaction + * but transaction suspension is not supported by the underlying backend. + * + * @author Juergen Hoeller + * @since 1.1 + */ +@SuppressWarnings("serial") +public class TransactionSuspensionNotSupportedException extends CannotCreateTransactionException { + + /** + * Constructor for TransactionSuspensionNotSupportedException. + * @param msg the detail message + */ + public TransactionSuspensionNotSupportedException(String msg) { + super(msg); + } + + /** + * Constructor for TransactionSuspensionNotSupportedException. + * @param msg the detail message + * @param cause the root cause from the transaction API in use + */ + public TransactionSuspensionNotSupportedException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/TransactionSystemException.java b/spring-tx/src/main/java/org/springframework/transaction/TransactionSystemException.java new file mode 100644 index 0000000000000000000000000000000000000000..eb1e833ccf6436e988da0a58530ea10a6434f629 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/TransactionSystemException.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Exception thrown when a general transaction system error is encountered, + * like on commit or rollback. + * + * @author Juergen Hoeller + * @since 24.03.2003 + */ +@SuppressWarnings("serial") +public class TransactionSystemException extends TransactionException { + + @Nullable + private Throwable applicationException; + + + /** + * Constructor for TransactionSystemException. + * @param msg the detail message + */ + public TransactionSystemException(String msg) { + super(msg); + } + + /** + * Constructor for TransactionSystemException. + * @param msg the detail message + * @param cause the root cause from the transaction API in use + */ + public TransactionSystemException(String msg, Throwable cause) { + super(msg, cause); + } + + + /** + * Set an application exception that was thrown before this transaction exception, + * preserving the original exception despite the overriding TransactionSystemException. + * @param ex the application exception + * @throws IllegalStateException if this TransactionSystemException already holds an + * application exception + */ + public void initApplicationException(Throwable ex) { + Assert.notNull(ex, "Application exception must not be null"); + if (this.applicationException != null) { + throw new IllegalStateException("Already holding an application exception: " + this.applicationException); + } + this.applicationException = ex; + } + + /** + * Return the application exception that was thrown before this transaction exception, + * if any. + * @return the application exception, or {@code null} if none set + */ + @Nullable + public final Throwable getApplicationException() { + return this.applicationException; + } + + /** + * Return the exception that was the first to be thrown within the failed transaction: + * i.e. the application exception, if any, or the TransactionSystemException's own cause. + * @return the original exception, or {@code null} if there was none + */ + @Nullable + public Throwable getOriginalException() { + return (this.applicationException != null ? this.applicationException : getCause()); + } + + @Override + public boolean contains(@Nullable Class exType) { + return super.contains(exType) || (exType != null && exType.isInstance(this.applicationException)); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/TransactionTimedOutException.java b/spring-tx/src/main/java/org/springframework/transaction/TransactionTimedOutException.java new file mode 100644 index 0000000000000000000000000000000000000000..60f41a22c511f78a41689741d64bb6f02cb234dd --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/TransactionTimedOutException.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +/** + * Exception to be thrown when a transaction has timed out. + * + *

Thrown by Spring's local transaction strategies if the deadline + * for a transaction has been reached when an operation is attempted, + * according to the timeout specified for the given transaction. + * + *

Beyond such checks before each transactional operation, Spring's + * local transaction strategies will also pass appropriate timeout values + * to resource operations (for example to JDBC Statements, letting the JDBC + * driver respect the timeout). Such operations will usually throw native + * resource exceptions (for example, JDBC SQLExceptions) if their operation + * timeout has been exceeded, to be converted to Spring's DataAccessException + * in the respective DAO (which might use Spring's JdbcTemplate, for example). + * + *

In a JTA environment, it is up to the JTA transaction coordinator + * to apply transaction timeouts. Usually, the corresponding JTA-aware + * connection pool will perform timeout checks and throw corresponding + * native resource exceptions (for example, JDBC SQLExceptions). + * + * @author Juergen Hoeller + * @since 1.1.5 + * @see org.springframework.transaction.support.ResourceHolderSupport#getTimeToLiveInMillis + * @see java.sql.Statement#setQueryTimeout + * @see java.sql.SQLException + */ +@SuppressWarnings("serial") +public class TransactionTimedOutException extends TransactionException { + + /** + * Constructor for TransactionTimedOutException. + * @param msg the detail message + */ + public TransactionTimedOutException(String msg) { + super(msg); + } + + /** + * Constructor for TransactionTimedOutException. + * @param msg the detail message + * @param cause the root cause from the transaction API in use + */ + public TransactionTimedOutException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/TransactionUsageException.java b/spring-tx/src/main/java/org/springframework/transaction/TransactionUsageException.java new file mode 100644 index 0000000000000000000000000000000000000000..6308c701acc54bacf27c11f5b52aeee869332d0b --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/TransactionUsageException.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +/** + * Superclass for exceptions caused by inappropriate usage of + * a Spring transaction API. + * + * @author Rod Johnson + * @since 22.03.2003 + */ +@SuppressWarnings("serial") +public class TransactionUsageException extends TransactionException { + + /** + * Constructor for TransactionUsageException. + * @param msg the detail message + */ + public TransactionUsageException(String msg) { + super(msg); + } + + /** + * Constructor for TransactionUsageException. + * @param msg the detail message + * @param cause the root cause from the transaction API in use + */ + public TransactionUsageException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/UnexpectedRollbackException.java b/spring-tx/src/main/java/org/springframework/transaction/UnexpectedRollbackException.java new file mode 100644 index 0000000000000000000000000000000000000000..530beee4ac6dfbe87a8fd0ff1e1b2456be9512e7 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/UnexpectedRollbackException.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +/** + * Thrown when an attempt to commit a transaction resulted + * in an unexpected rollback. + * + * @author Rod Johnson + * @since 17.03.2003 + */ +@SuppressWarnings("serial") +public class UnexpectedRollbackException extends TransactionException { + + /** + * Constructor for UnexpectedRollbackException. + * @param msg the detail message + */ + public UnexpectedRollbackException(String msg) { + super(msg); + } + + /** + * Constructor for UnexpectedRollbackException. + * @param msg the detail message + * @param cause the root cause from the transaction API in use + */ + public UnexpectedRollbackException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/AbstractTransactionManagementConfiguration.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/AbstractTransactionManagementConfiguration.java new file mode 100644 index 0000000000000000000000000000000000000000..93a88a80f9d1aa0e41f01864abce53d9ec42b2dc --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/AbstractTransactionManagementConfiguration.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import java.util.Collection; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.ImportAware; +import org.springframework.context.annotation.Role; +import org.springframework.core.annotation.AnnotationAttributes; +import org.springframework.core.type.AnnotationMetadata; +import org.springframework.lang.Nullable; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.config.TransactionManagementConfigUtils; +import org.springframework.transaction.event.TransactionalEventListenerFactory; +import org.springframework.util.CollectionUtils; + +/** + * Abstract base {@code @Configuration} class providing common structure for enabling + * Spring's annotation-driven transaction management capability. + * + * @author Chris Beams + * @author Stephane Nicoll + * @since 3.1 + * @see EnableTransactionManagement + */ +@Configuration +public abstract class AbstractTransactionManagementConfiguration implements ImportAware { + + @Nullable + protected AnnotationAttributes enableTx; + + /** + * Default transaction manager, as configured through a {@link TransactionManagementConfigurer}. + */ + @Nullable + protected PlatformTransactionManager txManager; + + + @Override + public void setImportMetadata(AnnotationMetadata importMetadata) { + this.enableTx = AnnotationAttributes.fromMap( + importMetadata.getAnnotationAttributes(EnableTransactionManagement.class.getName(), false)); + if (this.enableTx == null) { + throw new IllegalArgumentException( + "@EnableTransactionManagement is not present on importing class " + importMetadata.getClassName()); + } + } + + @Autowired(required = false) + void setConfigurers(Collection configurers) { + if (CollectionUtils.isEmpty(configurers)) { + return; + } + if (configurers.size() > 1) { + throw new IllegalStateException("Only one TransactionManagementConfigurer may exist"); + } + TransactionManagementConfigurer configurer = configurers.iterator().next(); + this.txManager = configurer.annotationDrivenTransactionManager(); + } + + + @Bean(name = TransactionManagementConfigUtils.TRANSACTIONAL_EVENT_LISTENER_FACTORY_BEAN_NAME) + @Role(BeanDefinition.ROLE_INFRASTRUCTURE) + public static TransactionalEventListenerFactory transactionalEventListenerFactory() { + return new TransactionalEventListenerFactory(); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/AnnotationTransactionAttributeSource.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/AnnotationTransactionAttributeSource.java new file mode 100644 index 0000000000000000000000000000000000000000..e4f008ca0407c0a6d0ef66fa8f9aef50883f803e --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/AnnotationTransactionAttributeSource.java @@ -0,0 +1,200 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import java.io.Serializable; +import java.lang.reflect.AnnotatedElement; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Set; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.interceptor.AbstractFallbackTransactionAttributeSource; +import org.springframework.transaction.interceptor.TransactionAttribute; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; + +/** + * Implementation of the + * {@link org.springframework.transaction.interceptor.TransactionAttributeSource} + * interface for working with transaction metadata in JDK 1.5+ annotation format. + * + *

This class reads Spring's JDK 1.5+ {@link Transactional} annotation and + * exposes corresponding transaction attributes to Spring's transaction infrastructure. + * Also supports JTA 1.2's {@link javax.transaction.Transactional} and EJB3's + * {@link javax.ejb.TransactionAttribute} annotation (if present). + * This class may also serve as base class for a custom TransactionAttributeSource, + * or get customized through {@link TransactionAnnotationParser} strategies. + * + * @author Colin Sampaleanu + * @author Juergen Hoeller + * @since 1.2 + * @see Transactional + * @see TransactionAnnotationParser + * @see SpringTransactionAnnotationParser + * @see Ejb3TransactionAnnotationParser + * @see org.springframework.transaction.interceptor.TransactionInterceptor#setTransactionAttributeSource + * @see org.springframework.transaction.interceptor.TransactionProxyFactoryBean#setTransactionAttributeSource + */ +@SuppressWarnings("serial") +public class AnnotationTransactionAttributeSource extends AbstractFallbackTransactionAttributeSource + implements Serializable { + + private static final boolean jta12Present; + + private static final boolean ejb3Present; + + static { + ClassLoader classLoader = AnnotationTransactionAttributeSource.class.getClassLoader(); + jta12Present = ClassUtils.isPresent("javax.transaction.Transactional", classLoader); + ejb3Present = ClassUtils.isPresent("javax.ejb.TransactionAttribute", classLoader); + } + + private final boolean publicMethodsOnly; + + private final Set annotationParsers; + + + /** + * Create a default AnnotationTransactionAttributeSource, supporting + * public methods that carry the {@code Transactional} annotation + * or the EJB3 {@link javax.ejb.TransactionAttribute} annotation. + */ + public AnnotationTransactionAttributeSource() { + this(true); + } + + /** + * Create a custom AnnotationTransactionAttributeSource, supporting + * public methods that carry the {@code Transactional} annotation + * or the EJB3 {@link javax.ejb.TransactionAttribute} annotation. + * @param publicMethodsOnly whether to support public methods that carry + * the {@code Transactional} annotation only (typically for use + * with proxy-based AOP), or protected/private methods as well + * (typically used with AspectJ class weaving) + */ + public AnnotationTransactionAttributeSource(boolean publicMethodsOnly) { + this.publicMethodsOnly = publicMethodsOnly; + if (jta12Present || ejb3Present) { + this.annotationParsers = new LinkedHashSet<>(4); + this.annotationParsers.add(new SpringTransactionAnnotationParser()); + if (jta12Present) { + this.annotationParsers.add(new JtaTransactionAnnotationParser()); + } + if (ejb3Present) { + this.annotationParsers.add(new Ejb3TransactionAnnotationParser()); + } + } + else { + this.annotationParsers = Collections.singleton(new SpringTransactionAnnotationParser()); + } + } + + /** + * Create a custom AnnotationTransactionAttributeSource. + * @param annotationParser the TransactionAnnotationParser to use + */ + public AnnotationTransactionAttributeSource(TransactionAnnotationParser annotationParser) { + this.publicMethodsOnly = true; + Assert.notNull(annotationParser, "TransactionAnnotationParser must not be null"); + this.annotationParsers = Collections.singleton(annotationParser); + } + + /** + * Create a custom AnnotationTransactionAttributeSource. + * @param annotationParsers the TransactionAnnotationParsers to use + */ + public AnnotationTransactionAttributeSource(TransactionAnnotationParser... annotationParsers) { + this.publicMethodsOnly = true; + Assert.notEmpty(annotationParsers, "At least one TransactionAnnotationParser needs to be specified"); + this.annotationParsers = new LinkedHashSet<>(Arrays.asList(annotationParsers)); + } + + /** + * Create a custom AnnotationTransactionAttributeSource. + * @param annotationParsers the TransactionAnnotationParsers to use + */ + public AnnotationTransactionAttributeSource(Set annotationParsers) { + this.publicMethodsOnly = true; + Assert.notEmpty(annotationParsers, "At least one TransactionAnnotationParser needs to be specified"); + this.annotationParsers = annotationParsers; + } + + + @Override + @Nullable + protected TransactionAttribute findTransactionAttribute(Class clazz) { + return determineTransactionAttribute(clazz); + } + + @Override + @Nullable + protected TransactionAttribute findTransactionAttribute(Method method) { + return determineTransactionAttribute(method); + } + + /** + * Determine the transaction attribute for the given method or class. + *

This implementation delegates to configured + * {@link TransactionAnnotationParser TransactionAnnotationParsers} + * for parsing known annotations into Spring's metadata attribute class. + * Returns {@code null} if it's not transactional. + *

Can be overridden to support custom annotations that carry transaction metadata. + * @param element the annotated method or class + * @return the configured transaction attribute, or {@code null} if none was found + */ + @Nullable + protected TransactionAttribute determineTransactionAttribute(AnnotatedElement element) { + for (TransactionAnnotationParser annotationParser : this.annotationParsers) { + TransactionAttribute attr = annotationParser.parseTransactionAnnotation(element); + if (attr != null) { + return attr; + } + } + return null; + } + + /** + * By default, only public methods can be made transactional. + */ + @Override + protected boolean allowPublicMethodsOnly() { + return this.publicMethodsOnly; + } + + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof AnnotationTransactionAttributeSource)) { + return false; + } + AnnotationTransactionAttributeSource otherTas = (AnnotationTransactionAttributeSource) other; + return (this.annotationParsers.equals(otherTas.annotationParsers) && + this.publicMethodsOnly == otherTas.publicMethodsOnly); + } + + @Override + public int hashCode() { + return this.annotationParsers.hashCode(); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/Ejb3TransactionAnnotationParser.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/Ejb3TransactionAnnotationParser.java new file mode 100644 index 0000000000000000000000000000000000000000..f80630516d55585c60073105632507b52a03302d --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/Ejb3TransactionAnnotationParser.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import java.io.Serializable; +import java.lang.reflect.AnnotatedElement; + +import javax.ejb.ApplicationException; +import javax.ejb.TransactionAttributeType; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.interceptor.DefaultTransactionAttribute; +import org.springframework.transaction.interceptor.TransactionAttribute; + +/** + * Strategy implementation for parsing EJB3's {@link javax.ejb.TransactionAttribute} + * annotation. + * + * @author Juergen Hoeller + * @since 2.5 + */ +@SuppressWarnings("serial") +public class Ejb3TransactionAnnotationParser implements TransactionAnnotationParser, Serializable { + + @Override + @Nullable + public TransactionAttribute parseTransactionAnnotation(AnnotatedElement element) { + javax.ejb.TransactionAttribute ann = element.getAnnotation(javax.ejb.TransactionAttribute.class); + if (ann != null) { + return parseTransactionAnnotation(ann); + } + else { + return null; + } + } + + public TransactionAttribute parseTransactionAnnotation(javax.ejb.TransactionAttribute ann) { + return new Ejb3TransactionAttribute(ann.value()); + } + + + @Override + public boolean equals(Object other) { + return (this == other || other instanceof Ejb3TransactionAnnotationParser); + } + + @Override + public int hashCode() { + return Ejb3TransactionAnnotationParser.class.hashCode(); + } + + + /** + * EJB3-specific TransactionAttribute, implementing EJB3's rollback rules + * which are based on annotated exceptions. + */ + private static class Ejb3TransactionAttribute extends DefaultTransactionAttribute { + + public Ejb3TransactionAttribute(TransactionAttributeType type) { + setPropagationBehaviorName(PREFIX_PROPAGATION + type.name()); + } + + @Override + public boolean rollbackOn(Throwable ex) { + ApplicationException ann = ex.getClass().getAnnotation(ApplicationException.class); + return (ann != null ? ann.rollback() : super.rollbackOn(ex)); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/EnableTransactionManagement.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/EnableTransactionManagement.java new file mode 100644 index 0000000000000000000000000000000000000000..1daef74c7320662fa92939f7eb00109fd608394a --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/EnableTransactionManagement.java @@ -0,0 +1,191 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.context.annotation.AdviceMode; +import org.springframework.context.annotation.Import; +import org.springframework.core.Ordered; + +/** + * Enables Spring's annotation-driven transaction management capability, similar to + * the support found in Spring's {@code } XML namespace. To be used on + * {@link org.springframework.context.annotation.Configuration @Configuration} + * classes as follows: + * + *

+ * @Configuration
+ * @EnableTransactionManagement
+ * public class AppConfig {
+ *
+ *     @Bean
+ *     public FooRepository fooRepository() {
+ *         // configure and return a class having @Transactional methods
+ *         return new JdbcFooRepository(dataSource());
+ *     }
+ *
+ *     @Bean
+ *     public DataSource dataSource() {
+ *         // configure and return the necessary JDBC DataSource
+ *     }
+ *
+ *     @Bean
+ *     public PlatformTransactionManager txManager() {
+ *         return new DataSourceTransactionManager(dataSource());
+ *     }
+ * }
+ * + *

For reference, the example above can be compared to the following Spring XML + * configuration: + * + *

+ * <beans>
+ *
+ *     <tx:annotation-driven/>
+ *
+ *     <bean id="fooRepository" class="com.foo.JdbcFooRepository">
+ *         <constructor-arg ref="dataSource"/>
+ *     </bean>
+ *
+ *     <bean id="dataSource" class="com.vendor.VendorDataSource"/>
+ *
+ *     <bean id="transactionManager" class="org.sfwk...DataSourceTransactionManager">
+ *         <constructor-arg ref="dataSource"/>
+ *     </bean>
+ *
+ * </beans>
+ * 
+ * + * In both of the scenarios above, {@code @EnableTransactionManagement} and {@code + * } are responsible for registering the necessary Spring + * components that power annotation-driven transaction management, such as the + * TransactionInterceptor and the proxy- or AspectJ-based advice that weave the + * interceptor into the call stack when {@code JdbcFooRepository}'s {@code @Transactional} + * methods are invoked. + * + *

A minor difference between the two examples lies in the naming of the {@code + * PlatformTransactionManager} bean: In the {@code @Bean} case, the name is + * "txManager" (per the name of the method); in the XML case, the name is + * "transactionManager". The {@code } is hard-wired to + * look for a bean named "transactionManager" by default, however + * {@code @EnableTransactionManagement} is more flexible; it will fall back to a by-type + * lookup for any {@code PlatformTransactionManager} bean in the container. Thus the name + * can be "txManager", "transactionManager", or "tm": it simply does not matter. + * + *

For those that wish to establish a more direct relationship between + * {@code @EnableTransactionManagement} and the exact transaction manager bean to be used, + * the {@link TransactionManagementConfigurer} callback interface may be implemented - + * notice the {@code implements} clause and the {@code @Override}-annotated method below: + * + *

+ * @Configuration
+ * @EnableTransactionManagement
+ * public class AppConfig implements TransactionManagementConfigurer {
+ *
+ *     @Bean
+ *     public FooRepository fooRepository() {
+ *         // configure and return a class having @Transactional methods
+ *         return new JdbcFooRepository(dataSource());
+ *     }
+ *
+ *     @Bean
+ *     public DataSource dataSource() {
+ *         // configure and return the necessary JDBC DataSource
+ *     }
+ *
+ *     @Bean
+ *     public PlatformTransactionManager txManager() {
+ *         return new DataSourceTransactionManager(dataSource());
+ *     }
+ *
+ *     @Override
+ *     public PlatformTransactionManager annotationDrivenTransactionManager() {
+ *         return txManager();
+ *     }
+ * }
+ * + * This approach may be desirable simply because it is more explicit, or it may be + * necessary in order to distinguish between two {@code PlatformTransactionManager} beans + * present in the same container. As the name suggests, the + * {@code annotationDrivenTransactionManager()} will be the one used for processing + * {@code @Transactional} methods. See {@link TransactionManagementConfigurer} Javadoc + * for further details. + * + *

The {@link #mode} attribute controls how advice is applied: If the mode is + * {@link AdviceMode#PROXY} (the default), then the other attributes control the behavior + * of the proxying. Please note that proxy mode allows for interception of calls through + * the proxy only; local calls within the same class cannot get intercepted that way. + * + *

Note that if the {@linkplain #mode} is set to {@link AdviceMode#ASPECTJ}, then the + * value of the {@link #proxyTargetClass} attribute will be ignored. Note also that in + * this case the {@code spring-aspects} module JAR must be present on the classpath, with + * compile-time weaving or load-time weaving applying the aspect to the affected classes. + * There is no proxy involved in such a scenario; local calls will be intercepted as well. + * + * @author Chris Beams + * @author Juergen Hoeller + * @since 3.1 + * @see TransactionManagementConfigurer + * @see TransactionManagementConfigurationSelector + * @see ProxyTransactionManagementConfiguration + * @see org.springframework.transaction.aspectj.AspectJTransactionManagementConfiguration + */ +@Target(ElementType.TYPE) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@Import(TransactionManagementConfigurationSelector.class) +public @interface EnableTransactionManagement { + + /** + * Indicate whether subclass-based (CGLIB) proxies are to be created ({@code true}) as + * opposed to standard Java interface-based proxies ({@code false}). The default is + * {@code false}. Applicable only if {@link #mode()} is set to + * {@link AdviceMode#PROXY}. + *

Note that setting this attribute to {@code true} will affect all + * Spring-managed beans requiring proxying, not just those marked with + * {@code @Transactional}. For example, other beans marked with Spring's + * {@code @Async} annotation will be upgraded to subclass proxying at the same + * time. This approach has no negative impact in practice unless one is explicitly + * expecting one type of proxy vs another, e.g. in tests. + */ + boolean proxyTargetClass() default false; + + /** + * Indicate how transactional advice should be applied. + *

The default is {@link AdviceMode#PROXY}. + * Please note that proxy mode allows for interception of calls through the proxy + * only. Local calls within the same class cannot get intercepted that way; an + * {@link Transactional} annotation on such a method within a local call will be + * ignored since Spring's interceptor does not even kick in for such a runtime + * scenario. For a more advanced mode of interception, consider switching this to + * {@link AdviceMode#ASPECTJ}. + */ + AdviceMode mode() default AdviceMode.PROXY; + + /** + * Indicate the ordering of the execution of the transaction advisor + * when multiple advices are applied at a specific joinpoint. + *

The default is {@link Ordered#LOWEST_PRECEDENCE}. + */ + int order() default Ordered.LOWEST_PRECEDENCE; + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/Isolation.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/Isolation.java new file mode 100644 index 0000000000000000000000000000000000000000..6bc6e2125361100eb804ce058af5cff39ece7216 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/Isolation.java @@ -0,0 +1,92 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import org.springframework.transaction.TransactionDefinition; + +/** + * Enumeration that represents transaction isolation levels for use + * with the {@link Transactional} annotation, corresponding to the + * {@link TransactionDefinition} interface. + * + * @author Colin Sampaleanu + * @author Juergen Hoeller + * @since 1.2 + */ +public enum Isolation { + + /** + * Use the default isolation level of the underlying datastore. + * All other levels correspond to the JDBC isolation levels. + * @see java.sql.Connection + */ + DEFAULT(TransactionDefinition.ISOLATION_DEFAULT), + + /** + * A constant indicating that dirty reads, non-repeatable reads and phantom reads + * can occur. This level allows a row changed by one transaction to be read by + * another transaction before any changes in that row have been committed + * (a "dirty read"). If any of the changes are rolled back, the second + * transaction will have retrieved an invalid row. + * @see java.sql.Connection#TRANSACTION_READ_UNCOMMITTED + */ + READ_UNCOMMITTED(TransactionDefinition.ISOLATION_READ_UNCOMMITTED), + + /** + * A constant indicating that dirty reads are prevented; non-repeatable reads + * and phantom reads can occur. This level only prohibits a transaction + * from reading a row with uncommitted changes in it. + * @see java.sql.Connection#TRANSACTION_READ_COMMITTED + */ + READ_COMMITTED(TransactionDefinition.ISOLATION_READ_COMMITTED), + + /** + * A constant indicating that dirty reads and non-repeatable reads are + * prevented; phantom reads can occur. This level prohibits a transaction + * from reading a row with uncommitted changes in it, and it also prohibits + * the situation where one transaction reads a row, a second transaction + * alters the row, and the first transaction rereads the row, getting + * different values the second time (a "non-repeatable read"). + * @see java.sql.Connection#TRANSACTION_REPEATABLE_READ + */ + REPEATABLE_READ(TransactionDefinition.ISOLATION_REPEATABLE_READ), + + /** + * A constant indicating that dirty reads, non-repeatable reads and phantom + * reads are prevented. This level includes the prohibitions in + * {@code ISOLATION_REPEATABLE_READ} and further prohibits the situation + * where one transaction reads all rows that satisfy a {@code WHERE} + * condition, a second transaction inserts a row that satisfies that + * {@code WHERE} condition, and the first transaction rereads for the + * same condition, retrieving the additional "phantom" row in the second read. + * @see java.sql.Connection#TRANSACTION_SERIALIZABLE + */ + SERIALIZABLE(TransactionDefinition.ISOLATION_SERIALIZABLE); + + + private final int value; + + + Isolation(int value) { + this.value = value; + } + + public int value() { + return this.value; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/JtaTransactionAnnotationParser.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/JtaTransactionAnnotationParser.java new file mode 100644 index 0000000000000000000000000000000000000000..cc69944b8e584b14609e37d4619b8b35537ee005 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/JtaTransactionAnnotationParser.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import java.io.Serializable; +import java.lang.reflect.AnnotatedElement; +import java.util.ArrayList; +import java.util.List; + +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.core.annotation.AnnotationAttributes; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.lang.Nullable; +import org.springframework.transaction.interceptor.NoRollbackRuleAttribute; +import org.springframework.transaction.interceptor.RollbackRuleAttribute; +import org.springframework.transaction.interceptor.RuleBasedTransactionAttribute; +import org.springframework.transaction.interceptor.TransactionAttribute; + +/** + * Strategy implementation for parsing JTA 1.2's {@link javax.transaction.Transactional} annotation. + * + * @author Juergen Hoeller + * @since 4.0 + */ +@SuppressWarnings("serial") +public class JtaTransactionAnnotationParser implements TransactionAnnotationParser, Serializable { + + @Override + @Nullable + public TransactionAttribute parseTransactionAnnotation(AnnotatedElement element) { + AnnotationAttributes attributes = AnnotatedElementUtils.getMergedAnnotationAttributes( + element, javax.transaction.Transactional.class); + if (attributes != null) { + return parseTransactionAnnotation(attributes); + } + else { + return null; + } + } + + public TransactionAttribute parseTransactionAnnotation(javax.transaction.Transactional ann) { + return parseTransactionAnnotation(AnnotationUtils.getAnnotationAttributes(ann, false, false)); + } + + protected TransactionAttribute parseTransactionAnnotation(AnnotationAttributes attributes) { + RuleBasedTransactionAttribute rbta = new RuleBasedTransactionAttribute(); + + rbta.setPropagationBehaviorName( + RuleBasedTransactionAttribute.PREFIX_PROPAGATION + attributes.getEnum("value").toString()); + + List rollbackRules = new ArrayList<>(); + for (Class rbRule : attributes.getClassArray("rollbackOn")) { + rollbackRules.add(new RollbackRuleAttribute(rbRule)); + } + for (Class rbRule : attributes.getClassArray("dontRollbackOn")) { + rollbackRules.add(new NoRollbackRuleAttribute(rbRule)); + } + rbta.setRollbackRules(rollbackRules); + + return rbta; + } + + + @Override + public boolean equals(Object other) { + return (this == other || other instanceof JtaTransactionAnnotationParser); + } + + @Override + public int hashCode() { + return JtaTransactionAnnotationParser.class.hashCode(); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/Propagation.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/Propagation.java new file mode 100644 index 0000000000000000000000000000000000000000..1e235ce8770f496507d30bf6ffe14e9d52b5cb72 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/Propagation.java @@ -0,0 +1,111 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import org.springframework.transaction.TransactionDefinition; + +/** + * Enumeration that represents transaction propagation behaviors for use + * with the {@link Transactional} annotation, corresponding to the + * {@link TransactionDefinition} interface. + * + * @author Colin Sampaleanu + * @author Juergen Hoeller + * @since 1.2 + */ +public enum Propagation { + + /** + * Support a current transaction, create a new one if none exists. + * Analogous to EJB transaction attribute of the same name. + *

This is the default setting of a transaction annotation. + */ + REQUIRED(TransactionDefinition.PROPAGATION_REQUIRED), + + /** + * Support a current transaction, execute non-transactionally if none exists. + * Analogous to EJB transaction attribute of the same name. + *

Note: For transaction managers with transaction synchronization, + * {@code SUPPORTS} is slightly different from no transaction at all, + * as it defines a transaction scope that synchronization will apply for. + * As a consequence, the same resources (JDBC Connection, Hibernate Session, etc) + * will be shared for the entire specified scope. Note that this depends on + * the actual synchronization configuration of the transaction manager. + * @see org.springframework.transaction.support.AbstractPlatformTransactionManager#setTransactionSynchronization + */ + SUPPORTS(TransactionDefinition.PROPAGATION_SUPPORTS), + + /** + * Support a current transaction, throw an exception if none exists. + * Analogous to EJB transaction attribute of the same name. + */ + MANDATORY(TransactionDefinition.PROPAGATION_MANDATORY), + + /** + * Create a new transaction, and suspend the current transaction if one exists. + * Analogous to the EJB transaction attribute of the same name. + *

NOTE: Actual transaction suspension will not work out-of-the-box + * on all transaction managers. This in particular applies to + * {@link org.springframework.transaction.jta.JtaTransactionManager}, + * which requires the {@code javax.transaction.TransactionManager} to be + * made available to it (which is server-specific in standard Java EE). + * @see org.springframework.transaction.jta.JtaTransactionManager#setTransactionManager + */ + REQUIRES_NEW(TransactionDefinition.PROPAGATION_REQUIRES_NEW), + + /** + * Execute non-transactionally, suspend the current transaction if one exists. + * Analogous to EJB transaction attribute of the same name. + *

NOTE: Actual transaction suspension will not work out-of-the-box + * on all transaction managers. This in particular applies to + * {@link org.springframework.transaction.jta.JtaTransactionManager}, + * which requires the {@code javax.transaction.TransactionManager} to be + * made available to it (which is server-specific in standard Java EE). + * @see org.springframework.transaction.jta.JtaTransactionManager#setTransactionManager + */ + NOT_SUPPORTED(TransactionDefinition.PROPAGATION_NOT_SUPPORTED), + + /** + * Execute non-transactionally, throw an exception if a transaction exists. + * Analogous to EJB transaction attribute of the same name. + */ + NEVER(TransactionDefinition.PROPAGATION_NEVER), + + /** + * Execute within a nested transaction if a current transaction exists, + * behave like {@code REQUIRED} otherwise. There is no analogous feature in EJB. + *

Note: Actual creation of a nested transaction will only work on specific + * transaction managers. Out of the box, this only applies to the JDBC + * DataSourceTransactionManager. Some JTA providers might support nested + * transactions as well. + * @see org.springframework.jdbc.datasource.DataSourceTransactionManager + */ + NESTED(TransactionDefinition.PROPAGATION_NESTED); + + + private final int value; + + + Propagation(int value) { + this.value = value; + } + + public int value() { + return this.value; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/ProxyTransactionManagementConfiguration.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/ProxyTransactionManagementConfiguration.java new file mode 100644 index 0000000000000000000000000000000000000000..f126d4d681f562df2485ca85da712db2a33e3175 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/ProxyTransactionManagementConfiguration.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Role; +import org.springframework.transaction.config.TransactionManagementConfigUtils; +import org.springframework.transaction.interceptor.BeanFactoryTransactionAttributeSourceAdvisor; +import org.springframework.transaction.interceptor.TransactionAttributeSource; +import org.springframework.transaction.interceptor.TransactionInterceptor; + +/** + * {@code @Configuration} class that registers the Spring infrastructure beans + * necessary to enable proxy-based annotation-driven transaction management. + * + * @author Chris Beams + * @since 3.1 + * @see EnableTransactionManagement + * @see TransactionManagementConfigurationSelector + */ +@Configuration +@Role(BeanDefinition.ROLE_INFRASTRUCTURE) +public class ProxyTransactionManagementConfiguration extends AbstractTransactionManagementConfiguration { + + @Bean(name = TransactionManagementConfigUtils.TRANSACTION_ADVISOR_BEAN_NAME) + @Role(BeanDefinition.ROLE_INFRASTRUCTURE) + public BeanFactoryTransactionAttributeSourceAdvisor transactionAdvisor() { + BeanFactoryTransactionAttributeSourceAdvisor advisor = new BeanFactoryTransactionAttributeSourceAdvisor(); + advisor.setTransactionAttributeSource(transactionAttributeSource()); + advisor.setAdvice(transactionInterceptor()); + if (this.enableTx != null) { + advisor.setOrder(this.enableTx.getNumber("order")); + } + return advisor; + } + + @Bean + @Role(BeanDefinition.ROLE_INFRASTRUCTURE) + public TransactionAttributeSource transactionAttributeSource() { + return new AnnotationTransactionAttributeSource(); + } + + @Bean + @Role(BeanDefinition.ROLE_INFRASTRUCTURE) + public TransactionInterceptor transactionInterceptor() { + TransactionInterceptor interceptor = new TransactionInterceptor(); + interceptor.setTransactionAttributeSource(transactionAttributeSource()); + if (this.txManager != null) { + interceptor.setTransactionManager(this.txManager); + } + return interceptor; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/SpringTransactionAnnotationParser.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/SpringTransactionAnnotationParser.java new file mode 100644 index 0000000000000000000000000000000000000000..d2ff739575f764a6f6125412d94a92ec508cc8c6 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/SpringTransactionAnnotationParser.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import java.io.Serializable; +import java.lang.reflect.AnnotatedElement; +import java.util.ArrayList; +import java.util.List; + +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.core.annotation.AnnotationAttributes; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.lang.Nullable; +import org.springframework.transaction.interceptor.NoRollbackRuleAttribute; +import org.springframework.transaction.interceptor.RollbackRuleAttribute; +import org.springframework.transaction.interceptor.RuleBasedTransactionAttribute; +import org.springframework.transaction.interceptor.TransactionAttribute; + +/** + * Strategy implementation for parsing Spring's {@link Transactional} annotation. + * + * @author Juergen Hoeller + * @since 2.5 + */ +@SuppressWarnings("serial") +public class SpringTransactionAnnotationParser implements TransactionAnnotationParser, Serializable { + + @Override + @Nullable + public TransactionAttribute parseTransactionAnnotation(AnnotatedElement element) { + AnnotationAttributes attributes = AnnotatedElementUtils.findMergedAnnotationAttributes( + element, Transactional.class, false, false); + if (attributes != null) { + return parseTransactionAnnotation(attributes); + } + else { + return null; + } + } + + public TransactionAttribute parseTransactionAnnotation(Transactional ann) { + return parseTransactionAnnotation(AnnotationUtils.getAnnotationAttributes(ann, false, false)); + } + + protected TransactionAttribute parseTransactionAnnotation(AnnotationAttributes attributes) { + RuleBasedTransactionAttribute rbta = new RuleBasedTransactionAttribute(); + + Propagation propagation = attributes.getEnum("propagation"); + rbta.setPropagationBehavior(propagation.value()); + Isolation isolation = attributes.getEnum("isolation"); + rbta.setIsolationLevel(isolation.value()); + rbta.setTimeout(attributes.getNumber("timeout").intValue()); + rbta.setReadOnly(attributes.getBoolean("readOnly")); + rbta.setQualifier(attributes.getString("value")); + + List rollbackRules = new ArrayList<>(); + for (Class rbRule : attributes.getClassArray("rollbackFor")) { + rollbackRules.add(new RollbackRuleAttribute(rbRule)); + } + for (String rbRule : attributes.getStringArray("rollbackForClassName")) { + rollbackRules.add(new RollbackRuleAttribute(rbRule)); + } + for (Class rbRule : attributes.getClassArray("noRollbackFor")) { + rollbackRules.add(new NoRollbackRuleAttribute(rbRule)); + } + for (String rbRule : attributes.getStringArray("noRollbackForClassName")) { + rollbackRules.add(new NoRollbackRuleAttribute(rbRule)); + } + rbta.setRollbackRules(rollbackRules); + + return rbta; + } + + + @Override + public boolean equals(Object other) { + return (this == other || other instanceof SpringTransactionAnnotationParser); + } + + @Override + public int hashCode() { + return SpringTransactionAnnotationParser.class.hashCode(); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/TransactionAnnotationParser.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/TransactionAnnotationParser.java new file mode 100644 index 0000000000000000000000000000000000000000..4561cbc2ce66db7e3485f2ec3caff2d2ef00577a --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/TransactionAnnotationParser.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import java.lang.reflect.AnnotatedElement; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.interceptor.TransactionAttribute; + +/** + * Strategy interface for parsing known transaction annotation types. + * {@link AnnotationTransactionAttributeSource} delegates to such + * parsers for supporting specific annotation types such as Spring's own + * {@link Transactional}, JTA 1.2's {@link javax.transaction.Transactional} + * or EJB3's {@link javax.ejb.TransactionAttribute}. + * + * @author Juergen Hoeller + * @since 2.5 + * @see AnnotationTransactionAttributeSource + * @see SpringTransactionAnnotationParser + * @see Ejb3TransactionAnnotationParser + * @see JtaTransactionAnnotationParser + */ +public interface TransactionAnnotationParser { + + /** + * Parse the transaction attribute for the given method or class, + * based on an annotation type understood by this parser. + *

This essentially parses a known transaction annotation into Spring's metadata + * attribute class. Returns {@code null} if the method/class is not transactional. + * @param element the annotated method or class + * @return the configured transaction attribute, or {@code null} if none found + * @see AnnotationTransactionAttributeSource#determineTransactionAttribute + */ + @Nullable + TransactionAttribute parseTransactionAnnotation(AnnotatedElement element); + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/TransactionManagementConfigurationSelector.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/TransactionManagementConfigurationSelector.java new file mode 100644 index 0000000000000000000000000000000000000000..5e16f915f81c112cfb202df99de65b5e53c7fb3e --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/TransactionManagementConfigurationSelector.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import org.springframework.context.annotation.AdviceMode; +import org.springframework.context.annotation.AdviceModeImportSelector; +import org.springframework.context.annotation.AutoProxyRegistrar; +import org.springframework.transaction.config.TransactionManagementConfigUtils; +import org.springframework.util.ClassUtils; + +/** + * Selects which implementation of {@link AbstractTransactionManagementConfiguration} + * should be used based on the value of {@link EnableTransactionManagement#mode} on the + * importing {@code @Configuration} class. + * + * @author Chris Beams + * @author Juergen Hoeller + * @since 3.1 + * @see EnableTransactionManagement + * @see ProxyTransactionManagementConfiguration + * @see TransactionManagementConfigUtils#TRANSACTION_ASPECT_CONFIGURATION_CLASS_NAME + * @see TransactionManagementConfigUtils#JTA_TRANSACTION_ASPECT_CONFIGURATION_CLASS_NAME + */ +public class TransactionManagementConfigurationSelector extends AdviceModeImportSelector { + + /** + * Returns {@link ProxyTransactionManagementConfiguration} or + * {@code AspectJ(Jta)TransactionManagementConfiguration} for {@code PROXY} + * and {@code ASPECTJ} values of {@link EnableTransactionManagement#mode()}, + * respectively. + */ + @Override + protected String[] selectImports(AdviceMode adviceMode) { + switch (adviceMode) { + case PROXY: + return new String[] {AutoProxyRegistrar.class.getName(), + ProxyTransactionManagementConfiguration.class.getName()}; + case ASPECTJ: + return new String[] {determineTransactionAspectClass()}; + default: + return null; + } + } + + private String determineTransactionAspectClass() { + return (ClassUtils.isPresent("javax.transaction.Transactional", getClass().getClassLoader()) ? + TransactionManagementConfigUtils.JTA_TRANSACTION_ASPECT_CONFIGURATION_CLASS_NAME : + TransactionManagementConfigUtils.TRANSACTION_ASPECT_CONFIGURATION_CLASS_NAME); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/TransactionManagementConfigurer.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/TransactionManagementConfigurer.java new file mode 100644 index 0000000000000000000000000000000000000000..c7d6a07f468b3a6db8441200c30ddfa2bfca71f7 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/TransactionManagementConfigurer.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import org.springframework.transaction.PlatformTransactionManager; + +/** + * Interface to be implemented by @{@link org.springframework.context.annotation.Configuration + * Configuration} classes annotated with @{@link EnableTransactionManagement} that wish to + * or need to explicitly specify the default {@link PlatformTransactionManager} bean to be + * used for annotation-driven transaction management, as opposed to the default approach + * of a by-type lookup. One reason this might be necessary is if there are two + * {@code PlatformTransactionManager} beans present in the container. + * + *

See @{@link EnableTransactionManagement} for general examples and context; + * see {@link #annotationDrivenTransactionManager()} for detailed instructions. + * + *

Note that in by-type lookup disambiguation cases, an alternative approach to + * implementing this interface is to simply mark one of the offending + * {@code PlatformTransactionManager} {@code @Bean} methods as + * {@link org.springframework.context.annotation.Primary @Primary}. + * This is even generally preferred since it doesn't lead to early initialization + * of the {@code PlatformTransactionManager} bean. + * + * @author Chris Beams + * @since 3.1 + * @see EnableTransactionManagement + * @see org.springframework.context.annotation.Primary + */ +public interface TransactionManagementConfigurer { + + /** + * Return the default transaction manager bean to use for annotation-driven database + * transaction management, i.e. when processing {@code @Transactional} methods. + *

There are two basic approaches to implementing this method: + *

1. Implement the method and annotate it with {@code @Bean}

+ * In this case, the implementing {@code @Configuration} class implements this method, + * marks it with {@code @Bean} and configures and returns the transaction manager + * directly within the method body: + *
+	 * @Bean
+	 * @Override
+	 * public PlatformTransactionManager annotationDrivenTransactionManager() {
+	 *     return new DataSourceTransactionManager(dataSource());
+	 * }
+ *

2. Implement the method without {@code @Bean} and delegate to another existing + * {@code @Bean} method

+ *
+	 * @Bean
+	 * public PlatformTransactionManager txManager() {
+	 *     return new DataSourceTransactionManager(dataSource());
+	 * }
+	 *
+	 * @Override
+	 * public PlatformTransactionManager annotationDrivenTransactionManager() {
+	 *     return txManager(); // reference the existing {@code @Bean} method above
+	 * }
+ * If taking approach #2, be sure that only one of the methods is marked + * with {@code @Bean}! + *

In either scenario #1 or #2, it is important that the + * {@code PlatformTransactionManager} instance is managed as a Spring bean within the + * container as all {@code PlatformTransactionManager} implementations take advantage + * of Spring lifecycle callbacks such as {@code InitializingBean} and + * {@code BeanFactoryAware}. + */ + PlatformTransactionManager annotationDrivenTransactionManager(); + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/Transactional.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/Transactional.java new file mode 100644 index 0000000000000000000000000000000000000000..87a5f56711cc1326bd057350fe461fa880de8c92 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/Transactional.java @@ -0,0 +1,188 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Inherited; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; +import org.springframework.transaction.TransactionDefinition; + +/** + * Describes a transaction attribute on an individual method or on a class. + * + *

At the class level, this annotation applies as a default to all methods of + * the declaring class and its subclasses. Note that it does not apply to ancestor + * classes up the class hierarchy; methods need to be locally redeclared in order + * to participate in a subclass-level annotation. + * + *

This annotation type is generally directly comparable to Spring's + * {@link org.springframework.transaction.interceptor.RuleBasedTransactionAttribute} + * class, and in fact {@link AnnotationTransactionAttributeSource} will directly + * convert the data to the latter class, so that Spring's transaction support code + * does not have to know about annotations. If no custom rollback rules apply, + * the transaction will roll back on {@link RuntimeException} and {@link Error} + * but not on checked exceptions. + * + *

For specific information about the semantics of this annotation's attributes, + * consult the {@link org.springframework.transaction.TransactionDefinition} and + * {@link org.springframework.transaction.interceptor.TransactionAttribute} javadocs. + * + * @author Colin Sampaleanu + * @author Juergen Hoeller + * @author Sam Brannen + * @since 1.2 + * @see org.springframework.transaction.interceptor.TransactionAttribute + * @see org.springframework.transaction.interceptor.DefaultTransactionAttribute + * @see org.springframework.transaction.interceptor.RuleBasedTransactionAttribute + */ +@Target({ElementType.TYPE, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Inherited +@Documented +public @interface Transactional { + + /** + * Alias for {@link #transactionManager}. + * @see #transactionManager + */ + @AliasFor("transactionManager") + String value() default ""; + + /** + * A qualifier value for the specified transaction. + *

May be used to determine the target transaction manager, + * matching the qualifier value (or the bean name) of a specific + * {@link org.springframework.transaction.PlatformTransactionManager} + * bean definition. + * @since 4.2 + * @see #value + */ + @AliasFor("value") + String transactionManager() default ""; + + /** + * The transaction propagation type. + *

Defaults to {@link Propagation#REQUIRED}. + * @see org.springframework.transaction.interceptor.TransactionAttribute#getPropagationBehavior() + */ + Propagation propagation() default Propagation.REQUIRED; + + /** + * The transaction isolation level. + *

Defaults to {@link Isolation#DEFAULT}. + *

Exclusively designed for use with {@link Propagation#REQUIRED} or + * {@link Propagation#REQUIRES_NEW} since it only applies to newly started + * transactions. Consider switching the "validateExistingTransactions" flag to + * "true" on your transaction manager if you'd like isolation level declarations + * to get rejected when participating in an existing transaction with a different + * isolation level. + * @see org.springframework.transaction.interceptor.TransactionAttribute#getIsolationLevel() + * @see org.springframework.transaction.support.AbstractPlatformTransactionManager#setValidateExistingTransaction + */ + Isolation isolation() default Isolation.DEFAULT; + + /** + * The timeout for this transaction (in seconds). + *

Defaults to the default timeout of the underlying transaction system. + *

Exclusively designed for use with {@link Propagation#REQUIRED} or + * {@link Propagation#REQUIRES_NEW} since it only applies to newly started + * transactions. + * @see org.springframework.transaction.interceptor.TransactionAttribute#getTimeout() + */ + int timeout() default TransactionDefinition.TIMEOUT_DEFAULT; + + /** + * A boolean flag that can be set to {@code true} if the transaction is + * effectively read-only, allowing for corresponding optimizations at runtime. + *

Defaults to {@code false}. + *

This just serves as a hint for the actual transaction subsystem; + * it will not necessarily cause failure of write access attempts. + * A transaction manager which cannot interpret the read-only hint will + * not throw an exception when asked for a read-only transaction + * but rather silently ignore the hint. + * @see org.springframework.transaction.interceptor.TransactionAttribute#isReadOnly() + * @see org.springframework.transaction.support.TransactionSynchronizationManager#isCurrentTransactionReadOnly() + */ + boolean readOnly() default false; + + /** + * Defines zero (0) or more exception {@link Class classes}, which must be + * subclasses of {@link Throwable}, indicating which exception types must cause + * a transaction rollback. + *

By default, a transaction will be rolling back on {@link RuntimeException} + * and {@link Error} but not on checked exceptions (business exceptions). See + * {@link org.springframework.transaction.interceptor.DefaultTransactionAttribute#rollbackOn(Throwable)} + * for a detailed explanation. + *

This is the preferred way to construct a rollback rule (in contrast to + * {@link #rollbackForClassName}), matching the exception class and its subclasses. + *

Similar to {@link org.springframework.transaction.interceptor.RollbackRuleAttribute#RollbackRuleAttribute(Class clazz)}. + * @see #rollbackForClassName + * @see org.springframework.transaction.interceptor.DefaultTransactionAttribute#rollbackOn(Throwable) + */ + Class[] rollbackFor() default {}; + + /** + * Defines zero (0) or more exception names (for exceptions which must be a + * subclass of {@link Throwable}), indicating which exception types must cause + * a transaction rollback. + *

This can be a substring of a fully qualified class name, with no wildcard + * support at present. For example, a value of {@code "ServletException"} would + * match {@code javax.servlet.ServletException} and its subclasses. + *

NB: Consider carefully how specific the pattern is and whether + * to include package information (which isn't mandatory). For example, + * {@code "Exception"} will match nearly anything and will probably hide other + * rules. {@code "java.lang.Exception"} would be correct if {@code "Exception"} + * were meant to define a rule for all checked exceptions. With more unusual + * {@link Exception} names such as {@code "BaseBusinessException"} there is no + * need to use a FQN. + *

Similar to {@link org.springframework.transaction.interceptor.RollbackRuleAttribute#RollbackRuleAttribute(String exceptionName)}. + * @see #rollbackFor + * @see org.springframework.transaction.interceptor.DefaultTransactionAttribute#rollbackOn(Throwable) + */ + String[] rollbackForClassName() default {}; + + /** + * Defines zero (0) or more exception {@link Class Classes}, which must be + * subclasses of {@link Throwable}, indicating which exception types must + * not cause a transaction rollback. + *

This is the preferred way to construct a rollback rule (in contrast + * to {@link #noRollbackForClassName}), matching the exception class and + * its subclasses. + *

Similar to {@link org.springframework.transaction.interceptor.NoRollbackRuleAttribute#NoRollbackRuleAttribute(Class clazz)}. + * @see #noRollbackForClassName + * @see org.springframework.transaction.interceptor.DefaultTransactionAttribute#rollbackOn(Throwable) + */ + Class[] noRollbackFor() default {}; + + /** + * Defines zero (0) or more exception names (for exceptions which must be a + * subclass of {@link Throwable}) indicating which exception types must not + * cause a transaction rollback. + *

See the description of {@link #rollbackForClassName} for further + * information on how the specified names are treated. + *

Similar to {@link org.springframework.transaction.interceptor.NoRollbackRuleAttribute#NoRollbackRuleAttribute(String exceptionName)}. + * @see #noRollbackFor + * @see org.springframework.transaction.interceptor.DefaultTransactionAttribute#rollbackOn(Throwable) + */ + String[] noRollbackForClassName() default {}; + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/package-info.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..fdef1fa4b563ea24da5f42b03d95f88c01ee56f2 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/package-info.java @@ -0,0 +1,11 @@ +/** + * Spring's support for annotation-based transaction demarcation. + * Hooked into Spring's transaction interception infrastructure + * via a special TransactionAttributeSource implementation. + */ +@NonNullApi +@NonNullFields +package org.springframework.transaction.annotation; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/transaction/config/AnnotationDrivenBeanDefinitionParser.java b/spring-tx/src/main/java/org/springframework/transaction/config/AnnotationDrivenBeanDefinitionParser.java new file mode 100644 index 0000000000000000000000000000000000000000..8d6ba603deb1a232750cf782e02b8f412560f3ea --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/config/AnnotationDrivenBeanDefinitionParser.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.config; + +import org.w3c.dom.Element; + +import org.springframework.aop.config.AopNamespaceUtils; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.RuntimeBeanReference; +import org.springframework.beans.factory.parsing.BeanComponentDefinition; +import org.springframework.beans.factory.parsing.CompositeComponentDefinition; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.factory.xml.BeanDefinitionParser; +import org.springframework.beans.factory.xml.ParserContext; +import org.springframework.lang.Nullable; +import org.springframework.transaction.event.TransactionalEventListenerFactory; +import org.springframework.transaction.interceptor.BeanFactoryTransactionAttributeSourceAdvisor; +import org.springframework.transaction.interceptor.TransactionInterceptor; +import org.springframework.util.ClassUtils; + +/** + * {@link org.springframework.beans.factory.xml.BeanDefinitionParser + * BeanDefinitionParser} implementation that allows users to easily configure + * all the infrastructure beans required to enable annotation-driven transaction + * demarcation. + * + *

By default, all proxies are created as JDK proxies. This may cause some + * problems if you are injecting objects as concrete classes rather than + * interfaces. To overcome this restriction you can set the + * '{@code proxy-target-class}' attribute to '{@code true}', which + * will result in class-based proxies being created. + * + * @author Juergen Hoeller + * @author Rob Harrop + * @author Chris Beams + * @author Stephane Nicoll + * @since 2.0 + */ +class AnnotationDrivenBeanDefinitionParser implements BeanDefinitionParser { + + /** + * Parses the {@code } tag. Will + * {@link AopNamespaceUtils#registerAutoProxyCreatorIfNecessary register an AutoProxyCreator} + * with the container as necessary. + */ + @Override + @Nullable + public BeanDefinition parse(Element element, ParserContext parserContext) { + registerTransactionalEventListenerFactory(parserContext); + String mode = element.getAttribute("mode"); + if ("aspectj".equals(mode)) { + // mode="aspectj" + registerTransactionAspect(element, parserContext); + if (ClassUtils.isPresent("javax.transaction.Transactional", getClass().getClassLoader())) { + registerJtaTransactionAspect(element, parserContext); + } + } + else { + // mode="proxy" + AopAutoProxyConfigurer.configureAutoProxyCreator(element, parserContext); + } + return null; + } + + private void registerTransactionAspect(Element element, ParserContext parserContext) { + String txAspectBeanName = TransactionManagementConfigUtils.TRANSACTION_ASPECT_BEAN_NAME; + String txAspectClassName = TransactionManagementConfigUtils.TRANSACTION_ASPECT_CLASS_NAME; + if (!parserContext.getRegistry().containsBeanDefinition(txAspectBeanName)) { + RootBeanDefinition def = new RootBeanDefinition(); + def.setBeanClassName(txAspectClassName); + def.setFactoryMethodName("aspectOf"); + registerTransactionManager(element, def); + parserContext.registerBeanComponent(new BeanComponentDefinition(def, txAspectBeanName)); + } + } + + private void registerJtaTransactionAspect(Element element, ParserContext parserContext) { + String txAspectBeanName = TransactionManagementConfigUtils.JTA_TRANSACTION_ASPECT_BEAN_NAME; + String txAspectClassName = TransactionManagementConfigUtils.JTA_TRANSACTION_ASPECT_CLASS_NAME; + if (!parserContext.getRegistry().containsBeanDefinition(txAspectBeanName)) { + RootBeanDefinition def = new RootBeanDefinition(); + def.setBeanClassName(txAspectClassName); + def.setFactoryMethodName("aspectOf"); + registerTransactionManager(element, def); + parserContext.registerBeanComponent(new BeanComponentDefinition(def, txAspectBeanName)); + } + } + + private static void registerTransactionManager(Element element, BeanDefinition def) { + def.getPropertyValues().add("transactionManagerBeanName", + TxNamespaceHandler.getTransactionManagerName(element)); + } + + private void registerTransactionalEventListenerFactory(ParserContext parserContext) { + RootBeanDefinition def = new RootBeanDefinition(); + def.setBeanClass(TransactionalEventListenerFactory.class); + parserContext.registerBeanComponent(new BeanComponentDefinition(def, + TransactionManagementConfigUtils.TRANSACTIONAL_EVENT_LISTENER_FACTORY_BEAN_NAME)); + } + + + /** + * Inner class to just introduce an AOP framework dependency when actually in proxy mode. + */ + private static class AopAutoProxyConfigurer { + + public static void configureAutoProxyCreator(Element element, ParserContext parserContext) { + AopNamespaceUtils.registerAutoProxyCreatorIfNecessary(parserContext, element); + + String txAdvisorBeanName = TransactionManagementConfigUtils.TRANSACTION_ADVISOR_BEAN_NAME; + if (!parserContext.getRegistry().containsBeanDefinition(txAdvisorBeanName)) { + Object eleSource = parserContext.extractSource(element); + + // Create the TransactionAttributeSource definition. + RootBeanDefinition sourceDef = new RootBeanDefinition( + "org.springframework.transaction.annotation.AnnotationTransactionAttributeSource"); + sourceDef.setSource(eleSource); + sourceDef.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); + String sourceName = parserContext.getReaderContext().registerWithGeneratedName(sourceDef); + + // Create the TransactionInterceptor definition. + RootBeanDefinition interceptorDef = new RootBeanDefinition(TransactionInterceptor.class); + interceptorDef.setSource(eleSource); + interceptorDef.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); + registerTransactionManager(element, interceptorDef); + interceptorDef.getPropertyValues().add("transactionAttributeSource", new RuntimeBeanReference(sourceName)); + String interceptorName = parserContext.getReaderContext().registerWithGeneratedName(interceptorDef); + + // Create the TransactionAttributeSourceAdvisor definition. + RootBeanDefinition advisorDef = new RootBeanDefinition(BeanFactoryTransactionAttributeSourceAdvisor.class); + advisorDef.setSource(eleSource); + advisorDef.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); + advisorDef.getPropertyValues().add("transactionAttributeSource", new RuntimeBeanReference(sourceName)); + advisorDef.getPropertyValues().add("adviceBeanName", interceptorName); + if (element.hasAttribute("order")) { + advisorDef.getPropertyValues().add("order", element.getAttribute("order")); + } + parserContext.getRegistry().registerBeanDefinition(txAdvisorBeanName, advisorDef); + + CompositeComponentDefinition compositeDef = new CompositeComponentDefinition(element.getTagName(), eleSource); + compositeDef.addNestedComponent(new BeanComponentDefinition(sourceDef, sourceName)); + compositeDef.addNestedComponent(new BeanComponentDefinition(interceptorDef, interceptorName)); + compositeDef.addNestedComponent(new BeanComponentDefinition(advisorDef, txAdvisorBeanName)); + parserContext.registerComponent(compositeDef); + } + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/config/JtaTransactionManagerBeanDefinitionParser.java b/spring-tx/src/main/java/org/springframework/transaction/config/JtaTransactionManagerBeanDefinitionParser.java new file mode 100644 index 0000000000000000000000000000000000000000..cae08558e4a18cf3260bf11925bcc7b5a9718d5b --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/config/JtaTransactionManagerBeanDefinitionParser.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.config; + +import org.w3c.dom.Element; + +import org.springframework.beans.factory.support.AbstractBeanDefinition; +import org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser; +import org.springframework.beans.factory.xml.ParserContext; + +/** + * Parser for the <tx:jta-transaction-manager/> XML configuration element, + * autodetecting WebLogic and WebSphere servers and exposing the corresponding + * {@link org.springframework.transaction.jta.JtaTransactionManager} subclass. + * + * @author Juergen Hoeller + * @author Christian Dupuis + * @since 2.5 + * @see org.springframework.transaction.jta.WebLogicJtaTransactionManager + * @see org.springframework.transaction.jta.WebSphereUowTransactionManager + */ +public class JtaTransactionManagerBeanDefinitionParser extends AbstractSingleBeanDefinitionParser { + + @Override + protected String getBeanClassName(Element element) { + return JtaTransactionManagerFactoryBean.resolveJtaTransactionManagerClassName(); + } + + @Override + protected String resolveId(Element element, AbstractBeanDefinition definition, ParserContext parserContext) { + return TxNamespaceHandler.DEFAULT_TRANSACTION_MANAGER_BEAN_NAME; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/config/JtaTransactionManagerFactoryBean.java b/spring-tx/src/main/java/org/springframework/transaction/config/JtaTransactionManagerFactoryBean.java new file mode 100644 index 0000000000000000000000000000000000000000..3c25b3ba6d323509f85b4aa8a7badc715c68ce88 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/config/JtaTransactionManagerFactoryBean.java @@ -0,0 +1,105 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.config; + +import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.lang.Nullable; +import org.springframework.transaction.jta.JtaTransactionManager; +import org.springframework.util.ClassUtils; + +/** + * A {@link FactoryBean} equivalent to the <tx:jta-transaction-manager/> XML element, + * autodetecting WebLogic and WebSphere servers and exposing the corresponding + * {@link org.springframework.transaction.jta.JtaTransactionManager} subclass. + * + * @author Juergen Hoeller + * @since 4.1.1 + * @see org.springframework.transaction.jta.WebLogicJtaTransactionManager + * @see org.springframework.transaction.jta.WebSphereUowTransactionManager + */ +public class JtaTransactionManagerFactoryBean implements FactoryBean { + + private static final String WEBLOGIC_JTA_TRANSACTION_MANAGER_CLASS_NAME = + "org.springframework.transaction.jta.WebLogicJtaTransactionManager"; + + private static final String WEBSPHERE_TRANSACTION_MANAGER_CLASS_NAME = + "org.springframework.transaction.jta.WebSphereUowTransactionManager"; + + private static final String JTA_TRANSACTION_MANAGER_CLASS_NAME = + "org.springframework.transaction.jta.JtaTransactionManager"; + + + private static final boolean weblogicPresent; + + private static final boolean webspherePresent; + + static { + ClassLoader classLoader = JtaTransactionManagerFactoryBean.class.getClassLoader(); + weblogicPresent = ClassUtils.isPresent("weblogic.transaction.UserTransaction", classLoader); + webspherePresent = ClassUtils.isPresent("com.ibm.wsspi.uow.UOWManager", classLoader); + } + + + @Nullable + private final JtaTransactionManager transactionManager; + + + @SuppressWarnings("unchecked") + public JtaTransactionManagerFactoryBean() { + String className = resolveJtaTransactionManagerClassName(); + try { + Class clazz = (Class) + ClassUtils.forName(className, JtaTransactionManagerFactoryBean.class.getClassLoader()); + this.transactionManager = BeanUtils.instantiateClass(clazz); + } + catch (ClassNotFoundException ex) { + throw new IllegalStateException("Failed to load JtaTransactionManager class: " + className, ex); + } + } + + + @Override + @Nullable + public JtaTransactionManager getObject() { + return this.transactionManager; + } + + @Override + public Class getObjectType() { + return (this.transactionManager != null ? this.transactionManager.getClass() : JtaTransactionManager.class); + } + + @Override + public boolean isSingleton() { + return true; + } + + + static String resolveJtaTransactionManagerClassName() { + if (weblogicPresent) { + return WEBLOGIC_JTA_TRANSACTION_MANAGER_CLASS_NAME; + } + else if (webspherePresent) { + return WEBSPHERE_TRANSACTION_MANAGER_CLASS_NAME; + } + else { + return JTA_TRANSACTION_MANAGER_CLASS_NAME; + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/config/TransactionManagementConfigUtils.java b/spring-tx/src/main/java/org/springframework/transaction/config/TransactionManagementConfigUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..d54994b47f972b89b283b902c7ba20fd290aa0fd --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/config/TransactionManagementConfigUtils.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.config; + +/** + * Configuration constants for internal sharing across subpackages. + * + * @author Chris Beams + * @author Stephane Nicoll + * @since 3.1 + */ +public abstract class TransactionManagementConfigUtils { + + /** + * The bean name of the internally managed transaction advisor (used when mode == PROXY). + */ + public static final String TRANSACTION_ADVISOR_BEAN_NAME = + "org.springframework.transaction.config.internalTransactionAdvisor"; + + /** + * The bean name of the internally managed transaction aspect (used when mode == ASPECTJ). + */ + public static final String TRANSACTION_ASPECT_BEAN_NAME = + "org.springframework.transaction.config.internalTransactionAspect"; + + /** + * The class name of the AspectJ transaction management aspect. + */ + public static final String TRANSACTION_ASPECT_CLASS_NAME = + "org.springframework.transaction.aspectj.AnnotationTransactionAspect"; + + /** + * The name of the AspectJ transaction management @{@code Configuration} class. + */ + public static final String TRANSACTION_ASPECT_CONFIGURATION_CLASS_NAME = + "org.springframework.transaction.aspectj.AspectJTransactionManagementConfiguration"; + + /** + * The bean name of the internally managed JTA transaction aspect (used when mode == ASPECTJ). + * @since 5.1 + */ + public static final String JTA_TRANSACTION_ASPECT_BEAN_NAME = + "org.springframework.transaction.config.internalJtaTransactionAspect"; + + /** + * The class name of the AspectJ transaction management aspect. + * @since 5.1 + */ + public static final String JTA_TRANSACTION_ASPECT_CLASS_NAME = + "org.springframework.transaction.aspectj.JtaAnnotationTransactionAspect"; + + /** + * The name of the AspectJ transaction management @{@code Configuration} class for JTA. + * @since 5.1 + */ + public static final String JTA_TRANSACTION_ASPECT_CONFIGURATION_CLASS_NAME = + "org.springframework.transaction.aspectj.AspectJJtaTransactionManagementConfiguration"; + + /** + * The bean name of the internally managed TransactionalEventListenerFactory. + */ + public static final String TRANSACTIONAL_EVENT_LISTENER_FACTORY_BEAN_NAME = + "org.springframework.transaction.config.internalTransactionalEventListenerFactory"; + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/config/TxAdviceBeanDefinitionParser.java b/spring-tx/src/main/java/org/springframework/transaction/config/TxAdviceBeanDefinitionParser.java new file mode 100644 index 0000000000000000000000000000000000000000..31da50dd410f5d678113c011c51583595ae7136c --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/config/TxAdviceBeanDefinitionParser.java @@ -0,0 +1,164 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.config; + +import java.util.LinkedList; +import java.util.List; + +import org.w3c.dom.Element; + +import org.springframework.beans.factory.config.TypedStringValue; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.ManagedMap; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser; +import org.springframework.beans.factory.xml.ParserContext; +import org.springframework.transaction.interceptor.NameMatchTransactionAttributeSource; +import org.springframework.transaction.interceptor.NoRollbackRuleAttribute; +import org.springframework.transaction.interceptor.RollbackRuleAttribute; +import org.springframework.transaction.interceptor.RuleBasedTransactionAttribute; +import org.springframework.transaction.interceptor.TransactionInterceptor; +import org.springframework.util.StringUtils; +import org.springframework.util.xml.DomUtils; + +/** + * {@link org.springframework.beans.factory.xml.BeanDefinitionParser + * BeanDefinitionParser} for the {@code } tag. + * + * @author Rob Harrop + * @author Juergen Hoeller + * @author Adrian Colyer + * @author Chris Beams + * @since 2.0 + */ +class TxAdviceBeanDefinitionParser extends AbstractSingleBeanDefinitionParser { + + private static final String METHOD_ELEMENT = "method"; + + private static final String METHOD_NAME_ATTRIBUTE = "name"; + + private static final String ATTRIBUTES_ELEMENT = "attributes"; + + private static final String TIMEOUT_ATTRIBUTE = "timeout"; + + private static final String READ_ONLY_ATTRIBUTE = "read-only"; + + private static final String PROPAGATION_ATTRIBUTE = "propagation"; + + private static final String ISOLATION_ATTRIBUTE = "isolation"; + + private static final String ROLLBACK_FOR_ATTRIBUTE = "rollback-for"; + + private static final String NO_ROLLBACK_FOR_ATTRIBUTE = "no-rollback-for"; + + + @Override + protected Class getBeanClass(Element element) { + return TransactionInterceptor.class; + } + + @Override + protected void doParse(Element element, ParserContext parserContext, BeanDefinitionBuilder builder) { + builder.addPropertyReference("transactionManager", TxNamespaceHandler.getTransactionManagerName(element)); + + List txAttributes = DomUtils.getChildElementsByTagName(element, ATTRIBUTES_ELEMENT); + if (txAttributes.size() > 1) { + parserContext.getReaderContext().error( + "Element is allowed at most once inside element ", element); + } + else if (txAttributes.size() == 1) { + // Using attributes source. + Element attributeSourceElement = txAttributes.get(0); + RootBeanDefinition attributeSourceDefinition = parseAttributeSource(attributeSourceElement, parserContext); + builder.addPropertyValue("transactionAttributeSource", attributeSourceDefinition); + } + else { + // Assume annotations source. + builder.addPropertyValue("transactionAttributeSource", + new RootBeanDefinition("org.springframework.transaction.annotation.AnnotationTransactionAttributeSource")); + } + } + + private RootBeanDefinition parseAttributeSource(Element attrEle, ParserContext parserContext) { + List methods = DomUtils.getChildElementsByTagName(attrEle, METHOD_ELEMENT); + ManagedMap transactionAttributeMap = + new ManagedMap<>(methods.size()); + transactionAttributeMap.setSource(parserContext.extractSource(attrEle)); + + for (Element methodEle : methods) { + String name = methodEle.getAttribute(METHOD_NAME_ATTRIBUTE); + TypedStringValue nameHolder = new TypedStringValue(name); + nameHolder.setSource(parserContext.extractSource(methodEle)); + + RuleBasedTransactionAttribute attribute = new RuleBasedTransactionAttribute(); + String propagation = methodEle.getAttribute(PROPAGATION_ATTRIBUTE); + String isolation = methodEle.getAttribute(ISOLATION_ATTRIBUTE); + String timeout = methodEle.getAttribute(TIMEOUT_ATTRIBUTE); + String readOnly = methodEle.getAttribute(READ_ONLY_ATTRIBUTE); + if (StringUtils.hasText(propagation)) { + attribute.setPropagationBehaviorName(RuleBasedTransactionAttribute.PREFIX_PROPAGATION + propagation); + } + if (StringUtils.hasText(isolation)) { + attribute.setIsolationLevelName(RuleBasedTransactionAttribute.PREFIX_ISOLATION + isolation); + } + if (StringUtils.hasText(timeout)) { + try { + attribute.setTimeout(Integer.parseInt(timeout)); + } + catch (NumberFormatException ex) { + parserContext.getReaderContext().error("Timeout must be an integer value: [" + timeout + "]", methodEle); + } + } + if (StringUtils.hasText(readOnly)) { + attribute.setReadOnly(Boolean.parseBoolean(methodEle.getAttribute(READ_ONLY_ATTRIBUTE))); + } + + List rollbackRules = new LinkedList<>(); + if (methodEle.hasAttribute(ROLLBACK_FOR_ATTRIBUTE)) { + String rollbackForValue = methodEle.getAttribute(ROLLBACK_FOR_ATTRIBUTE); + addRollbackRuleAttributesTo(rollbackRules, rollbackForValue); + } + if (methodEle.hasAttribute(NO_ROLLBACK_FOR_ATTRIBUTE)) { + String noRollbackForValue = methodEle.getAttribute(NO_ROLLBACK_FOR_ATTRIBUTE); + addNoRollbackRuleAttributesTo(rollbackRules, noRollbackForValue); + } + attribute.setRollbackRules(rollbackRules); + + transactionAttributeMap.put(nameHolder, attribute); + } + + RootBeanDefinition attributeSourceDefinition = new RootBeanDefinition(NameMatchTransactionAttributeSource.class); + attributeSourceDefinition.setSource(parserContext.extractSource(attrEle)); + attributeSourceDefinition.getPropertyValues().add("nameMap", transactionAttributeMap); + return attributeSourceDefinition; + } + + private void addRollbackRuleAttributesTo(List rollbackRules, String rollbackForValue) { + String[] exceptionTypeNames = StringUtils.commaDelimitedListToStringArray(rollbackForValue); + for (String typeName : exceptionTypeNames) { + rollbackRules.add(new RollbackRuleAttribute(StringUtils.trimWhitespace(typeName))); + } + } + + private void addNoRollbackRuleAttributesTo(List rollbackRules, String noRollbackForValue) { + String[] exceptionTypeNames = StringUtils.commaDelimitedListToStringArray(noRollbackForValue); + for (String typeName : exceptionTypeNames) { + rollbackRules.add(new NoRollbackRuleAttribute(StringUtils.trimWhitespace(typeName))); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/config/TxNamespaceHandler.java b/spring-tx/src/main/java/org/springframework/transaction/config/TxNamespaceHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..206f3b6894cad8ae5c73235e0a65265100f693eb --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/config/TxNamespaceHandler.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.config; + +import org.w3c.dom.Element; + +import org.springframework.beans.factory.xml.NamespaceHandlerSupport; + +/** + * {@code NamespaceHandler} allowing for the configuration of + * declarative transaction management using either XML or using annotations. + * + *

This namespace handler is the central piece of functionality in the + * Spring transaction management facilities and offers two approaches + * to declaratively manage transactions. + * + *

One approach uses transaction semantics defined in XML using the + * {@code } elements, the other uses annotations + * in combination with the {@code } element. + * Both approached are detailed to great extent in the Spring reference manual. + * + * @author Rob Harrop + * @author Juergen Hoeller + * @since 2.0 + */ +public class TxNamespaceHandler extends NamespaceHandlerSupport { + + static final String TRANSACTION_MANAGER_ATTRIBUTE = "transaction-manager"; + + static final String DEFAULT_TRANSACTION_MANAGER_BEAN_NAME = "transactionManager"; + + + static String getTransactionManagerName(Element element) { + return (element.hasAttribute(TRANSACTION_MANAGER_ATTRIBUTE) ? + element.getAttribute(TRANSACTION_MANAGER_ATTRIBUTE) : DEFAULT_TRANSACTION_MANAGER_BEAN_NAME); + } + + + @Override + public void init() { + registerBeanDefinitionParser("advice", new TxAdviceBeanDefinitionParser()); + registerBeanDefinitionParser("annotation-driven", new AnnotationDrivenBeanDefinitionParser()); + registerBeanDefinitionParser("jta-transaction-manager", new JtaTransactionManagerBeanDefinitionParser()); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/config/package-info.java b/spring-tx/src/main/java/org/springframework/transaction/config/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..120413df0a660878c6819bfa1c471ac5fae55857 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/config/package-info.java @@ -0,0 +1,10 @@ +/** + * Support package for declarative transaction configuration, + * with XML schema being the primary configuration format. + */ +@NonNullApi +@NonNullFields +package org.springframework.transaction.config; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/transaction/event/ApplicationListenerMethodTransactionalAdapter.java b/spring-tx/src/main/java/org/springframework/transaction/event/ApplicationListenerMethodTransactionalAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..7444ec5b395dacae6e35cccba5fe84b2cbc245b3 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/event/ApplicationListenerMethodTransactionalAdapter.java @@ -0,0 +1,132 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.event; + +import java.lang.reflect.Method; + +import org.springframework.context.ApplicationEvent; +import org.springframework.context.event.ApplicationListenerMethodAdapter; +import org.springframework.context.event.EventListener; +import org.springframework.context.event.GenericApplicationListener; +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.transaction.support.TransactionSynchronization; +import org.springframework.transaction.support.TransactionSynchronizationAdapter; +import org.springframework.transaction.support.TransactionSynchronizationManager; + +/** + * {@link GenericApplicationListener} adapter that delegates the processing of + * an event to a {@link TransactionalEventListener} annotated method. Supports + * the exact same features as any regular {@link EventListener} annotated method + * but is aware of the transactional context of the event publisher. + * + *

Processing of {@link TransactionalEventListener} is enabled automatically + * when Spring's transaction management is enabled. For other cases, registering + * a bean of type {@link TransactionalEventListenerFactory} is required. + * + * @author Stephane Nicoll + * @author Juergen Hoeller + * @since 4.2 + * @see ApplicationListenerMethodAdapter + * @see TransactionalEventListener + */ +class ApplicationListenerMethodTransactionalAdapter extends ApplicationListenerMethodAdapter { + + private final TransactionalEventListener annotation; + + + public ApplicationListenerMethodTransactionalAdapter(String beanName, Class targetClass, Method method) { + super(beanName, targetClass, method); + TransactionalEventListener ann = AnnotatedElementUtils.findMergedAnnotation(method, TransactionalEventListener.class); + if (ann == null) { + throw new IllegalStateException("No TransactionalEventListener annotation found on method: " + method); + } + this.annotation = ann; + } + + + @Override + public void onApplicationEvent(ApplicationEvent event) { + if (TransactionSynchronizationManager.isSynchronizationActive()) { + TransactionSynchronization transactionSynchronization = createTransactionSynchronization(event); + TransactionSynchronizationManager.registerSynchronization(transactionSynchronization); + } + else if (this.annotation.fallbackExecution()) { + if (this.annotation.phase() == TransactionPhase.AFTER_ROLLBACK && logger.isWarnEnabled()) { + logger.warn("Processing " + event + " as a fallback execution on AFTER_ROLLBACK phase"); + } + processEvent(event); + } + else { + // No transactional event execution at all + if (logger.isDebugEnabled()) { + logger.debug("No transaction is active - skipping " + event); + } + } + } + + private TransactionSynchronization createTransactionSynchronization(ApplicationEvent event) { + return new TransactionSynchronizationEventAdapter(this, event, this.annotation.phase()); + } + + + private static class TransactionSynchronizationEventAdapter extends TransactionSynchronizationAdapter { + + private final ApplicationListenerMethodAdapter listener; + + private final ApplicationEvent event; + + private final TransactionPhase phase; + + public TransactionSynchronizationEventAdapter(ApplicationListenerMethodAdapter listener, + ApplicationEvent event, TransactionPhase phase) { + + this.listener = listener; + this.event = event; + this.phase = phase; + } + + @Override + public int getOrder() { + return this.listener.getOrder(); + } + + @Override + public void beforeCommit(boolean readOnly) { + if (this.phase == TransactionPhase.BEFORE_COMMIT) { + processEvent(); + } + } + + @Override + public void afterCompletion(int status) { + if (this.phase == TransactionPhase.AFTER_COMMIT && status == STATUS_COMMITTED) { + processEvent(); + } + else if (this.phase == TransactionPhase.AFTER_ROLLBACK && status == STATUS_ROLLED_BACK) { + processEvent(); + } + else if (this.phase == TransactionPhase.AFTER_COMPLETION) { + processEvent(); + } + } + + protected void processEvent() { + this.listener.processEvent(this.event); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/event/TransactionPhase.java b/spring-tx/src/main/java/org/springframework/transaction/event/TransactionPhase.java new file mode 100644 index 0000000000000000000000000000000000000000..d8a9ae7a13a2c69de449064c6d3d72b88a044d8b --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/event/TransactionPhase.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.event; + +import org.springframework.transaction.support.TransactionSynchronization; + +/** + * The phase at which a transactional event listener applies. + * + * @author Stephane Nicoll + * @author Juergen Hoeller + * @since 4.2 + * @see TransactionalEventListener + */ +public enum TransactionPhase { + + /** + * Fire the event before transaction commit. + * @see TransactionSynchronization#beforeCommit(boolean) + */ + BEFORE_COMMIT, + + /** + * Fire the event after the commit has completed successfully. + *

Note: This is a specialization of {@link #AFTER_COMPLETION} and + * therefore executes in the same after-completion sequence of events, + * (and not in {@link TransactionSynchronization#afterCommit()}). + * @see TransactionSynchronization#afterCompletion(int) + * @see TransactionSynchronization#STATUS_COMMITTED + */ + AFTER_COMMIT, + + /** + * Fire the event if the transaction has rolled back. + *

Note: This is a specialization of {@link #AFTER_COMPLETION} and + * therefore executes in the same after-completion sequence of events. + * @see TransactionSynchronization#afterCompletion(int) + * @see TransactionSynchronization#STATUS_ROLLED_BACK + */ + AFTER_ROLLBACK, + + /** + * Fire the event after the transaction has completed. + *

For more fine-grained events, use {@link #AFTER_COMMIT} or + * {@link #AFTER_ROLLBACK} to intercept transaction commit + * or rollback, respectively. + * @see TransactionSynchronization#afterCompletion(int) + */ + AFTER_COMPLETION + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/event/TransactionalEventListener.java b/spring-tx/src/main/java/org/springframework/transaction/event/TransactionalEventListener.java new file mode 100644 index 0000000000000000000000000000000000000000..0fb6fc6047cf70be88d8c6854c318bccfbe7f620 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/event/TransactionalEventListener.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.event; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.context.event.EventListener; +import org.springframework.core.annotation.AliasFor; + +/** + * An {@link EventListener} that is invoked according to a {@link TransactionPhase}. + * + *

If the event is not published within the boundaries of a managed transaction, the + * event is discarded unless the {@link #fallbackExecution} flag is explicitly set. If a + * transaction is running, the event is processed according to its {@code TransactionPhase}. + * + *

Adding {@link org.springframework.core.annotation.Order @Order} to your annotated + * method allows you to prioritize that listener amongst other listeners running before + * or after transaction completion. + * + * @author Stephane Nicoll + * @author Sam Brannen + * @since 4.2 + */ +@Target({ElementType.METHOD, ElementType.ANNOTATION_TYPE}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@EventListener +public @interface TransactionalEventListener { + + /** + * Phase to bind the handling of an event to. + *

The default phase is {@link TransactionPhase#AFTER_COMMIT}. + *

If no transaction is in progress, the event is not processed at + * all unless {@link #fallbackExecution} has been enabled explicitly. + */ + TransactionPhase phase() default TransactionPhase.AFTER_COMMIT; + + /** + * Whether the event should be processed if no transaction is running. + */ + boolean fallbackExecution() default false; + + /** + * Alias for {@link #classes}. + */ + @AliasFor(annotation = EventListener.class, attribute = "classes") + Class[] value() default {}; + + /** + * The event classes that this listener handles. + *

If this attribute is specified with a single value, the annotated + * method may optionally accept a single parameter. However, if this + * attribute is specified with multiple values, the annotated method + * must not declare any parameters. + */ + @AliasFor(annotation = EventListener.class, attribute = "classes") + Class[] classes() default {}; + + /** + * Spring Expression Language (SpEL) attribute used for making the event + * handling conditional. + *

The default is {@code ""}, meaning the event is always handled. + * @see EventListener#condition + */ + String condition() default ""; + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/event/TransactionalEventListenerFactory.java b/spring-tx/src/main/java/org/springframework/transaction/event/TransactionalEventListenerFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..b912de33e80032a894c747305540681ac63b3960 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/event/TransactionalEventListenerFactory.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.event; + +import java.lang.reflect.Method; + +import org.springframework.context.ApplicationListener; +import org.springframework.context.event.EventListenerFactory; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.AnnotatedElementUtils; + +/** + * {@link EventListenerFactory} implementation that handles {@link TransactionalEventListener} + * annotated methods. + * + * @author Stephane Nicoll + * @since 4.2 + */ +public class TransactionalEventListenerFactory implements EventListenerFactory, Ordered { + + private int order = 50; + + + public void setOrder(int order) { + this.order = order; + } + + @Override + public int getOrder() { + return this.order; + } + + + @Override + public boolean supportsMethod(Method method) { + return AnnotatedElementUtils.hasAnnotation(method, TransactionalEventListener.class); + } + + @Override + public ApplicationListener createApplicationListener(String beanName, Class type, Method method) { + return new ApplicationListenerMethodTransactionalAdapter(beanName, type, method); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/event/package-info.java b/spring-tx/src/main/java/org/springframework/transaction/event/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..507a1af2e8b02b009157e03a1b6f974c03d84f22 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/event/package-info.java @@ -0,0 +1,9 @@ +/** + * Spring's support for listening to transaction events. + */ +@NonNullApi +@NonNullFields +package org.springframework.transaction.event; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/AbstractFallbackTransactionAttributeSource.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/AbstractFallbackTransactionAttributeSource.java new file mode 100644 index 0000000000000000000000000000000000000000..9ea4456bac944c586c6d22f13a2100c9f9c7ab73 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/AbstractFallbackTransactionAttributeSource.java @@ -0,0 +1,216 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.aop.support.AopUtils; +import org.springframework.core.MethodClassKey; +import org.springframework.lang.Nullable; +import org.springframework.util.ClassUtils; + +/** + * Abstract implementation of {@link TransactionAttributeSource} that caches + * attributes for methods and implements a fallback policy: 1. specific target + * method; 2. target class; 3. declaring method; 4. declaring class/interface. + * + *

Defaults to using the target class's transaction attribute if none is + * associated with the target method. Any transaction attribute associated with + * the target method completely overrides a class transaction attribute. + * If none found on the target class, the interface that the invoked method + * has been called through (in case of a JDK proxy) will be checked. + * + *

This implementation caches attributes by method after they are first used. + * If it is ever desirable to allow dynamic changing of transaction attributes + * (which is very unlikely), caching could be made configurable. Caching is + * desirable because of the cost of evaluating rollback rules. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 1.1 + */ +public abstract class AbstractFallbackTransactionAttributeSource implements TransactionAttributeSource { + + /** + * Canonical value held in cache to indicate no transaction attribute was + * found for this method, and we don't need to look again. + */ + @SuppressWarnings("serial") + private static final TransactionAttribute NULL_TRANSACTION_ATTRIBUTE = new DefaultTransactionAttribute() { + @Override + public String toString() { + return "null"; + } + }; + + + /** + * Logger available to subclasses. + *

As this base class is not marked Serializable, the logger will be recreated + * after serialization - provided that the concrete subclass is Serializable. + */ + protected final Log logger = LogFactory.getLog(getClass()); + + /** + * Cache of TransactionAttributes, keyed by method on a specific target class. + *

As this base class is not marked Serializable, the cache will be recreated + * after serialization - provided that the concrete subclass is Serializable. + */ + private final Map attributeCache = new ConcurrentHashMap<>(1024); + + + /** + * Determine the transaction attribute for this method invocation. + *

Defaults to the class's transaction attribute if no method attribute is found. + * @param method the method for the current invocation (never {@code null}) + * @param targetClass the target class for this invocation (may be {@code null}) + * @return a TransactionAttribute for this method, or {@code null} if the method + * is not transactional + */ + @Override + @Nullable + public TransactionAttribute getTransactionAttribute(Method method, @Nullable Class targetClass) { + if (method.getDeclaringClass() == Object.class) { + return null; + } + + // First, see if we have a cached value. + Object cacheKey = getCacheKey(method, targetClass); + TransactionAttribute cached = this.attributeCache.get(cacheKey); + if (cached != null) { + // Value will either be canonical value indicating there is no transaction attribute, + // or an actual transaction attribute. + if (cached == NULL_TRANSACTION_ATTRIBUTE) { + return null; + } + else { + return cached; + } + } + else { + // We need to work it out. + TransactionAttribute txAttr = computeTransactionAttribute(method, targetClass); + // Put it in the cache. + if (txAttr == null) { + this.attributeCache.put(cacheKey, NULL_TRANSACTION_ATTRIBUTE); + } + else { + String methodIdentification = ClassUtils.getQualifiedMethodName(method, targetClass); + if (txAttr instanceof DefaultTransactionAttribute) { + ((DefaultTransactionAttribute) txAttr).setDescriptor(methodIdentification); + } + if (logger.isTraceEnabled()) { + logger.trace("Adding transactional method '" + methodIdentification + "' with attribute: " + txAttr); + } + this.attributeCache.put(cacheKey, txAttr); + } + return txAttr; + } + } + + /** + * Determine a cache key for the given method and target class. + *

Must not produce same key for overloaded methods. + * Must produce same key for different instances of the same method. + * @param method the method (never {@code null}) + * @param targetClass the target class (may be {@code null}) + * @return the cache key (never {@code null}) + */ + protected Object getCacheKey(Method method, @Nullable Class targetClass) { + return new MethodClassKey(method, targetClass); + } + + /** + * Same signature as {@link #getTransactionAttribute}, but doesn't cache the result. + * {@link #getTransactionAttribute} is effectively a caching decorator for this method. + *

As of 4.1.8, this method can be overridden. + * @since 4.1.8 + * @see #getTransactionAttribute + */ + @Nullable + protected TransactionAttribute computeTransactionAttribute(Method method, @Nullable Class targetClass) { + // Don't allow no-public methods as required. + if (allowPublicMethodsOnly() && !Modifier.isPublic(method.getModifiers())) { + return null; + } + + // The method may be on an interface, but we need attributes from the target class. + // If the target class is null, the method will be unchanged. + Method specificMethod = AopUtils.getMostSpecificMethod(method, targetClass); + + // First try is the method in the target class. + TransactionAttribute txAttr = findTransactionAttribute(specificMethod); + if (txAttr != null) { + return txAttr; + } + + // Second try is the transaction attribute on the target class. + txAttr = findTransactionAttribute(specificMethod.getDeclaringClass()); + if (txAttr != null && ClassUtils.isUserLevelMethod(method)) { + return txAttr; + } + + if (specificMethod != method) { + // Fallback is to look at the original method. + txAttr = findTransactionAttribute(method); + if (txAttr != null) { + return txAttr; + } + // Last fallback is the class of the original method. + txAttr = findTransactionAttribute(method.getDeclaringClass()); + if (txAttr != null && ClassUtils.isUserLevelMethod(method)) { + return txAttr; + } + } + + return null; + } + + + /** + * Subclasses need to implement this to return the transaction attribute for the + * given class, if any. + * @param clazz the class to retrieve the attribute for + * @return all transaction attribute associated with this class, or {@code null} if none + */ + @Nullable + protected abstract TransactionAttribute findTransactionAttribute(Class clazz); + + /** + * Subclasses need to implement this to return the transaction attribute for the + * given method, if any. + * @param method the method to retrieve the attribute for + * @return all transaction attribute associated with this method, or {@code null} if none + */ + @Nullable + protected abstract TransactionAttribute findTransactionAttribute(Method method); + + /** + * Should only public methods be allowed to have transactional semantics? + *

The default implementation returns {@code false}. + */ + protected boolean allowPublicMethodsOnly() { + return false; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/BeanFactoryTransactionAttributeSourceAdvisor.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/BeanFactoryTransactionAttributeSourceAdvisor.java new file mode 100644 index 0000000000000000000000000000000000000000..f52410ed9eaa6daf96d78b6c805da8ca61438fd5 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/BeanFactoryTransactionAttributeSourceAdvisor.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import org.springframework.aop.ClassFilter; +import org.springframework.aop.Pointcut; +import org.springframework.aop.support.AbstractBeanFactoryPointcutAdvisor; +import org.springframework.lang.Nullable; + +/** + * Advisor driven by a {@link TransactionAttributeSource}, used to include + * a transaction advice bean for methods that are transactional. + * + * @author Juergen Hoeller + * @since 2.5.5 + * @see #setAdviceBeanName + * @see TransactionInterceptor + * @see TransactionAttributeSourceAdvisor + */ +@SuppressWarnings("serial") +public class BeanFactoryTransactionAttributeSourceAdvisor extends AbstractBeanFactoryPointcutAdvisor { + + @Nullable + private TransactionAttributeSource transactionAttributeSource; + + private final TransactionAttributeSourcePointcut pointcut = new TransactionAttributeSourcePointcut() { + @Override + @Nullable + protected TransactionAttributeSource getTransactionAttributeSource() { + return transactionAttributeSource; + } + }; + + + /** + * Set the transaction attribute source which is used to find transaction + * attributes. This should usually be identical to the source reference + * set on the transaction interceptor itself. + * @see TransactionInterceptor#setTransactionAttributeSource + */ + public void setTransactionAttributeSource(TransactionAttributeSource transactionAttributeSource) { + this.transactionAttributeSource = transactionAttributeSource; + } + + /** + * Set the {@link ClassFilter} to use for this pointcut. + * Default is {@link ClassFilter#TRUE}. + */ + public void setClassFilter(ClassFilter classFilter) { + this.pointcut.setClassFilter(classFilter); + } + + @Override + public Pointcut getPointcut() { + return this.pointcut; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/CompositeTransactionAttributeSource.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/CompositeTransactionAttributeSource.java new file mode 100644 index 0000000000000000000000000000000000000000..1b9e28c9629c4d3bc090edc30fb2b7321acd63ab --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/CompositeTransactionAttributeSource.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.io.Serializable; +import java.lang.reflect.Method; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Composite {@link TransactionAttributeSource} implementation that iterates + * over a given array of {@link TransactionAttributeSource} instances. + * + * @author Juergen Hoeller + * @since 2.0 + */ +@SuppressWarnings("serial") +public class CompositeTransactionAttributeSource implements TransactionAttributeSource, Serializable { + + private final TransactionAttributeSource[] transactionAttributeSources; + + + /** + * Create a new CompositeTransactionAttributeSource for the given sources. + * @param transactionAttributeSources the TransactionAttributeSource instances to combine + */ + public CompositeTransactionAttributeSource(TransactionAttributeSource... transactionAttributeSources) { + Assert.notNull(transactionAttributeSources, "TransactionAttributeSource array must not be null"); + this.transactionAttributeSources = transactionAttributeSources; + } + + /** + * Return the TransactionAttributeSource instances that this + * CompositeTransactionAttributeSource combines. + */ + public final TransactionAttributeSource[] getTransactionAttributeSources() { + return this.transactionAttributeSources; + } + + + @Override + @Nullable + public TransactionAttribute getTransactionAttribute(Method method, @Nullable Class targetClass) { + for (TransactionAttributeSource source : this.transactionAttributeSources) { + TransactionAttribute attr = source.getTransactionAttribute(method, targetClass); + if (attr != null) { + return attr; + } + } + return null; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/DefaultTransactionAttribute.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/DefaultTransactionAttribute.java new file mode 100644 index 0000000000000000000000000000000000000000..6efeabd08f568df1eaab699cde1318329055181d --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/DefaultTransactionAttribute.java @@ -0,0 +1,151 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.support.DefaultTransactionDefinition; +import org.springframework.util.StringUtils; + +/** + * Spring's common transaction attribute implementation. + * Rolls back on runtime, but not checked, exceptions by default. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 16.03.2003 + */ +@SuppressWarnings("serial") +public class DefaultTransactionAttribute extends DefaultTransactionDefinition implements TransactionAttribute { + + @Nullable + private String qualifier; + + @Nullable + private String descriptor; + + + /** + * Create a new DefaultTransactionAttribute, with default settings. + * Can be modified through bean property setters. + * @see #setPropagationBehavior + * @see #setIsolationLevel + * @see #setTimeout + * @see #setReadOnly + * @see #setName + */ + public DefaultTransactionAttribute() { + super(); + } + + /** + * Copy constructor. Definition can be modified through bean property setters. + * @see #setPropagationBehavior + * @see #setIsolationLevel + * @see #setTimeout + * @see #setReadOnly + * @see #setName + */ + public DefaultTransactionAttribute(TransactionAttribute other) { + super(other); + } + + /** + * Create a new DefaultTransactionAttribute with the given + * propagation behavior. Can be modified through bean property setters. + * @param propagationBehavior one of the propagation constants in the + * TransactionDefinition interface + * @see #setIsolationLevel + * @see #setTimeout + * @see #setReadOnly + */ + public DefaultTransactionAttribute(int propagationBehavior) { + super(propagationBehavior); + } + + + /** + * Associate a qualifier value with this transaction attribute. + *

This may be used for choosing a corresponding transaction manager + * to process this specific transaction. + * @since 3.0 + */ + public void setQualifier(@Nullable String qualifier) { + this.qualifier = qualifier; + } + + /** + * Return a qualifier value associated with this transaction attribute. + * @since 3.0 + */ + @Override + @Nullable + public String getQualifier() { + return this.qualifier; + } + + /** + * Set a descriptor for this transaction attribute, + * e.g. indicating where the attribute is applying. + * @since 4.3.4 + */ + public void setDescriptor(@Nullable String descriptor) { + this.descriptor = descriptor; + } + + /** + * Return a descriptor for this transaction attribute, + * or {@code null} if none. + * @since 4.3.4 + */ + @Nullable + public String getDescriptor() { + return this.descriptor; + } + + /** + * The default behavior is as with EJB: rollback on unchecked exception + * ({@link RuntimeException}), assuming an unexpected outcome outside of any + * business rules. Additionally, we also attempt to rollback on {@link Error} which + * is clearly an unexpected outcome as well. By contrast, a checked exception is + * considered a business exception and therefore a regular expected outcome of the + * transactional business method, i.e. a kind of alternative return value which + * still allows for regular completion of resource operations. + *

This is largely consistent with TransactionTemplate's default behavior, + * except that TransactionTemplate also rolls back on undeclared checked exceptions + * (a corner case). For declarative transactions, we expect checked exceptions to be + * intentionally declared as business exceptions, leading to a commit by default. + * @see org.springframework.transaction.support.TransactionTemplate#execute + */ + @Override + public boolean rollbackOn(Throwable ex) { + return (ex instanceof RuntimeException || ex instanceof Error); + } + + + /** + * Return an identifying description for this transaction attribute. + *

Available to subclasses, for inclusion in their {@code toString()} result. + */ + protected final StringBuilder getAttributeDescription() { + StringBuilder result = getDefinitionDescription(); + if (StringUtils.hasText(this.qualifier)) { + result.append("; '").append(this.qualifier).append("'"); + } + return result; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/DelegatingTransactionAttribute.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/DelegatingTransactionAttribute.java new file mode 100644 index 0000000000000000000000000000000000000000..45870d0110932a7e80b819bc021842f2b9d3d2fc --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/DelegatingTransactionAttribute.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.io.Serializable; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.support.DelegatingTransactionDefinition; + +/** + * {@link TransactionAttribute} implementation that delegates all calls to a given target + * {@link TransactionAttribute} instance. Abstract because it is meant to be subclassed, + * with subclasses overriding specific methods that are not supposed to simply delegate + * to the target instance. + * + * @author Juergen Hoeller + * @since 1.2 + */ +@SuppressWarnings("serial") +public abstract class DelegatingTransactionAttribute extends DelegatingTransactionDefinition + implements TransactionAttribute, Serializable { + + private final TransactionAttribute targetAttribute; + + + /** + * Create a DelegatingTransactionAttribute for the given target attribute. + * @param targetAttribute the target TransactionAttribute to delegate to + */ + public DelegatingTransactionAttribute(TransactionAttribute targetAttribute) { + super(targetAttribute); + this.targetAttribute = targetAttribute; + } + + + @Override + @Nullable + public String getQualifier() { + return this.targetAttribute.getQualifier(); + } + + @Override + public boolean rollbackOn(Throwable ex) { + return this.targetAttribute.rollbackOn(ex); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/MatchAlwaysTransactionAttributeSource.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/MatchAlwaysTransactionAttributeSource.java new file mode 100644 index 0000000000000000000000000000000000000000..d095a58f36fbdf84f814d2365c68c96cfb95a63a --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/MatchAlwaysTransactionAttributeSource.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.io.Serializable; +import java.lang.reflect.Method; + +import org.springframework.lang.Nullable; +import org.springframework.util.ClassUtils; +import org.springframework.util.ObjectUtils; + +/** + * Very simple implementation of TransactionAttributeSource which will always return + * the same TransactionAttribute for all methods fed to it. The TransactionAttribute + * may be specified, but will otherwise default to PROPAGATION_REQUIRED. This may be + * used in the cases where you want to use the same transaction attribute with all + * methods being handled by a transaction interceptor. + * + * @author Colin Sampaleanu + * @since 15.10.2003 + * @see org.springframework.transaction.interceptor.TransactionProxyFactoryBean + * @see org.springframework.aop.framework.autoproxy.BeanNameAutoProxyCreator + */ +@SuppressWarnings("serial") +public class MatchAlwaysTransactionAttributeSource implements TransactionAttributeSource, Serializable { + + private TransactionAttribute transactionAttribute = new DefaultTransactionAttribute(); + + + /** + * Allows a transaction attribute to be specified, using the String form, for + * example, "PROPAGATION_REQUIRED". + * @param transactionAttribute the String form of the transactionAttribute to use. + * @see org.springframework.transaction.interceptor.TransactionAttributeEditor + */ + public void setTransactionAttribute(TransactionAttribute transactionAttribute) { + this.transactionAttribute = transactionAttribute; + } + + + @Override + @Nullable + public TransactionAttribute getTransactionAttribute(Method method, @Nullable Class targetClass) { + return (ClassUtils.isUserLevelMethod(method) ? this.transactionAttribute : null); + } + + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof MatchAlwaysTransactionAttributeSource)) { + return false; + } + MatchAlwaysTransactionAttributeSource otherTas = (MatchAlwaysTransactionAttributeSource) other; + return ObjectUtils.nullSafeEquals(this.transactionAttribute, otherTas.transactionAttribute); + } + + @Override + public int hashCode() { + return MatchAlwaysTransactionAttributeSource.class.hashCode(); + } + + @Override + public String toString() { + return getClass().getName() + ": " + this.transactionAttribute; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/MethodMapTransactionAttributeSource.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/MethodMapTransactionAttributeSource.java new file mode 100644 index 0000000000000000000000000000000000000000..02acd390add663554c8181e5bc442b91f7fbf04a --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/MethodMapTransactionAttributeSource.java @@ -0,0 +1,249 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.BeanClassLoaderAware; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ObjectUtils; +import org.springframework.util.PatternMatchUtils; + +/** + * Simple {@link TransactionAttributeSource} implementation that + * allows attributes to be stored per method in a {@link Map}. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 24.04.2003 + * @see #isMatch + * @see NameMatchTransactionAttributeSource + */ +public class MethodMapTransactionAttributeSource + implements TransactionAttributeSource, BeanClassLoaderAware, InitializingBean { + + /** Logger available to subclasses. */ + protected final Log logger = LogFactory.getLog(getClass()); + + /** Map from method name to attribute value. */ + @Nullable + private Map methodMap; + + @Nullable + private ClassLoader beanClassLoader = ClassUtils.getDefaultClassLoader(); + + private boolean eagerlyInitialized = false; + + private boolean initialized = false; + + /** Map from Method to TransactionAttribute. */ + private final Map transactionAttributeMap = new HashMap<>(); + + /** Map from Method to name pattern used for registration. */ + private final Map methodNameMap = new HashMap<>(); + + + /** + * Set a name/attribute map, consisting of "FQCN.method" method names + * (e.g. "com.mycompany.mycode.MyClass.myMethod") and + * {@link TransactionAttribute} instances (or Strings to be converted + * to {@code TransactionAttribute} instances). + *

Intended for configuration via setter injection, typically within + * a Spring bean factory. Relies on {@link #afterPropertiesSet()} + * being called afterwards. + * @param methodMap said {@link Map} from method name to attribute value + * @see TransactionAttribute + * @see TransactionAttributeEditor + */ + public void setMethodMap(Map methodMap) { + this.methodMap = methodMap; + } + + @Override + public void setBeanClassLoader(ClassLoader beanClassLoader) { + this.beanClassLoader = beanClassLoader; + } + + + /** + * Eagerly initializes the specified + * {@link #setMethodMap(java.util.Map) "methodMap"}, if any. + * @see #initMethodMap(java.util.Map) + */ + @Override + public void afterPropertiesSet() { + initMethodMap(this.methodMap); + this.eagerlyInitialized = true; + this.initialized = true; + } + + /** + * Initialize the specified {@link #setMethodMap(java.util.Map) "methodMap"}, if any. + * @param methodMap a Map from method names to {@code TransactionAttribute} instances + * @see #setMethodMap + */ + protected void initMethodMap(@Nullable Map methodMap) { + if (methodMap != null) { + methodMap.forEach(this::addTransactionalMethod); + } + } + + + /** + * Add an attribute for a transactional method. + *

Method names can end or start with "*" for matching multiple methods. + * @param name class and method name, separated by a dot + * @param attr attribute associated with the method + * @throws IllegalArgumentException in case of an invalid name + */ + public void addTransactionalMethod(String name, TransactionAttribute attr) { + Assert.notNull(name, "Name must not be null"); + int lastDotIndex = name.lastIndexOf('.'); + if (lastDotIndex == -1) { + throw new IllegalArgumentException("'" + name + "' is not a valid method name: format is FQN.methodName"); + } + String className = name.substring(0, lastDotIndex); + String methodName = name.substring(lastDotIndex + 1); + Class clazz = ClassUtils.resolveClassName(className, this.beanClassLoader); + addTransactionalMethod(clazz, methodName, attr); + } + + /** + * Add an attribute for a transactional method. + * Method names can end or start with "*" for matching multiple methods. + * @param clazz target interface or class + * @param mappedName mapped method name + * @param attr attribute associated with the method + */ + public void addTransactionalMethod(Class clazz, String mappedName, TransactionAttribute attr) { + Assert.notNull(clazz, "Class must not be null"); + Assert.notNull(mappedName, "Mapped name must not be null"); + String name = clazz.getName() + '.' + mappedName; + + Method[] methods = clazz.getDeclaredMethods(); + List matchingMethods = new ArrayList<>(); + for (Method method : methods) { + if (isMatch(method.getName(), mappedName)) { + matchingMethods.add(method); + } + } + if (matchingMethods.isEmpty()) { + throw new IllegalArgumentException( + "Could not find method '" + mappedName + "' on class [" + clazz.getName() + "]"); + } + + // Register all matching methods + for (Method method : matchingMethods) { + String regMethodName = this.methodNameMap.get(method); + if (regMethodName == null || (!regMethodName.equals(name) && regMethodName.length() <= name.length())) { + // No already registered method name, or more specific + // method name specification now -> (re-)register method. + if (logger.isDebugEnabled() && regMethodName != null) { + logger.debug("Replacing attribute for transactional method [" + method + "]: current name '" + + name + "' is more specific than '" + regMethodName + "'"); + } + this.methodNameMap.put(method, name); + addTransactionalMethod(method, attr); + } + else { + if (logger.isDebugEnabled()) { + logger.debug("Keeping attribute for transactional method [" + method + "]: current name '" + + name + "' is not more specific than '" + regMethodName + "'"); + } + } + } + } + + /** + * Add an attribute for a transactional method. + * @param method the method + * @param attr attribute associated with the method + */ + public void addTransactionalMethod(Method method, TransactionAttribute attr) { + Assert.notNull(method, "Method must not be null"); + Assert.notNull(attr, "TransactionAttribute must not be null"); + if (logger.isDebugEnabled()) { + logger.debug("Adding transactional method [" + method + "] with attribute [" + attr + "]"); + } + this.transactionAttributeMap.put(method, attr); + } + + /** + * Return if the given method name matches the mapped name. + *

The default implementation checks for "xxx*", "*xxx" and "*xxx*" + * matches, as well as direct equality. + * @param methodName the method name of the class + * @param mappedName the name in the descriptor + * @return if the names match + * @see org.springframework.util.PatternMatchUtils#simpleMatch(String, String) + */ + protected boolean isMatch(String methodName, String mappedName) { + return PatternMatchUtils.simpleMatch(mappedName, methodName); + } + + + @Override + @Nullable + public TransactionAttribute getTransactionAttribute(Method method, @Nullable Class targetClass) { + if (this.eagerlyInitialized) { + return this.transactionAttributeMap.get(method); + } + else { + synchronized (this.transactionAttributeMap) { + if (!this.initialized) { + initMethodMap(this.methodMap); + this.initialized = true; + } + return this.transactionAttributeMap.get(method); + } + } + } + + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof MethodMapTransactionAttributeSource)) { + return false; + } + MethodMapTransactionAttributeSource otherTas = (MethodMapTransactionAttributeSource) other; + return ObjectUtils.nullSafeEquals(this.methodMap, otherTas.methodMap); + } + + @Override + public int hashCode() { + return MethodMapTransactionAttributeSource.class.hashCode(); + } + + @Override + public String toString() { + return getClass().getName() + ": " + this.methodMap; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/NameMatchTransactionAttributeSource.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/NameMatchTransactionAttributeSource.java new file mode 100644 index 0000000000000000000000000000000000000000..0d54339296b34d18b7000a8447f2090e0b18a64b --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/NameMatchTransactionAttributeSource.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.io.Serializable; +import java.lang.reflect.Method; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.lang.Nullable; +import org.springframework.util.ClassUtils; +import org.springframework.util.ObjectUtils; +import org.springframework.util.PatternMatchUtils; + +/** + * Simple {@link TransactionAttributeSource} implementation that + * allows attributes to be matched by registered name. + * + * @author Juergen Hoeller + * @since 21.08.2003 + * @see #isMatch + * @see MethodMapTransactionAttributeSource + */ +@SuppressWarnings("serial") +public class NameMatchTransactionAttributeSource implements TransactionAttributeSource, Serializable { + + /** + * Logger available to subclasses. + *

Static for optimal serialization. + */ + protected static final Log logger = LogFactory.getLog(NameMatchTransactionAttributeSource.class); + + /** Keys are method names; values are TransactionAttributes. */ + private Map nameMap = new HashMap<>(); + + + /** + * Set a name/attribute map, consisting of method names + * (e.g. "myMethod") and TransactionAttribute instances + * (or Strings to be converted to TransactionAttribute instances). + * @see TransactionAttribute + * @see TransactionAttributeEditor + */ + public void setNameMap(Map nameMap) { + nameMap.forEach(this::addTransactionalMethod); + } + + /** + * Parses the given properties into a name/attribute map. + * Expects method names as keys and String attributes definitions as values, + * parsable into TransactionAttribute instances via TransactionAttributeEditor. + * @see #setNameMap + * @see TransactionAttributeEditor + */ + public void setProperties(Properties transactionAttributes) { + TransactionAttributeEditor tae = new TransactionAttributeEditor(); + Enumeration propNames = transactionAttributes.propertyNames(); + while (propNames.hasMoreElements()) { + String methodName = (String) propNames.nextElement(); + String value = transactionAttributes.getProperty(methodName); + tae.setAsText(value); + TransactionAttribute attr = (TransactionAttribute) tae.getValue(); + addTransactionalMethod(methodName, attr); + } + } + + /** + * Add an attribute for a transactional method. + *

Method names can be exact matches, or of the pattern "xxx*", + * "*xxx" or "*xxx*" for matching multiple methods. + * @param methodName the name of the method + * @param attr attribute associated with the method + */ + public void addTransactionalMethod(String methodName, TransactionAttribute attr) { + if (logger.isDebugEnabled()) { + logger.debug("Adding transactional method [" + methodName + "] with attribute [" + attr + "]"); + } + this.nameMap.put(methodName, attr); + } + + + @Override + @Nullable + public TransactionAttribute getTransactionAttribute(Method method, @Nullable Class targetClass) { + if (!ClassUtils.isUserLevelMethod(method)) { + return null; + } + + // Look for direct name match. + String methodName = method.getName(); + TransactionAttribute attr = this.nameMap.get(methodName); + + if (attr == null) { + // Look for most specific name match. + String bestNameMatch = null; + for (String mappedName : this.nameMap.keySet()) { + if (isMatch(methodName, mappedName) && + (bestNameMatch == null || bestNameMatch.length() <= mappedName.length())) { + attr = this.nameMap.get(mappedName); + bestNameMatch = mappedName; + } + } + } + + return attr; + } + + /** + * Return if the given method name matches the mapped name. + *

The default implementation checks for "xxx*", "*xxx" and "*xxx*" matches, + * as well as direct equality. Can be overridden in subclasses. + * @param methodName the method name of the class + * @param mappedName the name in the descriptor + * @return if the names match + * @see org.springframework.util.PatternMatchUtils#simpleMatch(String, String) + */ + protected boolean isMatch(String methodName, String mappedName) { + return PatternMatchUtils.simpleMatch(mappedName, methodName); + } + + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof NameMatchTransactionAttributeSource)) { + return false; + } + NameMatchTransactionAttributeSource otherTas = (NameMatchTransactionAttributeSource) other; + return ObjectUtils.nullSafeEquals(this.nameMap, otherTas.nameMap); + } + + @Override + public int hashCode() { + return NameMatchTransactionAttributeSource.class.hashCode(); + } + + @Override + public String toString() { + return getClass().getName() + ": " + this.nameMap; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/NoRollbackRuleAttribute.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/NoRollbackRuleAttribute.java new file mode 100644 index 0000000000000000000000000000000000000000..a92e14b9ffb490be092680671d231cee38552cff --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/NoRollbackRuleAttribute.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +/** + * Tag subclass of {@link RollbackRuleAttribute} that has the opposite behavior + * to the {@code RollbackRuleAttribute} superclass. + * + * @author Rod Johnson + * @since 09.04.2003 + */ +@SuppressWarnings("serial") +public class NoRollbackRuleAttribute extends RollbackRuleAttribute { + + /** + * Create a new instance of the {@code NoRollbackRuleAttribute} class + * for the supplied {@link Throwable} class. + * @param clazz the {@code Throwable} class + * @see RollbackRuleAttribute#RollbackRuleAttribute(Class) + */ + public NoRollbackRuleAttribute(Class clazz) { + super(clazz); + } + + /** + * Create a new instance of the {@code NoRollbackRuleAttribute} class + * for the supplied {@code exceptionName}. + * @param exceptionName the exception name pattern + * @see RollbackRuleAttribute#RollbackRuleAttribute(String) + */ + public NoRollbackRuleAttribute(String exceptionName) { + super(exceptionName); + } + + @Override + public String toString() { + return "No" + super.toString(); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/RollbackRuleAttribute.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/RollbackRuleAttribute.java new file mode 100644 index 0000000000000000000000000000000000000000..dc8e3791b2605e5832e359f65bbd1a11e5570733 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/RollbackRuleAttribute.java @@ -0,0 +1,148 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.io.Serializable; + +import org.springframework.util.Assert; + +/** + * Rule determining whether or not a given exception (and any subclasses) + * should cause a rollback. + * + *

Multiple such rules can be applied to determine whether a transaction + * should commit or rollback after an exception has been thrown. + * + * @author Rod Johnson + * @since 09.04.2003 + * @see NoRollbackRuleAttribute + */ +@SuppressWarnings("serial") +public class RollbackRuleAttribute implements Serializable{ + + /** + * The {@link RollbackRuleAttribute rollback rule} for + * {@link RuntimeException RuntimeExceptions}. + */ + public static final RollbackRuleAttribute ROLLBACK_ON_RUNTIME_EXCEPTIONS = + new RollbackRuleAttribute(RuntimeException.class); + + + /** + * Could hold exception, resolving class name but would always require FQN. + * This way does multiple string comparisons, but how often do we decide + * whether to roll back a transaction following an exception? + */ + private final String exceptionName; + + + /** + * Create a new instance of the {@code RollbackRuleAttribute} class. + *

This is the preferred way to construct a rollback rule that matches + * the supplied {@link Exception} class (and subclasses). + * @param clazz throwable class; must be {@link Throwable} or a subclass + * of {@code Throwable} + * @throws IllegalArgumentException if the supplied {@code clazz} is + * not a {@code Throwable} type or is {@code null} + */ + public RollbackRuleAttribute(Class clazz) { + Assert.notNull(clazz, "'clazz' cannot be null"); + if (!Throwable.class.isAssignableFrom(clazz)) { + throw new IllegalArgumentException( + "Cannot construct rollback rule from [" + clazz.getName() + "]: it's not a Throwable"); + } + this.exceptionName = clazz.getName(); + } + + /** + * Create a new instance of the {@code RollbackRuleAttribute} class + * for the given {@code exceptionName}. + *

This can be a substring, with no wildcard support at present. A value + * of "ServletException" would match + * {@code javax.servlet.ServletException} and subclasses, for example. + *

NB: Consider carefully how specific the pattern is, and + * whether to include package information (which is not mandatory). For + * example, "Exception" will match nearly anything, and will probably hide + * other rules. "java.lang.Exception" would be correct if "Exception" was + * meant to define a rule for all checked exceptions. With more unusual + * exception names such as "BaseBusinessException" there's no need to use a + * fully package-qualified name. + * @param exceptionName the exception name pattern; can also be a fully + * package-qualified class name + * @throws IllegalArgumentException if the supplied + * {@code exceptionName} is {@code null} or empty + */ + public RollbackRuleAttribute(String exceptionName) { + Assert.hasText(exceptionName, "'exceptionName' cannot be null or empty"); + this.exceptionName = exceptionName; + } + + + /** + * Return the pattern for the exception name. + */ + public String getExceptionName() { + return this.exceptionName; + } + + /** + * Return the depth of the superclass matching. + *

{@code 0} means {@code ex} matches exactly. Returns + * {@code -1} if there is no match. Otherwise, returns depth with the + * lowest depth winning. + */ + public int getDepth(Throwable ex) { + return getDepth(ex.getClass(), 0); + } + + + private int getDepth(Class exceptionClass, int depth) { + if (exceptionClass.getName().contains(this.exceptionName)) { + // Found it! + return depth; + } + // If we've gone as far as we can go and haven't found it... + if (exceptionClass == Throwable.class) { + return -1; + } + return getDepth(exceptionClass.getSuperclass(), depth + 1); + } + + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof RollbackRuleAttribute)) { + return false; + } + RollbackRuleAttribute rhs = (RollbackRuleAttribute) other; + return this.exceptionName.equals(rhs.exceptionName); + } + + @Override + public int hashCode() { + return this.exceptionName.hashCode(); + } + + @Override + public String toString() { + return "RollbackRuleAttribute with pattern [" + this.exceptionName + "]"; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/RuleBasedTransactionAttribute.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/RuleBasedTransactionAttribute.java new file mode 100644 index 0000000000000000000000000000000000000000..604c8c6d7c7866062372f19f469428fe3c6da9e5 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/RuleBasedTransactionAttribute.java @@ -0,0 +1,176 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.lang.Nullable; + +/** + * TransactionAttribute implementation that works out whether a given exception + * should cause transaction rollback by applying a number of rollback rules, + * both positive and negative. If no custom rollback rules apply, this attribute + * behaves like DefaultTransactionAttribute (rolling back on runtime exceptions). + * + *

{@link TransactionAttributeEditor} creates objects of this class. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 09.04.2003 + * @see TransactionAttributeEditor + */ +@SuppressWarnings("serial") +public class RuleBasedTransactionAttribute extends DefaultTransactionAttribute implements Serializable { + + /** Prefix for rollback-on-exception rules in description strings. */ + public static final String PREFIX_ROLLBACK_RULE = "-"; + + /** Prefix for commit-on-exception rules in description strings. */ + public static final String PREFIX_COMMIT_RULE = "+"; + + + /** Static for optimal serializability. */ + private static final Log logger = LogFactory.getLog(RuleBasedTransactionAttribute.class); + + @Nullable + private List rollbackRules; + + + /** + * Create a new RuleBasedTransactionAttribute, with default settings. + * Can be modified through bean property setters. + * @see #setPropagationBehavior + * @see #setIsolationLevel + * @see #setTimeout + * @see #setReadOnly + * @see #setName + * @see #setRollbackRules + */ + public RuleBasedTransactionAttribute() { + super(); + } + + /** + * Copy constructor. Definition can be modified through bean property setters. + * @see #setPropagationBehavior + * @see #setIsolationLevel + * @see #setTimeout + * @see #setReadOnly + * @see #setName + * @see #setRollbackRules + */ + public RuleBasedTransactionAttribute(RuleBasedTransactionAttribute other) { + super(other); + this.rollbackRules = (other.rollbackRules != null ? new ArrayList<>(other.rollbackRules) : null); + } + + /** + * Create a new DefaultTransactionAttribute with the given + * propagation behavior. Can be modified through bean property setters. + * @param propagationBehavior one of the propagation constants in the + * TransactionDefinition interface + * @param rollbackRules the list of RollbackRuleAttributes to apply + * @see #setIsolationLevel + * @see #setTimeout + * @see #setReadOnly + */ + public RuleBasedTransactionAttribute(int propagationBehavior, List rollbackRules) { + super(propagationBehavior); + this.rollbackRules = rollbackRules; + } + + + /** + * Set the list of {@code RollbackRuleAttribute} objects + * (and/or {@code NoRollbackRuleAttribute} objects) to apply. + * @see RollbackRuleAttribute + * @see NoRollbackRuleAttribute + */ + public void setRollbackRules(List rollbackRules) { + this.rollbackRules = rollbackRules; + } + + /** + * Return the list of {@code RollbackRuleAttribute} objects + * (never {@code null}). + */ + public List getRollbackRules() { + if (this.rollbackRules == null) { + this.rollbackRules = new LinkedList<>(); + } + return this.rollbackRules; + } + + + /** + * Winning rule is the shallowest rule (that is, the closest in the + * inheritance hierarchy to the exception). If no rule applies (-1), + * return false. + * @see TransactionAttribute#rollbackOn(java.lang.Throwable) + */ + @Override + public boolean rollbackOn(Throwable ex) { + if (logger.isTraceEnabled()) { + logger.trace("Applying rules to determine whether transaction should rollback on " + ex); + } + + RollbackRuleAttribute winner = null; + int deepest = Integer.MAX_VALUE; + + if (this.rollbackRules != null) { + for (RollbackRuleAttribute rule : this.rollbackRules) { + int depth = rule.getDepth(ex); + if (depth >= 0 && depth < deepest) { + deepest = depth; + winner = rule; + } + } + } + + if (logger.isTraceEnabled()) { + logger.trace("Winning rollback rule is: " + winner); + } + + // User superclass behavior (rollback on unchecked) if no rule matches. + if (winner == null) { + logger.trace("No relevant rollback rule found: applying default rules"); + return super.rollbackOn(ex); + } + + return !(winner instanceof NoRollbackRuleAttribute); + } + + + @Override + public String toString() { + StringBuilder result = getAttributeDescription(); + if (this.rollbackRules != null) { + for (RollbackRuleAttribute rule : this.rollbackRules) { + String sign = (rule instanceof NoRollbackRuleAttribute ? PREFIX_COMMIT_RULE : PREFIX_ROLLBACK_RULE); + result.append(',').append(sign).append(rule.getExceptionName()); + } + } + return result.toString(); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java new file mode 100644 index 0000000000000000000000000000000000000000..20f27a86e4202b1dbdea945feee2087bc2eb28af --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java @@ -0,0 +1,718 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.lang.reflect.Method; +import java.util.Properties; +import java.util.concurrent.ConcurrentMap; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.annotation.BeanFactoryAnnotationUtils; +import org.springframework.core.NamedThreadLocal; +import org.springframework.lang.Nullable; +import org.springframework.transaction.NoTransactionException; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.TransactionStatus; +import org.springframework.transaction.TransactionSystemException; +import org.springframework.transaction.support.CallbackPreferringPlatformTransactionManager; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ConcurrentReferenceHashMap; +import org.springframework.util.StringUtils; + +/** + * Base class for transactional aspects, such as the {@link TransactionInterceptor} + * or an AspectJ aspect. + * + *

This enables the underlying Spring transaction infrastructure to be used easily + * to implement an aspect for any aspect system. + * + *

Subclasses are responsible for calling methods in this class in the correct order. + * + *

If no transaction name has been specified in the {@code TransactionAttribute}, + * the exposed name will be the {@code fully-qualified class name + "." + method name} + * (by default). + * + *

Uses the Strategy design pattern. A {@code PlatformTransactionManager} + * implementation will perform the actual transaction management, and a + * {@code TransactionAttributeSource} is used for determining transaction definitions. + * + *

A transaction aspect is serializable if its {@code PlatformTransactionManager} + * and {@code TransactionAttributeSource} are serializable. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @author Stéphane Nicoll + * @author Sam Brannen + * @since 1.1 + * @see #setTransactionManager + * @see #setTransactionAttributes + * @see #setTransactionAttributeSource + */ +public abstract class TransactionAspectSupport implements BeanFactoryAware, InitializingBean { + + // NOTE: This class must not implement Serializable because it serves as base + // class for AspectJ aspects (which are not allowed to implement Serializable)! + + + /** + * Key to use to store the default transaction manager. + */ + private static final Object DEFAULT_TRANSACTION_MANAGER_KEY = new Object(); + + /** + * Holder to support the {@code currentTransactionStatus()} method, + * and to support communication between different cooperating advices + * (e.g. before and after advice) if the aspect involves more than a + * single method (as will be the case for around advice). + */ + private static final ThreadLocal transactionInfoHolder = + new NamedThreadLocal<>("Current aspect-driven transaction"); + + + /** + * Subclasses can use this to return the current TransactionInfo. + * Only subclasses that cannot handle all operations in one method, + * such as an AspectJ aspect involving distinct before and after advice, + * need to use this mechanism to get at the current TransactionInfo. + * An around advice such as an AOP Alliance MethodInterceptor can hold a + * reference to the TransactionInfo throughout the aspect method. + *

A TransactionInfo will be returned even if no transaction was created. + * The {@code TransactionInfo.hasTransaction()} method can be used to query this. + *

To find out about specific transaction characteristics, consider using + * TransactionSynchronizationManager's {@code isSynchronizationActive()} + * and/or {@code isActualTransactionActive()} methods. + * @return the TransactionInfo bound to this thread, or {@code null} if none + * @see TransactionInfo#hasTransaction() + * @see org.springframework.transaction.support.TransactionSynchronizationManager#isSynchronizationActive() + * @see org.springframework.transaction.support.TransactionSynchronizationManager#isActualTransactionActive() + */ + @Nullable + protected static TransactionInfo currentTransactionInfo() throws NoTransactionException { + return transactionInfoHolder.get(); + } + + /** + * Return the transaction status of the current method invocation. + * Mainly intended for code that wants to set the current transaction + * rollback-only but not throw an application exception. + * @throws NoTransactionException if the transaction info cannot be found, + * because the method was invoked outside an AOP invocation context + */ + public static TransactionStatus currentTransactionStatus() throws NoTransactionException { + TransactionInfo info = currentTransactionInfo(); + if (info == null || info.transactionStatus == null) { + throw new NoTransactionException("No transaction aspect-managed TransactionStatus in scope"); + } + return info.transactionStatus; + } + + + protected final Log logger = LogFactory.getLog(getClass()); + + @Nullable + private String transactionManagerBeanName; + + @Nullable + private PlatformTransactionManager transactionManager; + + @Nullable + private TransactionAttributeSource transactionAttributeSource; + + @Nullable + private BeanFactory beanFactory; + + private final ConcurrentMap transactionManagerCache = + new ConcurrentReferenceHashMap<>(4); + + + /** + * Specify the name of the default transaction manager bean. + */ + public void setTransactionManagerBeanName(@Nullable String transactionManagerBeanName) { + this.transactionManagerBeanName = transactionManagerBeanName; + } + + /** + * Return the name of the default transaction manager bean. + */ + @Nullable + protected final String getTransactionManagerBeanName() { + return this.transactionManagerBeanName; + } + + /** + * Specify the default transaction manager to use to drive transactions. + *

The default transaction manager will be used if a qualifier + * has not been declared for a given transaction or if an explicit name for the + * default transaction manager bean has not been specified. + * @see #setTransactionManagerBeanName + */ + public void setTransactionManager(@Nullable PlatformTransactionManager transactionManager) { + this.transactionManager = transactionManager; + } + + /** + * Return the default transaction manager, or {@code null} if unknown. + */ + @Nullable + public PlatformTransactionManager getTransactionManager() { + return this.transactionManager; + } + + /** + * Set properties with method names as keys and transaction attribute + * descriptors (parsed via TransactionAttributeEditor) as values: + * e.g. key = "myMethod", value = "PROPAGATION_REQUIRED,readOnly". + *

Note: Method names are always applied to the target class, + * no matter if defined in an interface or the class itself. + *

Internally, a NameMatchTransactionAttributeSource will be + * created from the given properties. + * @see #setTransactionAttributeSource + * @see TransactionAttributeEditor + * @see NameMatchTransactionAttributeSource + */ + public void setTransactionAttributes(Properties transactionAttributes) { + NameMatchTransactionAttributeSource tas = new NameMatchTransactionAttributeSource(); + tas.setProperties(transactionAttributes); + this.transactionAttributeSource = tas; + } + + /** + * Set multiple transaction attribute sources which are used to find transaction + * attributes. Will build a CompositeTransactionAttributeSource for the given sources. + * @see CompositeTransactionAttributeSource + * @see MethodMapTransactionAttributeSource + * @see NameMatchTransactionAttributeSource + * @see org.springframework.transaction.annotation.AnnotationTransactionAttributeSource + */ + public void setTransactionAttributeSources(TransactionAttributeSource... transactionAttributeSources) { + this.transactionAttributeSource = new CompositeTransactionAttributeSource(transactionAttributeSources); + } + + /** + * Set the transaction attribute source which is used to find transaction + * attributes. If specifying a String property value, a PropertyEditor + * will create a MethodMapTransactionAttributeSource from the value. + * @see TransactionAttributeSourceEditor + * @see MethodMapTransactionAttributeSource + * @see NameMatchTransactionAttributeSource + * @see org.springframework.transaction.annotation.AnnotationTransactionAttributeSource + */ + public void setTransactionAttributeSource(@Nullable TransactionAttributeSource transactionAttributeSource) { + this.transactionAttributeSource = transactionAttributeSource; + } + + /** + * Return the transaction attribute source. + */ + @Nullable + public TransactionAttributeSource getTransactionAttributeSource() { + return this.transactionAttributeSource; + } + + /** + * Set the BeanFactory to use for retrieving PlatformTransactionManager beans. + */ + @Override + public void setBeanFactory(@Nullable BeanFactory beanFactory) { + this.beanFactory = beanFactory; + } + + /** + * Return the BeanFactory to use for retrieving PlatformTransactionManager beans. + */ + @Nullable + protected final BeanFactory getBeanFactory() { + return this.beanFactory; + } + + /** + * Check that required properties were set. + */ + @Override + public void afterPropertiesSet() { + if (getTransactionManager() == null && this.beanFactory == null) { + throw new IllegalStateException( + "Set the 'transactionManager' property or make sure to run within a BeanFactory " + + "containing a PlatformTransactionManager bean!"); + } + if (getTransactionAttributeSource() == null) { + throw new IllegalStateException( + "Either 'transactionAttributeSource' or 'transactionAttributes' is required: " + + "If there are no transactional methods, then don't use a transaction aspect."); + } + } + + + /** + * General delegate for around-advice-based subclasses, delegating to several other template + * methods on this class. Able to handle {@link CallbackPreferringPlatformTransactionManager} + * as well as regular {@link PlatformTransactionManager} implementations. + * @param method the Method being invoked + * @param targetClass the target class that we're invoking the method on + * @param invocation the callback to use for proceeding with the target invocation + * @return the return value of the method, if any + * @throws Throwable propagated from the target invocation + */ + @Nullable + protected Object invokeWithinTransaction(Method method, @Nullable Class targetClass, + final InvocationCallback invocation) throws Throwable { + + // If the transaction attribute is null, the method is non-transactional. + TransactionAttributeSource tas = getTransactionAttributeSource(); + final TransactionAttribute txAttr = (tas != null ? tas.getTransactionAttribute(method, targetClass) : null); + final PlatformTransactionManager tm = determineTransactionManager(txAttr); + final String joinpointIdentification = methodIdentification(method, targetClass, txAttr); + + if (txAttr == null || !(tm instanceof CallbackPreferringPlatformTransactionManager)) { + // Standard transaction demarcation with getTransaction and commit/rollback calls. + TransactionInfo txInfo = createTransactionIfNecessary(tm, txAttr, joinpointIdentification); + + Object retVal; + try { + // This is an around advice: Invoke the next interceptor in the chain. + // This will normally result in a target object being invoked. + retVal = invocation.proceedWithInvocation(); + } + catch (Throwable ex) { + // target invocation exception + completeTransactionAfterThrowing(txInfo, ex); + throw ex; + } + finally { + cleanupTransactionInfo(txInfo); + } + commitTransactionAfterReturning(txInfo); + return retVal; + } + + else { + Object result; + final ThrowableHolder throwableHolder = new ThrowableHolder(); + + // It's a CallbackPreferringPlatformTransactionManager: pass a TransactionCallback in. + try { + result = ((CallbackPreferringPlatformTransactionManager) tm).execute(txAttr, status -> { + TransactionInfo txInfo = prepareTransactionInfo(tm, txAttr, joinpointIdentification, status); + try { + return invocation.proceedWithInvocation(); + } + catch (Throwable ex) { + if (txAttr.rollbackOn(ex)) { + // A RuntimeException: will lead to a rollback. + if (ex instanceof RuntimeException) { + throw (RuntimeException) ex; + } + else { + throw new ThrowableHolderException(ex); + } + } + else { + // A normal return value: will lead to a commit. + throwableHolder.throwable = ex; + return null; + } + } + finally { + cleanupTransactionInfo(txInfo); + } + }); + } + catch (ThrowableHolderException ex) { + throw ex.getCause(); + } + catch (TransactionSystemException ex2) { + if (throwableHolder.throwable != null) { + logger.error("Application exception overridden by commit exception", throwableHolder.throwable); + ex2.initApplicationException(throwableHolder.throwable); + } + throw ex2; + } + catch (Throwable ex2) { + if (throwableHolder.throwable != null) { + logger.error("Application exception overridden by commit exception", throwableHolder.throwable); + } + throw ex2; + } + + // Check result state: It might indicate a Throwable to rethrow. + if (throwableHolder.throwable != null) { + throw throwableHolder.throwable; + } + return result; + } + } + + /** + * Clear the transaction manager cache. + */ + protected void clearTransactionManagerCache() { + this.transactionManagerCache.clear(); + this.beanFactory = null; + } + + /** + * Determine the specific transaction manager to use for the given transaction. + */ + @Nullable + protected PlatformTransactionManager determineTransactionManager(@Nullable TransactionAttribute txAttr) { + // Do not attempt to lookup tx manager if no tx attributes are set + if (txAttr == null || this.beanFactory == null) { + return getTransactionManager(); + } + + String qualifier = txAttr.getQualifier(); + if (StringUtils.hasText(qualifier)) { + return determineQualifiedTransactionManager(this.beanFactory, qualifier); + } + else if (StringUtils.hasText(this.transactionManagerBeanName)) { + return determineQualifiedTransactionManager(this.beanFactory, this.transactionManagerBeanName); + } + else { + PlatformTransactionManager defaultTransactionManager = getTransactionManager(); + if (defaultTransactionManager == null) { + defaultTransactionManager = this.transactionManagerCache.get(DEFAULT_TRANSACTION_MANAGER_KEY); + if (defaultTransactionManager == null) { + defaultTransactionManager = this.beanFactory.getBean(PlatformTransactionManager.class); + this.transactionManagerCache.putIfAbsent( + DEFAULT_TRANSACTION_MANAGER_KEY, defaultTransactionManager); + } + } + return defaultTransactionManager; + } + } + + private PlatformTransactionManager determineQualifiedTransactionManager(BeanFactory beanFactory, String qualifier) { + PlatformTransactionManager txManager = this.transactionManagerCache.get(qualifier); + if (txManager == null) { + txManager = BeanFactoryAnnotationUtils.qualifiedBeanOfType( + beanFactory, PlatformTransactionManager.class, qualifier); + this.transactionManagerCache.putIfAbsent(qualifier, txManager); + } + return txManager; + } + + private String methodIdentification(Method method, @Nullable Class targetClass, + @Nullable TransactionAttribute txAttr) { + + String methodIdentification = methodIdentification(method, targetClass); + if (methodIdentification == null) { + if (txAttr instanceof DefaultTransactionAttribute) { + methodIdentification = ((DefaultTransactionAttribute) txAttr).getDescriptor(); + } + if (methodIdentification == null) { + methodIdentification = ClassUtils.getQualifiedMethodName(method, targetClass); + } + } + return methodIdentification; + } + + /** + * Convenience method to return a String representation of this Method + * for use in logging. Can be overridden in subclasses to provide a + * different identifier for the given method. + *

The default implementation returns {@code null}, indicating the + * use of {@link DefaultTransactionAttribute#getDescriptor()} instead, + * ending up as {@link ClassUtils#getQualifiedMethodName(Method, Class)}. + * @param method the method we're interested in + * @param targetClass the class that the method is being invoked on + * @return a String representation identifying this method + * @see org.springframework.util.ClassUtils#getQualifiedMethodName + */ + @Nullable + protected String methodIdentification(Method method, @Nullable Class targetClass) { + return null; + } + + /** + * Create a transaction if necessary based on the given TransactionAttribute. + *

Allows callers to perform custom TransactionAttribute lookups through + * the TransactionAttributeSource. + * @param txAttr the TransactionAttribute (may be {@code null}) + * @param joinpointIdentification the fully qualified method name + * (used for monitoring and logging purposes) + * @return a TransactionInfo object, whether or not a transaction was created. + * The {@code hasTransaction()} method on TransactionInfo can be used to + * tell if there was a transaction created. + * @see #getTransactionAttributeSource() + */ + @SuppressWarnings("serial") + protected TransactionInfo createTransactionIfNecessary(@Nullable PlatformTransactionManager tm, + @Nullable TransactionAttribute txAttr, final String joinpointIdentification) { + + // If no name specified, apply method identification as transaction name. + if (txAttr != null && txAttr.getName() == null) { + txAttr = new DelegatingTransactionAttribute(txAttr) { + @Override + public String getName() { + return joinpointIdentification; + } + }; + } + + TransactionStatus status = null; + if (txAttr != null) { + if (tm != null) { + status = tm.getTransaction(txAttr); + } + else { + if (logger.isDebugEnabled()) { + logger.debug("Skipping transactional joinpoint [" + joinpointIdentification + + "] because no transaction manager has been configured"); + } + } + } + return prepareTransactionInfo(tm, txAttr, joinpointIdentification, status); + } + + /** + * Prepare a TransactionInfo for the given attribute and status object. + * @param txAttr the TransactionAttribute (may be {@code null}) + * @param joinpointIdentification the fully qualified method name + * (used for monitoring and logging purposes) + * @param status the TransactionStatus for the current transaction + * @return the prepared TransactionInfo object + */ + protected TransactionInfo prepareTransactionInfo(@Nullable PlatformTransactionManager tm, + @Nullable TransactionAttribute txAttr, String joinpointIdentification, + @Nullable TransactionStatus status) { + + TransactionInfo txInfo = new TransactionInfo(tm, txAttr, joinpointIdentification); + if (txAttr != null) { + // We need a transaction for this method... + if (logger.isTraceEnabled()) { + logger.trace("Getting transaction for [" + txInfo.getJoinpointIdentification() + "]"); + } + // The transaction manager will flag an error if an incompatible tx already exists. + txInfo.newTransactionStatus(status); + } + else { + // The TransactionInfo.hasTransaction() method will return false. We created it only + // to preserve the integrity of the ThreadLocal stack maintained in this class. + if (logger.isTraceEnabled()) { + logger.trace("No need to create transaction for [" + joinpointIdentification + + "]: This method is not transactional."); + } + } + + // We always bind the TransactionInfo to the thread, even if we didn't create + // a new transaction here. This guarantees that the TransactionInfo stack + // will be managed correctly even if no transaction was created by this aspect. + txInfo.bindToThread(); + return txInfo; + } + + /** + * Execute after successful completion of call, but not after an exception was handled. + * Do nothing if we didn't create a transaction. + * @param txInfo information about the current transaction + */ + protected void commitTransactionAfterReturning(@Nullable TransactionInfo txInfo) { + if (txInfo != null && txInfo.getTransactionStatus() != null) { + if (logger.isTraceEnabled()) { + logger.trace("Completing transaction for [" + txInfo.getJoinpointIdentification() + "]"); + } + txInfo.getTransactionManager().commit(txInfo.getTransactionStatus()); + } + } + + /** + * Handle a throwable, completing the transaction. + * We may commit or roll back, depending on the configuration. + * @param txInfo information about the current transaction + * @param ex throwable encountered + */ + protected void completeTransactionAfterThrowing(@Nullable TransactionInfo txInfo, Throwable ex) { + if (txInfo != null && txInfo.getTransactionStatus() != null) { + if (logger.isTraceEnabled()) { + logger.trace("Completing transaction for [" + txInfo.getJoinpointIdentification() + + "] after exception: " + ex); + } + if (txInfo.transactionAttribute != null && txInfo.transactionAttribute.rollbackOn(ex)) { + try { + txInfo.getTransactionManager().rollback(txInfo.getTransactionStatus()); + } + catch (TransactionSystemException ex2) { + logger.error("Application exception overridden by rollback exception", ex); + ex2.initApplicationException(ex); + throw ex2; + } + catch (RuntimeException | Error ex2) { + logger.error("Application exception overridden by rollback exception", ex); + throw ex2; + } + } + else { + // We don't roll back on this exception. + // Will still roll back if TransactionStatus.isRollbackOnly() is true. + try { + txInfo.getTransactionManager().commit(txInfo.getTransactionStatus()); + } + catch (TransactionSystemException ex2) { + logger.error("Application exception overridden by commit exception", ex); + ex2.initApplicationException(ex); + throw ex2; + } + catch (RuntimeException | Error ex2) { + logger.error("Application exception overridden by commit exception", ex); + throw ex2; + } + } + } + } + + /** + * Reset the TransactionInfo ThreadLocal. + *

Call this in all cases: exception or normal return! + * @param txInfo information about the current transaction (may be {@code null}) + */ + protected void cleanupTransactionInfo(@Nullable TransactionInfo txInfo) { + if (txInfo != null) { + txInfo.restoreThreadLocalStatus(); + } + } + + + /** + * Opaque object used to hold transaction information. Subclasses + * must pass it back to methods on this class, but not see its internals. + */ + protected final class TransactionInfo { + + @Nullable + private final PlatformTransactionManager transactionManager; + + @Nullable + private final TransactionAttribute transactionAttribute; + + private final String joinpointIdentification; + + @Nullable + private TransactionStatus transactionStatus; + + @Nullable + private TransactionInfo oldTransactionInfo; + + public TransactionInfo(@Nullable PlatformTransactionManager transactionManager, + @Nullable TransactionAttribute transactionAttribute, String joinpointIdentification) { + + this.transactionManager = transactionManager; + this.transactionAttribute = transactionAttribute; + this.joinpointIdentification = joinpointIdentification; + } + + public PlatformTransactionManager getTransactionManager() { + Assert.state(this.transactionManager != null, "No PlatformTransactionManager set"); + return this.transactionManager; + } + + @Nullable + public TransactionAttribute getTransactionAttribute() { + return this.transactionAttribute; + } + + /** + * Return a String representation of this joinpoint (usually a Method call) + * for use in logging. + */ + public String getJoinpointIdentification() { + return this.joinpointIdentification; + } + + public void newTransactionStatus(@Nullable TransactionStatus status) { + this.transactionStatus = status; + } + + @Nullable + public TransactionStatus getTransactionStatus() { + return this.transactionStatus; + } + + /** + * Return whether a transaction was created by this aspect, + * or whether we just have a placeholder to keep ThreadLocal stack integrity. + */ + public boolean hasTransaction() { + return (this.transactionStatus != null); + } + + private void bindToThread() { + // Expose current TransactionStatus, preserving any existing TransactionStatus + // for restoration after this transaction is complete. + this.oldTransactionInfo = transactionInfoHolder.get(); + transactionInfoHolder.set(this); + } + + private void restoreThreadLocalStatus() { + // Use stack to restore old transaction TransactionInfo. + // Will be null if none was set. + transactionInfoHolder.set(this.oldTransactionInfo); + } + + @Override + public String toString() { + return (this.transactionAttribute != null ? this.transactionAttribute.toString() : "No transaction"); + } + } + + + /** + * Simple callback interface for proceeding with the target invocation. + * Concrete interceptors/aspects adapt this to their invocation mechanism. + */ + @FunctionalInterface + protected interface InvocationCallback { + + @Nullable + Object proceedWithInvocation() throws Throwable; + } + + + /** + * Internal holder class for a Throwable in a callback transaction model. + */ + private static class ThrowableHolder { + + @Nullable + public Throwable throwable; + } + + + /** + * Internal holder class for a Throwable, used as a RuntimeException to be + * thrown from a TransactionCallback (and subsequently unwrapped again). + */ + @SuppressWarnings("serial") + private static class ThrowableHolderException extends RuntimeException { + + public ThrowableHolderException(Throwable throwable) { + super(throwable); + } + + @Override + public String toString() { + return getCause().toString(); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttribute.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttribute.java new file mode 100644 index 0000000000000000000000000000000000000000..25ad7658482b58d94ce386ce75e117d77ef528b6 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttribute.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.TransactionDefinition; + +/** + * This interface adds a {@code rollbackOn} specification to {@link TransactionDefinition}. + * As custom {@code rollbackOn} is only possible with AOP, it resides in the AOP-related + * transaction subpackage. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 16.03.2003 + * @see DefaultTransactionAttribute + * @see RuleBasedTransactionAttribute + */ +public interface TransactionAttribute extends TransactionDefinition { + + /** + * Return a qualifier value associated with this transaction attribute. + *

This may be used for choosing a corresponding transaction manager + * to process this specific transaction. + * @since 3.0 + */ + @Nullable + String getQualifier(); + + /** + * Should we roll back on the given exception? + * @param ex the exception to evaluate + * @return whether to perform a rollback or not + */ + boolean rollbackOn(Throwable ex); + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttributeEditor.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttributeEditor.java new file mode 100644 index 0000000000000000000000000000000000000000..f89a83bd7404306ab958c4fa11ed9c66946e3045 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttributeEditor.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.beans.PropertyEditorSupport; + +import org.springframework.util.StringUtils; + +/** + * PropertyEditor for {@link TransactionAttribute} objects. Accepts a String of form + *

{@code PROPAGATION_NAME, ISOLATION_NAME, readOnly, timeout_NNNN,+Exception1,-Exception2} + *

where only propagation code is required. For example: + *

{@code PROPAGATION_MANDATORY, ISOLATION_DEFAULT} + * + *

The tokens can be in any order. Propagation and isolation codes + * must use the names of the constants in the TransactionDefinition class. Timeout values + * are in seconds. If no timeout is specified, the transaction manager will apply a default + * timeout specific to the particular transaction manager. + * + *

A "+" before an exception name substring indicates that transactions should commit + * even if this exception is thrown; a "-" that they should roll back. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 24.04.2003 + * @see org.springframework.transaction.TransactionDefinition + * @see org.springframework.core.Constants + */ +public class TransactionAttributeEditor extends PropertyEditorSupport { + + /** + * Format is PROPAGATION_NAME,ISOLATION_NAME,readOnly,timeout_NNNN,+Exception1,-Exception2. + * Null or the empty string means that the method is non transactional. + */ + @Override + public void setAsText(String text) throws IllegalArgumentException { + if (StringUtils.hasLength(text)) { + // tokenize it with "," + String[] tokens = StringUtils.commaDelimitedListToStringArray(text); + RuleBasedTransactionAttribute attr = new RuleBasedTransactionAttribute(); + for (String token : tokens) { + // Trim leading and trailing whitespace. + String trimmedToken = StringUtils.trimWhitespace(token.trim()); + // Check whether token contains illegal whitespace within text. + if (StringUtils.containsWhitespace(trimmedToken)) { + throw new IllegalArgumentException( + "Transaction attribute token contains illegal whitespace: [" + trimmedToken + "]"); + } + // Check token type. + if (trimmedToken.startsWith(RuleBasedTransactionAttribute.PREFIX_PROPAGATION)) { + attr.setPropagationBehaviorName(trimmedToken); + } + else if (trimmedToken.startsWith(RuleBasedTransactionAttribute.PREFIX_ISOLATION)) { + attr.setIsolationLevelName(trimmedToken); + } + else if (trimmedToken.startsWith(RuleBasedTransactionAttribute.PREFIX_TIMEOUT)) { + String value = trimmedToken.substring(DefaultTransactionAttribute.PREFIX_TIMEOUT.length()); + attr.setTimeout(Integer.parseInt(value)); + } + else if (trimmedToken.equals(RuleBasedTransactionAttribute.READ_ONLY_MARKER)) { + attr.setReadOnly(true); + } + else if (trimmedToken.startsWith(RuleBasedTransactionAttribute.PREFIX_COMMIT_RULE)) { + attr.getRollbackRules().add(new NoRollbackRuleAttribute(trimmedToken.substring(1))); + } + else if (trimmedToken.startsWith(RuleBasedTransactionAttribute.PREFIX_ROLLBACK_RULE)) { + attr.getRollbackRules().add(new RollbackRuleAttribute(trimmedToken.substring(1))); + } + else { + throw new IllegalArgumentException("Invalid transaction attribute token: [" + trimmedToken + "]"); + } + } + setValue(attr); + } + else { + setValue(null); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttributeSource.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttributeSource.java new file mode 100644 index 0000000000000000000000000000000000000000..254832dbc25a39c467757914c27131f6cef13ef5 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttributeSource.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.lang.reflect.Method; + +import org.springframework.lang.Nullable; + +/** + * Strategy interface used by {@link TransactionInterceptor} for metadata retrieval. + * + *

Implementations know how to source transaction attributes, whether from configuration, + * metadata attributes at source level (such as Java 5 annotations), or anywhere else. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 15.04.2003 + * @see TransactionInterceptor#setTransactionAttributeSource + * @see TransactionProxyFactoryBean#setTransactionAttributeSource + * @see org.springframework.transaction.annotation.AnnotationTransactionAttributeSource + */ +public interface TransactionAttributeSource { + + /** + * Return the transaction attribute for the given method, + * or {@code null} if the method is non-transactional. + * @param method the method to introspect + * @param targetClass the target class (may be {@code null}, + * in which case the declaring class of the method must be used) + * @return the matching transaction attribute, or {@code null} if none found + */ + @Nullable + TransactionAttribute getTransactionAttribute(Method method, @Nullable Class targetClass); + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttributeSourceAdvisor.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttributeSourceAdvisor.java new file mode 100644 index 0000000000000000000000000000000000000000..5e65816fdba39b8528970932ff9e91dea4bd49e6 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttributeSourceAdvisor.java @@ -0,0 +1,97 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import org.aopalliance.aop.Advice; + +import org.springframework.aop.ClassFilter; +import org.springframework.aop.Pointcut; +import org.springframework.aop.support.AbstractPointcutAdvisor; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Advisor driven by a {@link TransactionAttributeSource}, used to include + * a {@link TransactionInterceptor} only for methods that are transactional. + * + *

Because the AOP framework caches advice calculations, this is normally + * faster than just letting the TransactionInterceptor run and find out + * itself that it has no work to do. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @see #setTransactionInterceptor + * @see TransactionProxyFactoryBean + */ +@SuppressWarnings("serial") +public class TransactionAttributeSourceAdvisor extends AbstractPointcutAdvisor { + + @Nullable + private TransactionInterceptor transactionInterceptor; + + private final TransactionAttributeSourcePointcut pointcut = new TransactionAttributeSourcePointcut() { + @Override + @Nullable + protected TransactionAttributeSource getTransactionAttributeSource() { + return (transactionInterceptor != null ? transactionInterceptor.getTransactionAttributeSource() : null); + } + }; + + + /** + * Create a new TransactionAttributeSourceAdvisor. + */ + public TransactionAttributeSourceAdvisor() { + } + + /** + * Create a new TransactionAttributeSourceAdvisor. + * @param interceptor the transaction interceptor to use for this advisor + */ + public TransactionAttributeSourceAdvisor(TransactionInterceptor interceptor) { + setTransactionInterceptor(interceptor); + } + + + /** + * Set the transaction interceptor to use for this advisor. + */ + public void setTransactionInterceptor(TransactionInterceptor interceptor) { + this.transactionInterceptor = interceptor; + } + + /** + * Set the {@link ClassFilter} to use for this pointcut. + * Default is {@link ClassFilter#TRUE}. + */ + public void setClassFilter(ClassFilter classFilter) { + this.pointcut.setClassFilter(classFilter); + } + + + @Override + public Advice getAdvice() { + Assert.state(this.transactionInterceptor != null, "No TransactionInterceptor set"); + return this.transactionInterceptor; + } + + @Override + public Pointcut getPointcut() { + return this.pointcut; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttributeSourceEditor.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttributeSourceEditor.java new file mode 100644 index 0000000000000000000000000000000000000000..a1c9cbac8361d6d7c2924880a589fc35364ce034 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttributeSourceEditor.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.beans.PropertyEditorSupport; +import java.util.Enumeration; +import java.util.Properties; + +import org.springframework.beans.propertyeditors.PropertiesEditor; +import org.springframework.util.StringUtils; + +/** + * Property editor that converts a String into a {@link TransactionAttributeSource}. + * The transaction attribute string must be parseable by the + * {@link TransactionAttributeEditor} in this package. + * + *

Strings are in property syntax, with the form:
+ * {@code FQCN.methodName=<transaction attribute string>} + * + *

For example:
+ * {@code com.mycompany.mycode.MyClass.myMethod=PROPAGATION_MANDATORY,ISOLATION_DEFAULT} + * + *

NOTE: The specified class must be the one where the methods are + * defined; in case of implementing an interface, the interface class name. + * + *

Note: Will register all overloaded methods for a given name. + * Does not support explicit registration of certain overloaded methods. + * Supports "xxx*" mappings, e.g. "notify*" for "notify" and "notifyAll". + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 26.04.2003 + * @see TransactionAttributeEditor + */ +public class TransactionAttributeSourceEditor extends PropertyEditorSupport { + + @Override + public void setAsText(String text) throws IllegalArgumentException { + MethodMapTransactionAttributeSource source = new MethodMapTransactionAttributeSource(); + if (StringUtils.hasLength(text)) { + // Use properties editor to tokenize the hold string. + PropertiesEditor propertiesEditor = new PropertiesEditor(); + propertiesEditor.setAsText(text); + Properties props = (Properties) propertiesEditor.getValue(); + + // Now we have properties, process each one individually. + TransactionAttributeEditor tae = new TransactionAttributeEditor(); + Enumeration propNames = props.propertyNames(); + while (propNames.hasMoreElements()) { + String name = (String) propNames.nextElement(); + String value = props.getProperty(name); + // Convert value to a transaction attribute. + tae.setAsText(value); + TransactionAttribute attr = (TransactionAttribute) tae.getValue(); + // Register name and attribute. + source.addTransactionalMethod(name, attr); + } + } + setValue(source); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttributeSourcePointcut.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttributeSourcePointcut.java new file mode 100644 index 0000000000000000000000000000000000000000..f578768212e94572c4f7cffb30a21d23c4382d8d --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAttributeSourcePointcut.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.io.Serializable; +import java.lang.reflect.Method; + +import org.springframework.aop.support.StaticMethodMatcherPointcut; +import org.springframework.dao.support.PersistenceExceptionTranslator; +import org.springframework.lang.Nullable; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.util.ObjectUtils; + +/** + * Inner class that implements a Pointcut that matches if the underlying + * {@link TransactionAttributeSource} has an attribute for a given method. + * + * @author Juergen Hoeller + * @since 2.5.5 + */ +@SuppressWarnings("serial") +abstract class TransactionAttributeSourcePointcut extends StaticMethodMatcherPointcut implements Serializable { + + @Override + public boolean matches(Method method, Class targetClass) { + if (TransactionalProxy.class.isAssignableFrom(targetClass) || + PlatformTransactionManager.class.isAssignableFrom(targetClass) || + PersistenceExceptionTranslator.class.isAssignableFrom(targetClass)) { + return false; + } + TransactionAttributeSource tas = getTransactionAttributeSource(); + return (tas == null || tas.getTransactionAttribute(method, targetClass) != null); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof TransactionAttributeSourcePointcut)) { + return false; + } + TransactionAttributeSourcePointcut otherPc = (TransactionAttributeSourcePointcut) other; + return ObjectUtils.nullSafeEquals(getTransactionAttributeSource(), otherPc.getTransactionAttributeSource()); + } + + @Override + public int hashCode() { + return TransactionAttributeSourcePointcut.class.hashCode(); + } + + @Override + public String toString() { + return getClass().getName() + ": " + getTransactionAttributeSource(); + } + + + /** + * Obtain the underlying TransactionAttributeSource (may be {@code null}). + * To be implemented by subclasses. + */ + @Nullable + protected abstract TransactionAttributeSource getTransactionAttributeSource(); + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionInterceptor.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..c7a93c1af0ea53e478e78ae54d9cd8233ae29b0b --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionInterceptor.java @@ -0,0 +1,130 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.Properties; + +import org.aopalliance.intercept.MethodInterceptor; +import org.aopalliance.intercept.MethodInvocation; + +import org.springframework.aop.support.AopUtils; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.lang.Nullable; +import org.springframework.transaction.PlatformTransactionManager; + +/** + * AOP Alliance MethodInterceptor for declarative transaction + * management using the common Spring transaction infrastructure + * ({@link org.springframework.transaction.PlatformTransactionManager}). + * + *

Derives from the {@link TransactionAspectSupport} class which + * contains the integration with Spring's underlying transaction API. + * TransactionInterceptor simply calls the relevant superclass methods + * such as {@link #invokeWithinTransaction} in the correct order. + * + *

TransactionInterceptors are thread-safe. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @see TransactionProxyFactoryBean + * @see org.springframework.aop.framework.ProxyFactoryBean + * @see org.springframework.aop.framework.ProxyFactory + */ +@SuppressWarnings("serial") +public class TransactionInterceptor extends TransactionAspectSupport implements MethodInterceptor, Serializable { + + /** + * Create a new TransactionInterceptor. + *

Transaction manager and transaction attributes still need to be set. + * @see #setTransactionManager + * @see #setTransactionAttributes(java.util.Properties) + * @see #setTransactionAttributeSource(TransactionAttributeSource) + */ + public TransactionInterceptor() { + } + + /** + * Create a new TransactionInterceptor. + * @param ptm the default transaction manager to perform the actual transaction management + * @param attributes the transaction attributes in properties format + * @see #setTransactionManager + * @see #setTransactionAttributes(java.util.Properties) + */ + public TransactionInterceptor(PlatformTransactionManager ptm, Properties attributes) { + setTransactionManager(ptm); + setTransactionAttributes(attributes); + } + + /** + * Create a new TransactionInterceptor. + * @param ptm the default transaction manager to perform the actual transaction management + * @param tas the attribute source to be used to find transaction attributes + * @see #setTransactionManager + * @see #setTransactionAttributeSource(TransactionAttributeSource) + */ + public TransactionInterceptor(PlatformTransactionManager ptm, TransactionAttributeSource tas) { + setTransactionManager(ptm); + setTransactionAttributeSource(tas); + } + + + @Override + @Nullable + public Object invoke(MethodInvocation invocation) throws Throwable { + // Work out the target class: may be {@code null}. + // The TransactionAttributeSource should be passed the target class + // as well as the method, which may be from an interface. + Class targetClass = (invocation.getThis() != null ? AopUtils.getTargetClass(invocation.getThis()) : null); + + // Adapt to TransactionAspectSupport's invokeWithinTransaction... + return invokeWithinTransaction(invocation.getMethod(), targetClass, invocation::proceed); + } + + + //--------------------------------------------------------------------- + // Serialization support + //--------------------------------------------------------------------- + + private void writeObject(ObjectOutputStream oos) throws IOException { + // Rely on default serialization, although this class itself doesn't carry state anyway... + oos.defaultWriteObject(); + + // Deserialize superclass fields. + oos.writeObject(getTransactionManagerBeanName()); + oos.writeObject(getTransactionManager()); + oos.writeObject(getTransactionAttributeSource()); + oos.writeObject(getBeanFactory()); + } + + private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException { + // Rely on default serialization, although this class itself doesn't carry state anyway... + ois.defaultReadObject(); + + // Serialize all relevant superclass fields. + // Superclass can't implement Serializable because it also serves as base class + // for AspectJ aspects (which are not allowed to implement Serializable)! + setTransactionManagerBeanName((String) ois.readObject()); + setTransactionManager((PlatformTransactionManager) ois.readObject()); + setTransactionAttributeSource((TransactionAttributeSource) ois.readObject()); + setBeanFactory((BeanFactory) ois.readObject()); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionProxyFactoryBean.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionProxyFactoryBean.java new file mode 100644 index 0000000000000000000000000000000000000000..e6ebbe1441a05b2172f22ae9d115dd00a6195a1b --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionProxyFactoryBean.java @@ -0,0 +1,213 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.util.Properties; + +import org.springframework.aop.Pointcut; +import org.springframework.aop.framework.AbstractSingletonProxyFactoryBean; +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.aop.support.DefaultPointcutAdvisor; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.lang.Nullable; +import org.springframework.transaction.PlatformTransactionManager; + +/** + * Proxy factory bean for simplified declarative transaction handling. + * This is a convenient alternative to a standard AOP + * {@link org.springframework.aop.framework.ProxyFactoryBean} + * with a separate {@link TransactionInterceptor} definition. + * + *

HISTORICAL NOTE: This class was originally designed to cover the + * typical case of declarative transaction demarcation: namely, wrapping a singleton + * target object with a transactional proxy, proxying all the interfaces that the target + * implements. However, in Spring versions 2.0 and beyond, the functionality provided here + * is superseded by the more convenient {@code tx:} XML namespace. See the declarative transaction management section of the + * Spring reference documentation to understand the modern options for managing + * transactions in Spring applications. For these reasons, users should favor of + * the {@code tx:} XML namespace as well as + * the @{@link org.springframework.transaction.annotation.Transactional Transactional} + * and @{@link org.springframework.transaction.annotation.EnableTransactionManagement + * EnableTransactionManagement} annotations. + * + *

There are three main properties that need to be specified: + *

    + *
  • "transactionManager": the {@link PlatformTransactionManager} implementation to use + * (for example, a {@link org.springframework.transaction.jta.JtaTransactionManager} instance) + *
  • "target": the target object that a transactional proxy should be created for + *
  • "transactionAttributes": the transaction attributes (for example, propagation + * behavior and "readOnly" flag) per target method name (or method name pattern) + *
+ * + *

If the "transactionManager" property is not set explicitly and this {@link FactoryBean} + * is running in a {@link ListableBeanFactory}, a single matching bean of type + * {@link PlatformTransactionManager} will be fetched from the {@link BeanFactory}. + * + *

In contrast to {@link TransactionInterceptor}, the transaction attributes are + * specified as properties, with method names as keys and transaction attribute + * descriptors as values. Method names are always applied to the target class. + * + *

Internally, a {@link TransactionInterceptor} instance is used, but the user of this + * class does not have to care. Optionally, a method pointcut can be specified + * to cause conditional invocation of the underlying {@link TransactionInterceptor}. + * + *

The "preInterceptors" and "postInterceptors" properties can be set to add + * additional interceptors to the mix, like + * {@link org.springframework.aop.interceptor.PerformanceMonitorInterceptor}. + * + *

HINT: This class is often used with parent / child bean definitions. + * Typically, you will define the transaction manager and default transaction + * attributes (for method name patterns) in an abstract parent bean definition, + * deriving concrete child bean definitions for specific target objects. + * This reduces the per-bean definition effort to a minimum. + * + *

+ * <bean id="baseTransactionProxy" class="org.springframework.transaction.interceptor.TransactionProxyFactoryBean"
+ *     abstract="true">
+ *   <property name="transactionManager" ref="transactionManager"/>
+ *   <property name="transactionAttributes">
+ *     <props>
+ *       <prop key="insert*">PROPAGATION_REQUIRED</prop>
+ *       <prop key="update*">PROPAGATION_REQUIRED</prop>
+ *       <prop key="*">PROPAGATION_REQUIRED,readOnly</prop>
+ *     </props>
+ *   </property>
+ * </bean>
+ *
+ * <bean id="myProxy" parent="baseTransactionProxy">
+ *   <property name="target" ref="myTarget"/>
+ * </bean>
+ *
+ * <bean id="yourProxy" parent="baseTransactionProxy">
+ *   <property name="target" ref="yourTarget"/>
+ * </bean>
+ * + * @author Juergen Hoeller + * @author Dmitriy Kopylenko + * @author Rod Johnson + * @author Chris Beams + * @since 21.08.2003 + * @see #setTransactionManager + * @see #setTarget + * @see #setTransactionAttributes + * @see TransactionInterceptor + * @see org.springframework.aop.framework.ProxyFactoryBean + */ +@SuppressWarnings("serial") +public class TransactionProxyFactoryBean extends AbstractSingletonProxyFactoryBean + implements BeanFactoryAware { + + private final TransactionInterceptor transactionInterceptor = new TransactionInterceptor(); + + @Nullable + private Pointcut pointcut; + + + /** + * Set the default transaction manager. This will perform actual + * transaction management: This class is just a way of invoking it. + * @see TransactionInterceptor#setTransactionManager + */ + public void setTransactionManager(PlatformTransactionManager transactionManager) { + this.transactionInterceptor.setTransactionManager(transactionManager); + } + + /** + * Set properties with method names as keys and transaction attribute + * descriptors (parsed via TransactionAttributeEditor) as values: + * e.g. key = "myMethod", value = "PROPAGATION_REQUIRED,readOnly". + *

Note: Method names are always applied to the target class, + * no matter if defined in an interface or the class itself. + *

Internally, a NameMatchTransactionAttributeSource will be + * created from the given properties. + * @see #setTransactionAttributeSource + * @see TransactionInterceptor#setTransactionAttributes + * @see TransactionAttributeEditor + * @see NameMatchTransactionAttributeSource + */ + public void setTransactionAttributes(Properties transactionAttributes) { + this.transactionInterceptor.setTransactionAttributes(transactionAttributes); + } + + /** + * Set the transaction attribute source which is used to find transaction + * attributes. If specifying a String property value, a PropertyEditor + * will create a MethodMapTransactionAttributeSource from the value. + * @see #setTransactionAttributes + * @see TransactionInterceptor#setTransactionAttributeSource + * @see TransactionAttributeSourceEditor + * @see MethodMapTransactionAttributeSource + * @see NameMatchTransactionAttributeSource + * @see org.springframework.transaction.annotation.AnnotationTransactionAttributeSource + */ + public void setTransactionAttributeSource(TransactionAttributeSource transactionAttributeSource) { + this.transactionInterceptor.setTransactionAttributeSource(transactionAttributeSource); + } + + /** + * Set a pointcut, i.e a bean that can cause conditional invocation + * of the TransactionInterceptor depending on method and attributes passed. + * Note: Additional interceptors are always invoked. + * @see #setPreInterceptors + * @see #setPostInterceptors + */ + public void setPointcut(Pointcut pointcut) { + this.pointcut = pointcut; + } + + /** + * This callback is optional: If running in a BeanFactory and no transaction + * manager has been set explicitly, a single matching bean of type + * {@link PlatformTransactionManager} will be fetched from the BeanFactory. + * @see org.springframework.beans.factory.BeanFactory#getBean(Class) + * @see org.springframework.transaction.PlatformTransactionManager + */ + @Override + public void setBeanFactory(BeanFactory beanFactory) { + this.transactionInterceptor.setBeanFactory(beanFactory); + } + + + /** + * Creates an advisor for this FactoryBean's TransactionInterceptor. + */ + @Override + protected Object createMainInterceptor() { + this.transactionInterceptor.afterPropertiesSet(); + if (this.pointcut != null) { + return new DefaultPointcutAdvisor(this.pointcut, this.transactionInterceptor); + } + else { + // Rely on default pointcut. + return new TransactionAttributeSourceAdvisor(this.transactionInterceptor); + } + } + + /** + * As of 4.2, this method adds {@link TransactionalProxy} to the set of + * proxy interfaces in order to avoid re-processing of transaction metadata. + */ + @Override + protected void postProcessProxyFactory(ProxyFactory proxyFactory) { + proxyFactory.addInterface(TransactionalProxy.class); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionalProxy.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionalProxy.java new file mode 100644 index 0000000000000000000000000000000000000000..f14664498c55452aca68d7c2472f7912b3f4b947 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionalProxy.java @@ -0,0 +1,33 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import org.springframework.aop.SpringProxy; + +/** + * A marker interface for manually created transactional proxies. + * + *

{@link TransactionAttributeSourcePointcut} will ignore such existing + * transactional proxies during AOP auto-proxying and therefore avoid + * re-processing transaction metadata on them. + * + * @author Juergen Hoeller + * @since 4.1.7 + */ +public interface TransactionalProxy extends SpringProxy { + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/package-info.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..44e8ff8b6ce0f39a2c9ed073936e944874655362 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/package-info.java @@ -0,0 +1,19 @@ +/** + * AOP-based solution for declarative transaction demarcation. + * Builds on the AOP infrastructure in org.springframework.aop.framework. + * Any POJO can be transactionally advised with Spring. + * + *

The TransactionFactoryProxyBean can be used to create transactional + * AOP proxies transparently to code that uses them. + * + *

The TransactionInterceptor is the AOP Alliance MethodInterceptor that + * delivers transactional advice, based on the Spring transaction abstraction. + * This allows declarative transaction management in any environment, + * even without JTA if an application uses only a single database. + */ +@NonNullApi +@NonNullFields +package org.springframework.transaction.interceptor; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/transaction/jta/JtaAfterCompletionSynchronization.java b/spring-tx/src/main/java/org/springframework/transaction/jta/JtaAfterCompletionSynchronization.java new file mode 100644 index 0000000000000000000000000000000000000000..807a19fcf7128393c5ac0d378ea9b1bfd4af418e --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/jta/JtaAfterCompletionSynchronization.java @@ -0,0 +1,78 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.jta; + +import java.util.List; + +import javax.transaction.Status; +import javax.transaction.Synchronization; + +import org.springframework.transaction.support.TransactionSynchronization; +import org.springframework.transaction.support.TransactionSynchronizationUtils; + +/** + * Adapter for a JTA Synchronization, invoking the {@code afterCommit} / + * {@code afterCompletion} callbacks of Spring {@link TransactionSynchronization} + * objects callbacks after the outer JTA transaction has completed. + * Applied when participating in an existing (non-Spring) JTA transaction. + * + * @author Juergen Hoeller + * @since 2.0 + * @see TransactionSynchronization#afterCommit + * @see TransactionSynchronization#afterCompletion + */ +public class JtaAfterCompletionSynchronization implements Synchronization { + + private final List synchronizations; + + + /** + * Create a new JtaAfterCompletionSynchronization for the given synchronization objects. + * @param synchronizations the List of TransactionSynchronization objects + * @see org.springframework.transaction.support.TransactionSynchronization + */ + public JtaAfterCompletionSynchronization(List synchronizations) { + this.synchronizations = synchronizations; + } + + + @Override + public void beforeCompletion() { + } + + @Override + public void afterCompletion(int status) { + switch (status) { + case Status.STATUS_COMMITTED: + try { + TransactionSynchronizationUtils.invokeAfterCommit(this.synchronizations); + } + finally { + TransactionSynchronizationUtils.invokeAfterCompletion( + this.synchronizations, TransactionSynchronization.STATUS_COMMITTED); + } + break; + case Status.STATUS_ROLLEDBACK: + TransactionSynchronizationUtils.invokeAfterCompletion( + this.synchronizations, TransactionSynchronization.STATUS_ROLLED_BACK); + break; + default: + TransactionSynchronizationUtils.invokeAfterCompletion( + this.synchronizations, TransactionSynchronization.STATUS_UNKNOWN); + } + } +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/jta/JtaTransactionManager.java b/spring-tx/src/main/java/org/springframework/transaction/jta/JtaTransactionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..ca6e99173cfd0881ba2298774501bddaa7805bee --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/jta/JtaTransactionManager.java @@ -0,0 +1,1237 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.jta; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.Serializable; +import java.util.List; +import java.util.Properties; + +import javax.naming.NamingException; +import javax.transaction.HeuristicMixedException; +import javax.transaction.HeuristicRollbackException; +import javax.transaction.InvalidTransactionException; +import javax.transaction.NotSupportedException; +import javax.transaction.RollbackException; +import javax.transaction.Status; +import javax.transaction.SystemException; +import javax.transaction.Transaction; +import javax.transaction.TransactionManager; +import javax.transaction.TransactionSynchronizationRegistry; +import javax.transaction.UserTransaction; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.jndi.JndiTemplate; +import org.springframework.lang.Nullable; +import org.springframework.transaction.CannotCreateTransactionException; +import org.springframework.transaction.HeuristicCompletionException; +import org.springframework.transaction.IllegalTransactionStateException; +import org.springframework.transaction.InvalidIsolationLevelException; +import org.springframework.transaction.NestedTransactionNotSupportedException; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.TransactionSuspensionNotSupportedException; +import org.springframework.transaction.TransactionSystemException; +import org.springframework.transaction.UnexpectedRollbackException; +import org.springframework.transaction.support.AbstractPlatformTransactionManager; +import org.springframework.transaction.support.DefaultTransactionStatus; +import org.springframework.transaction.support.TransactionSynchronization; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * {@link org.springframework.transaction.PlatformTransactionManager} implementation + * for JTA, delegating to a backend JTA provider. This is typically used to delegate + * to a Java EE server's transaction coordinator, but may also be configured with a + * local JTA provider which is embedded within the application. + * + *

This transaction manager is appropriate for handling distributed transactions, + * i.e. transactions that span multiple resources, and for controlling transactions on + * application server resources (e.g. JDBC DataSources available in JNDI) in general. + * For a single JDBC DataSource, DataSourceTransactionManager is perfectly sufficient, + * and for accessing a single resource with Hibernate (including transactional cache), + * HibernateTransactionManager is appropriate, for example. + * + *

For typical JTA transactions (REQUIRED, SUPPORTS, MANDATORY, NEVER), a plain + * JtaTransactionManager definition is all you need, portable across all Java EE servers. + * This corresponds to the functionality of the JTA UserTransaction, for which Java EE + * specifies a standard JNDI name ("java:comp/UserTransaction"). There is no need to + * configure a server-specific TransactionManager lookup for this kind of JTA usage. + * + *

Transaction suspension (REQUIRES_NEW, NOT_SUPPORTED) is just available with a + * JTA TransactionManager being registered. Common TransactionManager locations are + * autodetected by JtaTransactionManager, provided that the "autodetectTransactionManager" + * flag is set to "true" (which it is by default). + * + *

Note: Support for the JTA TransactionManager interface is not required by Java EE. + * Almost all Java EE servers expose it, but do so as extension to EE. There might be some + * issues with compatibility, despite the TransactionManager interface being part of JTA. + * As a consequence, Spring provides various vendor-specific PlatformTransactionManagers, + * which are recommended to be used if appropriate: {@link WebLogicJtaTransactionManager} + * and {@link WebSphereUowTransactionManager}. For all other Java EE servers, the + * standard JtaTransactionManager is sufficient. + * + *

This pure JtaTransactionManager class supports timeouts but not per-transaction + * isolation levels. Custom subclasses may override the {@link #doJtaBegin} method for + * specific JTA extensions in order to provide this functionality; Spring includes a + * corresponding {@link WebLogicJtaTransactionManager} class for WebLogic Server. Such + * adapters for specific Java EE transaction coordinators may also expose transaction + * names for monitoring; with standard JTA, transaction names will simply be ignored. + * + *

Consider using Spring's {@code tx:jta-transaction-manager} configuration + * element for automatically picking the appropriate JTA platform transaction manager + * (automatically detecting WebLogic and WebSphere). + * + *

JTA 1.1 adds the TransactionSynchronizationRegistry facility, as public Java EE 5 + * API in addition to the standard JTA UserTransaction handle. As of Spring 2.5, this + * JtaTransactionManager autodetects the TransactionSynchronizationRegistry and uses + * it for registering Spring-managed synchronizations when participating in an existing + * JTA transaction (e.g. controlled by EJB CMT). If no TransactionSynchronizationRegistry + * is available, then such synchronizations will be registered via the (non-EE) JTA + * TransactionManager handle. + * + *

This class is serializable. However, active synchronizations do not survive serialization. + * + * @author Juergen Hoeller + * @since 24.03.2003 + * @see javax.transaction.UserTransaction + * @see javax.transaction.TransactionManager + * @see javax.transaction.TransactionSynchronizationRegistry + * @see #setUserTransactionName + * @see #setUserTransaction + * @see #setTransactionManagerName + * @see #setTransactionManager + * @see WebLogicJtaTransactionManager + */ +@SuppressWarnings("serial") +public class JtaTransactionManager extends AbstractPlatformTransactionManager + implements TransactionFactory, InitializingBean, Serializable { + + /** + * Default JNDI location for the JTA UserTransaction. Many Java EE servers + * also provide support for the JTA TransactionManager interface there. + * @see #setUserTransactionName + * @see #setAutodetectTransactionManager + */ + public static final String DEFAULT_USER_TRANSACTION_NAME = "java:comp/UserTransaction"; + + /** + * Fallback JNDI locations for the JTA TransactionManager. Applied if + * the JTA UserTransaction does not implement the JTA TransactionManager + * interface, provided that the "autodetectTransactionManager" flag is "true". + * @see #setTransactionManagerName + * @see #setAutodetectTransactionManager + */ + public static final String[] FALLBACK_TRANSACTION_MANAGER_NAMES = + new String[] {"java:comp/TransactionManager", "java:appserver/TransactionManager", + "java:pm/TransactionManager", "java:/TransactionManager"}; + + /** + * Standard Java EE 5 JNDI location for the JTA TransactionSynchronizationRegistry. + * Autodetected when available. + */ + public static final String DEFAULT_TRANSACTION_SYNCHRONIZATION_REGISTRY_NAME = + "java:comp/TransactionSynchronizationRegistry"; + + + private transient JndiTemplate jndiTemplate = new JndiTemplate(); + + @Nullable + private transient UserTransaction userTransaction; + + @Nullable + private String userTransactionName; + + private boolean autodetectUserTransaction = true; + + private boolean cacheUserTransaction = true; + + private boolean userTransactionObtainedFromJndi = false; + + @Nullable + private transient TransactionManager transactionManager; + + @Nullable + private String transactionManagerName; + + private boolean autodetectTransactionManager = true; + + @Nullable + private transient TransactionSynchronizationRegistry transactionSynchronizationRegistry; + + @Nullable + private String transactionSynchronizationRegistryName; + + private boolean autodetectTransactionSynchronizationRegistry = true; + + private boolean allowCustomIsolationLevels = false; + + + /** + * Create a new JtaTransactionManager instance, to be configured as bean. + * Invoke {@code afterPropertiesSet} to activate the configuration. + * @see #setUserTransactionName + * @see #setUserTransaction + * @see #setTransactionManagerName + * @see #setTransactionManager + * @see #afterPropertiesSet() + */ + public JtaTransactionManager() { + setNestedTransactionAllowed(true); + } + + /** + * Create a new JtaTransactionManager instance. + * @param userTransaction the JTA UserTransaction to use as direct reference + */ + public JtaTransactionManager(UserTransaction userTransaction) { + this(); + Assert.notNull(userTransaction, "UserTransaction must not be null"); + this.userTransaction = userTransaction; + } + + /** + * Create a new JtaTransactionManager instance. + * @param userTransaction the JTA UserTransaction to use as direct reference + * @param transactionManager the JTA TransactionManager to use as direct reference + */ + public JtaTransactionManager(UserTransaction userTransaction, TransactionManager transactionManager) { + this(); + Assert.notNull(userTransaction, "UserTransaction must not be null"); + Assert.notNull(transactionManager, "TransactionManager must not be null"); + this.userTransaction = userTransaction; + this.transactionManager = transactionManager; + } + + /** + * Create a new JtaTransactionManager instance. + * @param transactionManager the JTA TransactionManager to use as direct reference + */ + public JtaTransactionManager(TransactionManager transactionManager) { + this(); + Assert.notNull(transactionManager, "TransactionManager must not be null"); + this.transactionManager = transactionManager; + this.userTransaction = buildUserTransaction(transactionManager); + } + + + /** + * Set the JndiTemplate to use for JNDI lookups. + * A default one is used if not set. + */ + public void setJndiTemplate(JndiTemplate jndiTemplate) { + Assert.notNull(jndiTemplate, "JndiTemplate must not be null"); + this.jndiTemplate = jndiTemplate; + } + + /** + * Return the JndiTemplate used for JNDI lookups. + */ + public JndiTemplate getJndiTemplate() { + return this.jndiTemplate; + } + + /** + * Set the JNDI environment to use for JNDI lookups. + * Creates a JndiTemplate with the given environment settings. + * @see #setJndiTemplate + */ + public void setJndiEnvironment(@Nullable Properties jndiEnvironment) { + this.jndiTemplate = new JndiTemplate(jndiEnvironment); + } + + /** + * Return the JNDI environment to use for JNDI lookups. + */ + @Nullable + public Properties getJndiEnvironment() { + return this.jndiTemplate.getEnvironment(); + } + + + /** + * Set the JTA UserTransaction to use as direct reference. + *

Typically just used for local JTA setups; in a Java EE environment, + * the UserTransaction will always be fetched from JNDI. + * @see #setUserTransactionName + * @see #setAutodetectUserTransaction + */ + public void setUserTransaction(@Nullable UserTransaction userTransaction) { + this.userTransaction = userTransaction; + } + + /** + * Return the JTA UserTransaction that this transaction manager uses. + */ + @Nullable + public UserTransaction getUserTransaction() { + return this.userTransaction; + } + + /** + * Set the JNDI name of the JTA UserTransaction. + *

Note that the UserTransaction will be autodetected at the Java EE + * default location "java:comp/UserTransaction" if not specified explicitly. + * @see #DEFAULT_USER_TRANSACTION_NAME + * @see #setUserTransaction + * @see #setAutodetectUserTransaction + */ + public void setUserTransactionName(String userTransactionName) { + this.userTransactionName = userTransactionName; + } + + /** + * Set whether to autodetect the JTA UserTransaction at its default + * JNDI location "java:comp/UserTransaction", as specified by Java EE. + * Will proceed without UserTransaction if none found. + *

Default is "true", autodetecting the UserTransaction unless + * it has been specified explicitly. Turn this flag off to allow for + * JtaTransactionManager operating against the TransactionManager only, + * despite a default UserTransaction being available. + * @see #DEFAULT_USER_TRANSACTION_NAME + */ + public void setAutodetectUserTransaction(boolean autodetectUserTransaction) { + this.autodetectUserTransaction = autodetectUserTransaction; + } + + /** + * Set whether to cache the JTA UserTransaction object fetched from JNDI. + *

Default is "true": UserTransaction lookup will only happen at startup, + * reusing the same UserTransaction handle for all transactions of all threads. + * This is the most efficient choice for all application servers that provide + * a shared UserTransaction object (the typical case). + *

Turn this flag off to enforce a fresh lookup of the UserTransaction + * for every transaction. This is only necessary for application servers + * that return a new UserTransaction for every transaction, keeping state + * tied to the UserTransaction object itself rather than the current thread. + * @see #setUserTransactionName + */ + public void setCacheUserTransaction(boolean cacheUserTransaction) { + this.cacheUserTransaction = cacheUserTransaction; + } + + /** + * Set the JTA TransactionManager to use as direct reference. + *

A TransactionManager is necessary for suspending and resuming transactions, + * as this not supported by the UserTransaction interface. + *

Note that the TransactionManager will be autodetected if the JTA + * UserTransaction object implements the JTA TransactionManager interface too, + * as well as autodetected at various well-known fallback JNDI locations. + * @see #setTransactionManagerName + * @see #setAutodetectTransactionManager + */ + public void setTransactionManager(@Nullable TransactionManager transactionManager) { + this.transactionManager = transactionManager; + } + + /** + * Return the JTA TransactionManager that this transaction manager uses, if any. + */ + @Nullable + public TransactionManager getTransactionManager() { + return this.transactionManager; + } + + /** + * Set the JNDI name of the JTA TransactionManager. + *

A TransactionManager is necessary for suspending and resuming transactions, + * as this not supported by the UserTransaction interface. + *

Note that the TransactionManager will be autodetected if the JTA + * UserTransaction object implements the JTA TransactionManager interface too, + * as well as autodetected at various well-known fallback JNDI locations. + * @see #setTransactionManager + * @see #setAutodetectTransactionManager + */ + public void setTransactionManagerName(String transactionManagerName) { + this.transactionManagerName = transactionManagerName; + } + + /** + * Set whether to autodetect a JTA UserTransaction object that implements + * the JTA TransactionManager interface too (i.e. the JNDI location for the + * TransactionManager is "java:comp/UserTransaction", same as for the UserTransaction). + * Also checks the fallback JNDI locations "java:comp/TransactionManager" and + * "java:/TransactionManager". Will proceed without TransactionManager if none found. + *

Default is "true", autodetecting the TransactionManager unless it has been + * specified explicitly. Can be turned off to deliberately ignore an available + * TransactionManager, for example when there are known issues with suspend/resume + * and any attempt to use REQUIRES_NEW or NOT_SUPPORTED should fail fast. + * @see #FALLBACK_TRANSACTION_MANAGER_NAMES + */ + public void setAutodetectTransactionManager(boolean autodetectTransactionManager) { + this.autodetectTransactionManager = autodetectTransactionManager; + } + + /** + * Set the JTA 1.1 TransactionSynchronizationRegistry to use as direct reference. + *

A TransactionSynchronizationRegistry allows for interposed registration + * of transaction synchronizations, as an alternative to the regular registration + * methods on the JTA TransactionManager API. Also, it is an official part of the + * Java EE 5 platform, in contrast to the JTA TransactionManager itself. + *

Note that the TransactionSynchronizationRegistry will be autodetected in JNDI and + * also from the UserTransaction/TransactionManager object if implemented there as well. + * @see #setTransactionSynchronizationRegistryName + * @see #setAutodetectTransactionSynchronizationRegistry + */ + public void setTransactionSynchronizationRegistry(@Nullable TransactionSynchronizationRegistry transactionSynchronizationRegistry) { + this.transactionSynchronizationRegistry = transactionSynchronizationRegistry; + } + + /** + * Return the JTA 1.1 TransactionSynchronizationRegistry that this transaction manager uses, if any. + */ + @Nullable + public TransactionSynchronizationRegistry getTransactionSynchronizationRegistry() { + return this.transactionSynchronizationRegistry; + } + + /** + * Set the JNDI name of the JTA 1.1 TransactionSynchronizationRegistry. + *

Note that the TransactionSynchronizationRegistry will be autodetected + * at the Java EE 5 default location "java:comp/TransactionSynchronizationRegistry" + * if not specified explicitly. + * @see #DEFAULT_TRANSACTION_SYNCHRONIZATION_REGISTRY_NAME + */ + public void setTransactionSynchronizationRegistryName(String transactionSynchronizationRegistryName) { + this.transactionSynchronizationRegistryName = transactionSynchronizationRegistryName; + } + + /** + * Set whether to autodetect a JTA 1.1 TransactionSynchronizationRegistry object + * at its default JDNI location ("java:comp/TransactionSynchronizationRegistry") + * if the UserTransaction has also been obtained from JNDI, and also whether + * to fall back to checking whether the JTA UserTransaction/TransactionManager + * object implements the JTA TransactionSynchronizationRegistry interface too. + *

Default is "true", autodetecting the TransactionSynchronizationRegistry + * unless it has been specified explicitly. Can be turned off to delegate + * synchronization registration to the regular JTA TransactionManager API. + */ + public void setAutodetectTransactionSynchronizationRegistry(boolean autodetectTransactionSynchronizationRegistry) { + this.autodetectTransactionSynchronizationRegistry = autodetectTransactionSynchronizationRegistry; + } + + /** + * Set whether to allow custom isolation levels to be specified. + *

Default is "false", throwing an exception if a non-default isolation level + * is specified for a transaction. Turn this flag on if affected resource adapters + * check the thread-bound transaction context and apply the specified isolation + * levels individually (e.g. through an IsolationLevelDataSourceAdapter). + * @see org.springframework.jdbc.datasource.IsolationLevelDataSourceAdapter + * @see org.springframework.jdbc.datasource.lookup.IsolationLevelDataSourceRouter + */ + public void setAllowCustomIsolationLevels(boolean allowCustomIsolationLevels) { + this.allowCustomIsolationLevels = allowCustomIsolationLevels; + } + + + /** + * Initialize the UserTransaction as well as the TransactionManager handle. + * @see #initUserTransactionAndTransactionManager() + */ + @Override + public void afterPropertiesSet() throws TransactionSystemException { + initUserTransactionAndTransactionManager(); + checkUserTransactionAndTransactionManager(); + initTransactionSynchronizationRegistry(); + } + + /** + * Initialize the UserTransaction as well as the TransactionManager handle. + * @throws TransactionSystemException if initialization failed + */ + protected void initUserTransactionAndTransactionManager() throws TransactionSystemException { + if (this.userTransaction == null) { + // Fetch JTA UserTransaction from JNDI, if necessary. + if (StringUtils.hasLength(this.userTransactionName)) { + this.userTransaction = lookupUserTransaction(this.userTransactionName); + this.userTransactionObtainedFromJndi = true; + } + else { + this.userTransaction = retrieveUserTransaction(); + if (this.userTransaction == null && this.autodetectUserTransaction) { + // Autodetect UserTransaction at its default JNDI location. + this.userTransaction = findUserTransaction(); + } + } + } + + if (this.transactionManager == null) { + // Fetch JTA TransactionManager from JNDI, if necessary. + if (StringUtils.hasLength(this.transactionManagerName)) { + this.transactionManager = lookupTransactionManager(this.transactionManagerName); + } + else { + this.transactionManager = retrieveTransactionManager(); + if (this.transactionManager == null && this.autodetectTransactionManager) { + // Autodetect UserTransaction object that implements TransactionManager, + // and check fallback JNDI locations otherwise. + this.transactionManager = findTransactionManager(this.userTransaction); + } + } + } + + // If only JTA TransactionManager specified, create UserTransaction handle for it. + if (this.userTransaction == null && this.transactionManager != null) { + this.userTransaction = buildUserTransaction(this.transactionManager); + } + } + + /** + * Check the UserTransaction as well as the TransactionManager handle, + * assuming standard JTA requirements. + * @throws IllegalStateException if no sufficient handles are available + */ + protected void checkUserTransactionAndTransactionManager() throws IllegalStateException { + // We at least need the JTA UserTransaction. + if (this.userTransaction != null) { + if (logger.isDebugEnabled()) { + logger.debug("Using JTA UserTransaction: " + this.userTransaction); + } + } + else { + throw new IllegalStateException("No JTA UserTransaction available - specify either " + + "'userTransaction' or 'userTransactionName' or 'transactionManager' or 'transactionManagerName'"); + } + + // For transaction suspension, the JTA TransactionManager is necessary too. + if (this.transactionManager != null) { + if (logger.isDebugEnabled()) { + logger.debug("Using JTA TransactionManager: " + this.transactionManager); + } + } + else { + logger.warn("No JTA TransactionManager found: transaction suspension not available"); + } + } + + /** + * Initialize the JTA 1.1 TransactionSynchronizationRegistry, if available. + *

To be called after {@link #initUserTransactionAndTransactionManager()}, + * since it may check the UserTransaction and TransactionManager handles. + * @throws TransactionSystemException if initialization failed + */ + protected void initTransactionSynchronizationRegistry() { + if (this.transactionSynchronizationRegistry == null) { + // Fetch JTA TransactionSynchronizationRegistry from JNDI, if necessary. + if (StringUtils.hasLength(this.transactionSynchronizationRegistryName)) { + this.transactionSynchronizationRegistry = + lookupTransactionSynchronizationRegistry(this.transactionSynchronizationRegistryName); + } + else { + this.transactionSynchronizationRegistry = retrieveTransactionSynchronizationRegistry(); + if (this.transactionSynchronizationRegistry == null && this.autodetectTransactionSynchronizationRegistry) { + // Autodetect in JNDI if applicable, and check UserTransaction/TransactionManager + // object that implements TransactionSynchronizationRegistry otherwise. + this.transactionSynchronizationRegistry = + findTransactionSynchronizationRegistry(this.userTransaction, this.transactionManager); + } + } + } + + if (this.transactionSynchronizationRegistry != null) { + if (logger.isDebugEnabled()) { + logger.debug("Using JTA TransactionSynchronizationRegistry: " + this.transactionSynchronizationRegistry); + } + } + } + + + /** + * Build a UserTransaction handle based on the given TransactionManager. + * @param transactionManager the TransactionManager + * @return a corresponding UserTransaction handle + */ + protected UserTransaction buildUserTransaction(TransactionManager transactionManager) { + if (transactionManager instanceof UserTransaction) { + return (UserTransaction) transactionManager; + } + else { + return new UserTransactionAdapter(transactionManager); + } + } + + /** + * Look up the JTA UserTransaction in JNDI via the configured name. + *

Called by {@code afterPropertiesSet} if no direct UserTransaction reference was set. + * Can be overridden in subclasses to provide a different UserTransaction object. + * @param userTransactionName the JNDI name of the UserTransaction + * @return the UserTransaction object + * @throws TransactionSystemException if the JNDI lookup failed + * @see #setJndiTemplate + * @see #setUserTransactionName + */ + protected UserTransaction lookupUserTransaction(String userTransactionName) + throws TransactionSystemException { + try { + if (logger.isDebugEnabled()) { + logger.debug("Retrieving JTA UserTransaction from JNDI location [" + userTransactionName + "]"); + } + return getJndiTemplate().lookup(userTransactionName, UserTransaction.class); + } + catch (NamingException ex) { + throw new TransactionSystemException( + "JTA UserTransaction is not available at JNDI location [" + userTransactionName + "]", ex); + } + } + + /** + * Look up the JTA TransactionManager in JNDI via the configured name. + *

Called by {@code afterPropertiesSet} if no direct TransactionManager reference was set. + * Can be overridden in subclasses to provide a different TransactionManager object. + * @param transactionManagerName the JNDI name of the TransactionManager + * @return the UserTransaction object + * @throws TransactionSystemException if the JNDI lookup failed + * @see #setJndiTemplate + * @see #setTransactionManagerName + */ + protected TransactionManager lookupTransactionManager(String transactionManagerName) + throws TransactionSystemException { + try { + if (logger.isDebugEnabled()) { + logger.debug("Retrieving JTA TransactionManager from JNDI location [" + transactionManagerName + "]"); + } + return getJndiTemplate().lookup(transactionManagerName, TransactionManager.class); + } + catch (NamingException ex) { + throw new TransactionSystemException( + "JTA TransactionManager is not available at JNDI location [" + transactionManagerName + "]", ex); + } + } + + /** + * Look up the JTA 1.1 TransactionSynchronizationRegistry in JNDI via the configured name. + *

Can be overridden in subclasses to provide a different TransactionManager object. + * @param registryName the JNDI name of the + * TransactionSynchronizationRegistry + * @return the TransactionSynchronizationRegistry object + * @throws TransactionSystemException if the JNDI lookup failed + * @see #setJndiTemplate + * @see #setTransactionSynchronizationRegistryName + */ + protected TransactionSynchronizationRegistry lookupTransactionSynchronizationRegistry(String registryName) throws TransactionSystemException { + try { + if (logger.isDebugEnabled()) { + logger.debug("Retrieving JTA TransactionSynchronizationRegistry from JNDI location [" + registryName + "]"); + } + return getJndiTemplate().lookup(registryName, TransactionSynchronizationRegistry.class); + } + catch (NamingException ex) { + throw new TransactionSystemException( + "JTA TransactionSynchronizationRegistry is not available at JNDI location [" + registryName + "]", ex); + } + } + + /** + * Allows subclasses to retrieve the JTA UserTransaction in a vendor-specific manner. + * Only called if no "userTransaction" or "userTransactionName" specified. + *

The default implementation simply returns {@code null}. + * @return the JTA UserTransaction handle to use, or {@code null} if none found + * @throws TransactionSystemException in case of errors + * @see #setUserTransaction + * @see #setUserTransactionName + */ + @Nullable + protected UserTransaction retrieveUserTransaction() throws TransactionSystemException { + return null; + } + + /** + * Allows subclasses to retrieve the JTA TransactionManager in a vendor-specific manner. + * Only called if no "transactionManager" or "transactionManagerName" specified. + *

The default implementation simply returns {@code null}. + * @return the JTA TransactionManager handle to use, or {@code null} if none found + * @throws TransactionSystemException in case of errors + * @see #setTransactionManager + * @see #setTransactionManagerName + */ + @Nullable + protected TransactionManager retrieveTransactionManager() throws TransactionSystemException { + return null; + } + + /** + * Allows subclasses to retrieve the JTA 1.1 TransactionSynchronizationRegistry + * in a vendor-specific manner. + *

The default implementation simply returns {@code null}. + * @return the JTA TransactionSynchronizationRegistry handle to use, + * or {@code null} if none found + * @throws TransactionSystemException in case of errors + */ + @Nullable + protected TransactionSynchronizationRegistry retrieveTransactionSynchronizationRegistry() throws TransactionSystemException { + return null; + } + + /** + * Find the JTA UserTransaction through a default JNDI lookup: + * "java:comp/UserTransaction". + * @return the JTA UserTransaction reference, or {@code null} if not found + * @see #DEFAULT_USER_TRANSACTION_NAME + */ + @Nullable + protected UserTransaction findUserTransaction() { + String jndiName = DEFAULT_USER_TRANSACTION_NAME; + try { + UserTransaction ut = getJndiTemplate().lookup(jndiName, UserTransaction.class); + if (logger.isDebugEnabled()) { + logger.debug("JTA UserTransaction found at default JNDI location [" + jndiName + "]"); + } + this.userTransactionObtainedFromJndi = true; + return ut; + } + catch (NamingException ex) { + if (logger.isDebugEnabled()) { + logger.debug("No JTA UserTransaction found at default JNDI location [" + jndiName + "]", ex); + } + return null; + } + } + + /** + * Find the JTA TransactionManager through autodetection: checking whether the + * UserTransaction object implements the TransactionManager, and checking the + * fallback JNDI locations. + * @param ut the JTA UserTransaction object + * @return the JTA TransactionManager reference, or {@code null} if not found + * @see #FALLBACK_TRANSACTION_MANAGER_NAMES + */ + @Nullable + protected TransactionManager findTransactionManager(@Nullable UserTransaction ut) { + if (ut instanceof TransactionManager) { + if (logger.isDebugEnabled()) { + logger.debug("JTA UserTransaction object [" + ut + "] implements TransactionManager"); + } + return (TransactionManager) ut; + } + + // Check fallback JNDI locations. + for (String jndiName : FALLBACK_TRANSACTION_MANAGER_NAMES) { + try { + TransactionManager tm = getJndiTemplate().lookup(jndiName, TransactionManager.class); + if (logger.isDebugEnabled()) { + logger.debug("JTA TransactionManager found at fallback JNDI location [" + jndiName + "]"); + } + return tm; + } + catch (NamingException ex) { + if (logger.isDebugEnabled()) { + logger.debug("No JTA TransactionManager found at fallback JNDI location [" + jndiName + "]", ex); + } + } + } + + // OK, so no JTA TransactionManager is available... + return null; + } + + /** + * Find the JTA 1.1 TransactionSynchronizationRegistry through autodetection: + * checking whether the UserTransaction object or TransactionManager object + * implements it, and checking Java EE 5's standard JNDI location. + *

The default implementation simply returns {@code null}. + * @param ut the JTA UserTransaction object + * @param tm the JTA TransactionManager object + * @return the JTA TransactionSynchronizationRegistry handle to use, + * or {@code null} if none found + * @throws TransactionSystemException in case of errors + */ + @Nullable + protected TransactionSynchronizationRegistry findTransactionSynchronizationRegistry( + @Nullable UserTransaction ut, @Nullable TransactionManager tm) throws TransactionSystemException { + + if (this.userTransactionObtainedFromJndi) { + // UserTransaction has already been obtained from JNDI, so the + // TransactionSynchronizationRegistry probably sits there as well. + String jndiName = DEFAULT_TRANSACTION_SYNCHRONIZATION_REGISTRY_NAME; + try { + TransactionSynchronizationRegistry tsr = getJndiTemplate().lookup(jndiName, TransactionSynchronizationRegistry.class); + if (logger.isDebugEnabled()) { + logger.debug("JTA TransactionSynchronizationRegistry found at default JNDI location [" + jndiName + "]"); + } + return tsr; + } + catch (NamingException ex) { + if (logger.isDebugEnabled()) { + logger.debug("No JTA TransactionSynchronizationRegistry found at default JNDI location [" + jndiName + "]", ex); + } + } + } + // Check whether the UserTransaction or TransactionManager implements it... + if (ut instanceof TransactionSynchronizationRegistry) { + return (TransactionSynchronizationRegistry) ut; + } + if (tm instanceof TransactionSynchronizationRegistry) { + return (TransactionSynchronizationRegistry) tm; + } + // OK, so no JTA 1.1 TransactionSynchronizationRegistry is available... + return null; + } + + + /** + * This implementation returns a JtaTransactionObject instance for the + * JTA UserTransaction. + *

The UserTransaction object will either be looked up freshly for the + * current transaction, or the cached one looked up at startup will be used. + * The latter is the default: Most application servers use a shared singleton + * UserTransaction that can be cached. Turn off the "cacheUserTransaction" + * flag to enforce a fresh lookup for every transaction. + * @see #setCacheUserTransaction + */ + @Override + protected Object doGetTransaction() { + UserTransaction ut = getUserTransaction(); + if (ut == null) { + throw new CannotCreateTransactionException("No JTA UserTransaction available - " + + "programmatic PlatformTransactionManager.getTransaction usage not supported"); + } + if (!this.cacheUserTransaction) { + ut = lookupUserTransaction( + this.userTransactionName != null ? this.userTransactionName : DEFAULT_USER_TRANSACTION_NAME); + } + return doGetJtaTransaction(ut); + } + + /** + * Get a JTA transaction object for the given current UserTransaction. + *

Subclasses can override this to provide a JtaTransactionObject + * subclass, for example holding some additional JTA handle needed. + * @param ut the UserTransaction handle to use for the current transaction + * @return the JtaTransactionObject holding the UserTransaction + */ + protected JtaTransactionObject doGetJtaTransaction(UserTransaction ut) { + return new JtaTransactionObject(ut); + } + + @Override + protected boolean isExistingTransaction(Object transaction) { + JtaTransactionObject txObject = (JtaTransactionObject) transaction; + try { + return (txObject.getUserTransaction().getStatus() != Status.STATUS_NO_TRANSACTION); + } + catch (SystemException ex) { + throw new TransactionSystemException("JTA failure on getStatus", ex); + } + } + + /** + * This implementation returns false to cause a further invocation + * of doBegin despite an already existing transaction. + *

JTA implementations might support nested transactions via further + * {@code UserTransaction.begin()} invocations, but never support savepoints. + * @see #doBegin + * @see javax.transaction.UserTransaction#begin() + */ + @Override + protected boolean useSavepointForNestedTransaction() { + return false; + } + + + @Override + protected void doBegin(Object transaction, TransactionDefinition definition) { + JtaTransactionObject txObject = (JtaTransactionObject) transaction; + try { + doJtaBegin(txObject, definition); + } + catch (NotSupportedException | UnsupportedOperationException ex) { + throw new NestedTransactionNotSupportedException( + "JTA implementation does not support nested transactions", ex); + } + catch (SystemException ex) { + throw new CannotCreateTransactionException("JTA failure on begin", ex); + } + } + + /** + * Perform a JTA begin on the JTA UserTransaction or TransactionManager. + *

This implementation only supports standard JTA functionality: + * that is, no per-transaction isolation levels and no transaction names. + * Can be overridden in subclasses, for specific JTA implementations. + *

Calls {@code applyIsolationLevel} and {@code applyTimeout} + * before invoking the UserTransaction's {@code begin} method. + * @param txObject the JtaTransactionObject containing the UserTransaction + * @param definition the TransactionDefinition instance, describing propagation + * behavior, isolation level, read-only flag, timeout, and transaction name + * @throws NotSupportedException if thrown by JTA methods + * @throws SystemException if thrown by JTA methods + * @see #getUserTransaction + * @see #getTransactionManager + * @see #applyIsolationLevel + * @see #applyTimeout + * @see JtaTransactionObject#getUserTransaction() + * @see javax.transaction.UserTransaction#setTransactionTimeout + * @see javax.transaction.UserTransaction#begin + */ + protected void doJtaBegin(JtaTransactionObject txObject, TransactionDefinition definition) + throws NotSupportedException, SystemException { + + applyIsolationLevel(txObject, definition.getIsolationLevel()); + int timeout = determineTimeout(definition); + applyTimeout(txObject, timeout); + txObject.getUserTransaction().begin(); + } + + /** + * Apply the given transaction isolation level. The default implementation + * will throw an exception for any level other than ISOLATION_DEFAULT. + *

To be overridden in subclasses for specific JTA implementations, + * as alternative to overriding the full {@link #doJtaBegin} method. + * @param txObject the JtaTransactionObject containing the UserTransaction + * @param isolationLevel isolation level taken from transaction definition + * @throws InvalidIsolationLevelException if the given isolation level + * cannot be applied + * @throws SystemException if thrown by the JTA implementation + * @see #doJtaBegin + * @see JtaTransactionObject#getUserTransaction() + * @see #getTransactionManager() + */ + protected void applyIsolationLevel(JtaTransactionObject txObject, int isolationLevel) + throws InvalidIsolationLevelException, SystemException { + + if (!this.allowCustomIsolationLevels && isolationLevel != TransactionDefinition.ISOLATION_DEFAULT) { + throw new InvalidIsolationLevelException( + "JtaTransactionManager does not support custom isolation levels by default - " + + "switch 'allowCustomIsolationLevels' to 'true'"); + } + } + + /** + * Apply the given transaction timeout. The default implementation will call + * {@code UserTransaction.setTransactionTimeout} for a non-default timeout value. + * @param txObject the JtaTransactionObject containing the UserTransaction + * @param timeout timeout value taken from transaction definition + * @throws SystemException if thrown by the JTA implementation + * @see #doJtaBegin + * @see JtaTransactionObject#getUserTransaction() + * @see javax.transaction.UserTransaction#setTransactionTimeout(int) + */ + protected void applyTimeout(JtaTransactionObject txObject, int timeout) throws SystemException { + if (timeout > TransactionDefinition.TIMEOUT_DEFAULT) { + txObject.getUserTransaction().setTransactionTimeout(timeout); + if (timeout > 0) { + txObject.resetTransactionTimeout = true; + } + } + } + + + @Override + protected Object doSuspend(Object transaction) { + JtaTransactionObject txObject = (JtaTransactionObject) transaction; + try { + return doJtaSuspend(txObject); + } + catch (SystemException ex) { + throw new TransactionSystemException("JTA failure on suspend", ex); + } + } + + /** + * Perform a JTA suspend on the JTA TransactionManager. + *

Can be overridden in subclasses, for specific JTA implementations. + * @param txObject the JtaTransactionObject containing the UserTransaction + * @return the suspended JTA Transaction object + * @throws SystemException if thrown by JTA methods + * @see #getTransactionManager() + * @see javax.transaction.TransactionManager#suspend() + */ + protected Object doJtaSuspend(JtaTransactionObject txObject) throws SystemException { + if (getTransactionManager() == null) { + throw new TransactionSuspensionNotSupportedException( + "JtaTransactionManager needs a JTA TransactionManager for suspending a transaction: " + + "specify the 'transactionManager' or 'transactionManagerName' property"); + } + return getTransactionManager().suspend(); + } + + @Override + protected void doResume(@Nullable Object transaction, Object suspendedResources) { + JtaTransactionObject txObject = (JtaTransactionObject) transaction; + try { + doJtaResume(txObject, suspendedResources); + } + catch (InvalidTransactionException ex) { + throw new IllegalTransactionStateException("Tried to resume invalid JTA transaction", ex); + } + catch (IllegalStateException ex) { + throw new TransactionSystemException("Unexpected internal transaction state", ex); + } + catch (SystemException ex) { + throw new TransactionSystemException("JTA failure on resume", ex); + } + } + + /** + * Perform a JTA resume on the JTA TransactionManager. + *

Can be overridden in subclasses, for specific JTA implementations. + * @param txObject the JtaTransactionObject containing the UserTransaction + * @param suspendedTransaction the suspended JTA Transaction object + * @throws InvalidTransactionException if thrown by JTA methods + * @throws SystemException if thrown by JTA methods + * @see #getTransactionManager() + * @see javax.transaction.TransactionManager#resume(javax.transaction.Transaction) + */ + protected void doJtaResume(@Nullable JtaTransactionObject txObject, Object suspendedTransaction) + throws InvalidTransactionException, SystemException { + + if (getTransactionManager() == null) { + throw new TransactionSuspensionNotSupportedException( + "JtaTransactionManager needs a JTA TransactionManager for suspending a transaction: " + + "specify the 'transactionManager' or 'transactionManagerName' property"); + } + getTransactionManager().resume((Transaction) suspendedTransaction); + } + + + /** + * This implementation returns "true": a JTA commit will properly handle + * transactions that have been marked rollback-only at a global level. + */ + @Override + protected boolean shouldCommitOnGlobalRollbackOnly() { + return true; + } + + @Override + protected void doCommit(DefaultTransactionStatus status) { + JtaTransactionObject txObject = (JtaTransactionObject) status.getTransaction(); + try { + int jtaStatus = txObject.getUserTransaction().getStatus(); + if (jtaStatus == Status.STATUS_NO_TRANSACTION) { + // Should never happen... would have thrown an exception before + // and as a consequence led to a rollback, not to a commit call. + // In any case, the transaction is already fully cleaned up. + throw new UnexpectedRollbackException("JTA transaction already completed - probably rolled back"); + } + if (jtaStatus == Status.STATUS_ROLLEDBACK) { + // Only really happens on JBoss 4.2 in case of an early timeout... + // Explicit rollback call necessary to clean up the transaction. + // IllegalStateException expected on JBoss; call still necessary. + try { + txObject.getUserTransaction().rollback(); + } + catch (IllegalStateException ex) { + if (logger.isDebugEnabled()) { + logger.debug("Rollback failure with transaction already marked as rolled back: " + ex); + } + } + throw new UnexpectedRollbackException("JTA transaction already rolled back (probably due to a timeout)"); + } + txObject.getUserTransaction().commit(); + } + catch (RollbackException ex) { + throw new UnexpectedRollbackException( + "JTA transaction unexpectedly rolled back (maybe due to a timeout)", ex); + } + catch (HeuristicMixedException ex) { + throw new HeuristicCompletionException(HeuristicCompletionException.STATE_MIXED, ex); + } + catch (HeuristicRollbackException ex) { + throw new HeuristicCompletionException(HeuristicCompletionException.STATE_ROLLED_BACK, ex); + } + catch (IllegalStateException ex) { + throw new TransactionSystemException("Unexpected internal transaction state", ex); + } + catch (SystemException ex) { + throw new TransactionSystemException("JTA failure on commit", ex); + } + } + + @Override + protected void doRollback(DefaultTransactionStatus status) { + JtaTransactionObject txObject = (JtaTransactionObject) status.getTransaction(); + try { + int jtaStatus = txObject.getUserTransaction().getStatus(); + if (jtaStatus != Status.STATUS_NO_TRANSACTION) { + try { + txObject.getUserTransaction().rollback(); + } + catch (IllegalStateException ex) { + if (jtaStatus == Status.STATUS_ROLLEDBACK) { + // Only really happens on JBoss 4.2 in case of an early timeout... + if (logger.isDebugEnabled()) { + logger.debug("Rollback failure with transaction already marked as rolled back: " + ex); + } + } + else { + throw new TransactionSystemException("Unexpected internal transaction state", ex); + } + } + } + } + catch (SystemException ex) { + throw new TransactionSystemException("JTA failure on rollback", ex); + } + } + + @Override + protected void doSetRollbackOnly(DefaultTransactionStatus status) { + JtaTransactionObject txObject = (JtaTransactionObject) status.getTransaction(); + if (status.isDebug()) { + logger.debug("Setting JTA transaction rollback-only"); + } + try { + int jtaStatus = txObject.getUserTransaction().getStatus(); + if (jtaStatus != Status.STATUS_NO_TRANSACTION && jtaStatus != Status.STATUS_ROLLEDBACK) { + txObject.getUserTransaction().setRollbackOnly(); + } + } + catch (IllegalStateException ex) { + throw new TransactionSystemException("Unexpected internal transaction state", ex); + } + catch (SystemException ex) { + throw new TransactionSystemException("JTA failure on setRollbackOnly", ex); + } + } + + + @Override + protected void registerAfterCompletionWithExistingTransaction( + Object transaction, List synchronizations) { + + JtaTransactionObject txObject = (JtaTransactionObject) transaction; + logger.debug("Registering after-completion synchronization with existing JTA transaction"); + try { + doRegisterAfterCompletionWithJtaTransaction(txObject, synchronizations); + } + catch (SystemException ex) { + throw new TransactionSystemException("JTA failure on registerSynchronization", ex); + } + catch (Exception ex) { + // Note: JBoss throws plain RuntimeException with RollbackException as cause. + if (ex instanceof RollbackException || ex.getCause() instanceof RollbackException) { + logger.debug("Participating in existing JTA transaction that has been marked for rollback: " + + "cannot register Spring after-completion callbacks with outer JTA transaction - " + + "immediately performing Spring after-completion callbacks with outcome status 'rollback'. " + + "Original exception: " + ex); + invokeAfterCompletion(synchronizations, TransactionSynchronization.STATUS_ROLLED_BACK); + } + else { + logger.debug("Participating in existing JTA transaction, but unexpected internal transaction " + + "state encountered: cannot register Spring after-completion callbacks with outer JTA " + + "transaction - processing Spring after-completion callbacks with outcome status 'unknown'" + + "Original exception: " + ex); + invokeAfterCompletion(synchronizations, TransactionSynchronization.STATUS_UNKNOWN); + } + } + } + + /** + * Register a JTA synchronization on the JTA TransactionManager, for calling + * {@code afterCompletion} on the given Spring TransactionSynchronizations. + *

The default implementation registers the synchronizations on the + * JTA 1.1 TransactionSynchronizationRegistry, if available, or on the + * JTA TransactionManager's current Transaction - again, if available. + * If none of the two is available, a warning will be logged. + *

Can be overridden in subclasses, for specific JTA implementations. + * @param txObject the current transaction object + * @param synchronizations a List of TransactionSynchronization objects + * @throws RollbackException if thrown by JTA methods + * @throws SystemException if thrown by JTA methods + * @see #getTransactionManager() + * @see javax.transaction.Transaction#registerSynchronization + * @see javax.transaction.TransactionSynchronizationRegistry#registerInterposedSynchronization + */ + protected void doRegisterAfterCompletionWithJtaTransaction( + JtaTransactionObject txObject, List synchronizations) + throws RollbackException, SystemException { + + int jtaStatus = txObject.getUserTransaction().getStatus(); + if (jtaStatus == Status.STATUS_NO_TRANSACTION) { + throw new RollbackException("JTA transaction already completed - probably rolled back"); + } + if (jtaStatus == Status.STATUS_ROLLEDBACK) { + throw new RollbackException("JTA transaction already rolled back (probably due to a timeout)"); + } + + if (this.transactionSynchronizationRegistry != null) { + // JTA 1.1 TransactionSynchronizationRegistry available - use it. + this.transactionSynchronizationRegistry.registerInterposedSynchronization( + new JtaAfterCompletionSynchronization(synchronizations)); + } + + else if (getTransactionManager() != null) { + // At least the JTA TransactionManager available - use that one. + Transaction transaction = getTransactionManager().getTransaction(); + if (transaction == null) { + throw new IllegalStateException("No JTA Transaction available"); + } + transaction.registerSynchronization(new JtaAfterCompletionSynchronization(synchronizations)); + } + + else { + // No JTA TransactionManager available - log a warning. + logger.warn("Participating in existing JTA transaction, but no JTA TransactionManager available: " + + "cannot register Spring after-completion callbacks with outer JTA transaction - " + + "processing Spring after-completion callbacks with outcome status 'unknown'"); + invokeAfterCompletion(synchronizations, TransactionSynchronization.STATUS_UNKNOWN); + } + } + + @Override + protected void doCleanupAfterCompletion(Object transaction) { + JtaTransactionObject txObject = (JtaTransactionObject) transaction; + if (txObject.resetTransactionTimeout) { + try { + txObject.getUserTransaction().setTransactionTimeout(0); + } + catch (SystemException ex) { + logger.debug("Failed to reset transaction timeout after JTA completion", ex); + } + } + } + + + //--------------------------------------------------------------------- + // Implementation of TransactionFactory interface + //--------------------------------------------------------------------- + + @Override + public Transaction createTransaction(@Nullable String name, int timeout) throws NotSupportedException, SystemException { + TransactionManager tm = getTransactionManager(); + Assert.state(tm != null, "No JTA TransactionManager available"); + if (timeout >= 0) { + tm.setTransactionTimeout(timeout); + } + tm.begin(); + return new ManagedTransactionAdapter(tm); + } + + @Override + public boolean supportsResourceAdapterManagedTransactions() { + return false; + } + + + //--------------------------------------------------------------------- + // Serialization support + //--------------------------------------------------------------------- + + private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException { + // Rely on default serialization; just initialize state after deserialization. + ois.defaultReadObject(); + + // Create template for client-side JNDI lookup. + this.jndiTemplate = new JndiTemplate(); + + // Perform a fresh lookup for JTA handles. + initUserTransactionAndTransactionManager(); + initTransactionSynchronizationRegistry(); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/jta/JtaTransactionObject.java b/spring-tx/src/main/java/org/springframework/transaction/jta/JtaTransactionObject.java new file mode 100644 index 0000000000000000000000000000000000000000..4bbc8ff46ac032095c104fa468bcef049b6af65b --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/jta/JtaTransactionObject.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.jta; + +import javax.transaction.Status; +import javax.transaction.SystemException; +import javax.transaction.UserTransaction; + +import org.springframework.transaction.TransactionSystemException; +import org.springframework.transaction.support.SmartTransactionObject; +import org.springframework.transaction.support.TransactionSynchronizationUtils; + +/** + * JTA transaction object, representing a {@link javax.transaction.UserTransaction}. + * Used as transaction object by Spring's {@link JtaTransactionManager}. + * + *

Note: This is an SPI class, not intended to be used by applications. + * + * @author Juergen Hoeller + * @since 1.1 + * @see JtaTransactionManager + * @see javax.transaction.UserTransaction + */ +public class JtaTransactionObject implements SmartTransactionObject { + + private final UserTransaction userTransaction; + + boolean resetTransactionTimeout = false; + + + /** + * Create a new JtaTransactionObject for the given JTA UserTransaction. + * @param userTransaction the JTA UserTransaction for the current transaction + * (either a shared object or retrieved through a fresh per-transaction lookup) + */ + public JtaTransactionObject(UserTransaction userTransaction) { + this.userTransaction = userTransaction; + } + + /** + * Return the JTA UserTransaction object for the current transaction. + */ + public final UserTransaction getUserTransaction() { + return this.userTransaction; + } + + + /** + * This implementation checks the UserTransaction's rollback-only flag. + */ + @Override + public boolean isRollbackOnly() { + try { + int jtaStatus = this.userTransaction.getStatus(); + return (jtaStatus == Status.STATUS_MARKED_ROLLBACK || jtaStatus == Status.STATUS_ROLLEDBACK); + } + catch (SystemException ex) { + throw new TransactionSystemException("JTA failure on getStatus", ex); + } + } + + /** + * This implementation triggers flush callbacks, + * assuming that they will flush all affected ORM sessions. + * @see org.springframework.transaction.support.TransactionSynchronization#flush() + */ + @Override + public void flush() { + TransactionSynchronizationUtils.triggerFlush(); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/jta/ManagedTransactionAdapter.java b/spring-tx/src/main/java/org/springframework/transaction/jta/ManagedTransactionAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..47a804abc84e5c2114d592a87068c5c52166a482 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/jta/ManagedTransactionAdapter.java @@ -0,0 +1,96 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.jta; + +import javax.transaction.HeuristicMixedException; +import javax.transaction.HeuristicRollbackException; +import javax.transaction.RollbackException; +import javax.transaction.Synchronization; +import javax.transaction.SystemException; +import javax.transaction.Transaction; +import javax.transaction.TransactionManager; +import javax.transaction.xa.XAResource; + +import org.springframework.util.Assert; + +/** + * Adapter for a managed JTA Transaction handle, taking a JTA + * {@link javax.transaction.TransactionManager} reference and creating + * a JTA {@link javax.transaction.Transaction} handle for it. + * + * @author Juergen Hoeller + * @since 3.0.2 + */ +public class ManagedTransactionAdapter implements Transaction { + + private final TransactionManager transactionManager; + + + /** + * Create a new ManagedTransactionAdapter for the given TransactionManager. + * @param transactionManager the JTA TransactionManager to wrap + */ + public ManagedTransactionAdapter(TransactionManager transactionManager) throws SystemException { + Assert.notNull(transactionManager, "TransactionManager must not be null"); + this.transactionManager = transactionManager; + } + + /** + * Return the JTA TransactionManager that this adapter delegates to. + */ + public final TransactionManager getTransactionManager() { + return this.transactionManager; + } + + + @Override + public void commit() throws RollbackException, HeuristicMixedException, HeuristicRollbackException, + SecurityException, SystemException { + this.transactionManager.commit(); + } + + @Override + public void rollback() throws SystemException { + this.transactionManager.rollback(); + } + + @Override + public void setRollbackOnly() throws SystemException { + this.transactionManager.setRollbackOnly(); + } + + @Override + public int getStatus() throws SystemException { + return this.transactionManager.getStatus(); + } + + @Override + public boolean enlistResource(XAResource xaRes) throws RollbackException, SystemException { + return this.transactionManager.getTransaction().enlistResource(xaRes); + } + + @Override + public boolean delistResource(XAResource xaRes, int flag) throws SystemException { + return this.transactionManager.getTransaction().delistResource(xaRes, flag); + } + + @Override + public void registerSynchronization(Synchronization sync) throws RollbackException, SystemException { + this.transactionManager.getTransaction().registerSynchronization(sync); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/jta/SimpleTransactionFactory.java b/spring-tx/src/main/java/org/springframework/transaction/jta/SimpleTransactionFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..bd0828a9c370b8f4bcd74a2aaff73e95cb3ded89 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/jta/SimpleTransactionFactory.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.jta; + +import javax.transaction.NotSupportedException; +import javax.transaction.SystemException; +import javax.transaction.Transaction; +import javax.transaction.TransactionManager; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Default implementation of the {@link TransactionFactory} strategy interface, + * simply wrapping a standard JTA {@link javax.transaction.TransactionManager}. + * + *

Does not support transaction names; simply ignores any specified name. + * + * @author Juergen Hoeller + * @since 2.5 + * @see javax.transaction.TransactionManager#setTransactionTimeout(int) + * @see javax.transaction.TransactionManager#begin() + * @see javax.transaction.TransactionManager#getTransaction() + */ +public class SimpleTransactionFactory implements TransactionFactory { + + private final TransactionManager transactionManager; + + + /** + * Create a new SimpleTransactionFactory for the given TransactionManager. + * @param transactionManager the JTA TransactionManager to wrap + */ + public SimpleTransactionFactory(TransactionManager transactionManager) { + Assert.notNull(transactionManager, "TransactionManager must not be null"); + this.transactionManager = transactionManager; + } + + + @Override + public Transaction createTransaction(@Nullable String name, int timeout) throws NotSupportedException, SystemException { + if (timeout >= 0) { + this.transactionManager.setTransactionTimeout(timeout); + } + this.transactionManager.begin(); + return new ManagedTransactionAdapter(this.transactionManager); + } + + @Override + public boolean supportsResourceAdapterManagedTransactions() { + return false; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/jta/SpringJtaSynchronizationAdapter.java b/spring-tx/src/main/java/org/springframework/transaction/jta/SpringJtaSynchronizationAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..d2f8282ef15aec692c37d82f37a57f3a60b7ef9f --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/jta/SpringJtaSynchronizationAdapter.java @@ -0,0 +1,190 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.jta; + +import javax.transaction.Status; +import javax.transaction.Synchronization; +import javax.transaction.TransactionManager; +import javax.transaction.UserTransaction; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.support.TransactionSynchronization; +import org.springframework.transaction.support.TransactionSynchronizationManager; +import org.springframework.util.Assert; + +/** + * Adapter that implements the JTA {@link javax.transaction.Synchronization} + * interface delegating to an underlying Spring + * {@link org.springframework.transaction.support.TransactionSynchronization}. + * + *

Useful for synchronizing Spring resource management code with plain + * JTA / EJB CMT transactions, despite the original code being built for + * Spring transaction synchronization. + * + * @author Juergen Hoeller + * @since 2.0 + * @see javax.transaction.Transaction#registerSynchronization + * @see org.springframework.transaction.support.TransactionSynchronization + */ +public class SpringJtaSynchronizationAdapter implements Synchronization { + + protected static final Log logger = LogFactory.getLog(SpringJtaSynchronizationAdapter.class); + + private final TransactionSynchronization springSynchronization; + + @Nullable + private UserTransaction jtaTransaction; + + private boolean beforeCompletionCalled = false; + + + /** + * Create a new SpringJtaSynchronizationAdapter for the given Spring + * TransactionSynchronization and JTA TransactionManager. + * @param springSynchronization the Spring TransactionSynchronization to delegate to + */ + public SpringJtaSynchronizationAdapter(TransactionSynchronization springSynchronization) { + Assert.notNull(springSynchronization, "TransactionSynchronization must not be null"); + this.springSynchronization = springSynchronization; + } + + /** + * Create a new SpringJtaSynchronizationAdapter for the given Spring + * TransactionSynchronization and JTA TransactionManager. + *

Note that this adapter will never perform a rollback-only call on WebLogic, + * since WebLogic Server is known to automatically mark the transaction as + * rollback-only in case of a {@code beforeCompletion} exception. Hence, + * on WLS, this constructor is equivalent to the single-arg constructor. + * @param springSynchronization the Spring TransactionSynchronization to delegate to + * @param jtaUserTransaction the JTA UserTransaction to use for rollback-only + * setting in case of an exception thrown in {@code beforeCompletion} + * (can be omitted if the JTA provider itself marks the transaction rollback-only + * in such a scenario, which is required by the JTA specification as of JTA 1.1). + */ + public SpringJtaSynchronizationAdapter(TransactionSynchronization springSynchronization, + @Nullable UserTransaction jtaUserTransaction) { + + this(springSynchronization); + if (jtaUserTransaction != null && !jtaUserTransaction.getClass().getName().startsWith("weblogic.")) { + this.jtaTransaction = jtaUserTransaction; + } + } + + /** + * Create a new SpringJtaSynchronizationAdapter for the given Spring + * TransactionSynchronization and JTA TransactionManager. + *

Note that this adapter will never perform a rollback-only call on WebLogic, + * since WebLogic Server is known to automatically mark the transaction as + * rollback-only in case of a {@code beforeCompletion} exception. Hence, + * on WLS, this constructor is equivalent to the single-arg constructor. + * @param springSynchronization the Spring TransactionSynchronization to delegate to + * @param jtaTransactionManager the JTA TransactionManager to use for rollback-only + * setting in case of an exception thrown in {@code beforeCompletion} + * (can be omitted if the JTA provider itself marks the transaction rollback-only + * in such a scenario, which is required by the JTA specification as of JTA 1.1) + */ + public SpringJtaSynchronizationAdapter( + TransactionSynchronization springSynchronization, @Nullable TransactionManager jtaTransactionManager) { + + this(springSynchronization); + if (jtaTransactionManager != null && !jtaTransactionManager.getClass().getName().startsWith("weblogic.")) { + this.jtaTransaction = new UserTransactionAdapter(jtaTransactionManager); + } + } + + + /** + * JTA {@code beforeCompletion} callback: just invoked before commit. + *

In case of an exception, the JTA transaction will be marked as rollback-only. + * @see org.springframework.transaction.support.TransactionSynchronization#beforeCommit + */ + @Override + public void beforeCompletion() { + try { + boolean readOnly = TransactionSynchronizationManager.isCurrentTransactionReadOnly(); + this.springSynchronization.beforeCommit(readOnly); + } + catch (RuntimeException | Error ex) { + setRollbackOnlyIfPossible(); + throw ex; + } + finally { + // Process Spring's beforeCompletion early, in order to avoid issues + // with strict JTA implementations that issue warnings when doing JDBC + // operations after transaction completion (e.g. Connection.getWarnings). + this.beforeCompletionCalled = true; + this.springSynchronization.beforeCompletion(); + } + } + + /** + * Set the underlying JTA transaction to rollback-only. + */ + private void setRollbackOnlyIfPossible() { + if (this.jtaTransaction != null) { + try { + this.jtaTransaction.setRollbackOnly(); + } + catch (UnsupportedOperationException ex) { + // Probably Hibernate's WebSphereExtendedJTATransactionLookup pseudo JTA stuff... + logger.debug("JTA transaction handle does not support setRollbackOnly method - " + + "relying on JTA provider to mark the transaction as rollback-only based on " + + "the exception thrown from beforeCompletion", ex); + } + catch (Throwable ex) { + logger.error("Could not set JTA transaction rollback-only", ex); + } + } + else { + logger.debug("No JTA transaction handle available and/or running on WebLogic - " + + "relying on JTA provider to mark the transaction as rollback-only based on " + + "the exception thrown from beforeCompletion"); + } + } + + /** + * JTA {@code afterCompletion} callback: invoked after commit/rollback. + *

Needs to invoke the Spring synchronization's {@code beforeCompletion} + * at this late stage in case of a rollback, since there is no corresponding + * callback with JTA. + * @see org.springframework.transaction.support.TransactionSynchronization#beforeCompletion + * @see org.springframework.transaction.support.TransactionSynchronization#afterCompletion + */ + @Override + public void afterCompletion(int status) { + if (!this.beforeCompletionCalled) { + // beforeCompletion not called before (probably because of JTA rollback). + // Perform the cleanup here. + this.springSynchronization.beforeCompletion(); + } + // Call afterCompletion with the appropriate status indication. + switch (status) { + case Status.STATUS_COMMITTED: + this.springSynchronization.afterCompletion(TransactionSynchronization.STATUS_COMMITTED); + break; + case Status.STATUS_ROLLEDBACK: + this.springSynchronization.afterCompletion(TransactionSynchronization.STATUS_ROLLED_BACK); + break; + default: + this.springSynchronization.afterCompletion(TransactionSynchronization.STATUS_UNKNOWN); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/jta/TransactionFactory.java b/spring-tx/src/main/java/org/springframework/transaction/jta/TransactionFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..9317e9bb8ebbec9b21a0d2417fb4f36dcca52d3b --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/jta/TransactionFactory.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.jta; + +import javax.transaction.NotSupportedException; +import javax.transaction.SystemException; +import javax.transaction.Transaction; + +import org.springframework.lang.Nullable; + +/** + * Strategy interface for creating JTA {@link javax.transaction.Transaction} + * objects based on specified transactional characteristics. + * + *

The default implementation, {@link SimpleTransactionFactory}, simply + * wraps a standard JTA {@link javax.transaction.TransactionManager}. + * This strategy interface allows for more sophisticated implementations + * that adapt to vendor-specific JTA extensions. + * + * @author Juergen Hoeller + * @since 2.5 + * @see javax.transaction.TransactionManager#getTransaction() + * @see SimpleTransactionFactory + * @see JtaTransactionManager + */ +public interface TransactionFactory { + + /** + * Create an active Transaction object based on the given name and timeout. + * @param name the transaction name (may be {@code null}) + * @param timeout the transaction timeout (may be -1 for the default timeout) + * @return the active Transaction object (never {@code null}) + * @throws NotSupportedException if the transaction manager does not support + * a transaction of the specified type + * @throws SystemException if the transaction manager failed to create the + * transaction + */ + Transaction createTransaction(@Nullable String name, int timeout) throws NotSupportedException, SystemException; + + /** + * Determine whether the underlying transaction manager supports XA transactions + * managed by a resource adapter (i.e. without explicit XA resource enlistment). + *

Typically {@code false}. Checked by + * {@link org.springframework.jca.endpoint.AbstractMessageEndpointFactory} + * in order to differentiate between invalid configuration and valid + * ResourceAdapter-managed transactions. + * @see javax.resource.spi.ResourceAdapter#endpointActivation + * @see javax.resource.spi.endpoint.MessageEndpointFactory#isDeliveryTransacted + */ + boolean supportsResourceAdapterManagedTransactions(); + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/jta/UserTransactionAdapter.java b/spring-tx/src/main/java/org/springframework/transaction/jta/UserTransactionAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..8302071f0e32508919cb9c75ac3aa7ace3144868 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/jta/UserTransactionAdapter.java @@ -0,0 +1,100 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.jta; + +import javax.transaction.HeuristicMixedException; +import javax.transaction.HeuristicRollbackException; +import javax.transaction.NotSupportedException; +import javax.transaction.RollbackException; +import javax.transaction.SystemException; +import javax.transaction.TransactionManager; +import javax.transaction.UserTransaction; + +import org.springframework.util.Assert; + +/** + * Adapter for a JTA UserTransaction handle, taking a JTA + * {@link javax.transaction.TransactionManager} reference and creating + * a JTA {@link javax.transaction.UserTransaction} handle for it. + * + *

The JTA UserTransaction interface is an exact subset of the JTA + * TransactionManager interface. Unfortunately, it does not serve as + * super-interface of TransactionManager, though, which requires an + * adapter such as this class to be used when intending to talk to + * a TransactionManager handle through the UserTransaction interface. + * + *

Used internally by Spring's {@link JtaTransactionManager} for certain + * scenarios. Not intended for direct use in application code. + * + * @author Juergen Hoeller + * @since 1.1.5 + */ +public class UserTransactionAdapter implements UserTransaction { + + private final TransactionManager transactionManager; + + + /** + * Create a new UserTransactionAdapter for the given TransactionManager. + * @param transactionManager the JTA TransactionManager to wrap + */ + public UserTransactionAdapter(TransactionManager transactionManager) { + Assert.notNull(transactionManager, "TransactionManager must not be null"); + this.transactionManager = transactionManager; + } + + /** + * Return the JTA TransactionManager that this adapter delegates to. + */ + public final TransactionManager getTransactionManager() { + return this.transactionManager; + } + + + @Override + public void setTransactionTimeout(int timeout) throws SystemException { + this.transactionManager.setTransactionTimeout(timeout); + } + + @Override + public void begin() throws NotSupportedException, SystemException { + this.transactionManager.begin(); + } + + @Override + public void commit() + throws RollbackException, HeuristicMixedException, HeuristicRollbackException, + SecurityException, SystemException { + this.transactionManager.commit(); + } + + @Override + public void rollback() throws SecurityException, SystemException { + this.transactionManager.rollback(); + } + + @Override + public void setRollbackOnly() throws SystemException { + this.transactionManager.setRollbackOnly(); + } + + @Override + public int getStatus() throws SystemException { + return this.transactionManager.getStatus(); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/jta/WebLogicJtaTransactionManager.java b/spring-tx/src/main/java/org/springframework/transaction/jta/WebLogicJtaTransactionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..9d2cc38a723b8bd0fd330eff3a354e9319e950e0 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/jta/WebLogicJtaTransactionManager.java @@ -0,0 +1,366 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.jta; + +import java.io.Serializable; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +import javax.transaction.InvalidTransactionException; +import javax.transaction.NotSupportedException; +import javax.transaction.SystemException; +import javax.transaction.Transaction; +import javax.transaction.TransactionManager; +import javax.transaction.UserTransaction; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.TransactionSystemException; +import org.springframework.util.Assert; + +/** + * Special {@link JtaTransactionManager} variant for BEA WebLogic (9.0 and higher). + * Supports the full power of Spring's transaction definitions on WebLogic's + * transaction coordinator, beyond standard JTA: transaction names, + * per-transaction isolation levels, and proper resuming of transactions in all cases. + * + *

Uses WebLogic's special {@code begin(name)} method to start a JTA transaction, + * in order to make Spring-driven transactions visible in WebLogic's transaction + * monitor. In case of Spring's declarative transactions, the exposed name will + * (by default) be the fully-qualified class name + "." + method name. + * + *

Supports a per-transaction isolation level through WebLogic's corresponding + * JTA transaction property "ISOLATION LEVEL". This will apply the specified isolation + * level (e.g. ISOLATION_SERIALIZABLE) to all JDBC Connections that participate in the + * given transaction. + * + *

Invokes WebLogic's special {@code forceResume} method if standard JTA resume + * failed, to also resume if the target transaction was marked rollback-only. + * If you're not relying on this feature of transaction suspension in the first + * place, Spring's standard JtaTransactionManager will behave properly too. + * + *

By default, the JTA UserTransaction and TransactionManager handles are + * fetched directly from WebLogic's {@code TransactionHelper}. This can be + * overridden by specifying "userTransaction"/"userTransactionName" and + * "transactionManager"/"transactionManagerName", passing in existing handles + * or specifying corresponding JNDI locations to look up. + * + *

NOTE: This JtaTransactionManager is intended to refine specific transaction + * demarcation behavior on Spring's side. It will happily co-exist with independently + * configured WebLogic transaction strategies in your persistence provider, with no + * need to specifically connect those setups in any way. + * + * @author Juergen Hoeller + * @since 1.1 + * @see org.springframework.transaction.TransactionDefinition#getName + * @see org.springframework.transaction.TransactionDefinition#getIsolationLevel + * @see weblogic.transaction.UserTransaction#begin(String) + * @see weblogic.transaction.Transaction#setProperty + * @see weblogic.transaction.TransactionManager#forceResume + * @see weblogic.transaction.TransactionHelper + */ +@SuppressWarnings("serial") +public class WebLogicJtaTransactionManager extends JtaTransactionManager { + + private static final String USER_TRANSACTION_CLASS_NAME = "weblogic.transaction.UserTransaction"; + + private static final String CLIENT_TRANSACTION_MANAGER_CLASS_NAME = "weblogic.transaction.ClientTransactionManager"; + + private static final String TRANSACTION_CLASS_NAME = "weblogic.transaction.Transaction"; + + private static final String TRANSACTION_HELPER_CLASS_NAME = "weblogic.transaction.TransactionHelper"; + + private static final String ISOLATION_LEVEL_KEY = "ISOLATION LEVEL"; + + + private boolean weblogicUserTransactionAvailable; + + @Nullable + private Method beginWithNameMethod; + + @Nullable + private Method beginWithNameAndTimeoutMethod; + + private boolean weblogicTransactionManagerAvailable; + + @Nullable + private Method forceResumeMethod; + + @Nullable + private Method setPropertyMethod; + + @Nullable + private Object transactionHelper; + + + @Override + public void afterPropertiesSet() throws TransactionSystemException { + super.afterPropertiesSet(); + loadWebLogicTransactionClasses(); + } + + @Override + @Nullable + protected UserTransaction retrieveUserTransaction() throws TransactionSystemException { + Object helper = loadWebLogicTransactionHelper(); + try { + logger.trace("Retrieving JTA UserTransaction from WebLogic TransactionHelper"); + Method getUserTransactionMethod = helper.getClass().getMethod("getUserTransaction"); + return (UserTransaction) getUserTransactionMethod.invoke(this.transactionHelper); + } + catch (InvocationTargetException ex) { + throw new TransactionSystemException( + "WebLogic's TransactionHelper.getUserTransaction() method failed", ex.getTargetException()); + } + catch (Exception ex) { + throw new TransactionSystemException( + "Could not invoke WebLogic's TransactionHelper.getUserTransaction() method", ex); + } + } + + @Override + @Nullable + protected TransactionManager retrieveTransactionManager() throws TransactionSystemException { + Object helper = loadWebLogicTransactionHelper(); + try { + logger.trace("Retrieving JTA TransactionManager from WebLogic TransactionHelper"); + Method getTransactionManagerMethod = helper.getClass().getMethod("getTransactionManager"); + return (TransactionManager) getTransactionManagerMethod.invoke(this.transactionHelper); + } + catch (InvocationTargetException ex) { + throw new TransactionSystemException( + "WebLogic's TransactionHelper.getTransactionManager() method failed", ex.getTargetException()); + } + catch (Exception ex) { + throw new TransactionSystemException( + "Could not invoke WebLogic's TransactionHelper.getTransactionManager() method", ex); + } + } + + private Object loadWebLogicTransactionHelper() throws TransactionSystemException { + Object helper = this.transactionHelper; + if (helper == null) { + try { + Class transactionHelperClass = getClass().getClassLoader().loadClass(TRANSACTION_HELPER_CLASS_NAME); + Method getTransactionHelperMethod = transactionHelperClass.getMethod("getTransactionHelper"); + helper = getTransactionHelperMethod.invoke(null); + this.transactionHelper = helper; + logger.trace("WebLogic TransactionHelper found"); + } + catch (InvocationTargetException ex) { + throw new TransactionSystemException( + "WebLogic's TransactionHelper.getTransactionHelper() method failed", ex.getTargetException()); + } + catch (Exception ex) { + throw new TransactionSystemException( + "Could not initialize WebLogicJtaTransactionManager because WebLogic API classes are not available", + ex); + } + } + return helper; + } + + private void loadWebLogicTransactionClasses() throws TransactionSystemException { + try { + Class userTransactionClass = getClass().getClassLoader().loadClass(USER_TRANSACTION_CLASS_NAME); + this.weblogicUserTransactionAvailable = userTransactionClass.isInstance(getUserTransaction()); + if (this.weblogicUserTransactionAvailable) { + this.beginWithNameMethod = userTransactionClass.getMethod("begin", String.class); + this.beginWithNameAndTimeoutMethod = userTransactionClass.getMethod("begin", String.class, int.class); + logger.debug("Support for WebLogic transaction names available"); + } + else { + logger.debug("Support for WebLogic transaction names not available"); + } + + // Obtain WebLogic ClientTransactionManager interface. + Class transactionManagerClass = + getClass().getClassLoader().loadClass(CLIENT_TRANSACTION_MANAGER_CLASS_NAME); + logger.trace("WebLogic ClientTransactionManager found"); + + this.weblogicTransactionManagerAvailable = transactionManagerClass.isInstance(getTransactionManager()); + if (this.weblogicTransactionManagerAvailable) { + Class transactionClass = getClass().getClassLoader().loadClass(TRANSACTION_CLASS_NAME); + this.forceResumeMethod = transactionManagerClass.getMethod("forceResume", Transaction.class); + this.setPropertyMethod = transactionClass.getMethod("setProperty", String.class, Serializable.class); + logger.debug("Support for WebLogic forceResume available"); + } + else { + logger.debug("Support for WebLogic forceResume not available"); + } + } + catch (Exception ex) { + throw new TransactionSystemException( + "Could not initialize WebLogicJtaTransactionManager because WebLogic API classes are not available", + ex); + } + } + + private TransactionManager obtainTransactionManager() { + TransactionManager tm = getTransactionManager(); + Assert.state(tm != null, "No TransactionManager set"); + return tm; + } + + + @Override + protected void doJtaBegin(JtaTransactionObject txObject, TransactionDefinition definition) + throws NotSupportedException, SystemException { + + int timeout = determineTimeout(definition); + + // Apply transaction name (if any) to WebLogic transaction. + if (this.weblogicUserTransactionAvailable && definition.getName() != null) { + try { + if (timeout > TransactionDefinition.TIMEOUT_DEFAULT) { + /* + weblogic.transaction.UserTransaction wut = (weblogic.transaction.UserTransaction) ut; + wut.begin(definition.getName(), timeout); + */ + Assert.state(this.beginWithNameAndTimeoutMethod != null, "WebLogic JTA API not initialized"); + this.beginWithNameAndTimeoutMethod.invoke(txObject.getUserTransaction(), definition.getName(), timeout); + } + else { + /* + weblogic.transaction.UserTransaction wut = (weblogic.transaction.UserTransaction) ut; + wut.begin(definition.getName()); + */ + Assert.state(this.beginWithNameMethod != null, "WebLogic JTA API not initialized"); + this.beginWithNameMethod.invoke(txObject.getUserTransaction(), definition.getName()); + } + } + catch (InvocationTargetException ex) { + throw new TransactionSystemException( + "WebLogic's UserTransaction.begin() method failed", ex.getTargetException()); + } + catch (Exception ex) { + throw new TransactionSystemException( + "Could not invoke WebLogic's UserTransaction.begin() method", ex); + } + } + else { + // No WebLogic UserTransaction available or no transaction name specified + // -> standard JTA begin call. + applyTimeout(txObject, timeout); + txObject.getUserTransaction().begin(); + } + + // Specify isolation level, if any, through corresponding WebLogic transaction property. + if (this.weblogicTransactionManagerAvailable) { + if (definition.getIsolationLevel() != TransactionDefinition.ISOLATION_DEFAULT) { + try { + Transaction tx = obtainTransactionManager().getTransaction(); + Integer isolationLevel = definition.getIsolationLevel(); + /* + weblogic.transaction.Transaction wtx = (weblogic.transaction.Transaction) tx; + wtx.setProperty(ISOLATION_LEVEL_KEY, isolationLevel); + */ + Assert.state(this.setPropertyMethod != null, "WebLogic JTA API not initialized"); + this.setPropertyMethod.invoke(tx, ISOLATION_LEVEL_KEY, isolationLevel); + } + catch (InvocationTargetException ex) { + throw new TransactionSystemException( + "WebLogic's Transaction.setProperty(String, Serializable) method failed", ex.getTargetException()); + } + catch (Exception ex) { + throw new TransactionSystemException( + "Could not invoke WebLogic's Transaction.setProperty(String, Serializable) method", ex); + } + } + } + else { + applyIsolationLevel(txObject, definition.getIsolationLevel()); + } + } + + @Override + protected void doJtaResume(@Nullable JtaTransactionObject txObject, Object suspendedTransaction) + throws InvalidTransactionException, SystemException { + + try { + obtainTransactionManager().resume((Transaction) suspendedTransaction); + } + catch (InvalidTransactionException ex) { + if (!this.weblogicTransactionManagerAvailable) { + throw ex; + } + + if (logger.isDebugEnabled()) { + logger.debug("Standard JTA resume threw InvalidTransactionException: " + ex.getMessage() + + " - trying WebLogic JTA forceResume"); + } + /* + weblogic.transaction.TransactionManager wtm = + (weblogic.transaction.TransactionManager) getTransactionManager(); + wtm.forceResume(suspendedTransaction); + */ + try { + Assert.state(this.forceResumeMethod != null, "WebLogic JTA API not initialized"); + this.forceResumeMethod.invoke(getTransactionManager(), suspendedTransaction); + } + catch (InvocationTargetException ex2) { + throw new TransactionSystemException( + "WebLogic's TransactionManager.forceResume(Transaction) method failed", ex2.getTargetException()); + } + catch (Exception ex2) { + throw new TransactionSystemException( + "Could not access WebLogic's TransactionManager.forceResume(Transaction) method", ex2); + } + } + } + + @Override + public Transaction createTransaction(@Nullable String name, int timeout) throws NotSupportedException, SystemException { + if (this.weblogicUserTransactionAvailable && name != null) { + try { + if (timeout >= 0) { + Assert.state(this.beginWithNameAndTimeoutMethod != null, "WebLogic JTA API not initialized"); + this.beginWithNameAndTimeoutMethod.invoke(getUserTransaction(), name, timeout); + } + else { + Assert.state(this.beginWithNameMethod != null, "WebLogic JTA API not initialized"); + this.beginWithNameMethod.invoke(getUserTransaction(), name); + } + } + catch (InvocationTargetException ex) { + if (ex.getTargetException() instanceof NotSupportedException) { + throw (NotSupportedException) ex.getTargetException(); + } + else if (ex.getTargetException() instanceof SystemException) { + throw (SystemException) ex.getTargetException(); + } + else if (ex.getTargetException() instanceof RuntimeException) { + throw (RuntimeException) ex.getTargetException(); + } + else { + throw new SystemException( + "WebLogic's begin() method failed with an unexpected error: " + ex.getTargetException()); + } + } + catch (Exception ex) { + throw new SystemException("Could not invoke WebLogic's UserTransaction.begin() method: " + ex); + } + return new ManagedTransactionAdapter(obtainTransactionManager()); + } + + else { + // No name specified - standard JTA is sufficient. + return super.createTransaction(name, timeout); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/jta/WebSphereUowTransactionManager.java b/spring-tx/src/main/java/org/springframework/transaction/jta/WebSphereUowTransactionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..68929645a1c9b34ea7de9f77b88f99722e8285a2 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/jta/WebSphereUowTransactionManager.java @@ -0,0 +1,426 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.jta; + +import java.util.List; + +import javax.naming.NamingException; + +import com.ibm.websphere.uow.UOWSynchronizationRegistry; +import com.ibm.wsspi.uow.UOWAction; +import com.ibm.wsspi.uow.UOWActionException; +import com.ibm.wsspi.uow.UOWException; +import com.ibm.wsspi.uow.UOWManager; +import com.ibm.wsspi.uow.UOWManagerFactory; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.IllegalTransactionStateException; +import org.springframework.transaction.InvalidTimeoutException; +import org.springframework.transaction.NestedTransactionNotSupportedException; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.TransactionException; +import org.springframework.transaction.TransactionSystemException; +import org.springframework.transaction.support.CallbackPreferringPlatformTransactionManager; +import org.springframework.transaction.support.DefaultTransactionDefinition; +import org.springframework.transaction.support.DefaultTransactionStatus; +import org.springframework.transaction.support.SmartTransactionObject; +import org.springframework.transaction.support.TransactionCallback; +import org.springframework.transaction.support.TransactionSynchronization; +import org.springframework.transaction.support.TransactionSynchronizationManager; +import org.springframework.transaction.support.TransactionSynchronizationUtils; +import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; + +/** + * WebSphere-specific PlatformTransactionManager implementation that delegates + * to a {@link com.ibm.wsspi.uow.UOWManager} instance, obtained from WebSphere's + * JNDI environment. This allows Spring to leverage the full power of the WebSphere + * transaction coordinator, including transaction suspension, in a manner that is + * perfectly compliant with officially supported WebSphere API. + * + *

The {@link CallbackPreferringPlatformTransactionManager} interface + * implemented by this class indicates that callers should preferably pass in + * a {@link TransactionCallback} through the {@link #execute} method, which + * will be handled through the callback-based WebSphere UOWManager API instead + * of through standard JTA API (UserTransaction / TransactionManager). This avoids + * the use of the non-public {@code javax.transaction.TransactionManager} + * API on WebSphere, staying within supported WebSphere API boundaries. + * + *

This transaction manager implementation derives from Spring's standard + * {@link JtaTransactionManager}, inheriting the capability to support programmatic + * transaction demarcation via {@code getTransaction} / {@code commit} / + * {@code rollback} calls through a JTA UserTransaction handle, for callers + * that do not use the TransactionCallback-based {@link #execute} method. However, + * transaction suspension is not supported in this {@code getTransaction} + * style (unless you explicitly specify a {@link #setTransactionManager} reference, + * despite the official WebSphere recommendations). Use the {@link #execute} style + * for any code that might require transaction suspension. + * + *

This transaction manager is compatible with WebSphere 6.1.0.9 and above. + * The default JNDI location for the UOWManager is "java:comp/websphere/UOWManager". + * If the location happens to differ according to your WebSphere documentation, + * simply specify the actual location through this transaction manager's + * "uowManagerName" bean property. + * + *

NOTE: This JtaTransactionManager is intended to refine specific transaction + * demarcation behavior on Spring's side. It will happily co-exist with independently + * configured WebSphere transaction strategies in your persistence provider, with no + * need to specifically connect those setups in any way. + * + * @author Juergen Hoeller + * @since 2.5 + * @see #setUowManager + * @see #setUowManagerName + * @see com.ibm.wsspi.uow.UOWManager + */ +@SuppressWarnings("serial") +public class WebSphereUowTransactionManager extends JtaTransactionManager + implements CallbackPreferringPlatformTransactionManager { + + /** + * Default JNDI location for the WebSphere UOWManager. + * @see #setUowManagerName + */ + public static final String DEFAULT_UOW_MANAGER_NAME = "java:comp/websphere/UOWManager"; + + + @Nullable + private UOWManager uowManager; + + @Nullable + private String uowManagerName; + + + /** + * Create a new WebSphereUowTransactionManager. + */ + public WebSphereUowTransactionManager() { + setAutodetectTransactionManager(false); + } + + /** + * Create a new WebSphereUowTransactionManager for the given UOWManager. + * @param uowManager the WebSphere UOWManager to use as direct reference + */ + public WebSphereUowTransactionManager(UOWManager uowManager) { + this(); + this.uowManager = uowManager; + } + + + /** + * Set the WebSphere UOWManager to use as direct reference. + *

Typically just used for test setups; in a Java EE environment, + * the UOWManager will always be fetched from JNDI. + * @see #setUserTransactionName + */ + public void setUowManager(UOWManager uowManager) { + this.uowManager = uowManager; + } + + /** + * Set the JNDI name of the WebSphere UOWManager. + * The default "java:comp/websphere/UOWManager" is used if not set. + * @see #DEFAULT_USER_TRANSACTION_NAME + * @see #setUowManager + */ + public void setUowManagerName(String uowManagerName) { + this.uowManagerName = uowManagerName; + } + + + @Override + public void afterPropertiesSet() throws TransactionSystemException { + initUserTransactionAndTransactionManager(); + + // Fetch UOWManager handle from JNDI, if necessary. + if (this.uowManager == null) { + if (this.uowManagerName != null) { + this.uowManager = lookupUowManager(this.uowManagerName); + } + else { + this.uowManager = lookupDefaultUowManager(); + } + } + } + + /** + * Look up the WebSphere UOWManager in JNDI via the configured name. + * @param uowManagerName the JNDI name of the UOWManager + * @return the UOWManager object + * @throws TransactionSystemException if the JNDI lookup failed + * @see #setJndiTemplate + * @see #setUowManagerName + */ + protected UOWManager lookupUowManager(String uowManagerName) throws TransactionSystemException { + try { + if (logger.isDebugEnabled()) { + logger.debug("Retrieving WebSphere UOWManager from JNDI location [" + uowManagerName + "]"); + } + return getJndiTemplate().lookup(uowManagerName, UOWManager.class); + } + catch (NamingException ex) { + throw new TransactionSystemException( + "WebSphere UOWManager is not available at JNDI location [" + uowManagerName + "]", ex); + } + } + + /** + * Obtain the WebSphere UOWManager from the default JNDI location + * "java:comp/websphere/UOWManager". + * @return the UOWManager object + * @throws TransactionSystemException if the JNDI lookup failed + * @see #setJndiTemplate + */ + protected UOWManager lookupDefaultUowManager() throws TransactionSystemException { + try { + logger.debug("Retrieving WebSphere UOWManager from default JNDI location [" + DEFAULT_UOW_MANAGER_NAME + "]"); + return getJndiTemplate().lookup(DEFAULT_UOW_MANAGER_NAME, UOWManager.class); + } + catch (NamingException ex) { + logger.debug("WebSphere UOWManager is not available at default JNDI location [" + + DEFAULT_UOW_MANAGER_NAME + "] - falling back to UOWManagerFactory lookup"); + return UOWManagerFactory.getUOWManager(); + } + } + + private UOWManager obtainUOWManager() { + Assert.state(this.uowManager != null, "No UOWManager set"); + return this.uowManager; + } + + + /** + * Registers the synchronizations as interposed JTA Synchronization on the UOWManager. + */ + @Override + protected void doRegisterAfterCompletionWithJtaTransaction( + JtaTransactionObject txObject, List synchronizations) { + + obtainUOWManager().registerInterposedSynchronization(new JtaAfterCompletionSynchronization(synchronizations)); + } + + /** + * Returns {@code true} since WebSphere ResourceAdapters (as exposed in JNDI) + * implicitly perform transaction enlistment if the MessageEndpointFactory's + * {@code isDeliveryTransacted} method returns {@code true}. + * In that case we'll simply skip the {@link #createTransaction} call. + * @see javax.resource.spi.endpoint.MessageEndpointFactory#isDeliveryTransacted + * @see org.springframework.jca.endpoint.AbstractMessageEndpointFactory + * @see TransactionFactory#createTransaction + */ + @Override + public boolean supportsResourceAdapterManagedTransactions() { + return true; + } + + + @Override + @Nullable + public T execute(@Nullable TransactionDefinition definition, TransactionCallback callback) + throws TransactionException { + + if (definition == null) { + // Use defaults if no transaction definition given. + definition = new DefaultTransactionDefinition(); + } + + if (definition.getTimeout() < TransactionDefinition.TIMEOUT_DEFAULT) { + throw new InvalidTimeoutException("Invalid transaction timeout", definition.getTimeout()); + } + + UOWManager uowManager = obtainUOWManager(); + int pb = definition.getPropagationBehavior(); + boolean existingTx = (uowManager.getUOWStatus() != UOWSynchronizationRegistry.UOW_STATUS_NONE && + uowManager.getUOWType() != UOWSynchronizationRegistry.UOW_TYPE_LOCAL_TRANSACTION); + + int uowType = UOWSynchronizationRegistry.UOW_TYPE_GLOBAL_TRANSACTION; + boolean joinTx = false; + boolean newSynch = false; + + if (existingTx) { + if (pb == TransactionDefinition.PROPAGATION_NEVER) { + throw new IllegalTransactionStateException( + "Transaction propagation 'never' but existing transaction found"); + } + if (pb == TransactionDefinition.PROPAGATION_NESTED) { + throw new NestedTransactionNotSupportedException( + "Transaction propagation 'nested' not supported for WebSphere UOW transactions"); + } + if (pb == TransactionDefinition.PROPAGATION_SUPPORTS || + pb == TransactionDefinition.PROPAGATION_REQUIRED || + pb == TransactionDefinition.PROPAGATION_MANDATORY) { + joinTx = true; + newSynch = (getTransactionSynchronization() != SYNCHRONIZATION_NEVER); + } + else if (pb == TransactionDefinition.PROPAGATION_NOT_SUPPORTED) { + uowType = UOWSynchronizationRegistry.UOW_TYPE_LOCAL_TRANSACTION; + newSynch = (getTransactionSynchronization() == SYNCHRONIZATION_ALWAYS); + } + else { + newSynch = (getTransactionSynchronization() != SYNCHRONIZATION_NEVER); + } + } + else { + if (pb == TransactionDefinition.PROPAGATION_MANDATORY) { + throw new IllegalTransactionStateException( + "Transaction propagation 'mandatory' but no existing transaction found"); + } + if (pb == TransactionDefinition.PROPAGATION_SUPPORTS || + pb == TransactionDefinition.PROPAGATION_NOT_SUPPORTED || + pb == TransactionDefinition.PROPAGATION_NEVER) { + uowType = UOWSynchronizationRegistry.UOW_TYPE_LOCAL_TRANSACTION; + newSynch = (getTransactionSynchronization() == SYNCHRONIZATION_ALWAYS); + } + else { + newSynch = (getTransactionSynchronization() != SYNCHRONIZATION_NEVER); + } + } + + boolean debug = logger.isDebugEnabled(); + if (debug) { + logger.debug("Creating new transaction with name [" + definition.getName() + "]: " + definition); + } + SuspendedResourcesHolder suspendedResources = (!joinTx ? suspend(null) : null); + UOWActionAdapter action = null; + try { + boolean actualTransaction = (uowType == UOWManager.UOW_TYPE_GLOBAL_TRANSACTION); + if (actualTransaction && definition.getTimeout() > TransactionDefinition.TIMEOUT_DEFAULT) { + uowManager.setUOWTimeout(uowType, definition.getTimeout()); + } + if (debug) { + logger.debug("Invoking WebSphere UOW action: type=" + uowType + ", join=" + joinTx); + } + action = new UOWActionAdapter<>(definition, callback, actualTransaction, !joinTx, newSynch, debug); + uowManager.runUnderUOW(uowType, joinTx, action); + if (debug) { + logger.debug("Returned from WebSphere UOW action: type=" + uowType + ", join=" + joinTx); + } + return action.getResult(); + } + catch (UOWException | UOWActionException ex) { + TransactionSystemException tse = + new TransactionSystemException("UOWManager transaction processing failed", ex); + Throwable appEx = action.getException(); + if (appEx != null) { + logger.error("Application exception overridden by rollback exception", appEx); + tse.initApplicationException(appEx); + } + throw tse; + } + finally { + if (suspendedResources != null) { + resume(null, suspendedResources); + } + } + } + + + /** + * Adapter that executes the given Spring transaction within the WebSphere UOWAction shape. + */ + private class UOWActionAdapter implements UOWAction, SmartTransactionObject { + + private final TransactionDefinition definition; + + private final TransactionCallback callback; + + private final boolean actualTransaction; + + private final boolean newTransaction; + + private final boolean newSynchronization; + + private boolean debug; + + @Nullable + private T result; + + @Nullable + private Throwable exception; + + public UOWActionAdapter(TransactionDefinition definition, TransactionCallback callback, + boolean actualTransaction, boolean newTransaction, boolean newSynchronization, boolean debug) { + + this.definition = definition; + this.callback = callback; + this.actualTransaction = actualTransaction; + this.newTransaction = newTransaction; + this.newSynchronization = newSynchronization; + this.debug = debug; + } + + @Override + public void run() { + UOWManager uowManager = obtainUOWManager(); + DefaultTransactionStatus status = prepareTransactionStatus( + this.definition, (this.actualTransaction ? this : null), + this.newTransaction, this.newSynchronization, this.debug, null); + try { + this.result = this.callback.doInTransaction(status); + triggerBeforeCommit(status); + } + catch (Throwable ex) { + this.exception = ex; + if (status.isDebug()) { + logger.debug("Rolling back on application exception from transaction callback", ex); + } + uowManager.setRollbackOnly(); + } + finally { + if (status.isLocalRollbackOnly()) { + if (status.isDebug()) { + logger.debug("Transaction callback has explicitly requested rollback"); + } + uowManager.setRollbackOnly(); + } + triggerBeforeCompletion(status); + if (status.isNewSynchronization()) { + List synchronizations = TransactionSynchronizationManager.getSynchronizations(); + TransactionSynchronizationManager.clear(); + if (!synchronizations.isEmpty()) { + uowManager.registerInterposedSynchronization(new JtaAfterCompletionSynchronization(synchronizations)); + } + } + } + } + + @Nullable + public T getResult() { + if (this.exception != null) { + ReflectionUtils.rethrowRuntimeException(this.exception); + } + return this.result; + } + + @Nullable + public Throwable getException() { + return this.exception; + } + + @Override + public boolean isRollbackOnly() { + return obtainUOWManager().getRollbackOnly(); + } + + @Override + public void flush() { + TransactionSynchronizationUtils.triggerFlush(); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/jta/package-info.java b/spring-tx/src/main/java/org/springframework/transaction/jta/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..677749eb0a4c3f8d04d4ad0a08f51c50b7d4e891 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/jta/package-info.java @@ -0,0 +1,9 @@ +/** + * Transaction SPI implementation for JTA. + */ +@NonNullApi +@NonNullFields +package org.springframework.transaction.jta; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/transaction/package-info.java b/spring-tx/src/main/java/org/springframework/transaction/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..c6f2ed6104a155da231333804ad4051fa4e0b120 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/package-info.java @@ -0,0 +1,11 @@ +/** + * Exception hierarchy for Spring's transaction infrastructure, + * independent of any specific transaction management system. + * Contains transaction manager, definition, and status interfaces. + */ +@NonNullApi +@NonNullFields +package org.springframework.transaction; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/AbstractPlatformTransactionManager.java b/spring-tx/src/main/java/org/springframework/transaction/support/AbstractPlatformTransactionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..809a854aad27352af1b49f4792b5deb25993343a --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/AbstractPlatformTransactionManager.java @@ -0,0 +1,1313 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.Serializable; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.Constants; +import org.springframework.lang.Nullable; +import org.springframework.transaction.IllegalTransactionStateException; +import org.springframework.transaction.InvalidTimeoutException; +import org.springframework.transaction.NestedTransactionNotSupportedException; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.TransactionException; +import org.springframework.transaction.TransactionStatus; +import org.springframework.transaction.TransactionSuspensionNotSupportedException; +import org.springframework.transaction.UnexpectedRollbackException; + +/** + * Abstract base class that implements Spring's standard transaction workflow, + * serving as basis for concrete platform transaction managers like + * {@link org.springframework.transaction.jta.JtaTransactionManager}. + * + *

This base class provides the following workflow handling: + *

    + *
  • determines if there is an existing transaction; + *
  • applies the appropriate propagation behavior; + *
  • suspends and resumes transactions if necessary; + *
  • checks the rollback-only flag on commit; + *
  • applies the appropriate modification on rollback + * (actual rollback or setting rollback-only); + *
  • triggers registered synchronization callbacks + * (if transaction synchronization is active). + *
+ * + *

Subclasses have to implement specific template methods for specific + * states of a transaction, e.g.: begin, suspend, resume, commit, rollback. + * The most important of them are abstract and must be provided by a concrete + * implementation; for the rest, defaults are provided, so overriding is optional. + * + *

Transaction synchronization is a generic mechanism for registering callbacks + * that get invoked at transaction completion time. This is mainly used internally + * by the data access support classes for JDBC, Hibernate, JPA, etc when running + * within a JTA transaction: They register resources that are opened within the + * transaction for closing at transaction completion time, allowing e.g. for reuse + * of the same Hibernate Session within the transaction. The same mechanism can + * also be leveraged for custom synchronization needs in an application. + * + *

The state of this class is serializable, to allow for serializing the + * transaction strategy along with proxies that carry a transaction interceptor. + * It is up to subclasses if they wish to make their state to be serializable too. + * They should implement the {@code java.io.Serializable} marker interface in + * that case, and potentially a private {@code readObject()} method (according + * to Java serialization rules) if they need to restore any transient state. + * + * @author Juergen Hoeller + * @since 28.03.2003 + * @see #setTransactionSynchronization + * @see TransactionSynchronizationManager + * @see org.springframework.transaction.jta.JtaTransactionManager + */ +@SuppressWarnings("serial") +public abstract class AbstractPlatformTransactionManager implements PlatformTransactionManager, Serializable { + + /** + * Always activate transaction synchronization, even for "empty" transactions + * that result from PROPAGATION_SUPPORTS with no existing backend transaction. + * @see org.springframework.transaction.TransactionDefinition#PROPAGATION_SUPPORTS + * @see org.springframework.transaction.TransactionDefinition#PROPAGATION_NOT_SUPPORTED + * @see org.springframework.transaction.TransactionDefinition#PROPAGATION_NEVER + */ + public static final int SYNCHRONIZATION_ALWAYS = 0; + + /** + * Activate transaction synchronization only for actual transactions, + * that is, not for empty ones that result from PROPAGATION_SUPPORTS with + * no existing backend transaction. + * @see org.springframework.transaction.TransactionDefinition#PROPAGATION_REQUIRED + * @see org.springframework.transaction.TransactionDefinition#PROPAGATION_MANDATORY + * @see org.springframework.transaction.TransactionDefinition#PROPAGATION_REQUIRES_NEW + */ + public static final int SYNCHRONIZATION_ON_ACTUAL_TRANSACTION = 1; + + /** + * Never active transaction synchronization, not even for actual transactions. + */ + public static final int SYNCHRONIZATION_NEVER = 2; + + + /** Constants instance for AbstractPlatformTransactionManager. */ + private static final Constants constants = new Constants(AbstractPlatformTransactionManager.class); + + + protected transient Log logger = LogFactory.getLog(getClass()); + + private int transactionSynchronization = SYNCHRONIZATION_ALWAYS; + + private int defaultTimeout = TransactionDefinition.TIMEOUT_DEFAULT; + + private boolean nestedTransactionAllowed = false; + + private boolean validateExistingTransaction = false; + + private boolean globalRollbackOnParticipationFailure = true; + + private boolean failEarlyOnGlobalRollbackOnly = false; + + private boolean rollbackOnCommitFailure = false; + + + /** + * Set the transaction synchronization by the name of the corresponding constant + * in this class, e.g. "SYNCHRONIZATION_ALWAYS". + * @param constantName name of the constant + * @see #SYNCHRONIZATION_ALWAYS + */ + public final void setTransactionSynchronizationName(String constantName) { + setTransactionSynchronization(constants.asNumber(constantName).intValue()); + } + + /** + * Set when this transaction manager should activate the thread-bound + * transaction synchronization support. Default is "always". + *

Note that transaction synchronization isn't supported for + * multiple concurrent transactions by different transaction managers. + * Only one transaction manager is allowed to activate it at any time. + * @see #SYNCHRONIZATION_ALWAYS + * @see #SYNCHRONIZATION_ON_ACTUAL_TRANSACTION + * @see #SYNCHRONIZATION_NEVER + * @see TransactionSynchronizationManager + * @see TransactionSynchronization + */ + public final void setTransactionSynchronization(int transactionSynchronization) { + this.transactionSynchronization = transactionSynchronization; + } + + /** + * Return if this transaction manager should activate the thread-bound + * transaction synchronization support. + */ + public final int getTransactionSynchronization() { + return this.transactionSynchronization; + } + + /** + * Specify the default timeout that this transaction manager should apply + * if there is no timeout specified at the transaction level, in seconds. + *

Default is the underlying transaction infrastructure's default timeout, + * e.g. typically 30 seconds in case of a JTA provider, indicated by the + * {@code TransactionDefinition.TIMEOUT_DEFAULT} value. + * @see org.springframework.transaction.TransactionDefinition#TIMEOUT_DEFAULT + */ + public final void setDefaultTimeout(int defaultTimeout) { + if (defaultTimeout < TransactionDefinition.TIMEOUT_DEFAULT) { + throw new InvalidTimeoutException("Invalid default timeout", defaultTimeout); + } + this.defaultTimeout = defaultTimeout; + } + + /** + * Return the default timeout that this transaction manager should apply + * if there is no timeout specified at the transaction level, in seconds. + *

Returns {@code TransactionDefinition.TIMEOUT_DEFAULT} to indicate + * the underlying transaction infrastructure's default timeout. + */ + public final int getDefaultTimeout() { + return this.defaultTimeout; + } + + /** + * Set whether nested transactions are allowed. Default is "false". + *

Typically initialized with an appropriate default by the + * concrete transaction manager subclass. + */ + public final void setNestedTransactionAllowed(boolean nestedTransactionAllowed) { + this.nestedTransactionAllowed = nestedTransactionAllowed; + } + + /** + * Return whether nested transactions are allowed. + */ + public final boolean isNestedTransactionAllowed() { + return this.nestedTransactionAllowed; + } + + /** + * Set whether existing transactions should be validated before participating + * in them. + *

When participating in an existing transaction (e.g. with + * PROPAGATION_REQUIRED or PROPAGATION_SUPPORTS encountering an existing + * transaction), this outer transaction's characteristics will apply even + * to the inner transaction scope. Validation will detect incompatible + * isolation level and read-only settings on the inner transaction definition + * and reject participation accordingly through throwing a corresponding exception. + *

Default is "false", leniently ignoring inner transaction settings, + * simply overriding them with the outer transaction's characteristics. + * Switch this flag to "true" in order to enforce strict validation. + * @since 2.5.1 + */ + public final void setValidateExistingTransaction(boolean validateExistingTransaction) { + this.validateExistingTransaction = validateExistingTransaction; + } + + /** + * Return whether existing transactions should be validated before participating + * in them. + * @since 2.5.1 + */ + public final boolean isValidateExistingTransaction() { + return this.validateExistingTransaction; + } + + /** + * Set whether to globally mark an existing transaction as rollback-only + * after a participating transaction failed. + *

Default is "true": If a participating transaction (e.g. with + * PROPAGATION_REQUIRED or PROPAGATION_SUPPORTS encountering an existing + * transaction) fails, the transaction will be globally marked as rollback-only. + * The only possible outcome of such a transaction is a rollback: The + * transaction originator cannot make the transaction commit anymore. + *

Switch this to "false" to let the transaction originator make the rollback + * decision. If a participating transaction fails with an exception, the caller + * can still decide to continue with a different path within the transaction. + * However, note that this will only work as long as all participating resources + * are capable of continuing towards a transaction commit even after a data access + * failure: This is generally not the case for a Hibernate Session, for example; + * neither is it for a sequence of JDBC insert/update/delete operations. + *

Note:This flag only applies to an explicit rollback attempt for a + * subtransaction, typically caused by an exception thrown by a data access operation + * (where TransactionInterceptor will trigger a {@code PlatformTransactionManager.rollback()} + * call according to a rollback rule). If the flag is off, the caller can handle the exception + * and decide on a rollback, independent of the rollback rules of the subtransaction. + * This flag does, however, not apply to explicit {@code setRollbackOnly} + * calls on a {@code TransactionStatus}, which will always cause an eventual + * global rollback (as it might not throw an exception after the rollback-only call). + *

The recommended solution for handling failure of a subtransaction + * is a "nested transaction", where the global transaction can be rolled + * back to a savepoint taken at the beginning of the subtransaction. + * PROPAGATION_NESTED provides exactly those semantics; however, it will + * only work when nested transaction support is available. This is the case + * with DataSourceTransactionManager, but not with JtaTransactionManager. + * @see #setNestedTransactionAllowed + * @see org.springframework.transaction.jta.JtaTransactionManager + */ + public final void setGlobalRollbackOnParticipationFailure(boolean globalRollbackOnParticipationFailure) { + this.globalRollbackOnParticipationFailure = globalRollbackOnParticipationFailure; + } + + /** + * Return whether to globally mark an existing transaction as rollback-only + * after a participating transaction failed. + */ + public final boolean isGlobalRollbackOnParticipationFailure() { + return this.globalRollbackOnParticipationFailure; + } + + /** + * Set whether to fail early in case of the transaction being globally marked + * as rollback-only. + *

Default is "false", only causing an UnexpectedRollbackException at the + * outermost transaction boundary. Switch this flag on to cause an + * UnexpectedRollbackException as early as the global rollback-only marker + * has been first detected, even from within an inner transaction boundary. + *

Note that, as of Spring 2.0, the fail-early behavior for global + * rollback-only markers has been unified: All transaction managers will by + * default only cause UnexpectedRollbackException at the outermost transaction + * boundary. This allows, for example, to continue unit tests even after an + * operation failed and the transaction will never be completed. All transaction + * managers will only fail earlier if this flag has explicitly been set to "true". + * @since 2.0 + * @see org.springframework.transaction.UnexpectedRollbackException + */ + public final void setFailEarlyOnGlobalRollbackOnly(boolean failEarlyOnGlobalRollbackOnly) { + this.failEarlyOnGlobalRollbackOnly = failEarlyOnGlobalRollbackOnly; + } + + /** + * Return whether to fail early in case of the transaction being globally marked + * as rollback-only. + * @since 2.0 + */ + public final boolean isFailEarlyOnGlobalRollbackOnly() { + return this.failEarlyOnGlobalRollbackOnly; + } + + /** + * Set whether {@code doRollback} should be performed on failure of the + * {@code doCommit} call. Typically not necessary and thus to be avoided, + * as it can potentially override the commit exception with a subsequent + * rollback exception. + *

Default is "false". + * @see #doCommit + * @see #doRollback + */ + public final void setRollbackOnCommitFailure(boolean rollbackOnCommitFailure) { + this.rollbackOnCommitFailure = rollbackOnCommitFailure; + } + + /** + * Return whether {@code doRollback} should be performed on failure of the + * {@code doCommit} call. + */ + public final boolean isRollbackOnCommitFailure() { + return this.rollbackOnCommitFailure; + } + + + //--------------------------------------------------------------------- + // Implementation of PlatformTransactionManager + //--------------------------------------------------------------------- + + /** + * This implementation handles propagation behavior. Delegates to + * {@code doGetTransaction}, {@code isExistingTransaction} + * and {@code doBegin}. + * @see #doGetTransaction + * @see #isExistingTransaction + * @see #doBegin + */ + @Override + public final TransactionStatus getTransaction(@Nullable TransactionDefinition definition) throws TransactionException { + Object transaction = doGetTransaction(); + + // Cache debug flag to avoid repeated checks. + boolean debugEnabled = logger.isDebugEnabled(); + + if (definition == null) { + // Use defaults if no transaction definition given. + definition = new DefaultTransactionDefinition(); + } + + if (isExistingTransaction(transaction)) { + // Existing transaction found -> check propagation behavior to find out how to behave. + return handleExistingTransaction(definition, transaction, debugEnabled); + } + + // Check definition settings for new transaction. + if (definition.getTimeout() < TransactionDefinition.TIMEOUT_DEFAULT) { + throw new InvalidTimeoutException("Invalid transaction timeout", definition.getTimeout()); + } + + // No existing transaction found -> check propagation behavior to find out how to proceed. + if (definition.getPropagationBehavior() == TransactionDefinition.PROPAGATION_MANDATORY) { + throw new IllegalTransactionStateException( + "No existing transaction found for transaction marked with propagation 'mandatory'"); + } + else if (definition.getPropagationBehavior() == TransactionDefinition.PROPAGATION_REQUIRED || + definition.getPropagationBehavior() == TransactionDefinition.PROPAGATION_REQUIRES_NEW || + definition.getPropagationBehavior() == TransactionDefinition.PROPAGATION_NESTED) { + SuspendedResourcesHolder suspendedResources = suspend(null); + if (debugEnabled) { + logger.debug("Creating new transaction with name [" + definition.getName() + "]: " + definition); + } + try { + boolean newSynchronization = (getTransactionSynchronization() != SYNCHRONIZATION_NEVER); + DefaultTransactionStatus status = newTransactionStatus( + definition, transaction, true, newSynchronization, debugEnabled, suspendedResources); + doBegin(transaction, definition); + prepareSynchronization(status, definition); + return status; + } + catch (RuntimeException | Error ex) { + resume(null, suspendedResources); + throw ex; + } + } + else { + // Create "empty" transaction: no actual transaction, but potentially synchronization. + if (definition.getIsolationLevel() != TransactionDefinition.ISOLATION_DEFAULT && logger.isWarnEnabled()) { + logger.warn("Custom isolation level specified but no actual transaction initiated; " + + "isolation level will effectively be ignored: " + definition); + } + boolean newSynchronization = (getTransactionSynchronization() == SYNCHRONIZATION_ALWAYS); + return prepareTransactionStatus(definition, null, true, newSynchronization, debugEnabled, null); + } + } + + /** + * Create a TransactionStatus for an existing transaction. + */ + private TransactionStatus handleExistingTransaction( + TransactionDefinition definition, Object transaction, boolean debugEnabled) + throws TransactionException { + + if (definition.getPropagationBehavior() == TransactionDefinition.PROPAGATION_NEVER) { + throw new IllegalTransactionStateException( + "Existing transaction found for transaction marked with propagation 'never'"); + } + + if (definition.getPropagationBehavior() == TransactionDefinition.PROPAGATION_NOT_SUPPORTED) { + if (debugEnabled) { + logger.debug("Suspending current transaction"); + } + Object suspendedResources = suspend(transaction); + boolean newSynchronization = (getTransactionSynchronization() == SYNCHRONIZATION_ALWAYS); + return prepareTransactionStatus( + definition, null, false, newSynchronization, debugEnabled, suspendedResources); + } + + if (definition.getPropagationBehavior() == TransactionDefinition.PROPAGATION_REQUIRES_NEW) { + if (debugEnabled) { + logger.debug("Suspending current transaction, creating new transaction with name [" + + definition.getName() + "]"); + } + SuspendedResourcesHolder suspendedResources = suspend(transaction); + try { + boolean newSynchronization = (getTransactionSynchronization() != SYNCHRONIZATION_NEVER); + DefaultTransactionStatus status = newTransactionStatus( + definition, transaction, true, newSynchronization, debugEnabled, suspendedResources); + doBegin(transaction, definition); + prepareSynchronization(status, definition); + return status; + } + catch (RuntimeException | Error beginEx) { + resumeAfterBeginException(transaction, suspendedResources, beginEx); + throw beginEx; + } + } + + if (definition.getPropagationBehavior() == TransactionDefinition.PROPAGATION_NESTED) { + if (!isNestedTransactionAllowed()) { + throw new NestedTransactionNotSupportedException( + "Transaction manager does not allow nested transactions by default - " + + "specify 'nestedTransactionAllowed' property with value 'true'"); + } + if (debugEnabled) { + logger.debug("Creating nested transaction with name [" + definition.getName() + "]"); + } + if (useSavepointForNestedTransaction()) { + // Create savepoint within existing Spring-managed transaction, + // through the SavepointManager API implemented by TransactionStatus. + // Usually uses JDBC 3.0 savepoints. Never activates Spring synchronization. + DefaultTransactionStatus status = + prepareTransactionStatus(definition, transaction, false, false, debugEnabled, null); + status.createAndHoldSavepoint(); + return status; + } + else { + // Nested transaction through nested begin and commit/rollback calls. + // Usually only for JTA: Spring synchronization might get activated here + // in case of a pre-existing JTA transaction. + boolean newSynchronization = (getTransactionSynchronization() != SYNCHRONIZATION_NEVER); + DefaultTransactionStatus status = newTransactionStatus( + definition, transaction, true, newSynchronization, debugEnabled, null); + doBegin(transaction, definition); + prepareSynchronization(status, definition); + return status; + } + } + + // Assumably PROPAGATION_SUPPORTS or PROPAGATION_REQUIRED. + if (debugEnabled) { + logger.debug("Participating in existing transaction"); + } + if (isValidateExistingTransaction()) { + if (definition.getIsolationLevel() != TransactionDefinition.ISOLATION_DEFAULT) { + Integer currentIsolationLevel = TransactionSynchronizationManager.getCurrentTransactionIsolationLevel(); + if (currentIsolationLevel == null || currentIsolationLevel != definition.getIsolationLevel()) { + Constants isoConstants = DefaultTransactionDefinition.constants; + throw new IllegalTransactionStateException("Participating transaction with definition [" + + definition + "] specifies isolation level which is incompatible with existing transaction: " + + (currentIsolationLevel != null ? + isoConstants.toCode(currentIsolationLevel, DefaultTransactionDefinition.PREFIX_ISOLATION) : + "(unknown)")); + } + } + if (!definition.isReadOnly()) { + if (TransactionSynchronizationManager.isCurrentTransactionReadOnly()) { + throw new IllegalTransactionStateException("Participating transaction with definition [" + + definition + "] is not marked as read-only but existing transaction is"); + } + } + } + boolean newSynchronization = (getTransactionSynchronization() != SYNCHRONIZATION_NEVER); + return prepareTransactionStatus(definition, transaction, false, newSynchronization, debugEnabled, null); + } + + /** + * Create a new TransactionStatus for the given arguments, + * also initializing transaction synchronization as appropriate. + * @see #newTransactionStatus + * @see #prepareTransactionStatus + */ + protected final DefaultTransactionStatus prepareTransactionStatus( + TransactionDefinition definition, @Nullable Object transaction, boolean newTransaction, + boolean newSynchronization, boolean debug, @Nullable Object suspendedResources) { + + DefaultTransactionStatus status = newTransactionStatus( + definition, transaction, newTransaction, newSynchronization, debug, suspendedResources); + prepareSynchronization(status, definition); + return status; + } + + /** + * Create a TransactionStatus instance for the given arguments. + */ + protected DefaultTransactionStatus newTransactionStatus( + TransactionDefinition definition, @Nullable Object transaction, boolean newTransaction, + boolean newSynchronization, boolean debug, @Nullable Object suspendedResources) { + + boolean actualNewSynchronization = newSynchronization && + !TransactionSynchronizationManager.isSynchronizationActive(); + return new DefaultTransactionStatus( + transaction, newTransaction, actualNewSynchronization, + definition.isReadOnly(), debug, suspendedResources); + } + + /** + * Initialize transaction synchronization as appropriate. + */ + protected void prepareSynchronization(DefaultTransactionStatus status, TransactionDefinition definition) { + if (status.isNewSynchronization()) { + TransactionSynchronizationManager.setActualTransactionActive(status.hasTransaction()); + TransactionSynchronizationManager.setCurrentTransactionIsolationLevel( + definition.getIsolationLevel() != TransactionDefinition.ISOLATION_DEFAULT ? + definition.getIsolationLevel() : null); + TransactionSynchronizationManager.setCurrentTransactionReadOnly(definition.isReadOnly()); + TransactionSynchronizationManager.setCurrentTransactionName(definition.getName()); + TransactionSynchronizationManager.initSynchronization(); + } + } + + /** + * Determine the actual timeout to use for the given definition. + * Will fall back to this manager's default timeout if the + * transaction definition doesn't specify a non-default value. + * @param definition the transaction definition + * @return the actual timeout to use + * @see org.springframework.transaction.TransactionDefinition#getTimeout() + * @see #setDefaultTimeout + */ + protected int determineTimeout(TransactionDefinition definition) { + if (definition.getTimeout() != TransactionDefinition.TIMEOUT_DEFAULT) { + return definition.getTimeout(); + } + return getDefaultTimeout(); + } + + + /** + * Suspend the given transaction. Suspends transaction synchronization first, + * then delegates to the {@code doSuspend} template method. + * @param transaction the current transaction object + * (or {@code null} to just suspend active synchronizations, if any) + * @return an object that holds suspended resources + * (or {@code null} if neither transaction nor synchronization active) + * @see #doSuspend + * @see #resume + */ + @Nullable + protected final SuspendedResourcesHolder suspend(@Nullable Object transaction) throws TransactionException { + if (TransactionSynchronizationManager.isSynchronizationActive()) { + List suspendedSynchronizations = doSuspendSynchronization(); + try { + Object suspendedResources = null; + if (transaction != null) { + suspendedResources = doSuspend(transaction); + } + String name = TransactionSynchronizationManager.getCurrentTransactionName(); + TransactionSynchronizationManager.setCurrentTransactionName(null); + boolean readOnly = TransactionSynchronizationManager.isCurrentTransactionReadOnly(); + TransactionSynchronizationManager.setCurrentTransactionReadOnly(false); + Integer isolationLevel = TransactionSynchronizationManager.getCurrentTransactionIsolationLevel(); + TransactionSynchronizationManager.setCurrentTransactionIsolationLevel(null); + boolean wasActive = TransactionSynchronizationManager.isActualTransactionActive(); + TransactionSynchronizationManager.setActualTransactionActive(false); + return new SuspendedResourcesHolder( + suspendedResources, suspendedSynchronizations, name, readOnly, isolationLevel, wasActive); + } + catch (RuntimeException | Error ex) { + // doSuspend failed - original transaction is still active... + doResumeSynchronization(suspendedSynchronizations); + throw ex; + } + } + else if (transaction != null) { + // Transaction active but no synchronization active. + Object suspendedResources = doSuspend(transaction); + return new SuspendedResourcesHolder(suspendedResources); + } + else { + // Neither transaction nor synchronization active. + return null; + } + } + + /** + * Resume the given transaction. Delegates to the {@code doResume} + * template method first, then resuming transaction synchronization. + * @param transaction the current transaction object + * @param resourcesHolder the object that holds suspended resources, + * as returned by {@code suspend} (or {@code null} to just + * resume synchronizations, if any) + * @see #doResume + * @see #suspend + */ + protected final void resume(@Nullable Object transaction, @Nullable SuspendedResourcesHolder resourcesHolder) + throws TransactionException { + + if (resourcesHolder != null) { + Object suspendedResources = resourcesHolder.suspendedResources; + if (suspendedResources != null) { + doResume(transaction, suspendedResources); + } + List suspendedSynchronizations = resourcesHolder.suspendedSynchronizations; + if (suspendedSynchronizations != null) { + TransactionSynchronizationManager.setActualTransactionActive(resourcesHolder.wasActive); + TransactionSynchronizationManager.setCurrentTransactionIsolationLevel(resourcesHolder.isolationLevel); + TransactionSynchronizationManager.setCurrentTransactionReadOnly(resourcesHolder.readOnly); + TransactionSynchronizationManager.setCurrentTransactionName(resourcesHolder.name); + doResumeSynchronization(suspendedSynchronizations); + } + } + } + + /** + * Resume outer transaction after inner transaction begin failed. + */ + private void resumeAfterBeginException( + Object transaction, @Nullable SuspendedResourcesHolder suspendedResources, Throwable beginEx) { + + String exMessage = "Inner transaction begin exception overridden by outer transaction resume exception"; + try { + resume(transaction, suspendedResources); + } + catch (RuntimeException | Error resumeEx) { + logger.error(exMessage, beginEx); + throw resumeEx; + } + } + + /** + * Suspend all current synchronizations and deactivate transaction + * synchronization for the current thread. + * @return the List of suspended TransactionSynchronization objects + */ + private List doSuspendSynchronization() { + List suspendedSynchronizations = + TransactionSynchronizationManager.getSynchronizations(); + for (TransactionSynchronization synchronization : suspendedSynchronizations) { + synchronization.suspend(); + } + TransactionSynchronizationManager.clearSynchronization(); + return suspendedSynchronizations; + } + + /** + * Reactivate transaction synchronization for the current thread + * and resume all given synchronizations. + * @param suspendedSynchronizations a List of TransactionSynchronization objects + */ + private void doResumeSynchronization(List suspendedSynchronizations) { + TransactionSynchronizationManager.initSynchronization(); + for (TransactionSynchronization synchronization : suspendedSynchronizations) { + synchronization.resume(); + TransactionSynchronizationManager.registerSynchronization(synchronization); + } + } + + + /** + * This implementation of commit handles participating in existing + * transactions and programmatic rollback requests. + * Delegates to {@code isRollbackOnly}, {@code doCommit} + * and {@code rollback}. + * @see org.springframework.transaction.TransactionStatus#isRollbackOnly() + * @see #doCommit + * @see #rollback + */ + @Override + public final void commit(TransactionStatus status) throws TransactionException { + if (status.isCompleted()) { + throw new IllegalTransactionStateException( + "Transaction is already completed - do not call commit or rollback more than once per transaction"); + } + + DefaultTransactionStatus defStatus = (DefaultTransactionStatus) status; + if (defStatus.isLocalRollbackOnly()) { + if (defStatus.isDebug()) { + logger.debug("Transactional code has requested rollback"); + } + processRollback(defStatus, false); + return; + } + + if (!shouldCommitOnGlobalRollbackOnly() && defStatus.isGlobalRollbackOnly()) { + if (defStatus.isDebug()) { + logger.debug("Global transaction is marked as rollback-only but transactional code requested commit"); + } + processRollback(defStatus, true); + return; + } + + processCommit(defStatus); + } + + /** + * Process an actual commit. + * Rollback-only flags have already been checked and applied. + * @param status object representing the transaction + * @throws TransactionException in case of commit failure + */ + private void processCommit(DefaultTransactionStatus status) throws TransactionException { + try { + boolean beforeCompletionInvoked = false; + + try { + boolean unexpectedRollback = false; + prepareForCommit(status); + triggerBeforeCommit(status); + triggerBeforeCompletion(status); + beforeCompletionInvoked = true; + + if (status.hasSavepoint()) { + if (status.isDebug()) { + logger.debug("Releasing transaction savepoint"); + } + unexpectedRollback = status.isGlobalRollbackOnly(); + status.releaseHeldSavepoint(); + } + else if (status.isNewTransaction()) { + if (status.isDebug()) { + logger.debug("Initiating transaction commit"); + } + unexpectedRollback = status.isGlobalRollbackOnly(); + doCommit(status); + } + else if (isFailEarlyOnGlobalRollbackOnly()) { + unexpectedRollback = status.isGlobalRollbackOnly(); + } + + // Throw UnexpectedRollbackException if we have a global rollback-only + // marker but still didn't get a corresponding exception from commit. + if (unexpectedRollback) { + throw new UnexpectedRollbackException( + "Transaction silently rolled back because it has been marked as rollback-only"); + } + } + catch (UnexpectedRollbackException ex) { + // can only be caused by doCommit + triggerAfterCompletion(status, TransactionSynchronization.STATUS_ROLLED_BACK); + throw ex; + } + catch (TransactionException ex) { + // can only be caused by doCommit + if (isRollbackOnCommitFailure()) { + doRollbackOnCommitException(status, ex); + } + else { + triggerAfterCompletion(status, TransactionSynchronization.STATUS_UNKNOWN); + } + throw ex; + } + catch (RuntimeException | Error ex) { + if (!beforeCompletionInvoked) { + triggerBeforeCompletion(status); + } + doRollbackOnCommitException(status, ex); + throw ex; + } + + // Trigger afterCommit callbacks, with an exception thrown there + // propagated to callers but the transaction still considered as committed. + try { + triggerAfterCommit(status); + } + finally { + triggerAfterCompletion(status, TransactionSynchronization.STATUS_COMMITTED); + } + + } + finally { + cleanupAfterCompletion(status); + } + } + + /** + * This implementation of rollback handles participating in existing + * transactions. Delegates to {@code doRollback} and + * {@code doSetRollbackOnly}. + * @see #doRollback + * @see #doSetRollbackOnly + */ + @Override + public final void rollback(TransactionStatus status) throws TransactionException { + if (status.isCompleted()) { + throw new IllegalTransactionStateException( + "Transaction is already completed - do not call commit or rollback more than once per transaction"); + } + + DefaultTransactionStatus defStatus = (DefaultTransactionStatus) status; + processRollback(defStatus, false); + } + + /** + * Process an actual rollback. + * The completed flag has already been checked. + * @param status object representing the transaction + * @throws TransactionException in case of rollback failure + */ + private void processRollback(DefaultTransactionStatus status, boolean unexpected) { + try { + boolean unexpectedRollback = unexpected; + + try { + triggerBeforeCompletion(status); + + if (status.hasSavepoint()) { + if (status.isDebug()) { + logger.debug("Rolling back transaction to savepoint"); + } + status.rollbackToHeldSavepoint(); + } + else if (status.isNewTransaction()) { + if (status.isDebug()) { + logger.debug("Initiating transaction rollback"); + } + doRollback(status); + } + else { + // Participating in larger transaction + if (status.hasTransaction()) { + if (status.isLocalRollbackOnly() || isGlobalRollbackOnParticipationFailure()) { + if (status.isDebug()) { + logger.debug("Participating transaction failed - marking existing transaction as rollback-only"); + } + doSetRollbackOnly(status); + } + else { + if (status.isDebug()) { + logger.debug("Participating transaction failed - letting transaction originator decide on rollback"); + } + } + } + else { + logger.debug("Should roll back transaction but cannot - no transaction available"); + } + // Unexpected rollback only matters here if we're asked to fail early + if (!isFailEarlyOnGlobalRollbackOnly()) { + unexpectedRollback = false; + } + } + } + catch (RuntimeException | Error ex) { + triggerAfterCompletion(status, TransactionSynchronization.STATUS_UNKNOWN); + throw ex; + } + + triggerAfterCompletion(status, TransactionSynchronization.STATUS_ROLLED_BACK); + + // Raise UnexpectedRollbackException if we had a global rollback-only marker + if (unexpectedRollback) { + throw new UnexpectedRollbackException( + "Transaction rolled back because it has been marked as rollback-only"); + } + } + finally { + cleanupAfterCompletion(status); + } + } + + /** + * Invoke {@code doRollback}, handling rollback exceptions properly. + * @param status object representing the transaction + * @param ex the thrown application exception or error + * @throws TransactionException in case of rollback failure + * @see #doRollback + */ + private void doRollbackOnCommitException(DefaultTransactionStatus status, Throwable ex) throws TransactionException { + try { + if (status.isNewTransaction()) { + if (status.isDebug()) { + logger.debug("Initiating transaction rollback after commit exception", ex); + } + doRollback(status); + } + else if (status.hasTransaction() && isGlobalRollbackOnParticipationFailure()) { + if (status.isDebug()) { + logger.debug("Marking existing transaction as rollback-only after commit exception", ex); + } + doSetRollbackOnly(status); + } + } + catch (RuntimeException | Error rbex) { + logger.error("Commit exception overridden by rollback exception", ex); + triggerAfterCompletion(status, TransactionSynchronization.STATUS_UNKNOWN); + throw rbex; + } + triggerAfterCompletion(status, TransactionSynchronization.STATUS_ROLLED_BACK); + } + + + /** + * Trigger {@code beforeCommit} callbacks. + * @param status object representing the transaction + */ + protected final void triggerBeforeCommit(DefaultTransactionStatus status) { + if (status.isNewSynchronization()) { + if (status.isDebug()) { + logger.trace("Triggering beforeCommit synchronization"); + } + TransactionSynchronizationUtils.triggerBeforeCommit(status.isReadOnly()); + } + } + + /** + * Trigger {@code beforeCompletion} callbacks. + * @param status object representing the transaction + */ + protected final void triggerBeforeCompletion(DefaultTransactionStatus status) { + if (status.isNewSynchronization()) { + if (status.isDebug()) { + logger.trace("Triggering beforeCompletion synchronization"); + } + TransactionSynchronizationUtils.triggerBeforeCompletion(); + } + } + + /** + * Trigger {@code afterCommit} callbacks. + * @param status object representing the transaction + */ + private void triggerAfterCommit(DefaultTransactionStatus status) { + if (status.isNewSynchronization()) { + if (status.isDebug()) { + logger.trace("Triggering afterCommit synchronization"); + } + TransactionSynchronizationUtils.triggerAfterCommit(); + } + } + + /** + * Trigger {@code afterCompletion} callbacks. + * @param status object representing the transaction + * @param completionStatus completion status according to TransactionSynchronization constants + */ + private void triggerAfterCompletion(DefaultTransactionStatus status, int completionStatus) { + if (status.isNewSynchronization()) { + List synchronizations = TransactionSynchronizationManager.getSynchronizations(); + TransactionSynchronizationManager.clearSynchronization(); + if (!status.hasTransaction() || status.isNewTransaction()) { + if (status.isDebug()) { + logger.trace("Triggering afterCompletion synchronization"); + } + // No transaction or new transaction for the current scope -> + // invoke the afterCompletion callbacks immediately + invokeAfterCompletion(synchronizations, completionStatus); + } + else if (!synchronizations.isEmpty()) { + // Existing transaction that we participate in, controlled outside + // of the scope of this Spring transaction manager -> try to register + // an afterCompletion callback with the existing (JTA) transaction. + registerAfterCompletionWithExistingTransaction(status.getTransaction(), synchronizations); + } + } + } + + /** + * Actually invoke the {@code afterCompletion} methods of the + * given Spring TransactionSynchronization objects. + *

To be called by this abstract manager itself, or by special implementations + * of the {@code registerAfterCompletionWithExistingTransaction} callback. + * @param synchronizations a List of TransactionSynchronization objects + * @param completionStatus the completion status according to the + * constants in the TransactionSynchronization interface + * @see #registerAfterCompletionWithExistingTransaction(Object, java.util.List) + * @see TransactionSynchronization#STATUS_COMMITTED + * @see TransactionSynchronization#STATUS_ROLLED_BACK + * @see TransactionSynchronization#STATUS_UNKNOWN + */ + protected final void invokeAfterCompletion(List synchronizations, int completionStatus) { + TransactionSynchronizationUtils.invokeAfterCompletion(synchronizations, completionStatus); + } + + /** + * Clean up after completion, clearing synchronization if necessary, + * and invoking doCleanupAfterCompletion. + * @param status object representing the transaction + * @see #doCleanupAfterCompletion + */ + private void cleanupAfterCompletion(DefaultTransactionStatus status) { + status.setCompleted(); + if (status.isNewSynchronization()) { + TransactionSynchronizationManager.clear(); + } + if (status.isNewTransaction()) { + doCleanupAfterCompletion(status.getTransaction()); + } + if (status.getSuspendedResources() != null) { + if (status.isDebug()) { + logger.debug("Resuming suspended transaction after completion of inner transaction"); + } + Object transaction = (status.hasTransaction() ? status.getTransaction() : null); + resume(transaction, (SuspendedResourcesHolder) status.getSuspendedResources()); + } + } + + + //--------------------------------------------------------------------- + // Template methods to be implemented in subclasses + //--------------------------------------------------------------------- + + /** + * Return a transaction object for the current transaction state. + *

The returned object will usually be specific to the concrete transaction + * manager implementation, carrying corresponding transaction state in a + * modifiable fashion. This object will be passed into the other template + * methods (e.g. doBegin and doCommit), either directly or as part of a + * DefaultTransactionStatus instance. + *

The returned object should contain information about any existing + * transaction, that is, a transaction that has already started before the + * current {@code getTransaction} call on the transaction manager. + * Consequently, a {@code doGetTransaction} implementation will usually + * look for an existing transaction and store corresponding state in the + * returned transaction object. + * @return the current transaction object + * @throws org.springframework.transaction.CannotCreateTransactionException + * if transaction support is not available + * @throws TransactionException in case of lookup or system errors + * @see #doBegin + * @see #doCommit + * @see #doRollback + * @see DefaultTransactionStatus#getTransaction + */ + protected abstract Object doGetTransaction() throws TransactionException; + + /** + * Check if the given transaction object indicates an existing transaction + * (that is, a transaction which has already started). + *

The result will be evaluated according to the specified propagation + * behavior for the new transaction. An existing transaction might get + * suspended (in case of PROPAGATION_REQUIRES_NEW), or the new transaction + * might participate in the existing one (in case of PROPAGATION_REQUIRED). + *

The default implementation returns {@code false}, assuming that + * participating in existing transactions is generally not supported. + * Subclasses are of course encouraged to provide such support. + * @param transaction transaction object returned by doGetTransaction + * @return if there is an existing transaction + * @throws TransactionException in case of system errors + * @see #doGetTransaction + */ + protected boolean isExistingTransaction(Object transaction) throws TransactionException { + return false; + } + + /** + * Return whether to use a savepoint for a nested transaction. + *

Default is {@code true}, which causes delegation to DefaultTransactionStatus + * for creating and holding a savepoint. If the transaction object does not implement + * the SavepointManager interface, a NestedTransactionNotSupportedException will be + * thrown. Else, the SavepointManager will be asked to create a new savepoint to + * demarcate the start of the nested transaction. + *

Subclasses can override this to return {@code false}, causing a further + * call to {@code doBegin} - within the context of an already existing transaction. + * The {@code doBegin} implementation needs to handle this accordingly in such + * a scenario. This is appropriate for JTA, for example. + * @see DefaultTransactionStatus#createAndHoldSavepoint + * @see DefaultTransactionStatus#rollbackToHeldSavepoint + * @see DefaultTransactionStatus#releaseHeldSavepoint + * @see #doBegin + */ + protected boolean useSavepointForNestedTransaction() { + return true; + } + + /** + * Begin a new transaction with semantics according to the given transaction + * definition. Does not have to care about applying the propagation behavior, + * as this has already been handled by this abstract manager. + *

This method gets called when the transaction manager has decided to actually + * start a new transaction. Either there wasn't any transaction before, or the + * previous transaction has been suspended. + *

A special scenario is a nested transaction without savepoint: If + * {@code useSavepointForNestedTransaction()} returns "false", this method + * will be called to start a nested transaction when necessary. In such a context, + * there will be an active transaction: The implementation of this method has + * to detect this and start an appropriate nested transaction. + * @param transaction transaction object returned by {@code doGetTransaction} + * @param definition a TransactionDefinition instance, describing propagation + * behavior, isolation level, read-only flag, timeout, and transaction name + * @throws TransactionException in case of creation or system errors + * @throws org.springframework.transaction.NestedTransactionNotSupportedException + * if the underlying transaction does not support nesting + */ + protected abstract void doBegin(Object transaction, TransactionDefinition definition) + throws TransactionException; + + /** + * Suspend the resources of the current transaction. + * Transaction synchronization will already have been suspended. + *

The default implementation throws a TransactionSuspensionNotSupportedException, + * assuming that transaction suspension is generally not supported. + * @param transaction transaction object returned by {@code doGetTransaction} + * @return an object that holds suspended resources + * (will be kept unexamined for passing it into doResume) + * @throws org.springframework.transaction.TransactionSuspensionNotSupportedException + * if suspending is not supported by the transaction manager implementation + * @throws TransactionException in case of system errors + * @see #doResume + */ + protected Object doSuspend(Object transaction) throws TransactionException { + throw new TransactionSuspensionNotSupportedException( + "Transaction manager [" + getClass().getName() + "] does not support transaction suspension"); + } + + /** + * Resume the resources of the current transaction. + * Transaction synchronization will be resumed afterwards. + *

The default implementation throws a TransactionSuspensionNotSupportedException, + * assuming that transaction suspension is generally not supported. + * @param transaction transaction object returned by {@code doGetTransaction} + * @param suspendedResources the object that holds suspended resources, + * as returned by doSuspend + * @throws org.springframework.transaction.TransactionSuspensionNotSupportedException + * if resuming is not supported by the transaction manager implementation + * @throws TransactionException in case of system errors + * @see #doSuspend + */ + protected void doResume(@Nullable Object transaction, Object suspendedResources) throws TransactionException { + throw new TransactionSuspensionNotSupportedException( + "Transaction manager [" + getClass().getName() + "] does not support transaction suspension"); + } + + /** + * Return whether to call {@code doCommit} on a transaction that has been + * marked as rollback-only in a global fashion. + *

Does not apply if an application locally sets the transaction to rollback-only + * via the TransactionStatus, but only to the transaction itself being marked as + * rollback-only by the transaction coordinator. + *

Default is "false": Local transaction strategies usually don't hold the rollback-only + * marker in the transaction itself, therefore they can't handle rollback-only transactions + * as part of transaction commit. Hence, AbstractPlatformTransactionManager will trigger + * a rollback in that case, throwing an UnexpectedRollbackException afterwards. + *

Override this to return "true" if the concrete transaction manager expects a + * {@code doCommit} call even for a rollback-only transaction, allowing for + * special handling there. This will, for example, be the case for JTA, where + * {@code UserTransaction.commit} will check the read-only flag itself and + * throw a corresponding RollbackException, which might include the specific reason + * (such as a transaction timeout). + *

If this method returns "true" but the {@code doCommit} implementation does not + * throw an exception, this transaction manager will throw an UnexpectedRollbackException + * itself. This should not be the typical case; it is mainly checked to cover misbehaving + * JTA providers that silently roll back even when the rollback has not been requested + * by the calling code. + * @see #doCommit + * @see DefaultTransactionStatus#isGlobalRollbackOnly() + * @see DefaultTransactionStatus#isLocalRollbackOnly() + * @see org.springframework.transaction.TransactionStatus#setRollbackOnly() + * @see org.springframework.transaction.UnexpectedRollbackException + * @see javax.transaction.UserTransaction#commit() + * @see javax.transaction.RollbackException + */ + protected boolean shouldCommitOnGlobalRollbackOnly() { + return false; + } + + /** + * Make preparations for commit, to be performed before the + * {@code beforeCommit} synchronization callbacks occur. + *

Note that exceptions will get propagated to the commit caller + * and cause a rollback of the transaction. + * @param status the status representation of the transaction + * @throws RuntimeException in case of errors; will be propagated to the caller + * (note: do not throw TransactionException subclasses here!) + */ + protected void prepareForCommit(DefaultTransactionStatus status) { + } + + /** + * Perform an actual commit of the given transaction. + *

An implementation does not need to check the "new transaction" flag + * or the rollback-only flag; this will already have been handled before. + * Usually, a straight commit will be performed on the transaction object + * contained in the passed-in status. + * @param status the status representation of the transaction + * @throws TransactionException in case of commit or system errors + * @see DefaultTransactionStatus#getTransaction + */ + protected abstract void doCommit(DefaultTransactionStatus status) throws TransactionException; + + /** + * Perform an actual rollback of the given transaction. + *

An implementation does not need to check the "new transaction" flag; + * this will already have been handled before. Usually, a straight rollback + * will be performed on the transaction object contained in the passed-in status. + * @param status the status representation of the transaction + * @throws TransactionException in case of system errors + * @see DefaultTransactionStatus#getTransaction + */ + protected abstract void doRollback(DefaultTransactionStatus status) throws TransactionException; + + /** + * Set the given transaction rollback-only. Only called on rollback + * if the current transaction participates in an existing one. + *

The default implementation throws an IllegalTransactionStateException, + * assuming that participating in existing transactions is generally not + * supported. Subclasses are of course encouraged to provide such support. + * @param status the status representation of the transaction + * @throws TransactionException in case of system errors + */ + protected void doSetRollbackOnly(DefaultTransactionStatus status) throws TransactionException { + throw new IllegalTransactionStateException( + "Participating in existing transactions is not supported - when 'isExistingTransaction' " + + "returns true, appropriate 'doSetRollbackOnly' behavior must be provided"); + } + + /** + * Register the given list of transaction synchronizations with the existing transaction. + *

Invoked when the control of the Spring transaction manager and thus all Spring + * transaction synchronizations end, without the transaction being completed yet. This + * is for example the case when participating in an existing JTA or EJB CMT transaction. + *

The default implementation simply invokes the {@code afterCompletion} methods + * immediately, passing in "STATUS_UNKNOWN". This is the best we can do if there's no + * chance to determine the actual outcome of the outer transaction. + * @param transaction transaction object returned by {@code doGetTransaction} + * @param synchronizations a List of TransactionSynchronization objects + * @throws TransactionException in case of system errors + * @see #invokeAfterCompletion(java.util.List, int) + * @see TransactionSynchronization#afterCompletion(int) + * @see TransactionSynchronization#STATUS_UNKNOWN + */ + protected void registerAfterCompletionWithExistingTransaction( + Object transaction, List synchronizations) throws TransactionException { + + logger.debug("Cannot register Spring after-completion synchronization with existing transaction - " + + "processing Spring after-completion callbacks immediately, with outcome status 'unknown'"); + invokeAfterCompletion(synchronizations, TransactionSynchronization.STATUS_UNKNOWN); + } + + /** + * Cleanup resources after transaction completion. + *

Called after {@code doCommit} and {@code doRollback} execution, + * on any outcome. The default implementation does nothing. + *

Should not throw any exceptions but just issue warnings on errors. + * @param transaction transaction object returned by {@code doGetTransaction} + */ + protected void doCleanupAfterCompletion(Object transaction) { + } + + + //--------------------------------------------------------------------- + // Serialization support + //--------------------------------------------------------------------- + + private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException { + // Rely on default serialization; just initialize state after deserialization. + ois.defaultReadObject(); + + // Initialize transient fields. + this.logger = LogFactory.getLog(getClass()); + } + + + /** + * Holder for suspended resources. + * Used internally by {@code suspend} and {@code resume}. + */ + protected static final class SuspendedResourcesHolder { + + @Nullable + private final Object suspendedResources; + + @Nullable + private List suspendedSynchronizations; + + @Nullable + private String name; + + private boolean readOnly; + + @Nullable + private Integer isolationLevel; + + private boolean wasActive; + + private SuspendedResourcesHolder(Object suspendedResources) { + this.suspendedResources = suspendedResources; + } + + private SuspendedResourcesHolder( + @Nullable Object suspendedResources, List suspendedSynchronizations, + @Nullable String name, boolean readOnly, @Nullable Integer isolationLevel, boolean wasActive) { + + this.suspendedResources = suspendedResources; + this.suspendedSynchronizations = suspendedSynchronizations; + this.name = name; + this.readOnly = readOnly; + this.isolationLevel = isolationLevel; + this.wasActive = wasActive; + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/AbstractTransactionStatus.java b/spring-tx/src/main/java/org/springframework/transaction/support/AbstractTransactionStatus.java new file mode 100644 index 0000000000000000000000000000000000000000..f82c96a56c41642a0688e95e654a3094937e1333 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/AbstractTransactionStatus.java @@ -0,0 +1,226 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.NestedTransactionNotSupportedException; +import org.springframework.transaction.SavepointManager; +import org.springframework.transaction.TransactionException; +import org.springframework.transaction.TransactionStatus; +import org.springframework.transaction.TransactionUsageException; + +/** + * Abstract base implementation of the + * {@link org.springframework.transaction.TransactionStatus} interface. + * + *

Pre-implements the handling of local rollback-only and completed flags, and + * delegation to an underlying {@link org.springframework.transaction.SavepointManager}. + * Also offers the option of a holding a savepoint within the transaction. + * + *

Does not assume any specific internal transaction handling, such as an + * underlying transaction object, and no transaction synchronization mechanism. + * + * @author Juergen Hoeller + * @since 1.2.3 + * @see #setRollbackOnly() + * @see #isRollbackOnly() + * @see #setCompleted() + * @see #isCompleted() + * @see #getSavepointManager() + * @see SimpleTransactionStatus + * @see DefaultTransactionStatus + */ +public abstract class AbstractTransactionStatus implements TransactionStatus { + + private boolean rollbackOnly = false; + + private boolean completed = false; + + @Nullable + private Object savepoint; + + + //--------------------------------------------------------------------- + // Handling of current transaction state + //--------------------------------------------------------------------- + + @Override + public void setRollbackOnly() { + this.rollbackOnly = true; + } + + /** + * Determine the rollback-only flag via checking both the local rollback-only flag + * of this TransactionStatus and the global rollback-only flag of the underlying + * transaction, if any. + * @see #isLocalRollbackOnly() + * @see #isGlobalRollbackOnly() + */ + @Override + public boolean isRollbackOnly() { + return (isLocalRollbackOnly() || isGlobalRollbackOnly()); + } + + /** + * Determine the rollback-only flag via checking this TransactionStatus. + *

Will only return "true" if the application called {@code setRollbackOnly} + * on this TransactionStatus object. + */ + public boolean isLocalRollbackOnly() { + return this.rollbackOnly; + } + + /** + * Template method for determining the global rollback-only flag of the + * underlying transaction, if any. + *

This implementation always returns {@code false}. + */ + public boolean isGlobalRollbackOnly() { + return false; + } + + /** + * This implementations is empty, considering flush as a no-op. + */ + @Override + public void flush() { + } + + /** + * Mark this transaction as completed, that is, committed or rolled back. + */ + public void setCompleted() { + this.completed = true; + } + + @Override + public boolean isCompleted() { + return this.completed; + } + + + //--------------------------------------------------------------------- + // Handling of current savepoint state + //--------------------------------------------------------------------- + + /** + * Set a savepoint for this transaction. Useful for PROPAGATION_NESTED. + * @see org.springframework.transaction.TransactionDefinition#PROPAGATION_NESTED + */ + protected void setSavepoint(@Nullable Object savepoint) { + this.savepoint = savepoint; + } + + /** + * Get the savepoint for this transaction, if any. + */ + @Nullable + protected Object getSavepoint() { + return this.savepoint; + } + + @Override + public boolean hasSavepoint() { + return (this.savepoint != null); + } + + /** + * Create a savepoint and hold it for the transaction. + * @throws org.springframework.transaction.NestedTransactionNotSupportedException + * if the underlying transaction does not support savepoints + */ + public void createAndHoldSavepoint() throws TransactionException { + setSavepoint(getSavepointManager().createSavepoint()); + } + + /** + * Roll back to the savepoint that is held for the transaction + * and release the savepoint right afterwards. + */ + public void rollbackToHeldSavepoint() throws TransactionException { + Object savepoint = getSavepoint(); + if (savepoint == null) { + throw new TransactionUsageException( + "Cannot roll back to savepoint - no savepoint associated with current transaction"); + } + getSavepointManager().rollbackToSavepoint(savepoint); + getSavepointManager().releaseSavepoint(savepoint); + setSavepoint(null); + } + + /** + * Release the savepoint that is held for the transaction. + */ + public void releaseHeldSavepoint() throws TransactionException { + Object savepoint = getSavepoint(); + if (savepoint == null) { + throw new TransactionUsageException( + "Cannot release savepoint - no savepoint associated with current transaction"); + } + getSavepointManager().releaseSavepoint(savepoint); + setSavepoint(null); + } + + + //--------------------------------------------------------------------- + // Implementation of SavepointManager + //--------------------------------------------------------------------- + + /** + * This implementation delegates to a SavepointManager for the + * underlying transaction, if possible. + * @see #getSavepointManager() + * @see SavepointManager#createSavepoint() + */ + @Override + public Object createSavepoint() throws TransactionException { + return getSavepointManager().createSavepoint(); + } + + /** + * This implementation delegates to a SavepointManager for the + * underlying transaction, if possible. + * @see #getSavepointManager() + * @see SavepointManager#rollbackToSavepoint(Object) + */ + @Override + public void rollbackToSavepoint(Object savepoint) throws TransactionException { + getSavepointManager().rollbackToSavepoint(savepoint); + } + + /** + * This implementation delegates to a SavepointManager for the + * underlying transaction, if possible. + * @see #getSavepointManager() + * @see SavepointManager#releaseSavepoint(Object) + */ + @Override + public void releaseSavepoint(Object savepoint) throws TransactionException { + getSavepointManager().releaseSavepoint(savepoint); + } + + /** + * Return a SavepointManager for the underlying transaction, if possible. + *

Default implementation always throws a NestedTransactionNotSupportedException. + * @throws org.springframework.transaction.NestedTransactionNotSupportedException + * if the underlying transaction does not support savepoints + */ + protected SavepointManager getSavepointManager() { + throw new NestedTransactionNotSupportedException("This transaction does not support savepoints"); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/CallbackPreferringPlatformTransactionManager.java b/spring-tx/src/main/java/org/springframework/transaction/support/CallbackPreferringPlatformTransactionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..5cc9ee59f5fb5eb60969b81d0cd14e08d1601f96 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/CallbackPreferringPlatformTransactionManager.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.TransactionException; + +/** + * Extension of the {@link org.springframework.transaction.PlatformTransactionManager} + * interface, exposing a method for executing a given callback within a transaction. + * + *

Implementors of this interface automatically express a preference for + * callbacks over programmatic {@code getTransaction}, {@code commit} + * and {@code rollback} calls. Calling code may check whether a given + * transaction manager implements this interface to choose to prepare a + * callback instead of explicit transaction demarcation control. + * + *

Spring's {@link TransactionTemplate} and + * {@link org.springframework.transaction.interceptor.TransactionInterceptor} + * detect and use this PlatformTransactionManager variant automatically. + * + * @author Juergen Hoeller + * @since 2.0 + * @see TransactionTemplate + * @see org.springframework.transaction.interceptor.TransactionInterceptor + */ +public interface CallbackPreferringPlatformTransactionManager extends PlatformTransactionManager { + + /** + * Execute the action specified by the given callback object within a transaction. + *

Allows for returning a result object created within the transaction, that is, + * a domain object or a collection of domain objects. A RuntimeException thrown + * by the callback is treated as a fatal exception that enforces a rollback. + * Such an exception gets propagated to the caller of the template. + * @param definition the definition for the transaction to wrap the callback in + * @param callback the callback object that specifies the transactional action + * @return a result object returned by the callback, or {@code null} if none + * @throws TransactionException in case of initialization, rollback, or system errors + * @throws RuntimeException if thrown by the TransactionCallback + */ + @Nullable + T execute(@Nullable TransactionDefinition definition, TransactionCallback callback) + throws TransactionException; + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/DefaultTransactionDefinition.java b/spring-tx/src/main/java/org/springframework/transaction/support/DefaultTransactionDefinition.java new file mode 100644 index 0000000000000000000000000000000000000000..973ae911839ef57566878ee637e4c7026a1b3796 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/DefaultTransactionDefinition.java @@ -0,0 +1,311 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import java.io.Serializable; + +import org.springframework.core.Constants; +import org.springframework.lang.Nullable; +import org.springframework.transaction.TransactionDefinition; + +/** + * Default implementation of the {@link TransactionDefinition} interface, + * offering bean-style configuration and sensible default values + * (PROPAGATION_REQUIRED, ISOLATION_DEFAULT, TIMEOUT_DEFAULT, readOnly=false). + * + *

Base class for both {@link TransactionTemplate} and + * {@link org.springframework.transaction.interceptor.DefaultTransactionAttribute}. + * + * @author Juergen Hoeller + * @since 08.05.2003 + */ +@SuppressWarnings("serial") +public class DefaultTransactionDefinition implements TransactionDefinition, Serializable { + + /** Prefix for the propagation constants defined in TransactionDefinition. */ + public static final String PREFIX_PROPAGATION = "PROPAGATION_"; + + /** Prefix for the isolation constants defined in TransactionDefinition. */ + public static final String PREFIX_ISOLATION = "ISOLATION_"; + + /** Prefix for transaction timeout values in description strings. */ + public static final String PREFIX_TIMEOUT = "timeout_"; + + /** Marker for read-only transactions in description strings. */ + public static final String READ_ONLY_MARKER = "readOnly"; + + + /** Constants instance for TransactionDefinition. */ + static final Constants constants = new Constants(TransactionDefinition.class); + + private int propagationBehavior = PROPAGATION_REQUIRED; + + private int isolationLevel = ISOLATION_DEFAULT; + + private int timeout = TIMEOUT_DEFAULT; + + private boolean readOnly = false; + + @Nullable + private String name; + + + /** + * Create a new DefaultTransactionDefinition, with default settings. + * Can be modified through bean property setters. + * @see #setPropagationBehavior + * @see #setIsolationLevel + * @see #setTimeout + * @see #setReadOnly + * @see #setName + */ + public DefaultTransactionDefinition() { + } + + /** + * Copy constructor. Definition can be modified through bean property setters. + * @see #setPropagationBehavior + * @see #setIsolationLevel + * @see #setTimeout + * @see #setReadOnly + * @see #setName + */ + public DefaultTransactionDefinition(TransactionDefinition other) { + this.propagationBehavior = other.getPropagationBehavior(); + this.isolationLevel = other.getIsolationLevel(); + this.timeout = other.getTimeout(); + this.readOnly = other.isReadOnly(); + this.name = other.getName(); + } + + /** + * Create a new DefaultTransactionDefinition with the given + * propagation behavior. Can be modified through bean property setters. + * @param propagationBehavior one of the propagation constants in the + * TransactionDefinition interface + * @see #setIsolationLevel + * @see #setTimeout + * @see #setReadOnly + */ + public DefaultTransactionDefinition(int propagationBehavior) { + this.propagationBehavior = propagationBehavior; + } + + + /** + * Set the propagation behavior by the name of the corresponding constant in + * TransactionDefinition, e.g. "PROPAGATION_REQUIRED". + * @param constantName name of the constant + * @throws IllegalArgumentException if the supplied value is not resolvable + * to one of the {@code PROPAGATION_} constants or is {@code null} + * @see #setPropagationBehavior + * @see #PROPAGATION_REQUIRED + */ + public final void setPropagationBehaviorName(String constantName) throws IllegalArgumentException { + if (!constantName.startsWith(PREFIX_PROPAGATION)) { + throw new IllegalArgumentException("Only propagation constants allowed"); + } + setPropagationBehavior(constants.asNumber(constantName).intValue()); + } + + /** + * Set the propagation behavior. Must be one of the propagation constants + * in the TransactionDefinition interface. Default is PROPAGATION_REQUIRED. + *

Exclusively designed for use with {@link #PROPAGATION_REQUIRED} or + * {@link #PROPAGATION_REQUIRES_NEW} since it only applies to newly started + * transactions. Consider switching the "validateExistingTransactions" flag to + * "true" on your transaction manager if you'd like isolation level declarations + * to get rejected when participating in an existing transaction with a different + * isolation level. + *

Note that a transaction manager that does not support custom isolation levels + * will throw an exception when given any other level than {@link #ISOLATION_DEFAULT}. + * @throws IllegalArgumentException if the supplied value is not one of the + * {@code PROPAGATION_} constants + * @see #PROPAGATION_REQUIRED + */ + public final void setPropagationBehavior(int propagationBehavior) { + if (!constants.getValues(PREFIX_PROPAGATION).contains(propagationBehavior)) { + throw new IllegalArgumentException("Only values of propagation constants allowed"); + } + this.propagationBehavior = propagationBehavior; + } + + @Override + public final int getPropagationBehavior() { + return this.propagationBehavior; + } + + /** + * Set the isolation level by the name of the corresponding constant in + * TransactionDefinition, e.g. "ISOLATION_DEFAULT". + * @param constantName name of the constant + * @throws IllegalArgumentException if the supplied value is not resolvable + * to one of the {@code ISOLATION_} constants or is {@code null} + * @see #setIsolationLevel + * @see #ISOLATION_DEFAULT + */ + public final void setIsolationLevelName(String constantName) throws IllegalArgumentException { + if (!constantName.startsWith(PREFIX_ISOLATION)) { + throw new IllegalArgumentException("Only isolation constants allowed"); + } + setIsolationLevel(constants.asNumber(constantName).intValue()); + } + + /** + * Set the isolation level. Must be one of the isolation constants + * in the TransactionDefinition interface. Default is ISOLATION_DEFAULT. + *

Exclusively designed for use with {@link #PROPAGATION_REQUIRED} or + * {@link #PROPAGATION_REQUIRES_NEW} since it only applies to newly started + * transactions. Consider switching the "validateExistingTransactions" flag to + * "true" on your transaction manager if you'd like isolation level declarations + * to get rejected when participating in an existing transaction with a different + * isolation level. + *

Note that a transaction manager that does not support custom isolation levels + * will throw an exception when given any other level than {@link #ISOLATION_DEFAULT}. + * @throws IllegalArgumentException if the supplied value is not one of the + * {@code ISOLATION_} constants + * @see #ISOLATION_DEFAULT + */ + public final void setIsolationLevel(int isolationLevel) { + if (!constants.getValues(PREFIX_ISOLATION).contains(isolationLevel)) { + throw new IllegalArgumentException("Only values of isolation constants allowed"); + } + this.isolationLevel = isolationLevel; + } + + @Override + public final int getIsolationLevel() { + return this.isolationLevel; + } + + /** + * Set the timeout to apply, as number of seconds. + * Default is TIMEOUT_DEFAULT (-1). + *

Exclusively designed for use with {@link #PROPAGATION_REQUIRED} or + * {@link #PROPAGATION_REQUIRES_NEW} since it only applies to newly started + * transactions. + *

Note that a transaction manager that does not support timeouts will throw + * an exception when given any other timeout than {@link #TIMEOUT_DEFAULT}. + * @see #TIMEOUT_DEFAULT + */ + public final void setTimeout(int timeout) { + if (timeout < TIMEOUT_DEFAULT) { + throw new IllegalArgumentException("Timeout must be a positive integer or TIMEOUT_DEFAULT"); + } + this.timeout = timeout; + } + + @Override + public final int getTimeout() { + return this.timeout; + } + + /** + * Set whether to optimize as read-only transaction. + * Default is "false". + *

The read-only flag applies to any transaction context, whether backed + * by an actual resource transaction ({@link #PROPAGATION_REQUIRED}/ + * {@link #PROPAGATION_REQUIRES_NEW}) or operating non-transactionally at + * the resource level ({@link #PROPAGATION_SUPPORTS}). In the latter case, + * the flag will only apply to managed resources within the application, + * such as a Hibernate {@code Session}. + *

This just serves as a hint for the actual transaction subsystem; + * it will not necessarily cause failure of write access attempts. + * A transaction manager which cannot interpret the read-only hint will + * not throw an exception when asked for a read-only transaction. + */ + public final void setReadOnly(boolean readOnly) { + this.readOnly = readOnly; + } + + @Override + public final boolean isReadOnly() { + return this.readOnly; + } + + /** + * Set the name of this transaction. Default is none. + *

This will be used as transaction name to be shown in a + * transaction monitor, if applicable (for example, WebLogic's). + */ + public final void setName(String name) { + this.name = name; + } + + @Override + @Nullable + public final String getName() { + return this.name; + } + + + /** + * This implementation compares the {@code toString()} results. + * @see #toString() + */ + @Override + public boolean equals(Object other) { + return (this == other || (other instanceof TransactionDefinition && toString().equals(other.toString()))); + } + + /** + * This implementation returns {@code toString()}'s hash code. + * @see #toString() + */ + @Override + public int hashCode() { + return toString().hashCode(); + } + + /** + * Return an identifying description for this transaction definition. + *

The format matches the one used by + * {@link org.springframework.transaction.interceptor.TransactionAttributeEditor}, + * to be able to feed {@code toString} results into bean properties of type + * {@link org.springframework.transaction.interceptor.TransactionAttribute}. + *

Has to be overridden in subclasses for correct {@code equals} + * and {@code hashCode} behavior. Alternatively, {@link #equals} + * and {@link #hashCode} can be overridden themselves. + * @see #getDefinitionDescription() + * @see org.springframework.transaction.interceptor.TransactionAttributeEditor + */ + @Override + public String toString() { + return getDefinitionDescription().toString(); + } + + /** + * Return an identifying description for this transaction definition. + *

Available to subclasses, for inclusion in their {@code toString()} result. + */ + protected final StringBuilder getDefinitionDescription() { + StringBuilder result = new StringBuilder(); + result.append(constants.toCode(this.propagationBehavior, PREFIX_PROPAGATION)); + result.append(','); + result.append(constants.toCode(this.isolationLevel, PREFIX_ISOLATION)); + if (this.timeout != TIMEOUT_DEFAULT) { + result.append(','); + result.append(PREFIX_TIMEOUT).append(this.timeout); + } + if (this.readOnly) { + result.append(','); + result.append(READ_ONLY_MARKER); + } + return result; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/DefaultTransactionStatus.java b/spring-tx/src/main/java/org/springframework/transaction/support/DefaultTransactionStatus.java new file mode 100644 index 0000000000000000000000000000000000000000..574789a2e492f627664906b495da3b044ca251ed --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/DefaultTransactionStatus.java @@ -0,0 +1,206 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.NestedTransactionNotSupportedException; +import org.springframework.transaction.SavepointManager; +import org.springframework.util.Assert; + +/** + * Default implementation of the {@link org.springframework.transaction.TransactionStatus} + * interface, used by {@link AbstractPlatformTransactionManager}. Based on the concept + * of an underlying "transaction object". + * + *

Holds all status information that {@link AbstractPlatformTransactionManager} + * needs internally, including a generic transaction object determined by the + * concrete transaction manager implementation. + * + *

Supports delegating savepoint-related methods to a transaction object + * that implements the {@link SavepointManager} interface. + * + *

NOTE: This is not intended for use with other PlatformTransactionManager + * implementations, in particular not for mock transaction managers in testing environments. + * Use the alternative {@link SimpleTransactionStatus} class or a mock for the plain + * {@link org.springframework.transaction.TransactionStatus} interface instead. + * + * @author Juergen Hoeller + * @since 19.01.2004 + * @see AbstractPlatformTransactionManager + * @see org.springframework.transaction.SavepointManager + * @see #getTransaction + * @see #createSavepoint + * @see #rollbackToSavepoint + * @see #releaseSavepoint + * @see SimpleTransactionStatus + */ +public class DefaultTransactionStatus extends AbstractTransactionStatus { + + @Nullable + private final Object transaction; + + private final boolean newTransaction; + + private final boolean newSynchronization; + + private final boolean readOnly; + + private final boolean debug; + + @Nullable + private final Object suspendedResources; + + + /** + * Create a new {@code DefaultTransactionStatus} instance. + * @param transaction underlying transaction object that can hold state + * for the internal transaction implementation + * @param newTransaction if the transaction is new, otherwise participating + * in an existing transaction + * @param newSynchronization if a new transaction synchronization has been + * opened for the given transaction + * @param readOnly whether the transaction is marked as read-only + * @param debug should debug logging be enabled for the handling of this transaction? + * Caching it in here can prevent repeated calls to ask the logging system whether + * debug logging should be enabled. + * @param suspendedResources a holder for resources that have been suspended + * for this transaction, if any + */ + public DefaultTransactionStatus( + @Nullable Object transaction, boolean newTransaction, boolean newSynchronization, + boolean readOnly, boolean debug, @Nullable Object suspendedResources) { + + this.transaction = transaction; + this.newTransaction = newTransaction; + this.newSynchronization = newSynchronization; + this.readOnly = readOnly; + this.debug = debug; + this.suspendedResources = suspendedResources; + } + + + /** + * Return the underlying transaction object. + * @throws IllegalStateException if no transaction is active + */ + public Object getTransaction() { + Assert.state(this.transaction != null, "No transaction active"); + return this.transaction; + } + + /** + * Return whether there is an actual transaction active. + */ + public boolean hasTransaction() { + return (this.transaction != null); + } + + @Override + public boolean isNewTransaction() { + return (hasTransaction() && this.newTransaction); + } + + /** + * Return if a new transaction synchronization has been opened + * for this transaction. + */ + public boolean isNewSynchronization() { + return this.newSynchronization; + } + + /** + * Return if this transaction is defined as read-only transaction. + */ + public boolean isReadOnly() { + return this.readOnly; + } + + /** + * Return whether the progress of this transaction is debugged. This is used by + * {@link AbstractPlatformTransactionManager} as an optimization, to prevent repeated + * calls to {@code logger.isDebugEnabled()}. Not really intended for client code. + */ + public boolean isDebug() { + return this.debug; + } + + /** + * Return the holder for resources that have been suspended for this transaction, + * if any. + */ + @Nullable + public Object getSuspendedResources() { + return this.suspendedResources; + } + + + //--------------------------------------------------------------------- + // Enable functionality through underlying transaction object + //--------------------------------------------------------------------- + + /** + * Determine the rollback-only flag via checking the transaction object, provided + * that the latter implements the {@link SmartTransactionObject} interface. + *

Will return {@code true} if the global transaction itself has been marked + * rollback-only by the transaction coordinator, for example in case of a timeout. + * @see SmartTransactionObject#isRollbackOnly() + */ + @Override + public boolean isGlobalRollbackOnly() { + return ((this.transaction instanceof SmartTransactionObject) && + ((SmartTransactionObject) this.transaction).isRollbackOnly()); + } + + /** + * Delegate the flushing to the transaction object, provided that the latter + * implements the {@link SmartTransactionObject} interface. + * @see SmartTransactionObject#flush() + */ + @Override + public void flush() { + if (this.transaction instanceof SmartTransactionObject) { + ((SmartTransactionObject) this.transaction).flush(); + } + } + + /** + * This implementation exposes the {@link SavepointManager} interface + * of the underlying transaction object, if any. + * @throws NestedTransactionNotSupportedException if savepoints are not supported + * @see #isTransactionSavepointManager() + */ + @Override + protected SavepointManager getSavepointManager() { + Object transaction = this.transaction; + if (!(transaction instanceof SavepointManager)) { + throw new NestedTransactionNotSupportedException( + "Transaction object [" + this.transaction + "] does not support savepoints"); + } + return (SavepointManager) transaction; + } + + /** + * Return whether the underlying transaction implements the {@link SavepointManager} + * interface and therefore supports savepoints. + * @see #getTransaction() + * @see #getSavepointManager() + */ + public boolean isTransactionSavepointManager() { + return (this.transaction instanceof SavepointManager); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/DelegatingTransactionDefinition.java b/spring-tx/src/main/java/org/springframework/transaction/support/DelegatingTransactionDefinition.java new file mode 100644 index 0000000000000000000000000000000000000000..fbedd01b2611a06e7ead18d8ca7c52b6da85f2d2 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/DelegatingTransactionDefinition.java @@ -0,0 +1,92 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import java.io.Serializable; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.util.Assert; + +/** + * {@link TransactionDefinition} implementation that delegates all calls to a given target + * {@link TransactionDefinition} instance. Abstract because it is meant to be subclassed, + * with subclasses overriding specific methods that are not supposed to simply delegate + * to the target instance. + * + * @author Juergen Hoeller + * @since 3.0 + */ +@SuppressWarnings("serial") +public abstract class DelegatingTransactionDefinition implements TransactionDefinition, Serializable { + + private final TransactionDefinition targetDefinition; + + + /** + * Create a DelegatingTransactionAttribute for the given target attribute. + * @param targetDefinition the target TransactionAttribute to delegate to + */ + public DelegatingTransactionDefinition(TransactionDefinition targetDefinition) { + Assert.notNull(targetDefinition, "Target definition must not be null"); + this.targetDefinition = targetDefinition; + } + + + @Override + public int getPropagationBehavior() { + return this.targetDefinition.getPropagationBehavior(); + } + + @Override + public int getIsolationLevel() { + return this.targetDefinition.getIsolationLevel(); + } + + @Override + public int getTimeout() { + return this.targetDefinition.getTimeout(); + } + + @Override + public boolean isReadOnly() { + return this.targetDefinition.isReadOnly(); + } + + @Override + @Nullable + public String getName() { + return this.targetDefinition.getName(); + } + + + @Override + public boolean equals(Object other) { + return this.targetDefinition.equals(other); + } + + @Override + public int hashCode() { + return this.targetDefinition.hashCode(); + } + + @Override + public String toString() { + return this.targetDefinition.toString(); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/ResourceHolder.java b/spring-tx/src/main/java/org/springframework/transaction/support/ResourceHolder.java new file mode 100644 index 0000000000000000000000000000000000000000..ed9aaed037b5dbee8e61ea1a0f31779f9703e924 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/ResourceHolder.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2008 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +/** + * Generic interface to be implemented by resource holders. + * Allows Spring's transaction infrastructure to introspect + * and reset the holder when necessary. + * + * @author Juergen Hoeller + * @since 2.5.5 + * @see ResourceHolderSupport + * @see ResourceHolderSynchronization + */ +public interface ResourceHolder { + + /** + * Reset the transactional state of this holder. + */ + void reset(); + + /** + * Notify this holder that it has been unbound from transaction synchronization. + */ + void unbound(); + + /** + * Determine whether this holder is considered as 'void', + * i.e. as a leftover from a previous thread. + */ + boolean isVoid(); + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/ResourceHolderSupport.java b/spring-tx/src/main/java/org/springframework/transaction/support/ResourceHolderSupport.java new file mode 100644 index 0000000000000000000000000000000000000000..c58256223c09d9fdb0296ba9d3c0c5d6e94a1ee6 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/ResourceHolderSupport.java @@ -0,0 +1,210 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import java.util.Date; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.TransactionTimedOutException; + +/** + * Convenient base class for resource holders. + * + *

Features rollback-only support for participating transactions. + * Can expire after a certain number of seconds or milliseconds + * in order to determine a transactional timeout. + * + * @author Juergen Hoeller + * @since 02.02.2004 + * @see org.springframework.jdbc.datasource.DataSourceTransactionManager#doBegin + * @see org.springframework.jdbc.datasource.DataSourceUtils#applyTransactionTimeout + */ +public abstract class ResourceHolderSupport implements ResourceHolder { + + private boolean synchronizedWithTransaction = false; + + private boolean rollbackOnly = false; + + @Nullable + private Date deadline; + + private int referenceCount = 0; + + private boolean isVoid = false; + + + /** + * Mark the resource as synchronized with a transaction. + */ + public void setSynchronizedWithTransaction(boolean synchronizedWithTransaction) { + this.synchronizedWithTransaction = synchronizedWithTransaction; + } + + /** + * Return whether the resource is synchronized with a transaction. + */ + public boolean isSynchronizedWithTransaction() { + return this.synchronizedWithTransaction; + } + + /** + * Mark the resource transaction as rollback-only. + */ + public void setRollbackOnly() { + this.rollbackOnly = true; + } + + /** + * Reset the rollback-only status for this resource transaction. + *

Only really intended to be called after custom rollback steps which + * keep the original resource in action, e.g. in case of a savepoint. + * @since 5.0 + * @see org.springframework.transaction.SavepointManager#rollbackToSavepoint + */ + public void resetRollbackOnly() { + this.rollbackOnly = false; + } + + /** + * Return whether the resource transaction is marked as rollback-only. + */ + public boolean isRollbackOnly() { + return this.rollbackOnly; + } + + /** + * Set the timeout for this object in seconds. + * @param seconds number of seconds until expiration + */ + public void setTimeoutInSeconds(int seconds) { + setTimeoutInMillis(seconds * 1000L); + } + + /** + * Set the timeout for this object in milliseconds. + * @param millis number of milliseconds until expiration + */ + public void setTimeoutInMillis(long millis) { + this.deadline = new Date(System.currentTimeMillis() + millis); + } + + /** + * Return whether this object has an associated timeout. + */ + public boolean hasTimeout() { + return (this.deadline != null); + } + + /** + * Return the expiration deadline of this object. + * @return the deadline as Date object + */ + @Nullable + public Date getDeadline() { + return this.deadline; + } + + /** + * Return the time to live for this object in seconds. + * Rounds up eagerly, e.g. 9.00001 still to 10. + * @return number of seconds until expiration + * @throws TransactionTimedOutException if the deadline has already been reached + */ + public int getTimeToLiveInSeconds() { + double diff = ((double) getTimeToLiveInMillis()) / 1000; + int secs = (int) Math.ceil(diff); + checkTransactionTimeout(secs <= 0); + return secs; + } + + /** + * Return the time to live for this object in milliseconds. + * @return number of milliseconds until expiration + * @throws TransactionTimedOutException if the deadline has already been reached + */ + public long getTimeToLiveInMillis() throws TransactionTimedOutException{ + if (this.deadline == null) { + throw new IllegalStateException("No timeout specified for this resource holder"); + } + long timeToLive = this.deadline.getTime() - System.currentTimeMillis(); + checkTransactionTimeout(timeToLive <= 0); + return timeToLive; + } + + /** + * Set the transaction rollback-only if the deadline has been reached, + * and throw a TransactionTimedOutException. + */ + private void checkTransactionTimeout(boolean deadlineReached) throws TransactionTimedOutException { + if (deadlineReached) { + setRollbackOnly(); + throw new TransactionTimedOutException("Transaction timed out: deadline was " + this.deadline); + } + } + + /** + * Increase the reference count by one because the holder has been requested + * (i.e. someone requested the resource held by it). + */ + public void requested() { + this.referenceCount++; + } + + /** + * Decrease the reference count by one because the holder has been released + * (i.e. someone released the resource held by it). + */ + public void released() { + this.referenceCount--; + } + + /** + * Return whether there are still open references to this holder. + */ + public boolean isOpen() { + return (this.referenceCount > 0); + } + + /** + * Clear the transactional state of this resource holder. + */ + public void clear() { + this.synchronizedWithTransaction = false; + this.rollbackOnly = false; + this.deadline = null; + } + + /** + * Reset this resource holder - transactional state as well as reference count. + */ + @Override + public void reset() { + clear(); + this.referenceCount = 0; + } + + @Override + public void unbound() { + this.isVoid = true; + } + + @Override + public boolean isVoid() { + return this.isVoid; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/ResourceHolderSynchronization.java b/spring-tx/src/main/java/org/springframework/transaction/support/ResourceHolderSynchronization.java new file mode 100644 index 0000000000000000000000000000000000000000..39b8c347fcfa7d685f73fb37372da5b5d346cf78 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/ResourceHolderSynchronization.java @@ -0,0 +1,185 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +/** + * {@link TransactionSynchronization} implementation that manages a + * {@link ResourceHolder} bound through {@link TransactionSynchronizationManager}. + * + * @author Juergen Hoeller + * @since 2.5.5 + * @param the resource holder type + * @param the resource key type + */ +public abstract class ResourceHolderSynchronization + implements TransactionSynchronization { + + private final H resourceHolder; + + private final K resourceKey; + + private volatile boolean holderActive = true; + + + /** + * Create a new ResourceHolderSynchronization for the given holder. + * @param resourceHolder the ResourceHolder to manage + * @param resourceKey the key to bind the ResourceHolder for + * @see TransactionSynchronizationManager#bindResource + */ + public ResourceHolderSynchronization(H resourceHolder, K resourceKey) { + this.resourceHolder = resourceHolder; + this.resourceKey = resourceKey; + } + + + @Override + public void suspend() { + if (this.holderActive) { + TransactionSynchronizationManager.unbindResource(this.resourceKey); + } + } + + @Override + public void resume() { + if (this.holderActive) { + TransactionSynchronizationManager.bindResource(this.resourceKey, this.resourceHolder); + } + } + + @Override + public void flush() { + flushResource(this.resourceHolder); + } + + @Override + public void beforeCommit(boolean readOnly) { + } + + @Override + public void beforeCompletion() { + if (shouldUnbindAtCompletion()) { + TransactionSynchronizationManager.unbindResource(this.resourceKey); + this.holderActive = false; + if (shouldReleaseBeforeCompletion()) { + releaseResource(this.resourceHolder, this.resourceKey); + } + } + } + + @Override + public void afterCommit() { + if (!shouldReleaseBeforeCompletion()) { + processResourceAfterCommit(this.resourceHolder); + } + } + + @Override + public void afterCompletion(int status) { + if (shouldUnbindAtCompletion()) { + boolean releaseNecessary = false; + if (this.holderActive) { + // The thread-bound resource holder might not be available anymore, + // since afterCompletion might get called from a different thread. + this.holderActive = false; + TransactionSynchronizationManager.unbindResourceIfPossible(this.resourceKey); + this.resourceHolder.unbound(); + releaseNecessary = true; + } + else { + releaseNecessary = shouldReleaseAfterCompletion(this.resourceHolder); + } + if (releaseNecessary) { + releaseResource(this.resourceHolder, this.resourceKey); + } + } + else { + // Probably a pre-bound resource... + cleanupResource(this.resourceHolder, this.resourceKey, (status == STATUS_COMMITTED)); + } + this.resourceHolder.reset(); + } + + + /** + * Return whether this holder should be unbound at completion + * (or should rather be left bound to the thread after the transaction). + *

The default implementation returns {@code true}. + */ + protected boolean shouldUnbindAtCompletion() { + return true; + } + + /** + * Return whether this holder's resource should be released before + * transaction completion ({@code true}) or rather after + * transaction completion ({@code false}). + *

Note that resources will only be released when they are + * unbound from the thread ({@link #shouldUnbindAtCompletion()}). + *

The default implementation returns {@code true}. + * @see #releaseResource + */ + protected boolean shouldReleaseBeforeCompletion() { + return true; + } + + /** + * Return whether this holder's resource should be released after + * transaction completion ({@code true}). + *

The default implementation returns {@code !shouldReleaseBeforeCompletion()}, + * releasing after completion if no attempt was made before completion. + * @see #releaseResource + */ + protected boolean shouldReleaseAfterCompletion(H resourceHolder) { + return !shouldReleaseBeforeCompletion(); + } + + /** + * Flush callback for the given resource holder. + * @param resourceHolder the resource holder to flush + */ + protected void flushResource(H resourceHolder) { + } + + /** + * After-commit callback for the given resource holder. + * Only called when the resource hasn't been released yet + * ({@link #shouldReleaseBeforeCompletion()}). + * @param resourceHolder the resource holder to process + */ + protected void processResourceAfterCommit(H resourceHolder) { + } + + /** + * Release the given resource (after it has been unbound from the thread). + * @param resourceHolder the resource holder to process + * @param resourceKey the key that the ResourceHolder was bound for + */ + protected void releaseResource(H resourceHolder, K resourceKey) { + } + + /** + * Perform a cleanup on the given resource (which is left bound to the thread). + * @param resourceHolder the resource holder to process + * @param resourceKey the key that the ResourceHolder was bound for + * @param committed whether the transaction has committed ({@code true}) + * or rolled back ({@code false}) + */ + protected void cleanupResource(H resourceHolder, K resourceKey, boolean committed) { + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/ResourceTransactionDefinition.java b/spring-tx/src/main/java/org/springframework/transaction/support/ResourceTransactionDefinition.java new file mode 100644 index 0000000000000000000000000000000000000000..ae295384b6f8e65f41dc0328ff2c35f9517e30df --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/ResourceTransactionDefinition.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import org.springframework.transaction.TransactionDefinition; + +/** + * Extended variant of {@link TransactionDefinition}, indicating a resource transaction + * and in particular whether the transactional resource is ready for local optimizations. + * + * @author Juergen Hoeller + * @since 5.1 + * @see ResourceTransactionManager + */ +public interface ResourceTransactionDefinition extends TransactionDefinition { + + /** + * Determine whether the transactional resource is ready for local optimizations. + * @return {@code true} if the resource is known to be entirely transaction-local, + * not affecting any operations outside of the scope of the current transaction + * @see #isReadOnly() + */ + boolean isLocalResource(); + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/ResourceTransactionManager.java b/spring-tx/src/main/java/org/springframework/transaction/support/ResourceTransactionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..f6ee501bba463617c211fb615f3c3d75bea9242b --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/ResourceTransactionManager.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import org.springframework.transaction.PlatformTransactionManager; + +/** + * Extension of the {@link org.springframework.transaction.PlatformTransactionManager} + * interface, indicating a native resource transaction manager, operating on a single + * target resource. Such transaction managers differ from JTA transaction managers in + * that they do not use XA transaction enlistment for an open number of resources but + * rather focus on leveraging the native power and simplicity of a single target resource. + * + *

This interface is mainly used for abstract introspection of a transaction manager, + * giving clients a hint on what kind of transaction manager they have been given + * and on what concrete resource the transaction manager is operating on. + * + * @author Juergen Hoeller + * @since 2.0.4 + * @see TransactionSynchronizationManager + */ +public interface ResourceTransactionManager extends PlatformTransactionManager { + + /** + * Return the resource factory that this transaction manager operates on, + * e.g. a JDBC DataSource or a JMS ConnectionFactory. + *

This target resource factory is usually used as resource key for + * {@link TransactionSynchronizationManager}'s resource bindings per thread. + * @return the target resource factory (never {@code null}) + * @see TransactionSynchronizationManager#bindResource + * @see TransactionSynchronizationManager#getResource + */ + Object getResourceFactory(); + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/SimpleTransactionScope.java b/spring-tx/src/main/java/org/springframework/transaction/support/SimpleTransactionScope.java new file mode 100644 index 0000000000000000000000000000000000000000..58effc4329f1dbd225a61df6a403a7593f3de99a --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/SimpleTransactionScope.java @@ -0,0 +1,135 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.springframework.beans.factory.ObjectFactory; +import org.springframework.beans.factory.config.Scope; +import org.springframework.lang.Nullable; + +/** + * A simple transaction-backed {@link Scope} implementation, delegating to + * {@link TransactionSynchronizationManager}'s resource binding mechanism. + * + *

NOTE: Like {@link org.springframework.context.support.SimpleThreadScope}, + * this transaction scope is not registered by default in common contexts. Instead, + * you need to explicitly assign it to a scope key in your setup, either through + * {@link org.springframework.beans.factory.config.ConfigurableBeanFactory#registerScope} + * or through a {@link org.springframework.beans.factory.config.CustomScopeConfigurer} bean. + * + * @author Juergen Hoeller + * @since 4.2 + * @see org.springframework.context.support.SimpleThreadScope + * @see org.springframework.beans.factory.config.ConfigurableBeanFactory#registerScope + * @see org.springframework.beans.factory.config.CustomScopeConfigurer + */ +public class SimpleTransactionScope implements Scope { + + @Override + public Object get(String name, ObjectFactory objectFactory) { + ScopedObjectsHolder scopedObjects = (ScopedObjectsHolder) TransactionSynchronizationManager.getResource(this); + if (scopedObjects == null) { + scopedObjects = new ScopedObjectsHolder(); + TransactionSynchronizationManager.registerSynchronization(new CleanupSynchronization(scopedObjects)); + TransactionSynchronizationManager.bindResource(this, scopedObjects); + } + Object scopedObject = scopedObjects.scopedInstances.get(name); + if (scopedObject == null) { + scopedObject = objectFactory.getObject(); + scopedObjects.scopedInstances.put(name, scopedObject); + } + return scopedObject; + } + + @Override + @Nullable + public Object remove(String name) { + ScopedObjectsHolder scopedObjects = (ScopedObjectsHolder) TransactionSynchronizationManager.getResource(this); + if (scopedObjects != null) { + scopedObjects.destructionCallbacks.remove(name); + return scopedObjects.scopedInstances.remove(name); + } + else { + return null; + } + } + + @Override + public void registerDestructionCallback(String name, Runnable callback) { + ScopedObjectsHolder scopedObjects = (ScopedObjectsHolder) TransactionSynchronizationManager.getResource(this); + if (scopedObjects != null) { + scopedObjects.destructionCallbacks.put(name, callback); + } + } + + @Override + @Nullable + public Object resolveContextualObject(String key) { + return null; + } + + @Override + @Nullable + public String getConversationId() { + return TransactionSynchronizationManager.getCurrentTransactionName(); + } + + + /** + * Holder for scoped objects. + */ + static class ScopedObjectsHolder { + + final Map scopedInstances = new HashMap<>(); + + final Map destructionCallbacks = new LinkedHashMap<>(); + } + + + private class CleanupSynchronization extends TransactionSynchronizationAdapter { + + private final ScopedObjectsHolder scopedObjects; + + public CleanupSynchronization(ScopedObjectsHolder scopedObjects) { + this.scopedObjects = scopedObjects; + } + + @Override + public void suspend() { + TransactionSynchronizationManager.unbindResource(SimpleTransactionScope.this); + } + + @Override + public void resume() { + TransactionSynchronizationManager.bindResource(SimpleTransactionScope.this, this.scopedObjects); + } + + @Override + public void afterCompletion(int status) { + TransactionSynchronizationManager.unbindResourceIfPossible(SimpleTransactionScope.this); + for (Runnable callback : this.scopedObjects.destructionCallbacks.values()) { + callback.run(); + } + this.scopedObjects.destructionCallbacks.clear(); + this.scopedObjects.scopedInstances.clear(); + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/SimpleTransactionStatus.java b/spring-tx/src/main/java/org/springframework/transaction/support/SimpleTransactionStatus.java new file mode 100644 index 0000000000000000000000000000000000000000..2524941b0169d74c439a760cfe876139a87a7ac7 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/SimpleTransactionStatus.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +/** + * A simple {@link org.springframework.transaction.TransactionStatus} + * implementation. Derives from {@link AbstractTransactionStatus} and + * adds an explicit {@link #isNewTransaction() "newTransaction"} flag. + * + *

This class is not used by any of Spring's pre-built + * {@link org.springframework.transaction.PlatformTransactionManager} + * implementations. It is mainly provided as a start for custom transaction + * manager implementations and as a static mock for testing transactional + * code (either as part of a mock {@code PlatformTransactionManager} or + * as argument passed into a {@link TransactionCallback} to be tested). + * + * @author Juergen Hoeller + * @since 1.2.3 + * @see TransactionCallback#doInTransaction + */ +public class SimpleTransactionStatus extends AbstractTransactionStatus { + + private final boolean newTransaction; + + + /** + * Create a new {@code SimpleTransactionStatus} instance, + * indicating a new transaction. + */ + public SimpleTransactionStatus() { + this(true); + } + + /** + * Create a new {@code SimpleTransactionStatus} instance. + * @param newTransaction whether to indicate a new transaction + */ + public SimpleTransactionStatus(boolean newTransaction) { + this.newTransaction = newTransaction; + } + + + @Override + public boolean isNewTransaction() { + return this.newTransaction; + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/SmartTransactionObject.java b/spring-tx/src/main/java/org/springframework/transaction/support/SmartTransactionObject.java new file mode 100644 index 0000000000000000000000000000000000000000..923e00480456c564927632dc92eb40faddf3790c --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/SmartTransactionObject.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import java.io.Flushable; + +/** + * Interface to be implemented by transaction objects that are able to + * return an internal rollback-only marker, typically from a another + * transaction that has participated and marked it as rollback-only. + * + *

Autodetected by DefaultTransactionStatus, to always return a + * current rollbackOnly flag even if not resulting from the current + * TransactionStatus. + * + * @author Juergen Hoeller + * @since 1.1 + * @see DefaultTransactionStatus#isRollbackOnly + */ +public interface SmartTransactionObject extends Flushable { + + /** + * Return whether the transaction is internally marked as rollback-only. + * Can, for example, check the JTA UserTransaction. + * @see javax.transaction.UserTransaction#getStatus + * @see javax.transaction.Status#STATUS_MARKED_ROLLBACK + */ + boolean isRollbackOnly(); + + /** + * Flush the underlying sessions to the datastore, if applicable: + * for example, all affected Hibernate/JPA sessions. + */ + @Override + void flush(); + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/TransactionCallback.java b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionCallback.java new file mode 100644 index 0000000000000000000000000000000000000000..08bacac4ec43678f890c6322e1350ef5681387a1 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionCallback.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.TransactionStatus; + +/** + * Callback interface for transactional code. Used with {@link TransactionTemplate}'s + * {@code execute} method, often as anonymous class within a method implementation. + * + *

Typically used to assemble various calls to transaction-unaware data access + * services into a higher-level service method with transaction demarcation. As an + * alternative, consider the use of declarative transaction demarcation (e.g. through + * Spring's {@link org.springframework.transaction.annotation.Transactional} annotation). + * + * @author Juergen Hoeller + * @since 17.03.2003 + * @see TransactionTemplate + * @see CallbackPreferringPlatformTransactionManager + * @param the result type + */ +@FunctionalInterface +public interface TransactionCallback { + + /** + * Gets called by {@link TransactionTemplate#execute} within a transactional context. + * Does not need to care about transactions itself, although it can retrieve and + * influence the status of the current transaction via the given status object, + * e.g. setting rollback-only. + *

Allows for returning a result object created within the transaction, i.e. a + * domain object or a collection of domain objects. A RuntimeException thrown by the + * callback is treated as application exception that enforces a rollback. Any such + * exception will be propagated to the caller of the template, unless there is a + * problem rolling back, in which case a TransactionException will be thrown. + * @param status associated transaction status + * @return a result object, or {@code null} + * @see TransactionTemplate#execute + * @see CallbackPreferringPlatformTransactionManager#execute + */ + @Nullable + T doInTransaction(TransactionStatus status); + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/TransactionCallbackWithoutResult.java b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionCallbackWithoutResult.java new file mode 100644 index 0000000000000000000000000000000000000000..84eb4d890088fd578f2f8daf639d403510cc4e42 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionCallbackWithoutResult.java @@ -0,0 +1,56 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.TransactionStatus; + +/** + * Simple convenience class for TransactionCallback implementation. + * Allows for implementing a doInTransaction version without result, + * i.e. without the need for a return statement. + * + * @author Juergen Hoeller + * @since 28.03.2003 + * @see TransactionTemplate + */ +public abstract class TransactionCallbackWithoutResult implements TransactionCallback { + + @Override + @Nullable + public final Object doInTransaction(TransactionStatus status) { + doInTransactionWithoutResult(status); + return null; + } + + /** + * Gets called by {@code TransactionTemplate.execute} within a transactional + * context. Does not need to care about transactions itself, although it can retrieve + * and influence the status of the current transaction via the given status object, + * e.g. setting rollback-only. + *

A RuntimeException thrown by the callback is treated as application + * exception that enforces a rollback. An exception gets propagated to the + * caller of the template. + *

Note when using JTA: JTA transactions only work with transactional + * JNDI resources, so implementations need to use such resources if they + * want transaction support. + * @param status associated transaction status + * @see TransactionTemplate#execute + */ + protected abstract void doInTransactionWithoutResult(TransactionStatus status); + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/TransactionOperations.java b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionOperations.java new file mode 100644 index 0000000000000000000000000000000000000000..bf70e1b6a0956595260e7541e3fe9490fc246114 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionOperations.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.TransactionException; + +/** + * Interface specifying basic transaction execution operations. + * Implemented by {@link TransactionTemplate}. Not often used directly, + * but a useful option to enhance testability, as it can easily be + * mocked or stubbed. + * + * @author Juergen Hoeller + * @since 2.0.4 + */ +public interface TransactionOperations { + + /** + * Execute the action specified by the given callback object within a transaction. + *

Allows for returning a result object created within the transaction, that is, + * a domain object or a collection of domain objects. A RuntimeException thrown + * by the callback is treated as a fatal exception that enforces a rollback. + * Such an exception gets propagated to the caller of the template. + * @param action the callback object that specifies the transactional action + * @return a result object returned by the callback, or {@code null} if none + * @throws TransactionException in case of initialization, rollback, or system errors + * @throws RuntimeException if thrown by the TransactionCallback + */ + @Nullable + T execute(TransactionCallback action) throws TransactionException; + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronization.java b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronization.java new file mode 100644 index 0000000000000000000000000000000000000000..9d21ddb01f3aae8574771cb51725ee940b5b44e3 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronization.java @@ -0,0 +1,145 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import java.io.Flushable; + +/** + * Interface for transaction synchronization callbacks. + * Supported by AbstractPlatformTransactionManager. + * + *

TransactionSynchronization implementations can implement the Ordered interface + * to influence their execution order. A synchronization that does not implement the + * Ordered interface is appended to the end of the synchronization chain. + * + *

System synchronizations performed by Spring itself use specific order values, + * allowing for fine-grained interaction with their execution order (if necessary). + * + * @author Juergen Hoeller + * @since 02.06.2003 + * @see TransactionSynchronizationManager + * @see AbstractPlatformTransactionManager + * @see org.springframework.jdbc.datasource.DataSourceUtils#CONNECTION_SYNCHRONIZATION_ORDER + */ +public interface TransactionSynchronization extends Flushable { + + /** Completion status in case of proper commit. */ + int STATUS_COMMITTED = 0; + + /** Completion status in case of proper rollback. */ + int STATUS_ROLLED_BACK = 1; + + /** Completion status in case of heuristic mixed completion or system errors. */ + int STATUS_UNKNOWN = 2; + + + /** + * Suspend this synchronization. + * Supposed to unbind resources from TransactionSynchronizationManager if managing any. + * @see TransactionSynchronizationManager#unbindResource + */ + default void suspend() { + } + + /** + * Resume this synchronization. + * Supposed to rebind resources to TransactionSynchronizationManager if managing any. + * @see TransactionSynchronizationManager#bindResource + */ + default void resume() { + } + + /** + * Flush the underlying session to the datastore, if applicable: + * for example, a Hibernate/JPA session. + * @see org.springframework.transaction.TransactionStatus#flush() + */ + @Override + default void flush() { + } + + /** + * Invoked before transaction commit (before "beforeCompletion"). + * Can e.g. flush transactional O/R Mapping sessions to the database. + *

This callback does not mean that the transaction will actually be committed. + * A rollback decision can still occur after this method has been called. This callback + * is rather meant to perform work that's only relevant if a commit still has a chance + * to happen, such as flushing SQL statements to the database. + *

Note that exceptions will get propagated to the commit caller and cause a + * rollback of the transaction. + * @param readOnly whether the transaction is defined as read-only transaction + * @throws RuntimeException in case of errors; will be propagated to the caller + * (note: do not throw TransactionException subclasses here!) + * @see #beforeCompletion + */ + default void beforeCommit(boolean readOnly) { + } + + /** + * Invoked before transaction commit/rollback. + * Can perform resource cleanup before transaction completion. + *

This method will be invoked after {@code beforeCommit}, even when + * {@code beforeCommit} threw an exception. This callback allows for + * closing resources before transaction completion, for any outcome. + * @throws RuntimeException in case of errors; will be logged but not propagated + * (note: do not throw TransactionException subclasses here!) + * @see #beforeCommit + * @see #afterCompletion + */ + default void beforeCompletion() { + } + + /** + * Invoked after transaction commit. Can perform further operations right + * after the main transaction has successfully committed. + *

Can e.g. commit further operations that are supposed to follow on a successful + * commit of the main transaction, like confirmation messages or emails. + *

NOTE: The transaction will have been committed already, but the + * transactional resources might still be active and accessible. As a consequence, + * any data access code triggered at this point will still "participate" in the + * original transaction, allowing to perform some cleanup (with no commit following + * anymore!), unless it explicitly declares that it needs to run in a separate + * transaction. Hence: Use {@code PROPAGATION_REQUIRES_NEW} for any + * transactional operation that is called from here. + * @throws RuntimeException in case of errors; will be propagated to the caller + * (note: do not throw TransactionException subclasses here!) + */ + default void afterCommit() { + } + + /** + * Invoked after transaction commit/rollback. + * Can perform resource cleanup after transaction completion. + *

NOTE: The transaction will have been committed or rolled back already, + * but the transactional resources might still be active and accessible. As a + * consequence, any data access code triggered at this point will still "participate" + * in the original transaction, allowing to perform some cleanup (with no commit + * following anymore!), unless it explicitly declares that it needs to run in a + * separate transaction. Hence: Use {@code PROPAGATION_REQUIRES_NEW} + * for any transactional operation that is called from here. + * @param status completion status according to the {@code STATUS_*} constants + * @throws RuntimeException in case of errors; will be logged but not propagated + * (note: do not throw TransactionException subclasses here!) + * @see #STATUS_COMMITTED + * @see #STATUS_ROLLED_BACK + * @see #STATUS_UNKNOWN + * @see #beforeCompletion + */ + default void afterCompletion(int status) { + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronizationAdapter.java b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronizationAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..49cbe02ed442acecee6311a08b283903c1d8020c --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronizationAdapter.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import org.springframework.core.Ordered; + +/** + * Simple {@link TransactionSynchronization} adapter containing empty + * method implementations, for easier overriding of single methods. + * + *

Also implements the {@link Ordered} interface to enable the execution + * order of synchronizations to be controlled declaratively. The default + * {@link #getOrder() order} is {@link Ordered#LOWEST_PRECEDENCE}, indicating + * late execution; return a lower value for earlier execution. + * + * @author Juergen Hoeller + * @since 22.01.2004 + */ +public abstract class TransactionSynchronizationAdapter implements TransactionSynchronization, Ordered { + + @Override + public int getOrder() { + return Ordered.LOWEST_PRECEDENCE; + } + + @Override + public void suspend() { + } + + @Override + public void resume() { + } + + @Override + public void flush() { + } + + @Override + public void beforeCommit(boolean readOnly) { + } + + @Override + public void beforeCompletion() { + } + + @Override + public void afterCommit() { + } + + @Override + public void afterCompletion(int status) { + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java new file mode 100644 index 0000000000000000000000000000000000000000..df9132d13d514ee66c702f72432e781e36bd3d98 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java @@ -0,0 +1,479 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.NamedThreadLocal; +import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Central delegate that manages resources and transaction synchronizations per thread. + * To be used by resource management code but not by typical application code. + * + *

Supports one resource per key without overwriting, that is, a resource needs + * to be removed before a new one can be set for the same key. + * Supports a list of transaction synchronizations if synchronization is active. + * + *

Resource management code should check for thread-bound resources, e.g. JDBC + * Connections or Hibernate Sessions, via {@code getResource}. Such code is + * normally not supposed to bind resources to threads, as this is the responsibility + * of transaction managers. A further option is to lazily bind on first use if + * transaction synchronization is active, for performing transactions that span + * an arbitrary number of resources. + * + *

Transaction synchronization must be activated and deactivated by a transaction + * manager via {@link #initSynchronization()} and {@link #clearSynchronization()}. + * This is automatically supported by {@link AbstractPlatformTransactionManager}, + * and thus by all standard Spring transaction managers, such as + * {@link org.springframework.transaction.jta.JtaTransactionManager} and + * {@link org.springframework.jdbc.datasource.DataSourceTransactionManager}. + * + *

Resource management code should only register synchronizations when this + * manager is active, which can be checked via {@link #isSynchronizationActive}; + * it should perform immediate resource cleanup else. If transaction synchronization + * isn't active, there is either no current transaction, or the transaction manager + * doesn't support transaction synchronization. + * + *

Synchronization is for example used to always return the same resources + * within a JTA transaction, e.g. a JDBC Connection or a Hibernate Session for + * any given DataSource or SessionFactory, respectively. + * + * @author Juergen Hoeller + * @since 02.06.2003 + * @see #isSynchronizationActive + * @see #registerSynchronization + * @see TransactionSynchronization + * @see AbstractPlatformTransactionManager#setTransactionSynchronization + * @see org.springframework.transaction.jta.JtaTransactionManager + * @see org.springframework.jdbc.datasource.DataSourceTransactionManager + * @see org.springframework.jdbc.datasource.DataSourceUtils#getConnection + */ +public abstract class TransactionSynchronizationManager { + + private static final Log logger = LogFactory.getLog(TransactionSynchronizationManager.class); + + private static final ThreadLocal> resources = + new NamedThreadLocal<>("Transactional resources"); + + private static final ThreadLocal> synchronizations = + new NamedThreadLocal<>("Transaction synchronizations"); + + private static final ThreadLocal currentTransactionName = + new NamedThreadLocal<>("Current transaction name"); + + private static final ThreadLocal currentTransactionReadOnly = + new NamedThreadLocal<>("Current transaction read-only status"); + + private static final ThreadLocal currentTransactionIsolationLevel = + new NamedThreadLocal<>("Current transaction isolation level"); + + private static final ThreadLocal actualTransactionActive = + new NamedThreadLocal<>("Actual transaction active"); + + + //------------------------------------------------------------------------- + // Management of transaction-associated resource handles + //------------------------------------------------------------------------- + + /** + * Return all resources that are bound to the current thread. + *

Mainly for debugging purposes. Resource managers should always invoke + * {@code hasResource} for a specific resource key that they are interested in. + * @return a Map with resource keys (usually the resource factory) and resource + * values (usually the active resource object), or an empty Map if there are + * currently no resources bound + * @see #hasResource + */ + public static Map getResourceMap() { + Map map = resources.get(); + return (map != null ? Collections.unmodifiableMap(map) : Collections.emptyMap()); + } + + /** + * Check if there is a resource for the given key bound to the current thread. + * @param key the key to check (usually the resource factory) + * @return if there is a value bound to the current thread + * @see ResourceTransactionManager#getResourceFactory() + */ + public static boolean hasResource(Object key) { + Object actualKey = TransactionSynchronizationUtils.unwrapResourceIfNecessary(key); + Object value = doGetResource(actualKey); + return (value != null); + } + + /** + * Retrieve a resource for the given key that is bound to the current thread. + * @param key the key to check (usually the resource factory) + * @return a value bound to the current thread (usually the active + * resource object), or {@code null} if none + * @see ResourceTransactionManager#getResourceFactory() + */ + @Nullable + public static Object getResource(Object key) { + Object actualKey = TransactionSynchronizationUtils.unwrapResourceIfNecessary(key); + Object value = doGetResource(actualKey); + if (value != null && logger.isTraceEnabled()) { + logger.trace("Retrieved value [" + value + "] for key [" + actualKey + "] bound to thread [" + + Thread.currentThread().getName() + "]"); + } + return value; + } + + /** + * Actually check the value of the resource that is bound for the given key. + */ + @Nullable + private static Object doGetResource(Object actualKey) { + Map map = resources.get(); + if (map == null) { + return null; + } + Object value = map.get(actualKey); + // Transparently remove ResourceHolder that was marked as void... + if (value instanceof ResourceHolder && ((ResourceHolder) value).isVoid()) { + map.remove(actualKey); + // Remove entire ThreadLocal if empty... + if (map.isEmpty()) { + resources.remove(); + } + value = null; + } + return value; + } + + /** + * Bind the given resource for the given key to the current thread. + * @param key the key to bind the value to (usually the resource factory) + * @param value the value to bind (usually the active resource object) + * @throws IllegalStateException if there is already a value bound to the thread + * @see ResourceTransactionManager#getResourceFactory() + */ + public static void bindResource(Object key, Object value) throws IllegalStateException { + Object actualKey = TransactionSynchronizationUtils.unwrapResourceIfNecessary(key); + Assert.notNull(value, "Value must not be null"); + Map map = resources.get(); + // set ThreadLocal Map if none found + if (map == null) { + map = new HashMap<>(); + resources.set(map); + } + Object oldValue = map.put(actualKey, value); + // Transparently suppress a ResourceHolder that was marked as void... + if (oldValue instanceof ResourceHolder && ((ResourceHolder) oldValue).isVoid()) { + oldValue = null; + } + if (oldValue != null) { + throw new IllegalStateException("Already value [" + oldValue + "] for key [" + + actualKey + "] bound to thread [" + Thread.currentThread().getName() + "]"); + } + if (logger.isTraceEnabled()) { + logger.trace("Bound value [" + value + "] for key [" + actualKey + "] to thread [" + + Thread.currentThread().getName() + "]"); + } + } + + /** + * Unbind a resource for the given key from the current thread. + * @param key the key to unbind (usually the resource factory) + * @return the previously bound value (usually the active resource object) + * @throws IllegalStateException if there is no value bound to the thread + * @see ResourceTransactionManager#getResourceFactory() + */ + public static Object unbindResource(Object key) throws IllegalStateException { + Object actualKey = TransactionSynchronizationUtils.unwrapResourceIfNecessary(key); + Object value = doUnbindResource(actualKey); + if (value == null) { + throw new IllegalStateException( + "No value for key [" + actualKey + "] bound to thread [" + Thread.currentThread().getName() + "]"); + } + return value; + } + + /** + * Unbind a resource for the given key from the current thread. + * @param key the key to unbind (usually the resource factory) + * @return the previously bound value, or {@code null} if none bound + */ + @Nullable + public static Object unbindResourceIfPossible(Object key) { + Object actualKey = TransactionSynchronizationUtils.unwrapResourceIfNecessary(key); + return doUnbindResource(actualKey); + } + + /** + * Actually remove the value of the resource that is bound for the given key. + */ + @Nullable + private static Object doUnbindResource(Object actualKey) { + Map map = resources.get(); + if (map == null) { + return null; + } + Object value = map.remove(actualKey); + // Remove entire ThreadLocal if empty... + if (map.isEmpty()) { + resources.remove(); + } + // Transparently suppress a ResourceHolder that was marked as void... + if (value instanceof ResourceHolder && ((ResourceHolder) value).isVoid()) { + value = null; + } + if (value != null && logger.isTraceEnabled()) { + logger.trace("Removed value [" + value + "] for key [" + actualKey + "] from thread [" + + Thread.currentThread().getName() + "]"); + } + return value; + } + + + //------------------------------------------------------------------------- + // Management of transaction synchronizations + //------------------------------------------------------------------------- + + /** + * Return if transaction synchronization is active for the current thread. + * Can be called before register to avoid unnecessary instance creation. + * @see #registerSynchronization + */ + public static boolean isSynchronizationActive() { + return (synchronizations.get() != null); + } + + /** + * Activate transaction synchronization for the current thread. + * Called by a transaction manager on transaction begin. + * @throws IllegalStateException if synchronization is already active + */ + public static void initSynchronization() throws IllegalStateException { + if (isSynchronizationActive()) { + throw new IllegalStateException("Cannot activate transaction synchronization - already active"); + } + logger.trace("Initializing transaction synchronization"); + synchronizations.set(new LinkedHashSet<>()); + } + + /** + * Register a new transaction synchronization for the current thread. + * Typically called by resource management code. + *

Note that synchronizations can implement the + * {@link org.springframework.core.Ordered} interface. + * They will be executed in an order according to their order value (if any). + * @param synchronization the synchronization object to register + * @throws IllegalStateException if transaction synchronization is not active + * @see org.springframework.core.Ordered + */ + public static void registerSynchronization(TransactionSynchronization synchronization) + throws IllegalStateException { + + Assert.notNull(synchronization, "TransactionSynchronization must not be null"); + Set synchs = synchronizations.get(); + if (synchs == null) { + throw new IllegalStateException("Transaction synchronization is not active"); + } + synchs.add(synchronization); + } + + /** + * Return an unmodifiable snapshot list of all registered synchronizations + * for the current thread. + * @return unmodifiable List of TransactionSynchronization instances + * @throws IllegalStateException if synchronization is not active + * @see TransactionSynchronization + */ + public static List getSynchronizations() throws IllegalStateException { + Set synchs = synchronizations.get(); + if (synchs == null) { + throw new IllegalStateException("Transaction synchronization is not active"); + } + // Return unmodifiable snapshot, to avoid ConcurrentModificationExceptions + // while iterating and invoking synchronization callbacks that in turn + // might register further synchronizations. + if (synchs.isEmpty()) { + return Collections.emptyList(); + } + else { + // Sort lazily here, not in registerSynchronization. + List sortedSynchs = new ArrayList<>(synchs); + AnnotationAwareOrderComparator.sort(sortedSynchs); + return Collections.unmodifiableList(sortedSynchs); + } + } + + /** + * Deactivate transaction synchronization for the current thread. + * Called by the transaction manager on transaction cleanup. + * @throws IllegalStateException if synchronization is not active + */ + public static void clearSynchronization() throws IllegalStateException { + if (!isSynchronizationActive()) { + throw new IllegalStateException("Cannot deactivate transaction synchronization - not active"); + } + logger.trace("Clearing transaction synchronization"); + synchronizations.remove(); + } + + + //------------------------------------------------------------------------- + // Exposure of transaction characteristics + //------------------------------------------------------------------------- + + /** + * Expose the name of the current transaction, if any. + * Called by the transaction manager on transaction begin and on cleanup. + * @param name the name of the transaction, or {@code null} to reset it + * @see org.springframework.transaction.TransactionDefinition#getName() + */ + public static void setCurrentTransactionName(@Nullable String name) { + currentTransactionName.set(name); + } + + /** + * Return the name of the current transaction, or {@code null} if none set. + * To be called by resource management code for optimizations per use case, + * for example to optimize fetch strategies for specific named transactions. + * @see org.springframework.transaction.TransactionDefinition#getName() + */ + @Nullable + public static String getCurrentTransactionName() { + return currentTransactionName.get(); + } + + /** + * Expose a read-only flag for the current transaction. + * Called by the transaction manager on transaction begin and on cleanup. + * @param readOnly {@code true} to mark the current transaction + * as read-only; {@code false} to reset such a read-only marker + * @see org.springframework.transaction.TransactionDefinition#isReadOnly() + */ + public static void setCurrentTransactionReadOnly(boolean readOnly) { + currentTransactionReadOnly.set(readOnly ? Boolean.TRUE : null); + } + + /** + * Return whether the current transaction is marked as read-only. + * To be called by resource management code when preparing a newly + * created resource (for example, a Hibernate Session). + *

Note that transaction synchronizations receive the read-only flag + * as argument for the {@code beforeCommit} callback, to be able + * to suppress change detection on commit. The present method is meant + * to be used for earlier read-only checks, for example to set the + * flush mode of a Hibernate Session to "FlushMode.MANUAL" upfront. + * @see org.springframework.transaction.TransactionDefinition#isReadOnly() + * @see TransactionSynchronization#beforeCommit(boolean) + */ + public static boolean isCurrentTransactionReadOnly() { + return (currentTransactionReadOnly.get() != null); + } + + /** + * Expose an isolation level for the current transaction. + * Called by the transaction manager on transaction begin and on cleanup. + * @param isolationLevel the isolation level to expose, according to the + * JDBC Connection constants (equivalent to the corresponding Spring + * TransactionDefinition constants), or {@code null} to reset it + * @see java.sql.Connection#TRANSACTION_READ_UNCOMMITTED + * @see java.sql.Connection#TRANSACTION_READ_COMMITTED + * @see java.sql.Connection#TRANSACTION_REPEATABLE_READ + * @see java.sql.Connection#TRANSACTION_SERIALIZABLE + * @see org.springframework.transaction.TransactionDefinition#ISOLATION_READ_UNCOMMITTED + * @see org.springframework.transaction.TransactionDefinition#ISOLATION_READ_COMMITTED + * @see org.springframework.transaction.TransactionDefinition#ISOLATION_REPEATABLE_READ + * @see org.springframework.transaction.TransactionDefinition#ISOLATION_SERIALIZABLE + * @see org.springframework.transaction.TransactionDefinition#getIsolationLevel() + */ + public static void setCurrentTransactionIsolationLevel(@Nullable Integer isolationLevel) { + currentTransactionIsolationLevel.set(isolationLevel); + } + + /** + * Return the isolation level for the current transaction, if any. + * To be called by resource management code when preparing a newly + * created resource (for example, a JDBC Connection). + * @return the currently exposed isolation level, according to the + * JDBC Connection constants (equivalent to the corresponding Spring + * TransactionDefinition constants), or {@code null} if none + * @see java.sql.Connection#TRANSACTION_READ_UNCOMMITTED + * @see java.sql.Connection#TRANSACTION_READ_COMMITTED + * @see java.sql.Connection#TRANSACTION_REPEATABLE_READ + * @see java.sql.Connection#TRANSACTION_SERIALIZABLE + * @see org.springframework.transaction.TransactionDefinition#ISOLATION_READ_UNCOMMITTED + * @see org.springframework.transaction.TransactionDefinition#ISOLATION_READ_COMMITTED + * @see org.springframework.transaction.TransactionDefinition#ISOLATION_REPEATABLE_READ + * @see org.springframework.transaction.TransactionDefinition#ISOLATION_SERIALIZABLE + * @see org.springframework.transaction.TransactionDefinition#getIsolationLevel() + */ + @Nullable + public static Integer getCurrentTransactionIsolationLevel() { + return currentTransactionIsolationLevel.get(); + } + + /** + * Expose whether there currently is an actual transaction active. + * Called by the transaction manager on transaction begin and on cleanup. + * @param active {@code true} to mark the current thread as being associated + * with an actual transaction; {@code false} to reset that marker + */ + public static void setActualTransactionActive(boolean active) { + actualTransactionActive.set(active ? Boolean.TRUE : null); + } + + /** + * Return whether there currently is an actual transaction active. + * This indicates whether the current thread is associated with an actual + * transaction rather than just with active transaction synchronization. + *

To be called by resource management code that wants to discriminate + * between active transaction synchronization (with or without backing + * resource transaction; also on PROPAGATION_SUPPORTS) and an actual + * transaction being active (with backing resource transaction; + * on PROPAGATION_REQUIRED, PROPAGATION_REQUIRES_NEW, etc). + * @see #isSynchronizationActive() + */ + public static boolean isActualTransactionActive() { + return (actualTransactionActive.get() != null); + } + + + /** + * Clear the entire transaction synchronization state for the current thread: + * registered synchronizations as well as the various transaction characteristics. + * @see #clearSynchronization() + * @see #setCurrentTransactionName + * @see #setCurrentTransactionReadOnly + * @see #setCurrentTransactionIsolationLevel + * @see #setActualTransactionActive + */ + public static void clear() { + synchronizations.remove(); + currentTransactionName.remove(); + currentTransactionReadOnly.remove(); + currentTransactionIsolationLevel.remove(); + actualTransactionActive.remove(); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronizationUtils.java b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronizationUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..e91d1bc2784d0f88d4651620392557cdde355f3e --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronizationUtils.java @@ -0,0 +1,196 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.aop.scope.ScopedObject; +import org.springframework.core.InfrastructureProxy; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; + +/** + * Utility methods for triggering specific {@link TransactionSynchronization} + * callback methods on all currently registered synchronizations. + * + * @author Juergen Hoeller + * @since 2.0 + * @see TransactionSynchronization + * @see TransactionSynchronizationManager#getSynchronizations() + */ +public abstract class TransactionSynchronizationUtils { + + private static final Log logger = LogFactory.getLog(TransactionSynchronizationUtils.class); + + private static final boolean aopAvailable = ClassUtils.isPresent( + "org.springframework.aop.scope.ScopedObject", TransactionSynchronizationUtils.class.getClassLoader()); + + + /** + * Check whether the given resource transaction managers refers to the given + * (underlying) resource factory. + * @see ResourceTransactionManager#getResourceFactory() + * @see org.springframework.core.InfrastructureProxy#getWrappedObject() + */ + public static boolean sameResourceFactory(ResourceTransactionManager tm, Object resourceFactory) { + return unwrapResourceIfNecessary(tm.getResourceFactory()).equals(unwrapResourceIfNecessary(resourceFactory)); + } + + /** + * Unwrap the given resource handle if necessary; otherwise return + * the given handle as-is. + * @see org.springframework.core.InfrastructureProxy#getWrappedObject() + */ + static Object unwrapResourceIfNecessary(Object resource) { + Assert.notNull(resource, "Resource must not be null"); + Object resourceRef = resource; + // unwrap infrastructure proxy + if (resourceRef instanceof InfrastructureProxy) { + resourceRef = ((InfrastructureProxy) resourceRef).getWrappedObject(); + } + if (aopAvailable) { + // now unwrap scoped proxy + resourceRef = ScopedProxyUnwrapper.unwrapIfNecessary(resourceRef); + } + return resourceRef; + } + + + /** + * Trigger {@code flush} callbacks on all currently registered synchronizations. + * @throws RuntimeException if thrown by a {@code flush} callback + * @see TransactionSynchronization#flush() + */ + public static void triggerFlush() { + for (TransactionSynchronization synchronization : TransactionSynchronizationManager.getSynchronizations()) { + synchronization.flush(); + } + } + + /** + * Trigger {@code beforeCommit} callbacks on all currently registered synchronizations. + * @param readOnly whether the transaction is defined as read-only transaction + * @throws RuntimeException if thrown by a {@code beforeCommit} callback + * @see TransactionSynchronization#beforeCommit(boolean) + */ + public static void triggerBeforeCommit(boolean readOnly) { + for (TransactionSynchronization synchronization : TransactionSynchronizationManager.getSynchronizations()) { + synchronization.beforeCommit(readOnly); + } + } + + /** + * Trigger {@code beforeCompletion} callbacks on all currently registered synchronizations. + * @see TransactionSynchronization#beforeCompletion() + */ + public static void triggerBeforeCompletion() { + for (TransactionSynchronization synchronization : TransactionSynchronizationManager.getSynchronizations()) { + try { + synchronization.beforeCompletion(); + } + catch (Throwable tsex) { + logger.error("TransactionSynchronization.beforeCompletion threw exception", tsex); + } + } + } + + /** + * Trigger {@code afterCommit} callbacks on all currently registered synchronizations. + * @throws RuntimeException if thrown by a {@code afterCommit} callback + * @see TransactionSynchronizationManager#getSynchronizations() + * @see TransactionSynchronization#afterCommit() + */ + public static void triggerAfterCommit() { + invokeAfterCommit(TransactionSynchronizationManager.getSynchronizations()); + } + + /** + * Actually invoke the {@code afterCommit} methods of the + * given Spring TransactionSynchronization objects. + * @param synchronizations a List of TransactionSynchronization objects + * @see TransactionSynchronization#afterCommit() + */ + public static void invokeAfterCommit(@Nullable List synchronizations) { + if (synchronizations != null) { + for (TransactionSynchronization synchronization : synchronizations) { + synchronization.afterCommit(); + } + } + } + + /** + * Trigger {@code afterCompletion} callbacks on all currently registered synchronizations. + * @param completionStatus the completion status according to the + * constants in the TransactionSynchronization interface + * @see TransactionSynchronizationManager#getSynchronizations() + * @see TransactionSynchronization#afterCompletion(int) + * @see TransactionSynchronization#STATUS_COMMITTED + * @see TransactionSynchronization#STATUS_ROLLED_BACK + * @see TransactionSynchronization#STATUS_UNKNOWN + */ + public static void triggerAfterCompletion(int completionStatus) { + List synchronizations = TransactionSynchronizationManager.getSynchronizations(); + invokeAfterCompletion(synchronizations, completionStatus); + } + + /** + * Actually invoke the {@code afterCompletion} methods of the + * given Spring TransactionSynchronization objects. + * @param synchronizations a List of TransactionSynchronization objects + * @param completionStatus the completion status according to the + * constants in the TransactionSynchronization interface + * @see TransactionSynchronization#afterCompletion(int) + * @see TransactionSynchronization#STATUS_COMMITTED + * @see TransactionSynchronization#STATUS_ROLLED_BACK + * @see TransactionSynchronization#STATUS_UNKNOWN + */ + public static void invokeAfterCompletion(@Nullable List synchronizations, + int completionStatus) { + + if (synchronizations != null) { + for (TransactionSynchronization synchronization : synchronizations) { + try { + synchronization.afterCompletion(completionStatus); + } + catch (Throwable tsex) { + logger.error("TransactionSynchronization.afterCompletion threw exception", tsex); + } + } + } + } + + + /** + * Inner class to avoid hard-coded dependency on AOP module. + */ + private static class ScopedProxyUnwrapper { + + public static Object unwrapIfNecessary(Object resource) { + if (resource instanceof ScopedObject) { + return ((ScopedObject) resource).getTargetObject(); + } + else { + return resource; + } + } + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/TransactionTemplate.java b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionTemplate.java new file mode 100644 index 0000000000000000000000000000000000000000..e6d1c4aff3d997fa8e35a570afc78a36b2b75971 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionTemplate.java @@ -0,0 +1,188 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import java.lang.reflect.UndeclaredThrowableException; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.TransactionException; +import org.springframework.transaction.TransactionStatus; +import org.springframework.transaction.TransactionSystemException; +import org.springframework.util.Assert; + +/** + * Template class that simplifies programmatic transaction demarcation and + * transaction exception handling. + * + *

The central method is {@link #execute}, supporting transactional code that + * implements the {@link TransactionCallback} interface. This template handles + * the transaction lifecycle and possible exceptions such that neither the + * TransactionCallback implementation nor the calling code needs to explicitly + * handle transactions. + * + *

Typical usage: Allows for writing low-level data access objects that use + * resources such as JDBC DataSources but are not transaction-aware themselves. + * Instead, they can implicitly participate in transactions handled by higher-level + * application services utilizing this class, making calls to the low-level + * services via an inner-class callback object. + * + *

Can be used within a service implementation via direct instantiation with + * a transaction manager reference, or get prepared in an application context + * and passed to services as bean reference. Note: The transaction manager should + * always be configured as bean in the application context: in the first case given + * to the service directly, in the second case given to the prepared template. + * + *

Supports setting the propagation behavior and the isolation level by name, + * for convenient configuration in context definitions. + * + * @author Juergen Hoeller + * @since 17.03.2003 + * @see #execute + * @see #setTransactionManager + * @see org.springframework.transaction.PlatformTransactionManager + */ +@SuppressWarnings("serial") +public class TransactionTemplate extends DefaultTransactionDefinition + implements TransactionOperations, InitializingBean { + + /** Logger available to subclasses. */ + protected final Log logger = LogFactory.getLog(getClass()); + + @Nullable + private PlatformTransactionManager transactionManager; + + + /** + * Construct a new TransactionTemplate for bean usage. + *

Note: The PlatformTransactionManager needs to be set before + * any {@code execute} calls. + * @see #setTransactionManager + */ + public TransactionTemplate() { + } + + /** + * Construct a new TransactionTemplate using the given transaction manager. + * @param transactionManager the transaction management strategy to be used + */ + public TransactionTemplate(PlatformTransactionManager transactionManager) { + this.transactionManager = transactionManager; + } + + /** + * Construct a new TransactionTemplate using the given transaction manager, + * taking its default settings from the given transaction definition. + * @param transactionManager the transaction management strategy to be used + * @param transactionDefinition the transaction definition to copy the + * default settings from. Local properties can still be set to change values. + */ + public TransactionTemplate(PlatformTransactionManager transactionManager, TransactionDefinition transactionDefinition) { + super(transactionDefinition); + this.transactionManager = transactionManager; + } + + + /** + * Set the transaction management strategy to be used. + */ + public void setTransactionManager(@Nullable PlatformTransactionManager transactionManager) { + this.transactionManager = transactionManager; + } + + /** + * Return the transaction management strategy to be used. + */ + @Nullable + public PlatformTransactionManager getTransactionManager() { + return this.transactionManager; + } + + @Override + public void afterPropertiesSet() { + if (this.transactionManager == null) { + throw new IllegalArgumentException("Property 'transactionManager' is required"); + } + } + + + @Override + @Nullable + public T execute(TransactionCallback action) throws TransactionException { + Assert.state(this.transactionManager != null, "No PlatformTransactionManager set"); + + if (this.transactionManager instanceof CallbackPreferringPlatformTransactionManager) { + return ((CallbackPreferringPlatformTransactionManager) this.transactionManager).execute(this, action); + } + else { + TransactionStatus status = this.transactionManager.getTransaction(this); + T result; + try { + result = action.doInTransaction(status); + } + catch (RuntimeException | Error ex) { + // Transactional code threw application exception -> rollback + rollbackOnException(status, ex); + throw ex; + } + catch (Throwable ex) { + // Transactional code threw unexpected exception -> rollback + rollbackOnException(status, ex); + throw new UndeclaredThrowableException(ex, "TransactionCallback threw undeclared checked exception"); + } + this.transactionManager.commit(status); + return result; + } + } + + /** + * Perform a rollback, handling rollback exceptions properly. + * @param status object representing the transaction + * @param ex the thrown application exception or error + * @throws TransactionException in case of a rollback error + */ + private void rollbackOnException(TransactionStatus status, Throwable ex) throws TransactionException { + Assert.state(this.transactionManager != null, "No PlatformTransactionManager set"); + + logger.debug("Initiating transaction rollback on application exception", ex); + try { + this.transactionManager.rollback(status); + } + catch (TransactionSystemException ex2) { + logger.error("Application exception overridden by rollback exception", ex); + ex2.initApplicationException(ex); + throw ex2; + } + catch (RuntimeException | Error ex2) { + logger.error("Application exception overridden by rollback exception", ex); + throw ex2; + } + } + + + @Override + public boolean equals(Object other) { + return (this == other || (super.equals(other) && (!(other instanceof TransactionTemplate) || + getTransactionManager() == ((TransactionTemplate) other).getTransactionManager()))); + } + +} diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/package-info.java b/spring-tx/src/main/java/org/springframework/transaction/support/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..59a310edf0a640013319c751ce7ff334919e5e64 --- /dev/null +++ b/spring-tx/src/main/java/org/springframework/transaction/support/package-info.java @@ -0,0 +1,11 @@ +/** + * Support classes for the org.springframework.transaction package. + * Provides an abstract base class for transaction manager implementations, + * and a template plus callback for transaction demarcation. + */ +@NonNullApi +@NonNullFields +package org.springframework.transaction.support; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-tx/src/main/java/overview.html b/spring-tx/src/main/java/overview.html new file mode 100644 index 0000000000000000000000000000000000000000..3f2136f8129de06f803b5618abad4dc83ebacf58 --- /dev/null +++ b/spring-tx/src/main/java/overview.html @@ -0,0 +1,7 @@ + + +

+Spring's transaction infrastructure. Also includes DAO support and JCA integration. +

+ + \ No newline at end of file diff --git a/spring-tx/src/main/resources/META-INF/spring.handlers b/spring-tx/src/main/resources/META-INF/spring.handlers new file mode 100644 index 0000000000000000000000000000000000000000..4c07197ea4e5cb9349f956201c35323eb9fb3daf --- /dev/null +++ b/spring-tx/src/main/resources/META-INF/spring.handlers @@ -0,0 +1 @@ +http\://www.springframework.org/schema/tx=org.springframework.transaction.config.TxNamespaceHandler diff --git a/spring-tx/src/main/resources/META-INF/spring.schemas b/spring-tx/src/main/resources/META-INF/spring.schemas new file mode 100644 index 0000000000000000000000000000000000000000..998c0557a1342aed5b3365e72f427740db1b784c --- /dev/null +++ b/spring-tx/src/main/resources/META-INF/spring.schemas @@ -0,0 +1,20 @@ +http\://www.springframework.org/schema/tx/spring-tx-2.0.xsd=org/springframework/transaction/config/spring-tx.xsd +http\://www.springframework.org/schema/tx/spring-tx-2.5.xsd=org/springframework/transaction/config/spring-tx.xsd +http\://www.springframework.org/schema/tx/spring-tx-3.0.xsd=org/springframework/transaction/config/spring-tx.xsd +http\://www.springframework.org/schema/tx/spring-tx-3.1.xsd=org/springframework/transaction/config/spring-tx.xsd +http\://www.springframework.org/schema/tx/spring-tx-3.2.xsd=org/springframework/transaction/config/spring-tx.xsd +http\://www.springframework.org/schema/tx/spring-tx-4.0.xsd=org/springframework/transaction/config/spring-tx.xsd +http\://www.springframework.org/schema/tx/spring-tx-4.1.xsd=org/springframework/transaction/config/spring-tx.xsd +http\://www.springframework.org/schema/tx/spring-tx-4.2.xsd=org/springframework/transaction/config/spring-tx.xsd +http\://www.springframework.org/schema/tx/spring-tx-4.3.xsd=org/springframework/transaction/config/spring-tx.xsd +http\://www.springframework.org/schema/tx/spring-tx.xsd=org/springframework/transaction/config/spring-tx.xsd +https\://www.springframework.org/schema/tx/spring-tx-2.0.xsd=org/springframework/transaction/config/spring-tx.xsd +https\://www.springframework.org/schema/tx/spring-tx-2.5.xsd=org/springframework/transaction/config/spring-tx.xsd +https\://www.springframework.org/schema/tx/spring-tx-3.0.xsd=org/springframework/transaction/config/spring-tx.xsd +https\://www.springframework.org/schema/tx/spring-tx-3.1.xsd=org/springframework/transaction/config/spring-tx.xsd +https\://www.springframework.org/schema/tx/spring-tx-3.2.xsd=org/springframework/transaction/config/spring-tx.xsd +https\://www.springframework.org/schema/tx/spring-tx-4.0.xsd=org/springframework/transaction/config/spring-tx.xsd +https\://www.springframework.org/schema/tx/spring-tx-4.1.xsd=org/springframework/transaction/config/spring-tx.xsd +https\://www.springframework.org/schema/tx/spring-tx-4.2.xsd=org/springframework/transaction/config/spring-tx.xsd +https\://www.springframework.org/schema/tx/spring-tx-4.3.xsd=org/springframework/transaction/config/spring-tx.xsd +https\://www.springframework.org/schema/tx/spring-tx.xsd=org/springframework/transaction/config/spring-tx.xsd diff --git a/spring-tx/src/main/resources/META-INF/spring.tooling b/spring-tx/src/main/resources/META-INF/spring.tooling new file mode 100644 index 0000000000000000000000000000000000000000..8a83a4b6b12556b0ddbf79dd79c9bce9ba6f6896 --- /dev/null +++ b/spring-tx/src/main/resources/META-INF/spring.tooling @@ -0,0 +1,4 @@ +# Tooling related information for the tx namespace +http\://www.springframework.org/schema/tx@name=tx Namespace +http\://www.springframework.org/schema/tx@prefix=tx +http\://www.springframework.org/schema/tx@icon=org/springframework/transaction/config/spring-tx.gif diff --git a/spring-tx/src/main/resources/org/springframework/jca/context/ra.xml b/spring-tx/src/main/resources/org/springframework/jca/context/ra.xml new file mode 100644 index 0000000000000000000000000000000000000000..4309c9a4f214a96b580f7a3fc6515283f7a1f0ca --- /dev/null +++ b/spring-tx/src/main/resources/org/springframework/jca/context/ra.xml @@ -0,0 +1,17 @@ + + + Spring Framework + Spring Connector + 1.0 + + org.springframework.jca.context.SpringContextResourceAdapter + + ContextConfigLocation + java.lang.String + META-INF/applicationContext.xml + + + diff --git a/spring-tx/src/main/resources/org/springframework/transaction/config/spring-tx.gif b/spring-tx/src/main/resources/org/springframework/transaction/config/spring-tx.gif new file mode 100644 index 0000000000000000000000000000000000000000..20ed1f9a4438054835c3bd7231c59dcc36d9f24e Binary files /dev/null and b/spring-tx/src/main/resources/org/springframework/transaction/config/spring-tx.gif differ diff --git a/spring-tx/src/main/resources/org/springframework/transaction/config/spring-tx.xsd b/spring-tx/src/main/resources/org/springframework/transaction/config/spring-tx.xsd new file mode 100644 index 0000000000000000000000000000000000000000..461d13a43f89a461c44765f9c66a9fe477ffecd8 --- /dev/null +++ b/spring-tx/src/main/resources/org/springframework/transaction/config/spring-tx.xsd @@ -0,0 +1,247 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/spring-tx/src/test/java/org/springframework/dao/annotation/PersistenceExceptionTranslationAdvisorTests.java b/spring-tx/src/test/java/org/springframework/dao/annotation/PersistenceExceptionTranslationAdvisorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a72105a0622334bc62b9177bd59c2846783204e5 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/dao/annotation/PersistenceExceptionTranslationAdvisorTests.java @@ -0,0 +1,227 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import javax.persistence.PersistenceException; + +import org.junit.Test; + +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.dao.DataAccessException; +import org.springframework.dao.InvalidDataAccessApiUsageException; +import org.springframework.dao.support.DataAccessUtilsTests.MapPersistenceExceptionTranslator; +import org.springframework.dao.support.PersistenceExceptionTranslator; +import org.springframework.stereotype.Repository; + +import static org.junit.Assert.*; + +/** + * Tests for PersistenceExceptionTranslationAdvisor's exception translation, as applied by + * PersistenceExceptionTranslationPostProcessor. + * + * @author Rod Johnson + * @author Juergen Hoeller + */ +public class PersistenceExceptionTranslationAdvisorTests { + + private RuntimeException doNotTranslate = new RuntimeException(); + + private PersistenceException persistenceException1 = new PersistenceException(); + + protected RepositoryInterface createProxy(RepositoryInterfaceImpl target) { + MapPersistenceExceptionTranslator mpet = new MapPersistenceExceptionTranslator(); + mpet.addTranslation(persistenceException1, new InvalidDataAccessApiUsageException("", persistenceException1)); + ProxyFactory pf = new ProxyFactory(target); + pf.addInterface(RepositoryInterface.class); + addPersistenceExceptionTranslation(pf, mpet); + return (RepositoryInterface) pf.getProxy(); + } + + protected void addPersistenceExceptionTranslation(ProxyFactory pf, PersistenceExceptionTranslator pet) { + pf.addAdvisor(new PersistenceExceptionTranslationAdvisor(pet, Repository.class)); + } + + @Test + public void noTranslationNeeded() { + RepositoryInterfaceImpl target = new RepositoryInterfaceImpl(); + RepositoryInterface ri = createProxy(target); + + ri.noThrowsClause(); + ri.throwsPersistenceException(); + + target.setBehavior(persistenceException1); + try { + ri.noThrowsClause(); + fail(); + } + catch (RuntimeException ex) { + assertSame(persistenceException1, ex); + } + try { + ri.throwsPersistenceException(); + fail(); + } + catch (RuntimeException ex) { + assertSame(persistenceException1, ex); + } + } + + @Test + public void translationNotNeededForTheseExceptions() { + RepositoryInterfaceImpl target = new StereotypedRepositoryInterfaceImpl(); + RepositoryInterface ri = createProxy(target); + + ri.noThrowsClause(); + ri.throwsPersistenceException(); + + target.setBehavior(doNotTranslate); + try { + ri.noThrowsClause(); + fail(); + } + catch (RuntimeException ex) { + assertSame(doNotTranslate, ex); + } + try { + ri.throwsPersistenceException(); + fail(); + } + catch (RuntimeException ex) { + assertSame(doNotTranslate, ex); + } + } + + @Test + public void translationNeededForTheseExceptions() { + doTestTranslationNeededForTheseExceptions(new StereotypedRepositoryInterfaceImpl()); + } + + @Test + public void translationNeededForTheseExceptionsOnSuperclass() { + doTestTranslationNeededForTheseExceptions(new MyStereotypedRepositoryInterfaceImpl()); + } + + @Test + public void translationNeededForTheseExceptionsWithCustomStereotype() { + doTestTranslationNeededForTheseExceptions(new CustomStereotypedRepositoryInterfaceImpl()); + } + + @Test + public void translationNeededForTheseExceptionsOnInterface() { + doTestTranslationNeededForTheseExceptions(new MyInterfaceStereotypedRepositoryInterfaceImpl()); + } + + @Test + public void translationNeededForTheseExceptionsOnInheritedInterface() { + doTestTranslationNeededForTheseExceptions(new MyInterfaceInheritedStereotypedRepositoryInterfaceImpl()); + } + + private void doTestTranslationNeededForTheseExceptions(RepositoryInterfaceImpl target) { + RepositoryInterface ri = createProxy(target); + + target.setBehavior(persistenceException1); + try { + ri.noThrowsClause(); + fail(); + } + catch (DataAccessException ex) { + // Expected + assertSame(persistenceException1, ex.getCause()); + } + catch (PersistenceException ex) { + fail("Should have been translated"); + } + + try { + ri.throwsPersistenceException(); + fail(); + } + catch (PersistenceException ex) { + assertSame(persistenceException1, ex); + } + } + + + public interface RepositoryInterface { + + void noThrowsClause(); + + void throwsPersistenceException() throws PersistenceException; + } + + public static class RepositoryInterfaceImpl implements RepositoryInterface { + + private RuntimeException runtimeException; + + public void setBehavior(RuntimeException rex) { + this.runtimeException = rex; + } + + @Override + public void noThrowsClause() { + if (runtimeException != null) { + throw runtimeException; + } + } + + @Override + public void throwsPersistenceException() throws PersistenceException { + if (runtimeException != null) { + throw runtimeException; + } + } + } + + @Repository + public static class StereotypedRepositoryInterfaceImpl extends RepositoryInterfaceImpl { + // Extends above class just to add repository annotation + } + + public static class MyStereotypedRepositoryInterfaceImpl extends StereotypedRepositoryInterfaceImpl { + } + + @MyRepository + public static class CustomStereotypedRepositoryInterfaceImpl extends RepositoryInterfaceImpl { + } + + @Target({ElementType.TYPE}) + @Retention(RetentionPolicy.RUNTIME) + @Repository + public @interface MyRepository { + } + + @Repository + public interface StereotypedInterface { + } + + public static class MyInterfaceStereotypedRepositoryInterfaceImpl extends RepositoryInterfaceImpl + implements StereotypedInterface { + } + + public interface StereotypedInheritingInterface extends StereotypedInterface { + } + + public static class MyInterfaceInheritedStereotypedRepositoryInterfaceImpl extends RepositoryInterfaceImpl + implements StereotypedInheritingInterface { + } + +} diff --git a/spring-tx/src/test/java/org/springframework/dao/annotation/PersistenceExceptionTranslationInterceptorTests.java b/spring-tx/src/test/java/org/springframework/dao/annotation/PersistenceExceptionTranslationInterceptorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..351e971b3ea8cb19614c004167f9dcde9c7311b0 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/dao/annotation/PersistenceExceptionTranslationInterceptorTests.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2007 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao.annotation; + +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.dao.support.PersistenceExceptionTranslationInterceptor; +import org.springframework.dao.support.PersistenceExceptionTranslator; +import org.springframework.stereotype.Repository; + +/** + * Tests for standalone usage of a PersistenceExceptionTranslationInterceptor, as explicit advice bean in a BeanFactory + * rather than applied as part of a PersistenceExceptionTranslationAdvisor. + * + * @author Juergen Hoeller + */ +public class PersistenceExceptionTranslationInterceptorTests extends PersistenceExceptionTranslationAdvisorTests { + + @Override + protected void addPersistenceExceptionTranslation(ProxyFactory pf, PersistenceExceptionTranslator pet) { + if (AnnotationUtils.findAnnotation(pf.getTargetClass(), Repository.class) != null) { + DefaultListableBeanFactory bf = new DefaultListableBeanFactory(); + bf.registerBeanDefinition("peti", new RootBeanDefinition(PersistenceExceptionTranslationInterceptor.class)); + bf.registerSingleton("pet", pet); + pf.addAdvice((PersistenceExceptionTranslationInterceptor) bf.getBean("peti")); + } + } + +} diff --git a/spring-tx/src/test/java/org/springframework/dao/annotation/PersistenceExceptionTranslationPostProcessorTests.java b/spring-tx/src/test/java/org/springframework/dao/annotation/PersistenceExceptionTranslationPostProcessorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..95ebe3fb71eb642a937dec91ffc7b8037c61660d --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/dao/annotation/PersistenceExceptionTranslationPostProcessorTests.java @@ -0,0 +1,148 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao.annotation; + +import javax.persistence.PersistenceException; + +import org.aspectj.lang.JoinPoint; +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.annotation.Before; +import org.junit.Test; + +import org.springframework.aop.Advisor; +import org.springframework.aop.aspectj.annotation.AnnotationAwareAspectJAutoProxyCreator; +import org.springframework.aop.framework.Advised; +import org.springframework.aop.support.AopUtils; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.dao.DataAccessException; +import org.springframework.dao.DataAccessResourceFailureException; +import org.springframework.dao.annotation.PersistenceExceptionTranslationAdvisorTests.RepositoryInterface; +import org.springframework.dao.annotation.PersistenceExceptionTranslationAdvisorTests.RepositoryInterfaceImpl; +import org.springframework.dao.annotation.PersistenceExceptionTranslationAdvisorTests.StereotypedRepositoryInterfaceImpl; +import org.springframework.dao.support.PersistenceExceptionTranslator; +import org.springframework.stereotype.Repository; + +import static org.junit.Assert.*; + +/** + * @author Rod Johnson + * @author Juergen Hoeller + */ +public class PersistenceExceptionTranslationPostProcessorTests { + + @Test + @SuppressWarnings("resource") + public void proxiesCorrectly() { + GenericApplicationContext gac = new GenericApplicationContext(); + gac.registerBeanDefinition("translator", + new RootBeanDefinition(PersistenceExceptionTranslationPostProcessor.class)); + gac.registerBeanDefinition("notProxied", new RootBeanDefinition(RepositoryInterfaceImpl.class)); + gac.registerBeanDefinition("proxied", new RootBeanDefinition(StereotypedRepositoryInterfaceImpl.class)); + gac.registerBeanDefinition("classProxied", new RootBeanDefinition(RepositoryWithoutInterface.class)); + gac.registerBeanDefinition("classProxiedAndAdvised", + new RootBeanDefinition(RepositoryWithoutInterfaceAndOtherwiseAdvised.class)); + gac.registerBeanDefinition("myTranslator", + new RootBeanDefinition(MyPersistenceExceptionTranslator.class)); + gac.registerBeanDefinition("proxyCreator", + BeanDefinitionBuilder.rootBeanDefinition(AnnotationAwareAspectJAutoProxyCreator.class). + addPropertyValue("order", 50).getBeanDefinition()); + gac.registerBeanDefinition("logger", new RootBeanDefinition(LogAllAspect.class)); + gac.refresh(); + + RepositoryInterface shouldNotBeProxied = (RepositoryInterface) gac.getBean("notProxied"); + assertFalse(AopUtils.isAopProxy(shouldNotBeProxied)); + RepositoryInterface shouldBeProxied = (RepositoryInterface) gac.getBean("proxied"); + assertTrue(AopUtils.isAopProxy(shouldBeProxied)); + RepositoryWithoutInterface rwi = (RepositoryWithoutInterface) gac.getBean("classProxied"); + assertTrue(AopUtils.isAopProxy(rwi)); + checkWillTranslateExceptions(rwi); + + Additional rwi2 = (Additional) gac.getBean("classProxiedAndAdvised"); + assertTrue(AopUtils.isAopProxy(rwi2)); + rwi2.additionalMethod(false); + checkWillTranslateExceptions(rwi2); + try { + rwi2.additionalMethod(true); + fail("Should have thrown DataAccessResourceFailureException"); + } + catch (DataAccessResourceFailureException ex) { + assertEquals("my failure", ex.getMessage()); + } + } + + protected void checkWillTranslateExceptions(Object o) { + assertTrue(o instanceof Advised); + Advised a = (Advised) o; + for (Advisor advisor : a.getAdvisors()) { + if (advisor instanceof PersistenceExceptionTranslationAdvisor) { + return; + } + } + fail("No translation"); + } + + + @Repository + public static class RepositoryWithoutInterface { + + public void nameDoesntMatter() { + } + } + + + public interface Additional { + + void additionalMethod(boolean fail); + } + + + public static class RepositoryWithoutInterfaceAndOtherwiseAdvised extends StereotypedRepositoryInterfaceImpl + implements Additional { + + @Override + public void additionalMethod(boolean fail) { + if (fail) { + throw new PersistenceException("my failure"); + } + } + } + + + public static class MyPersistenceExceptionTranslator implements PersistenceExceptionTranslator { + + @Override + public DataAccessException translateExceptionIfPossible(RuntimeException ex) { + if (ex instanceof PersistenceException) { + return new DataAccessResourceFailureException(ex.getMessage()); + } + return null; + } + } + + + @Aspect + public static class LogAllAspect { + + @Before("execution(void *.additionalMethod(*))") + public void log(JoinPoint jp) { + System.out.println("Before " + jp.getSignature().getName()); + } + } + +} diff --git a/spring-tx/src/test/java/org/springframework/dao/support/ChainedPersistenceExceptionTranslatorTests.java b/spring-tx/src/test/java/org/springframework/dao/support/ChainedPersistenceExceptionTranslatorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..7ba7cc365ee43430f95bd7e2989c687c0bdc77c6 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/dao/support/ChainedPersistenceExceptionTranslatorTests.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao.support; + +import org.junit.Test; + +import org.springframework.dao.InvalidDataAccessApiUsageException; +import org.springframework.dao.OptimisticLockingFailureException; +import org.springframework.dao.support.DataAccessUtilsTests.MapPersistenceExceptionTranslator; + +import static org.junit.Assert.*; + +/** + * @author Rod Johnson + * @since 2.0 + */ +public class ChainedPersistenceExceptionTranslatorTests { + + @Test + public void empty() { + ChainedPersistenceExceptionTranslator pet = new ChainedPersistenceExceptionTranslator(); + //MapPersistenceExceptionTranslator mpet = new MapPersistenceExceptionTranslator(); + RuntimeException in = new RuntimeException("in"); + assertSame(in, DataAccessUtils.translateIfNecessary(in, pet)); + } + + @Test + public void exceptionTranslationWithTranslation() { + MapPersistenceExceptionTranslator mpet1 = new MapPersistenceExceptionTranslator(); + RuntimeException in1 = new RuntimeException("in"); + InvalidDataAccessApiUsageException out1 = new InvalidDataAccessApiUsageException("out"); + InvalidDataAccessApiUsageException out2 = new InvalidDataAccessApiUsageException("out"); + mpet1.addTranslation(in1, out1); + + ChainedPersistenceExceptionTranslator chainedPet1 = new ChainedPersistenceExceptionTranslator(); + assertSame("Should not translate yet", in1, DataAccessUtils.translateIfNecessary(in1, chainedPet1)); + chainedPet1.addDelegate(mpet1); + assertSame("Should now translate", out1, DataAccessUtils.translateIfNecessary(in1, chainedPet1)); + + // Now add a new translator and verify it wins + MapPersistenceExceptionTranslator mpet2 = new MapPersistenceExceptionTranslator(); + mpet2.addTranslation(in1, out2); + chainedPet1.addDelegate(mpet2); + assertSame("Should still translate the same due to ordering", + out1, DataAccessUtils.translateIfNecessary(in1, chainedPet1)); + + ChainedPersistenceExceptionTranslator chainedPet2 = new ChainedPersistenceExceptionTranslator(); + chainedPet2.addDelegate(mpet2); + chainedPet2.addDelegate(mpet1); + assertSame("Should translate differently due to ordering", + out2, DataAccessUtils.translateIfNecessary(in1, chainedPet2)); + + RuntimeException in2 = new RuntimeException("in2"); + OptimisticLockingFailureException out3 = new OptimisticLockingFailureException("out2"); + assertNull(chainedPet2.translateExceptionIfPossible(in2)); + MapPersistenceExceptionTranslator mpet3 = new MapPersistenceExceptionTranslator(); + mpet3.addTranslation(in2, out3); + chainedPet2.addDelegate(mpet3); + assertSame(out3, chainedPet2.translateExceptionIfPossible(in2)); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/dao/support/DataAccessUtilsTests.java b/spring-tx/src/test/java/org/springframework/dao/support/DataAccessUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5489b4b350fd830e32ef8fe940b5853bb8fa70c4 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/dao/support/DataAccessUtilsTests.java @@ -0,0 +1,290 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.dao.support; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Date; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; + +import org.junit.Test; + +import org.springframework.dao.DataAccessException; +import org.springframework.dao.IncorrectResultSizeDataAccessException; +import org.springframework.dao.InvalidDataAccessApiUsageException; +import org.springframework.dao.TypeMismatchDataAccessException; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + * @since 20.10.2004 + */ +public class DataAccessUtilsTests { + + @Test + public void withEmptyCollection() { + Collection col = new HashSet<>(); + + assertNull(DataAccessUtils.uniqueResult(col)); + + try { + DataAccessUtils.requiredUniqueResult(col); + fail("Should have thrown IncorrectResultSizeDataAccessException"); + } + catch (IncorrectResultSizeDataAccessException ex) { + // expected + assertEquals(1, ex.getExpectedSize()); + assertEquals(0, ex.getActualSize()); + } + + try { + DataAccessUtils.objectResult(col, String.class); + fail("Should have thrown IncorrectResultSizeDataAccessException"); + } + catch (IncorrectResultSizeDataAccessException ex) { + // expected + assertEquals(1, ex.getExpectedSize()); + assertEquals(0, ex.getActualSize()); + } + + try { + DataAccessUtils.intResult(col); + fail("Should have thrown IncorrectResultSizeDataAccessException"); + } + catch (IncorrectResultSizeDataAccessException ex) { + // expected + assertEquals(1, ex.getExpectedSize()); + assertEquals(0, ex.getActualSize()); + } + + try { + DataAccessUtils.longResult(col); + fail("Should have thrown IncorrectResultSizeDataAccessException"); + } + catch (IncorrectResultSizeDataAccessException ex) { + // expected + assertEquals(1, ex.getExpectedSize()); + assertEquals(0, ex.getActualSize()); + } + } + + @Test + public void withTooLargeCollection() { + Collection col = new HashSet<>(2); + col.add("test1"); + col.add("test2"); + + try { + DataAccessUtils.uniqueResult(col); + fail("Should have thrown IncorrectResultSizeDataAccessException"); + } + catch (IncorrectResultSizeDataAccessException ex) { + // expected + assertEquals(1, ex.getExpectedSize()); + assertEquals(2, ex.getActualSize()); + } + + try { + DataAccessUtils.requiredUniqueResult(col); + fail("Should have thrown IncorrectResultSizeDataAccessException"); + } + catch (IncorrectResultSizeDataAccessException ex) { + // expected + assertEquals(1, ex.getExpectedSize()); + assertEquals(2, ex.getActualSize()); + } + + try { + DataAccessUtils.objectResult(col, String.class); + fail("Should have thrown IncorrectResultSizeDataAccessException"); + } + catch (IncorrectResultSizeDataAccessException ex) { + // expected + assertEquals(1, ex.getExpectedSize()); + assertEquals(2, ex.getActualSize()); + } + + try { + DataAccessUtils.intResult(col); + fail("Should have thrown IncorrectResultSizeDataAccessException"); + } + catch (IncorrectResultSizeDataAccessException ex) { + // expected + assertEquals(1, ex.getExpectedSize()); + assertEquals(2, ex.getActualSize()); + } + + try { + DataAccessUtils.longResult(col); + fail("Should have thrown IncorrectResultSizeDataAccessException"); + } + catch (IncorrectResultSizeDataAccessException ex) { + // expected + assertEquals(1, ex.getExpectedSize()); + assertEquals(2, ex.getActualSize()); + } + } + + @Test + public void withInteger() { + Collection col = new HashSet<>(1); + col.add(5); + + assertEquals(Integer.valueOf(5), DataAccessUtils.uniqueResult(col)); + assertEquals(Integer.valueOf(5), DataAccessUtils.requiredUniqueResult(col)); + assertEquals(Integer.valueOf(5), DataAccessUtils.objectResult(col, Integer.class)); + assertEquals("5", DataAccessUtils.objectResult(col, String.class)); + assertEquals(5, DataAccessUtils.intResult(col)); + assertEquals(5, DataAccessUtils.longResult(col)); + } + + @Test + public void withSameIntegerInstanceTwice() { + Integer i = 5; + Collection col = new ArrayList<>(1); + col.add(i); + col.add(i); + + assertEquals(Integer.valueOf(5), DataAccessUtils.uniqueResult(col)); + assertEquals(Integer.valueOf(5), DataAccessUtils.requiredUniqueResult(col)); + assertEquals(Integer.valueOf(5), DataAccessUtils.objectResult(col, Integer.class)); + assertEquals("5", DataAccessUtils.objectResult(col, String.class)); + assertEquals(5, DataAccessUtils.intResult(col)); + assertEquals(5, DataAccessUtils.longResult(col)); + } + + @Test + @SuppressWarnings("deprecation") // on JDK 9 + public void withEquivalentIntegerInstanceTwice() { + Collection col = new ArrayList<>(2); + col.add(new Integer(5)); + col.add(new Integer(5)); + + try { + DataAccessUtils.uniqueResult(col); + fail("Should have thrown IncorrectResultSizeDataAccessException"); + } + catch (IncorrectResultSizeDataAccessException ex) { + // expected + assertEquals(1, ex.getExpectedSize()); + assertEquals(2, ex.getActualSize()); + } + } + + @Test + public void withLong() { + Collection col = new HashSet<>(1); + col.add(5L); + + assertEquals(Long.valueOf(5L), DataAccessUtils.uniqueResult(col)); + assertEquals(Long.valueOf(5L), DataAccessUtils.requiredUniqueResult(col)); + assertEquals(Long.valueOf(5L), DataAccessUtils.objectResult(col, Long.class)); + assertEquals("5", DataAccessUtils.objectResult(col, String.class)); + assertEquals(5, DataAccessUtils.intResult(col)); + assertEquals(5, DataAccessUtils.longResult(col)); + } + + @Test + public void withString() { + Collection col = new HashSet<>(1); + col.add("test1"); + + assertEquals("test1", DataAccessUtils.uniqueResult(col)); + assertEquals("test1", DataAccessUtils.requiredUniqueResult(col)); + assertEquals("test1", DataAccessUtils.objectResult(col, String.class)); + + try { + DataAccessUtils.intResult(col); + fail("Should have thrown TypeMismatchDataAccessException"); + } + catch (TypeMismatchDataAccessException ex) { + // expected + } + + try { + DataAccessUtils.longResult(col); + fail("Should have thrown TypeMismatchDataAccessException"); + } + catch (TypeMismatchDataAccessException ex) { + // expected + } + } + + @Test + public void withDate() { + Date date = new Date(); + Collection col = new HashSet<>(1); + col.add(date); + + assertEquals(date, DataAccessUtils.uniqueResult(col)); + assertEquals(date, DataAccessUtils.requiredUniqueResult(col)); + assertEquals(date, DataAccessUtils.objectResult(col, Date.class)); + assertEquals(date.toString(), DataAccessUtils.objectResult(col, String.class)); + + try { + DataAccessUtils.intResult(col); + fail("Should have thrown TypeMismatchDataAccessException"); + } + catch (TypeMismatchDataAccessException ex) { + // expected + } + + try { + DataAccessUtils.longResult(col); + fail("Should have thrown TypeMismatchDataAccessException"); + } + catch (TypeMismatchDataAccessException ex) { + // expected + } + } + + @Test + public void exceptionTranslationWithNoTranslation() { + MapPersistenceExceptionTranslator mpet = new MapPersistenceExceptionTranslator(); + RuntimeException in = new RuntimeException(); + assertSame(in, DataAccessUtils.translateIfNecessary(in, mpet)); + } + + @Test + public void exceptionTranslationWithTranslation() { + MapPersistenceExceptionTranslator mpet = new MapPersistenceExceptionTranslator(); + RuntimeException in = new RuntimeException("in"); + InvalidDataAccessApiUsageException out = new InvalidDataAccessApiUsageException("out"); + mpet.addTranslation(in, out); + assertSame(out, DataAccessUtils.translateIfNecessary(in, mpet)); + } + + + public static class MapPersistenceExceptionTranslator implements PersistenceExceptionTranslator { + + // in to out + private final Map translations = new HashMap<>(); + + public void addTranslation(RuntimeException in, RuntimeException out) { + this.translations.put(in, out); + } + + @Override + public DataAccessException translateExceptionIfPossible(RuntimeException ex) { + return (DataAccessException) translations.get(ex); + } + } + +} diff --git a/spring-tx/src/test/java/org/springframework/jca/cci/CciLocalTransactionTests.java b/spring-tx/src/test/java/org/springframework/jca/cci/CciLocalTransactionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..767a184ffc2753878f7c293caa578f4fc88ab116 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/jca/cci/CciLocalTransactionTests.java @@ -0,0 +1,129 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci; + +import javax.resource.ResourceException; +import javax.resource.cci.Connection; +import javax.resource.cci.ConnectionFactory; +import javax.resource.cci.Interaction; +import javax.resource.cci.InteractionSpec; +import javax.resource.cci.LocalTransaction; +import javax.resource.cci.Record; + +import org.junit.Test; + +import org.springframework.dao.DataRetrievalFailureException; +import org.springframework.jca.cci.connection.CciLocalTransactionManager; +import org.springframework.jca.cci.core.CciTemplate; +import org.springframework.transaction.TransactionStatus; +import org.springframework.transaction.support.TransactionCallback; +import org.springframework.transaction.support.TransactionCallbackWithoutResult; +import org.springframework.transaction.support.TransactionSynchronizationManager; +import org.springframework.transaction.support.TransactionTemplate; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * @author Thierry Templier + * @author Chris Beams + */ +public class CciLocalTransactionTests { + + /** + * Test if a transaction ( begin / commit ) is executed on the + * LocalTransaction when CciLocalTransactionManager is specified as + * transaction manager. + */ + @Test + public void testLocalTransactionCommit() throws ResourceException { + final ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + LocalTransaction localTransaction = mock(LocalTransaction.class); + final Record record = mock(Record.class); + final InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.getLocalTransaction()).willReturn(localTransaction); + given(connection.createInteraction()).willReturn(interaction); + given(interaction.execute(interactionSpec, record, record)).willReturn(true); + given(connection.getLocalTransaction()).willReturn(localTransaction); + + CciLocalTransactionManager tm = new CciLocalTransactionManager(); + tm.setConnectionFactory(connectionFactory); + TransactionTemplate tt = new TransactionTemplate(tm); + + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue("Has thread connection", TransactionSynchronizationManager.hasResource(connectionFactory)); + CciTemplate ct = new CciTemplate(connectionFactory); + ct.execute(interactionSpec, record, record); + } + }); + + verify(localTransaction).begin(); + verify(interaction).close(); + verify(localTransaction).commit(); + verify(connection).close(); + } + + /** + * Test if a transaction ( begin / rollback ) is executed on the + * LocalTransaction when CciLocalTransactionManager is specified as + * transaction manager and a non-checked exception is thrown. + */ + @Test + public void testLocalTransactionRollback() throws ResourceException { + final ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + LocalTransaction localTransaction = mock(LocalTransaction.class); + final Record record = mock(Record.class); + final InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.getLocalTransaction()).willReturn(localTransaction); + given(connection.createInteraction()).willReturn(interaction); + given(interaction.execute(interactionSpec, record, record)).willReturn(true); + given(connection.getLocalTransaction()).willReturn(localTransaction); + + CciLocalTransactionManager tm = new CciLocalTransactionManager(); + tm.setConnectionFactory(connectionFactory); + TransactionTemplate tt = new TransactionTemplate(tm); + + try { + tt.execute(new TransactionCallback() { + @Override + public Void doInTransaction(TransactionStatus status) { + assertTrue("Has thread connection", TransactionSynchronizationManager.hasResource(connectionFactory)); + CciTemplate ct = new CciTemplate(connectionFactory); + ct.execute(interactionSpec, record, record); + throw new DataRetrievalFailureException("error"); + } + }); + } + catch (Exception ex) { + } + + verify(localTransaction).begin(); + verify(interaction).close(); + verify(localTransaction).rollback(); + verify(connection).close(); + } +} diff --git a/spring-tx/src/test/java/org/springframework/jca/cci/CciTemplateTests.java b/spring-tx/src/test/java/org/springframework/jca/cci/CciTemplateTests.java new file mode 100644 index 0000000000000000000000000000000000000000..fee931e31ee81a0030ecd92f670603b840b9d98f --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/jca/cci/CciTemplateTests.java @@ -0,0 +1,541 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci; + +import java.sql.SQLException; + +import javax.resource.NotSupportedException; +import javax.resource.ResourceException; +import javax.resource.cci.Connection; +import javax.resource.cci.ConnectionFactory; +import javax.resource.cci.ConnectionSpec; +import javax.resource.cci.IndexedRecord; +import javax.resource.cci.Interaction; +import javax.resource.cci.InteractionSpec; +import javax.resource.cci.MappedRecord; +import javax.resource.cci.Record; +import javax.resource.cci.RecordFactory; +import javax.resource.cci.ResultSet; + +import org.junit.Test; + +import org.springframework.jca.cci.connection.ConnectionSpecConnectionFactoryAdapter; +import org.springframework.jca.cci.connection.NotSupportedRecordFactory; +import org.springframework.jca.cci.core.CciTemplate; +import org.springframework.jca.cci.core.ConnectionCallback; +import org.springframework.jca.cci.core.InteractionCallback; +import org.springframework.jca.cci.core.RecordCreator; +import org.springframework.jca.cci.core.RecordExtractor; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * @author Thierry Templier + * @author Juergen Hoeller + * @author Chris Beams + */ +public class CciTemplateTests { + + @Test + public void testCreateIndexedRecord() throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + RecordFactory recordFactory = mock(RecordFactory.class); + IndexedRecord indexedRecord = mock(IndexedRecord.class); + given(connectionFactory.getRecordFactory()).willReturn(recordFactory); + given(recordFactory.createIndexedRecord("name")).willReturn(indexedRecord); + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.createIndexedRecord("name"); + + verify(recordFactory).createIndexedRecord("name"); + } + + @Test + public void testCreateMappedRecord() throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + RecordFactory recordFactory = mock(RecordFactory.class); + MappedRecord mappedRecord = mock(MappedRecord.class); + + given(connectionFactory.getRecordFactory()).willReturn(recordFactory); + given(recordFactory.createMappedRecord("name")).willReturn(mappedRecord); + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.createMappedRecord("name"); + + verify(recordFactory).createMappedRecord("name"); + } + + @Test + public void testTemplateExecuteInputOutput() throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + + Record inputRecord = mock(Record.class); + Record outputRecord = mock(Record.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(interaction.execute(interactionSpec, inputRecord, outputRecord)).willReturn(true); + + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.execute(interactionSpec, inputRecord, outputRecord); + + verify(interaction).execute(interactionSpec, inputRecord, outputRecord); + verify(interaction).close(); + verify(connection).close(); + } + + @Test + public void testTemplateExecuteWithCreatorAndRecordFactoryNotSupported() + throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + + Record inputRecord = mock(Record.class); + final Record outputRecord = mock(Record.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connectionFactory.getRecordFactory()).willThrow(new NotSupportedException("not supported")); + given(connection.createInteraction()).willReturn(interaction); + given(interaction.execute(interactionSpec, inputRecord, outputRecord)).willReturn(true); + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.setOutputRecordCreator(new RecordCreator() { + @Override + public Record createRecord(RecordFactory recordFactory) { + assertTrue(recordFactory instanceof NotSupportedRecordFactory); + return outputRecord; + } + }); + ct.execute(interactionSpec, inputRecord); + + verify(interaction).execute(interactionSpec, inputRecord, outputRecord); + verify(interaction).close(); + verify(connection).close(); + } + + @Test + public void testTemplateExecuteInputTrueWithCreator2() + throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + RecordFactory recordFactory = mock(RecordFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + RecordCreator creator = mock(RecordCreator.class); + + Record inputRecord = mock(Record.class); + final Record outputRecord = mock(Record.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connectionFactory.getRecordFactory()).willReturn(recordFactory); + given(connection.createInteraction()).willReturn(interaction); + given(creator.createRecord(recordFactory)).willReturn(outputRecord); + given(interaction.execute(interactionSpec, inputRecord, outputRecord)).willReturn(true); + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.setOutputRecordCreator(creator); + ct.execute(interactionSpec, inputRecord); + + verify(interaction).execute(interactionSpec, inputRecord, outputRecord); + verify(interaction).close(); + verify(connection).close(); + } + + @Test + public void testTemplateExecuteInputFalse() throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + + Record inputRecord = mock(Record.class); + Record outputRecord = mock(Record.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(interaction.execute(interactionSpec, inputRecord)).willReturn(outputRecord); + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.execute(interactionSpec, inputRecord); + + verify(interaction).execute(interactionSpec, inputRecord); + verify(interaction).close(); + verify(connection).close(); + } + + @SuppressWarnings("unchecked") + @Test + public void testTemplateExecuteInputExtractorTrueWithCreator() + throws ResourceException, SQLException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + RecordFactory recordFactory = mock(RecordFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + RecordExtractor extractor = mock(RecordExtractor.class); + RecordCreator creator = mock(RecordCreator.class); + + Record inputRecord = mock(Record.class); + Record outputRecord = mock(Record.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(connectionFactory.getRecordFactory()).willReturn(recordFactory); + given(creator.createRecord(recordFactory)).willReturn(outputRecord); + given(interaction.execute(interactionSpec, inputRecord, outputRecord)).willReturn(true); + given(extractor.extractData(outputRecord)).willReturn(new Object()); + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.setOutputRecordCreator(creator); + ct.execute(interactionSpec, inputRecord, extractor); + + verify(extractor).extractData(outputRecord); + verify(interaction).close(); + verify(connection).close(); + } + + @SuppressWarnings("unchecked") + @Test + public void testTemplateExecuteInputExtractorFalse() + throws ResourceException, SQLException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + RecordExtractor extractor = mock(RecordExtractor.class); + + Record inputRecord = mock(Record.class); + Record outputRecord = mock(Record.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(interaction.execute(interactionSpec, inputRecord)).willReturn(outputRecord); + given(extractor.extractData(outputRecord)).willReturn(new Object()); + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.execute(interactionSpec, inputRecord, extractor); + + verify(extractor).extractData(outputRecord); + verify(interaction).close(); + verify(connection).close(); + } + + @Test + public void testTemplateExecuteInputGeneratorTrueWithCreator() + throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + RecordFactory recordFactory = mock(RecordFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + RecordCreator generator = mock(RecordCreator.class); + RecordCreator creator = mock(RecordCreator.class); + + Record inputRecord = mock(Record.class); + Record outputRecord = mock(Record.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getRecordFactory()).willReturn(recordFactory); + given(generator.createRecord(recordFactory)).willReturn(inputRecord); + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(creator.createRecord(recordFactory)).willReturn(outputRecord); + given(connectionFactory.getRecordFactory()).willReturn(recordFactory); + given(interaction.execute(interactionSpec, inputRecord, outputRecord)).willReturn(true); + + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.setOutputRecordCreator(creator); + ct.execute(interactionSpec, generator); + + verify(interaction).execute(interactionSpec, inputRecord, outputRecord); + verify(interaction).close(); + verify(connection).close(); + } + + @Test + public void testTemplateExecuteInputGeneratorFalse() + throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + RecordFactory recordFactory = mock(RecordFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + RecordCreator generator = mock(RecordCreator.class); + + Record inputRecord = mock(Record.class); + Record outputRecord = mock(Record.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getRecordFactory()).willReturn(recordFactory); + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(generator.createRecord(recordFactory)).willReturn(inputRecord); + given(interaction.execute(interactionSpec, inputRecord)).willReturn(outputRecord); + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.execute(interactionSpec, generator); + + verify(interaction).execute(interactionSpec, inputRecord); + verify(interaction).close(); + verify(connection).close(); + } + + @SuppressWarnings("unchecked") + @Test + public void testTemplateExecuteInputGeneratorExtractorTrueWithCreator() + throws ResourceException, SQLException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + RecordFactory recordFactory = mock(RecordFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + RecordCreator generator = mock(RecordCreator.class); + RecordExtractor extractor = mock(RecordExtractor.class); + RecordCreator creator = mock(RecordCreator.class); + + Record inputRecord = mock(Record.class); + Record outputRecord = mock(Record.class); + + Object obj = new Object(); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getRecordFactory()).willReturn(recordFactory); + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(creator.createRecord(recordFactory)).willReturn(outputRecord); + given(connectionFactory.getRecordFactory()).willReturn(recordFactory); + given(generator.createRecord(recordFactory)).willReturn(inputRecord); + given(interaction.execute(interactionSpec, inputRecord, outputRecord)).willReturn(true); + given(extractor.extractData(outputRecord)).willReturn(obj); + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.setOutputRecordCreator(creator); + assertEquals(obj, ct.execute(interactionSpec, generator, extractor)); + + verify(interaction).close(); + verify(connection).close(); + } + + @SuppressWarnings("unchecked") + @Test + public void testTemplateExecuteInputGeneratorExtractorFalse() + throws ResourceException, SQLException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + RecordFactory recordFactory = mock(RecordFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + RecordCreator generator = mock(RecordCreator.class); + RecordExtractor extractor = mock(RecordExtractor.class); + + Record inputRecord = mock(Record.class); + Record outputRecord = mock(Record.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getRecordFactory()).willReturn(recordFactory); + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(generator.createRecord(recordFactory)).willReturn(inputRecord); + given(interaction.execute(interactionSpec, inputRecord)).willReturn(outputRecord); + given(extractor.extractData(outputRecord)).willReturn(new Object()); + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.execute(interactionSpec, generator, extractor); + + verify(extractor).extractData(outputRecord); + verify(interaction).close(); + verify(connection).close(); + } + + @Test + public void testTemplateExecuteInputOutputConnectionSpec() throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + ConnectionSpec connectionSpec = mock(ConnectionSpec.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + + Record inputRecord = mock(Record.class); + Record outputRecord = mock(Record.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getConnection(connectionSpec)).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(interaction.execute(interactionSpec, inputRecord, outputRecord)).willReturn(true); + + ConnectionSpecConnectionFactoryAdapter adapter = new ConnectionSpecConnectionFactoryAdapter(); + adapter.setTargetConnectionFactory(connectionFactory); + adapter.setConnectionSpec(connectionSpec); + CciTemplate ct = new CciTemplate(adapter); + ct.execute(interactionSpec, inputRecord, outputRecord); + + verify(interaction).execute(interactionSpec, inputRecord, outputRecord); + verify(interaction).close(); + verify(connection).close(); + } + + @SuppressWarnings("unchecked") + @Test + public void testTemplateExecuteInputOutputResultsSetFalse() + throws ResourceException, SQLException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + RecordFactory recordFactory = mock(RecordFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + Record record = mock(Record.class); + ResultSet resultset = mock(ResultSet.class); + RecordCreator generator = mock(RecordCreator.class); + RecordExtractor extractor = mock(RecordExtractor.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getRecordFactory()).willReturn(recordFactory); + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(generator.createRecord(recordFactory)).willReturn(record); + given(interaction.execute(interactionSpec, record)).willReturn(resultset); + given(extractor.extractData(resultset)).willReturn(new Object()); + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.execute(interactionSpec, generator, extractor); + + verify(extractor).extractData(resultset); + verify(resultset).close(); + verify(interaction).close(); + verify(connection).close(); + } + + @SuppressWarnings("unchecked") + @Test + public void testTemplateExecuteConnectionCallback() + throws ResourceException, SQLException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + ConnectionCallback connectionCallback = mock(ConnectionCallback.class); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connectionCallback.doInConnection(connection, connectionFactory)).willReturn(new Object()); + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.execute(connectionCallback); + + verify(connectionCallback).doInConnection(connection, connectionFactory); + verify(connection).close(); + } + + @SuppressWarnings("unchecked") + @Test + public void testTemplateExecuteInteractionCallback() + throws ResourceException, SQLException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + InteractionCallback interactionCallback = mock(InteractionCallback.class); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(interactionCallback.doInInteraction(interaction,connectionFactory)).willReturn(new Object()); + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.execute(interactionCallback); + + verify(interactionCallback).doInInteraction(interaction,connectionFactory); + verify(interaction).close(); + verify(connection).close(); + } + + @Test + public void testTemplateExecuteInputTrueTrueWithCreator() + throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + RecordCreator creator = mock(RecordCreator.class); + + Record inputOutputRecord = mock(Record.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(interaction.execute(interactionSpec, inputOutputRecord, inputOutputRecord)).willReturn(true); + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.setOutputRecordCreator(creator); + ct.execute(interactionSpec, inputOutputRecord, inputOutputRecord); + + verify(interaction).execute(interactionSpec, inputOutputRecord, inputOutputRecord); + verify(interaction).close(); + verify(connection).close(); + } + + @Test + public void testTemplateExecuteInputTrueTrue() throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + Record inputOutputRecord = mock(Record.class); + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(interaction.execute(interactionSpec, inputOutputRecord, inputOutputRecord)).willReturn(true); + + CciTemplate ct = new CciTemplate(connectionFactory); + ct.execute(interactionSpec, inputOutputRecord, inputOutputRecord); + + verify(interaction).execute(interactionSpec, inputOutputRecord, inputOutputRecord); + verify(interaction).close(); + verify(connection).close(); + } + + @Test + public void testTemplateExecuteInputFalseTrue() throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + Record inputOutputRecord = mock(Record.class); + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(interaction.execute(interactionSpec, inputOutputRecord)).willReturn(null); + + CciTemplate ct = new CciTemplate(connectionFactory); + Record tmpOutputRecord = ct.execute(interactionSpec, inputOutputRecord); + assertNull(tmpOutputRecord); + + verify(interaction).execute(interactionSpec, inputOutputRecord); + verify(interaction).close(); + verify(connection).close(); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/jca/cci/EisOperationTests.java b/spring-tx/src/test/java/org/springframework/jca/cci/EisOperationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e834a75190b126cf63b60d3c7f720b49957ac922 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/jca/cci/EisOperationTests.java @@ -0,0 +1,214 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.cci; + +import javax.resource.ResourceException; +import javax.resource.cci.Connection; +import javax.resource.cci.ConnectionFactory; +import javax.resource.cci.Interaction; +import javax.resource.cci.InteractionSpec; +import javax.resource.cci.Record; +import javax.resource.cci.RecordFactory; + +import org.junit.Test; + +import org.springframework.jca.cci.core.RecordCreator; +import org.springframework.jca.cci.object.MappingRecordOperation; +import org.springframework.jca.cci.object.SimpleRecordOperation; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * @author Thierry Templier + * @author Chris Beams + */ +public class EisOperationTests { + + @Test + public void testSimpleRecordOperation() throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + + Record inputRecord = mock(Record.class); + Record outputRecord = mock(Record.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + SimpleRecordOperation query = new SimpleRecordOperation(connectionFactory, interactionSpec); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(interaction.execute(interactionSpec, inputRecord)).willReturn(outputRecord); + + query.execute(inputRecord); + + verify(interaction).execute(interactionSpec, inputRecord); + verify(interaction).close(); + verify(connection).close(); + } + + @Test + public void testSimpleRecordOperationWithExplicitOutputRecord() throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + + Record inputRecord = mock(Record.class); + Record outputRecord = mock(Record.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + SimpleRecordOperation operation = new SimpleRecordOperation(connectionFactory, interactionSpec); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(interaction.execute(interactionSpec, inputRecord, outputRecord)).willReturn(true); + + operation.execute(inputRecord, outputRecord); + + verify(interaction).execute(interactionSpec, inputRecord, outputRecord); + verify(interaction).close(); + verify(connection).close(); + } + + @Test + public void testSimpleRecordOperationWithInputOutputRecord() throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + + Record inputOutputRecord = mock(Record.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + SimpleRecordOperation query = new SimpleRecordOperation(connectionFactory, interactionSpec); + + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(interaction.execute(interactionSpec, inputOutputRecord, inputOutputRecord)).willReturn(true); + + query.execute(inputOutputRecord, inputOutputRecord); + + verify(interaction).execute(interactionSpec, inputOutputRecord, inputOutputRecord); + verify(interaction).close(); + verify(connection).close(); + } + + @Test + public void testMappingRecordOperation() throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + RecordFactory recordFactory = mock(RecordFactory.class); + + Record inputRecord = mock(Record.class); + Record outputRecord = mock(Record.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + QueryCallDetector callDetector = mock(QueryCallDetector.class); + + MappingRecordOperationImpl query = new MappingRecordOperationImpl(connectionFactory, interactionSpec); + query.setCallDetector(callDetector); + + Object inObj = new Object(); + Object outObj = new Object(); + + given(connectionFactory.getRecordFactory()).willReturn(recordFactory); + given(callDetector.callCreateInputRecord(recordFactory, inObj)).willReturn(inputRecord); + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(interaction.execute(interactionSpec, inputRecord)).willReturn(outputRecord); + given(callDetector.callExtractOutputData(outputRecord)).willReturn(outObj); + + assertSame(outObj, query.execute(inObj)); + verify(interaction).close(); + verify(connection).close(); + } + + @Test + public void testMappingRecordOperationWithOutputRecordCreator() throws ResourceException { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + Connection connection = mock(Connection.class); + Interaction interaction = mock(Interaction.class); + RecordFactory recordFactory = mock(RecordFactory.class); + + Record inputRecord = mock(Record.class); + Record outputRecord = mock(Record.class); + + RecordCreator outputCreator = mock(RecordCreator.class); + + InteractionSpec interactionSpec = mock(InteractionSpec.class); + + QueryCallDetector callDetector = mock(QueryCallDetector.class); + + MappingRecordOperationImpl query = new MappingRecordOperationImpl(connectionFactory, interactionSpec); + query.setOutputRecordCreator(outputCreator); + query.setCallDetector(callDetector); + + Object inObj = new Object(); + Object outObj = new Object(); + + given(connectionFactory.getRecordFactory()).willReturn(recordFactory); + given(callDetector.callCreateInputRecord(recordFactory, inObj)).willReturn(inputRecord); + given(connectionFactory.getConnection()).willReturn(connection); + given(connection.createInteraction()).willReturn(interaction); + given(connectionFactory.getRecordFactory()).willReturn(recordFactory); + given(outputCreator.createRecord(recordFactory)).willReturn(outputRecord); + given(interaction.execute(interactionSpec, inputRecord, outputRecord)).willReturn(true); + given(callDetector.callExtractOutputData(outputRecord)).willReturn(outObj); + + assertSame(outObj, query.execute(inObj)); + verify(interaction).close(); + verify(connection).close(); + } + + + private class MappingRecordOperationImpl extends MappingRecordOperation { + + private QueryCallDetector callDetector; + + public MappingRecordOperationImpl(ConnectionFactory connectionFactory, InteractionSpec interactionSpec) { + super(connectionFactory, interactionSpec); + } + + public void setCallDetector(QueryCallDetector callDetector) { + this.callDetector = callDetector; + } + + @Override + protected Record createInputRecord(RecordFactory recordFactory, Object inputObject) { + return this.callDetector.callCreateInputRecord(recordFactory, inputObject); + } + + @Override + protected Object extractOutputData(Record outputRecord) throws ResourceException { + return this.callDetector.callExtractOutputData(outputRecord); + } + } + + + private interface QueryCallDetector { + + Record callCreateInputRecord(RecordFactory recordFactory, Object inputObject); + + Object callExtractOutputData(Record outputRecord); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/jca/support/LocalConnectionFactoryBeanTests.java b/spring-tx/src/test/java/org/springframework/jca/support/LocalConnectionFactoryBeanTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a4f05cbbc98d6197af031b73823718b80aa22d7f --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/jca/support/LocalConnectionFactoryBeanTests.java @@ -0,0 +1,74 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.jca.support; + +import javax.resource.spi.ConnectionManager; +import javax.resource.spi.ManagedConnectionFactory; + +import org.junit.Test; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * Unit tests for the {@link LocalConnectionFactoryBean} class. + * + * @author Rick Evans + * @author Chris Beams + */ +public class LocalConnectionFactoryBeanTests { + + @Test(expected = IllegalArgumentException.class) + public void testManagedConnectionFactoryIsRequired() throws Exception { + new LocalConnectionFactoryBean().afterPropertiesSet(); + } + + @Test + public void testIsSingleton() throws Exception { + LocalConnectionFactoryBean factory = new LocalConnectionFactoryBean(); + assertTrue(factory.isSingleton()); + } + + @Test + public void testGetObjectTypeIsNullIfConnectionFactoryHasNotBeenConfigured() throws Exception { + LocalConnectionFactoryBean factory = new LocalConnectionFactoryBean(); + assertNull(factory.getObjectType()); + } + + @Test + public void testCreatesVanillaConnectionFactoryIfNoConnectionManagerHasBeenConfigured() throws Exception { + final Object CONNECTION_FACTORY = new Object(); + ManagedConnectionFactory managedConnectionFactory = mock(ManagedConnectionFactory.class); + given(managedConnectionFactory.createConnectionFactory()).willReturn(CONNECTION_FACTORY); + LocalConnectionFactoryBean factory = new LocalConnectionFactoryBean(); + factory.setManagedConnectionFactory(managedConnectionFactory); + factory.afterPropertiesSet(); + assertEquals(CONNECTION_FACTORY, factory.getObject()); + } + + @Test + public void testCreatesManagedConnectionFactoryIfAConnectionManagerHasBeenConfigured() throws Exception { + ManagedConnectionFactory managedConnectionFactory = mock(ManagedConnectionFactory.class); + ConnectionManager connectionManager = mock(ConnectionManager.class); + LocalConnectionFactoryBean factory = new LocalConnectionFactoryBean(); + factory.setManagedConnectionFactory(managedConnectionFactory); + factory.setConnectionManager(connectionManager); + factory.afterPropertiesSet(); + verify(managedConnectionFactory).createConnectionFactory(connectionManager); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/tests/transaction/CallCountingTransactionManager.java b/spring-tx/src/test/java/org/springframework/tests/transaction/CallCountingTransactionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..fd6acbad77f3045dff66da2914bc1830e63715db --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/tests/transaction/CallCountingTransactionManager.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.tests.transaction; + +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.support.AbstractPlatformTransactionManager; +import org.springframework.transaction.support.DefaultTransactionStatus; + +/** + * @author Rod Johnson + * @author Juergen Hoeller + */ +@SuppressWarnings("serial") +public class CallCountingTransactionManager extends AbstractPlatformTransactionManager { + + public TransactionDefinition lastDefinition; + public int begun; + public int commits; + public int rollbacks; + public int inflight; + + @Override + protected Object doGetTransaction() { + return new Object(); + } + + @Override + protected void doBegin(Object transaction, TransactionDefinition definition) { + this.lastDefinition = definition; + ++begun; + ++inflight; + } + + @Override + protected void doCommit(DefaultTransactionStatus status) { + ++commits; + --inflight; + } + + @Override + protected void doRollback(DefaultTransactionStatus status) { + ++rollbacks; + --inflight; + } + + public void clear() { + begun = commits = rollbacks = inflight = 0; + } + +} diff --git a/spring-tx/src/test/java/org/springframework/tests/transaction/MockJtaTransaction.java b/spring-tx/src/test/java/org/springframework/tests/transaction/MockJtaTransaction.java new file mode 100644 index 0000000000000000000000000000000000000000..8569ef31896ce4ae6213dde3690f31030e857b1b --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/tests/transaction/MockJtaTransaction.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.tests.transaction; + +import javax.transaction.Status; +import javax.transaction.Synchronization; +import javax.transaction.xa.XAResource; + +/** + * @author Juergen Hoeller + * @since 31.08.2004 + */ +public class MockJtaTransaction implements javax.transaction.Transaction { + + private Synchronization synchronization; + + @Override + public int getStatus() { + return Status.STATUS_ACTIVE; + } + + @Override + public void registerSynchronization(Synchronization synchronization) { + this.synchronization = synchronization; + } + + public Synchronization getSynchronization() { + return synchronization; + } + + @Override + public boolean enlistResource(XAResource xaResource) { + return false; + } + + @Override + public boolean delistResource(XAResource xaResource, int i) { + return false; + } + + @Override + public void commit() { + } + + @Override + public void rollback() { + } + + @Override + public void setRollbackOnly() { + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/JndiJtaTransactionManagerTests.java b/spring-tx/src/test/java/org/springframework/transaction/JndiJtaTransactionManagerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..25cfa725f2a927a292a2d2b4bbd8fa44aa2991e5 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/JndiJtaTransactionManagerTests.java @@ -0,0 +1,221 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +import javax.transaction.Status; +import javax.transaction.TransactionManager; +import javax.transaction.UserTransaction; + +import org.junit.After; +import org.junit.Test; + +import org.springframework.tests.mock.jndi.ExpectedLookupTemplate; +import org.springframework.transaction.jta.JtaTransactionManager; +import org.springframework.transaction.jta.UserTransactionAdapter; +import org.springframework.transaction.support.TransactionCallbackWithoutResult; +import org.springframework.transaction.support.TransactionSynchronizationManager; +import org.springframework.transaction.support.TransactionTemplate; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * @author Juergen Hoeller + * @since 05.08.2005 + */ +public class JndiJtaTransactionManagerTests { + + @Test + public void jtaTransactionManagerWithDefaultJndiLookups1() throws Exception { + doTestJtaTransactionManagerWithDefaultJndiLookups("java:comp/TransactionManager", true, true); + } + + @Test + public void jtaTransactionManagerWithDefaultJndiLookups2() throws Exception { + doTestJtaTransactionManagerWithDefaultJndiLookups("java:/TransactionManager", true, true); + } + + @Test + public void jtaTransactionManagerWithDefaultJndiLookupsAndNoTmFound() throws Exception { + doTestJtaTransactionManagerWithDefaultJndiLookups("java:/tm", false, true); + } + + @Test + public void jtaTransactionManagerWithDefaultJndiLookupsAndNoUtFound() throws Exception { + doTestJtaTransactionManagerWithDefaultJndiLookups("java:/TransactionManager", true, false); + } + + private void doTestJtaTransactionManagerWithDefaultJndiLookups(String tmName, boolean tmFound, boolean defaultUt) + throws Exception { + + UserTransaction ut = mock(UserTransaction.class); + TransactionManager tm = mock(TransactionManager.class); + if (defaultUt) { + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + } + else { + given(tm.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + } + + JtaTransactionManager ptm = new JtaTransactionManager(); + ExpectedLookupTemplate jndiTemplate = new ExpectedLookupTemplate(); + if (defaultUt) { + jndiTemplate.addObject("java:comp/UserTransaction", ut); + } + jndiTemplate.addObject(tmName, tm); + ptm.setJndiTemplate(jndiTemplate); + ptm.afterPropertiesSet(); + + if (tmFound) { + assertEquals(tm, ptm.getTransactionManager()); + } + else { + assertNull(ptm.getTransactionManager()); + } + + if (defaultUt) { + assertEquals(ut, ptm.getUserTransaction()); + } + else { + assertTrue(ptm.getUserTransaction() instanceof UserTransactionAdapter); + UserTransactionAdapter uta = (UserTransactionAdapter) ptm.getUserTransaction(); + assertEquals(tm, uta.getTransactionManager()); + } + + TransactionTemplate tt = new TransactionTemplate(ptm); + assertTrue(!TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + }); + assertTrue(!TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + + if (defaultUt) { + verify(ut).begin(); + verify(ut).commit(); + } + else { + verify(tm).begin(); + verify(tm).commit(); + } + + } + + @Test + public void jtaTransactionManagerWithCustomJndiLookups() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + + TransactionManager tm = mock(TransactionManager.class); + + JtaTransactionManager ptm = new JtaTransactionManager(); + ptm.setUserTransactionName("jndi-ut"); + ptm.setTransactionManagerName("jndi-tm"); + ExpectedLookupTemplate jndiTemplate = new ExpectedLookupTemplate(); + jndiTemplate.addObject("jndi-ut", ut); + jndiTemplate.addObject("jndi-tm", tm); + ptm.setJndiTemplate(jndiTemplate); + ptm.afterPropertiesSet(); + + assertEquals(ut, ptm.getUserTransaction()); + assertEquals(tm, ptm.getTransactionManager()); + + TransactionTemplate tt = new TransactionTemplate(ptm); + assertTrue(!TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + }); + assertTrue(!TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + verify(ut).begin(); + verify(ut).commit(); + } + + @Test + public void jtaTransactionManagerWithNotCacheUserTransaction() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + + UserTransaction ut2 = mock(UserTransaction.class); + given(ut2.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + + JtaTransactionManager ptm = new JtaTransactionManager(); + ptm.setJndiTemplate(new ExpectedLookupTemplate("java:comp/UserTransaction", ut)); + ptm.setCacheUserTransaction(false); + ptm.afterPropertiesSet(); + + assertEquals(ut, ptm.getUserTransaction()); + + TransactionTemplate tt = new TransactionTemplate(ptm); + assertEquals(JtaTransactionManager.SYNCHRONIZATION_ALWAYS, ptm.getTransactionSynchronization()); + assertTrue(!TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + }); + + ptm.setJndiTemplate(new ExpectedLookupTemplate("java:comp/UserTransaction", ut2)); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + }); + assertTrue(!TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + verify(ut).begin(); + verify(ut).commit(); + verify(ut2).begin(); + verify(ut2).commit(); + } + + /** + * Prevent any side-effects due to this test modifying ThreadLocals that might + * affect subsequent tests when all tests are run in the same JVM, as with Eclipse. + */ + @After + public void tearDown() { + assertTrue(TransactionSynchronizationManager.getResourceMap().isEmpty()); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertNull(TransactionSynchronizationManager.getCurrentTransactionName()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/JtaTransactionManagerTests.java b/spring-tx/src/test/java/org/springframework/transaction/JtaTransactionManagerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c69f25a21ced71d4779bf55f7d3a1878ed2fbb2f --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/JtaTransactionManagerTests.java @@ -0,0 +1,1296 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +import javax.transaction.HeuristicMixedException; +import javax.transaction.HeuristicRollbackException; +import javax.transaction.NotSupportedException; +import javax.transaction.RollbackException; +import javax.transaction.Status; +import javax.transaction.SystemException; +import javax.transaction.Transaction; +import javax.transaction.TransactionManager; +import javax.transaction.UserTransaction; + +import org.junit.After; +import org.junit.Test; + +import org.springframework.dao.OptimisticLockingFailureException; +import org.springframework.tests.transaction.MockJtaTransaction; +import org.springframework.transaction.jta.JtaTransactionManager; +import org.springframework.transaction.support.DefaultTransactionDefinition; +import org.springframework.transaction.support.TransactionCallbackWithoutResult; +import org.springframework.transaction.support.TransactionSynchronization; +import org.springframework.transaction.support.TransactionSynchronizationAdapter; +import org.springframework.transaction.support.TransactionSynchronizationManager; +import org.springframework.transaction.support.TransactionTemplate; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * @author Juergen Hoeller + * @since 12.05.2003 + */ +public class JtaTransactionManagerTests { + + @Test + public void jtaTransactionManagerWithCommit() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + + final TransactionSynchronization synch = mock(TransactionSynchronization.class); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setName("txName"); + + assertEquals(JtaTransactionManager.SYNCHRONIZATION_ALWAYS, ptm.getTransactionSynchronization()); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertNull(TransactionSynchronizationManager.getCurrentTransactionName()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionSynchronizationManager.registerSynchronization(synch); + assertEquals("txName", TransactionSynchronizationManager.getCurrentTransactionName()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertNull(TransactionSynchronizationManager.getCurrentTransactionName()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + verify(ut).begin(); + verify(ut).commit(); + verify(synch).beforeCommit(false); + verify(synch).beforeCompletion(); + verify(synch).afterCommit(); + verify(synch).afterCompletion(TransactionSynchronization.STATUS_COMMITTED); + } + + @Test + public void jtaTransactionManagerWithCommitAndSynchronizationOnActual() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + + final TransactionSynchronization synch = mock(TransactionSynchronization.class); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + ptm.setTransactionSynchronization(JtaTransactionManager.SYNCHRONIZATION_ON_ACTUAL_TRANSACTION); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionSynchronizationManager.registerSynchronization(synch); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(ut).begin(); + verify(ut).commit(); + verify(synch).beforeCommit(false); + verify(synch).beforeCompletion(); + verify(synch).afterCommit(); + verify(synch).afterCompletion(TransactionSynchronization.STATUS_COMMITTED); + } + + @Test + public void jtaTransactionManagerWithCommitAndSynchronizationNever() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn( + Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + ptm.setTransactionSynchronization(JtaTransactionManager.SYNCHRONIZATION_NEVER); + ptm.afterPropertiesSet(); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(ut).begin(); + verify(ut).commit(); + } + + @Test + public void jtaTransactionManagerWithRollback() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE); + final TransactionSynchronization synch = mock(TransactionSynchronization.class); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setTimeout(10); + tt.setName("txName"); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertNull(TransactionSynchronizationManager.getCurrentTransactionName()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionSynchronizationManager.registerSynchronization(synch); + assertEquals("txName", TransactionSynchronizationManager.getCurrentTransactionName()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + status.setRollbackOnly(); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertNull(TransactionSynchronizationManager.getCurrentTransactionName()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + verify(ut).setTransactionTimeout(10); + verify(ut).begin(); + verify(ut).rollback(); + verify(synch).beforeCompletion(); + verify(synch).afterCompletion(TransactionSynchronization.STATUS_ROLLED_BACK); + } + + @Test + public void jtaTransactionManagerWithRollbackAndSynchronizationOnActual() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE); + final TransactionSynchronization synch = mock(TransactionSynchronization.class); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + ptm.setTransactionSynchronization(JtaTransactionManager.SYNCHRONIZATION_ON_ACTUAL_TRANSACTION); + tt.setTimeout(10); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionSynchronizationManager.registerSynchronization(synch); + status.setRollbackOnly(); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(ut).setTransactionTimeout(10); + verify(ut).begin(); + verify(ut).rollback(); + verify(synch).beforeCompletion(); + verify(synch).afterCompletion(TransactionSynchronization.STATUS_ROLLED_BACK); + } + + @Test + public void jtaTransactionManagerWithRollbackAndSynchronizationNever() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + ptm.setTransactionSynchronizationName("SYNCHRONIZATION_NEVER"); + tt.setTimeout(10); + ptm.afterPropertiesSet(); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + status.setRollbackOnly(); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(ut).setTransactionTimeout(10); + verify(ut).begin(); + verify(ut, atLeastOnce()).getStatus(); + verify(ut).rollback(); + } + + @Test + public void jtaTransactionManagerWithExistingTransactionAndRollbackOnly() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + + final TransactionSynchronization synch = mock(TransactionSynchronization.class); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionSynchronizationManager.registerSynchronization(synch); + status.setRollbackOnly(); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(ut).setRollbackOnly(); + verify(synch).beforeCompletion(); + verify(synch).afterCompletion(TransactionSynchronization.STATUS_UNKNOWN); + } + + @Test + public void jtaTransactionManagerWithExistingTransactionAndException() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + + final TransactionSynchronization synch = mock(TransactionSynchronization.class); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + try { + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionSynchronizationManager.registerSynchronization(synch); + throw new IllegalStateException("I want a rollback"); + } + }); + fail("Should have thrown IllegalStateException"); + } + catch (IllegalStateException ex) { + // expected + } + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(ut).setRollbackOnly(); + verify(synch).beforeCompletion(); + verify(synch).afterCompletion(TransactionSynchronization.STATUS_UNKNOWN); + } + + @Test + public void jtaTransactionManagerWithExistingTransactionAndCommitException() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + + final TransactionSynchronization synch = mock(TransactionSynchronization.class); + willThrow(new OptimisticLockingFailureException("")).given(synch).beforeCommit(false); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + try { + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionSynchronizationManager.registerSynchronization(synch); + } + }); + fail("Should have thrown OptimisticLockingFailureException"); + } + catch (OptimisticLockingFailureException ex) { + // expected + } + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(ut).setRollbackOnly(); + verify(synch).beforeCompletion(); + verify(synch).afterCompletion(TransactionSynchronization.STATUS_UNKNOWN); + } + + @Test + public void jtaTransactionManagerWithExistingTransactionAndRollbackOnlyAndNoGlobalRollback() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + + final TransactionSynchronization synch = mock(TransactionSynchronization.class); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + ptm.setGlobalRollbackOnParticipationFailure(false); + TransactionTemplate tt = new TransactionTemplate(ptm); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionSynchronizationManager.registerSynchronization(synch); + status.setRollbackOnly(); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(ut).setRollbackOnly(); + verify(synch).beforeCompletion(); + verify(synch).afterCompletion(TransactionSynchronization.STATUS_UNKNOWN); + } + + @Test + public void jtaTransactionManagerWithExistingTransactionAndExceptionAndNoGlobalRollback() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + final TransactionSynchronization synch = mock(TransactionSynchronization.class); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + ptm.setGlobalRollbackOnParticipationFailure(false); + TransactionTemplate tt = new TransactionTemplate(ptm); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + try { + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionSynchronizationManager.registerSynchronization(synch); + throw new IllegalStateException("I want a rollback"); + } + }); + fail("Should have thrown IllegalStateException"); + } + catch (IllegalStateException ex) { + // expected + } + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(synch).beforeCompletion(); + verify(synch).afterCompletion(TransactionSynchronization.STATUS_UNKNOWN); + } + + @Test + public void jtaTransactionManagerWithExistingTransactionAndJtaSynchronization() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + TransactionManager tm = mock(TransactionManager.class); + MockJtaTransaction tx = new MockJtaTransaction(); + + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + given(tm.getTransaction()).willReturn(tx); + + final TransactionSynchronization synch = mock(TransactionSynchronization.class); + + JtaTransactionManager ptm = newJtaTransactionManager(ut, tm); + TransactionTemplate tt = new TransactionTemplate(ptm); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionSynchronizationManager.registerSynchronization(synch); + status.setRollbackOnly(); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertNotNull(tx.getSynchronization()); + tx.getSynchronization().beforeCompletion(); + tx.getSynchronization().afterCompletion(Status.STATUS_ROLLEDBACK); + + verify(ut).setRollbackOnly(); + verify(synch).beforeCompletion(); + verify(synch).afterCompletion(TransactionSynchronization.STATUS_ROLLED_BACK); + } + + @Test + public void jtaTransactionManagerWithExistingTransactionAndSynchronizationOnActual() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + + final TransactionSynchronization synch = mock(TransactionSynchronization.class); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + ptm.setTransactionSynchronization(JtaTransactionManager.SYNCHRONIZATION_ON_ACTUAL_TRANSACTION); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionSynchronizationManager.registerSynchronization(synch); + status.setRollbackOnly(); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(ut).setRollbackOnly(); + verify(synch).beforeCompletion(); + verify(synch).afterCompletion(TransactionSynchronization.STATUS_UNKNOWN); + } + + @Test + public void jtaTransactionManagerWithExistingTransactionAndSynchronizationNever() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + ptm.setTransactionSynchronization(JtaTransactionManager.SYNCHRONIZATION_NEVER); + ptm.afterPropertiesSet(); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + status.setRollbackOnly(); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(ut).setRollbackOnly(); + } + + @Test + public void jtaTransactionManagerWithExistingAndPropagationSupports() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + + final TransactionSynchronization synch = mock(TransactionSynchronization.class); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setPropagationBehavior(TransactionDefinition.PROPAGATION_SUPPORTS); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionSynchronizationManager.registerSynchronization(synch); + status.setRollbackOnly(); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(ut).setRollbackOnly(); + verify(synch).beforeCompletion(); + verify(synch).afterCompletion(TransactionSynchronization.STATUS_UNKNOWN); + } + + @Test + public void jtaTransactionManagerWithPropagationSupports() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION); + + final TransactionSynchronization synch = mock(TransactionSynchronization.class); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setPropagationBehavior(TransactionDefinition.PROPAGATION_SUPPORTS); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionSynchronizationManager.registerSynchronization(synch); + status.setRollbackOnly(); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(synch).beforeCompletion(); + verify(synch).afterCompletion(TransactionSynchronization.STATUS_ROLLED_BACK); + } + + @Test + public void jtaTransactionManagerWithPropagationSupportsAndSynchronizationOnActual() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + ptm.setTransactionSynchronization(JtaTransactionManager.SYNCHRONIZATION_ON_ACTUAL_TRANSACTION); + tt.setPropagationBehavior(TransactionDefinition.PROPAGATION_SUPPORTS); + ptm.afterPropertiesSet(); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + status.setRollbackOnly(); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + } + + @Test + public void jtaTransactionManagerWithPropagationSupportsAndSynchronizationNever() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + ptm.setTransactionSynchronization(JtaTransactionManager.SYNCHRONIZATION_NEVER); + tt.setPropagationBehavior(TransactionDefinition.PROPAGATION_SUPPORTS); + ptm.afterPropertiesSet(); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + status.setRollbackOnly(); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + } + + @Test + public void jtaTransactionManagerWithPropagationNotSupported() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + TransactionManager tm = mock(TransactionManager.class); + Transaction tx = mock(Transaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + given(tm.suspend()).willReturn(tx); + + JtaTransactionManager ptm = newJtaTransactionManager(ut, tm); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setPropagationBehavior(TransactionDefinition.PROPAGATION_NOT_SUPPORTED); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + status.setRollbackOnly(); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(tm).resume(tx); + } + + @Test + public void jtaTransactionManagerWithPropagationRequiresNew() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + TransactionManager tm = mock(TransactionManager.class); + Transaction tx = mock(Transaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, + Status.STATUS_ACTIVE, Status.STATUS_ACTIVE, Status.STATUS_ACTIVE, + Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + given(tm.suspend()).willReturn(tx); + + final JtaTransactionManager ptm = newJtaTransactionManager(ut, tm); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW); + tt.setName("txName"); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertEquals("txName", TransactionSynchronizationManager.getCurrentTransactionName()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + TransactionTemplate tt2 = new TransactionTemplate(ptm); + tt2.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW); + tt2.setReadOnly(true); + tt2.setName("txName2"); + tt2.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertEquals("txName2", TransactionSynchronizationManager.getCurrentTransactionName()); + assertTrue(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + }); + + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertEquals("txName", TransactionSynchronizationManager.getCurrentTransactionName()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(ut, times(2)).begin(); + verify(ut, times(2)).commit(); + verify(tm).resume(tx); + } + + @Test + public void jtaTransactionManagerWithPropagationRequiresNewWithinSupports() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, + Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + + final JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setPropagationBehavior(TransactionDefinition.PROPAGATION_SUPPORTS); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + + TransactionTemplate tt2 = new TransactionTemplate(ptm); + tt2.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW); + tt2.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + } + }); + + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(ut).begin(); + verify(ut).commit(); + } + + @Test + public void jtaTransactionManagerWithPropagationRequiresNewAndExisting() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + TransactionManager tm = mock(TransactionManager.class); + Transaction tx = mock(Transaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + given(tm.suspend()).willReturn(tx); + + JtaTransactionManager ptm = newJtaTransactionManager(ut, tm); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(ut).begin(); + verify(ut).commit(); + verify(tm).resume(tx); + } + + @Test + public void jtaTransactionManagerWithPropagationRequiresNewAndExistingWithSuspendException() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + TransactionManager tm = mock(TransactionManager.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + willThrow(new SystemException()).given(tm).suspend(); + + JtaTransactionManager ptm = newJtaTransactionManager(ut, tm); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + try { + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + } + }); + fail("Should have thrown TransactionSystemException"); + } + catch (TransactionSystemException ex) { + // expected + } + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + } + + @Test + public void jtaTransactionManagerWithPropagationRequiresNewAndExistingWithBeginException() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + TransactionManager tm = mock(TransactionManager.class); + Transaction tx = mock(Transaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + given(tm.suspend()).willReturn(tx); + willThrow(new SystemException()).given(ut).begin(); + + JtaTransactionManager ptm = newJtaTransactionManager(ut, tm); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + try { + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + } + }); + fail("Should have thrown CannotCreateTransactionException"); + } + catch (CannotCreateTransactionException ex) { + // expected + } + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + verify(tm).resume(tx); + } + + @Test + public void jtaTransactionManagerWithPropagationRequiresNewAndAdapter() throws Exception { + TransactionManager tm = mock(TransactionManager.class); + Transaction tx = mock(Transaction.class); + given(tm.getStatus()).willReturn(Status.STATUS_ACTIVE); + given(tm.suspend()).willReturn(tx); + + JtaTransactionManager ptm = newJtaTransactionManager(tm); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + } + }); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + + verify(tm).begin(); + verify(tm).commit(); + verify(tm).resume(tx); + } + + @Test + public void jtaTransactionManagerWithPropagationRequiresNewAndSuspensionNotSupported() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + try { + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + } + }); + fail("Should have thrown TransactionSuspensionNotSupportedException"); + } + catch (TransactionSuspensionNotSupportedException ex) { + // expected + } + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + } + + @Test + public void jtaTransactionManagerWithIsolationLevel() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION); + + try { + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setIsolationLevel(TransactionDefinition.ISOLATION_SERIALIZABLE); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + } + }); + fail("Should have thrown InvalidIsolationLevelException"); + } + catch (InvalidIsolationLevelException ex) { + // expected + } + } + + @Test + public void jtaTransactionManagerWithSystemExceptionOnIsExisting() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willThrow(new SystemException("system exception")); + + try { + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + } + }); + fail("Should have thrown TransactionSystemException"); + } + catch (TransactionSystemException ex) { + // expected + } + } + + @Test + public void jtaTransactionManagerWithNestedBegin() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setPropagationBehavior(TransactionDefinition.PROPAGATION_NESTED); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + } + }); + + verify(ut).begin(); + verify(ut).commit(); + } + + @Test + public void jtaTransactionManagerWithNotSupportedExceptionOnNestedBegin() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + willThrow(new NotSupportedException("not supported")).given(ut).begin(); + + try { + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setPropagationBehavior(TransactionDefinition.PROPAGATION_NESTED); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + } + }); + fail("Should have thrown NestedTransactionNotSupportedException"); + } + catch (NestedTransactionNotSupportedException ex) { + // expected + } + } + + @Test + public void jtaTransactionManagerWithUnsupportedOperationExceptionOnNestedBegin() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + willThrow(new UnsupportedOperationException("not supported")).given(ut).begin(); + + try { + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.setPropagationBehavior(TransactionDefinition.PROPAGATION_NESTED); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + } + }); + fail("Should have thrown NestedTransactionNotSupportedException"); + } + catch (NestedTransactionNotSupportedException ex) { + // expected + } + } + + @Test + public void jtaTransactionManagerWithSystemExceptionOnBegin() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION); + willThrow(new SystemException("system exception")).given(ut).begin(); + + try { + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + } + }); + fail("Should have thrown CannotCreateTransactionException"); + } + catch (CannotCreateTransactionException ex) { + // expected + } + } + + @Test + public void jtaTransactionManagerWithRollbackExceptionOnCommit() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, + Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + willThrow(new RollbackException("unexpected rollback")).given(ut).commit(); + + try { + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronizationAdapter() { + @Override + public void afterCompletion(int status) { + assertTrue("Correct completion status", status == TransactionSynchronization.STATUS_ROLLED_BACK); + } + }); + } + }); + fail("Should have thrown UnexpectedRollbackException"); + } + catch (UnexpectedRollbackException ex) { + // expected + } + + verify(ut).begin(); + } + + @Test + public void jtaTransactionManagerWithNoExceptionOnGlobalRollbackOnly() throws Exception { + doTestJtaTransactionManagerWithNoExceptionOnGlobalRollbackOnly(false); + } + + @Test + public void jtaTransactionManagerWithNoExceptionOnGlobalRollbackOnlyAndFailEarly() throws Exception { + doTestJtaTransactionManagerWithNoExceptionOnGlobalRollbackOnly(true); + } + + private void doTestJtaTransactionManagerWithNoExceptionOnGlobalRollbackOnly(boolean failEarly) throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, + Status.STATUS_MARKED_ROLLBACK, Status.STATUS_MARKED_ROLLBACK, + Status.STATUS_MARKED_ROLLBACK); + + JtaTransactionManager tm = newJtaTransactionManager(ut); + if (failEarly) { + tm.setFailEarlyOnGlobalRollbackOnly(true); + } + + TransactionStatus ts = tm.getTransaction(new DefaultTransactionDefinition()); + boolean outerTransactionBoundaryReached = false; + try { + assertTrue("Is new transaction", ts.isNewTransaction()); + + TransactionTemplate tt = new TransactionTemplate(tm); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronizationAdapter() { + @Override + public void afterCompletion(int status) { + assertTrue("Correct completion status", status == TransactionSynchronization.STATUS_ROLLED_BACK); + } + }); + } + }); + + outerTransactionBoundaryReached = true; + tm.commit(ts); + + fail("Should have thrown UnexpectedRollbackException"); + } + catch (UnexpectedRollbackException ex) { + // expected + if (!outerTransactionBoundaryReached) { + tm.rollback(ts); + } + if (failEarly) { + assertFalse(outerTransactionBoundaryReached); + } + else { + assertTrue(outerTransactionBoundaryReached); + } + } + + verify(ut).begin(); + if (failEarly) { + verify(ut).rollback(); + } + else { + verify(ut).commit(); + } + } + + @Test + public void jtaTransactionManagerWithHeuristicMixedExceptionOnCommit() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, + Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + willThrow(new HeuristicMixedException("heuristic exception")).given(ut).commit(); + + try { + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronizationAdapter() { + @Override + public void afterCompletion(int status) { + assertTrue("Correct completion status", status == TransactionSynchronization.STATUS_UNKNOWN); + } + }); + } + }); + fail("Should have thrown HeuristicCompletionException"); + } + catch (HeuristicCompletionException ex) { + // expected + assertTrue(ex.getOutcomeState() == HeuristicCompletionException.STATE_MIXED); + } + + verify(ut).begin(); + } + + @Test + public void jtaTransactionManagerWithHeuristicRollbackExceptionOnCommit() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, + Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + willThrow(new HeuristicRollbackException("heuristic exception")).given(ut).commit(); + + try { + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronizationAdapter() { + @Override + public void afterCompletion(int status) { + assertTrue("Correct completion status", status == TransactionSynchronization.STATUS_UNKNOWN); + } + }); + } + }); + fail("Should have thrown HeuristicCompletionException"); + } + catch (HeuristicCompletionException ex) { + // expected + assertTrue(ex.getOutcomeState() == HeuristicCompletionException.STATE_ROLLED_BACK); + } + + verify(ut).begin(); + } + + @Test + public void jtaTransactionManagerWithSystemExceptionOnCommit() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, + Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + willThrow(new SystemException("system exception")).given(ut).commit(); + + try { + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + // something transactional + TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronizationAdapter() { + @Override + public void afterCompletion(int status) { + assertTrue("Correct completion status", status == TransactionSynchronization.STATUS_UNKNOWN); + } + }); + } + }); + fail("Should have thrown TransactionSystemException"); + } + catch (TransactionSystemException ex) { + // expected + } + + verify(ut).begin(); + } + + @Test + public void jtaTransactionManagerWithSystemExceptionOnRollback() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE); + willThrow(new SystemException("system exception")).given(ut).rollback(); + + try { + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronizationAdapter() { + @Override + public void afterCompletion(int status) { + assertTrue("Correct completion status", status == TransactionSynchronization.STATUS_UNKNOWN); + } + }); + status.setRollbackOnly(); + } + }); + fail("Should have thrown TransactionSystemException"); + } + catch (TransactionSystemException ex) { + // expected + } + + verify(ut).begin(); + } + + @Test + public void jtaTransactionManagerWithIllegalStateExceptionOnRollbackOnly() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + willThrow(new IllegalStateException("no existing transaction")).given(ut).setRollbackOnly(); + + try { + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + status.setRollbackOnly(); + } + }); + fail("Should have thrown TransactionSystemException"); + } + catch (TransactionSystemException ex) { + // expected + } + } + + @Test + public void jtaTransactionManagerWithSystemExceptionOnRollbackOnly() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_ACTIVE); + willThrow(new SystemException("system exception")).given(ut).setRollbackOnly(); + + try { + JtaTransactionManager ptm = newJtaTransactionManager(ut); + TransactionTemplate tt = new TransactionTemplate(ptm); + tt.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + status.setRollbackOnly(); + TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronizationAdapter() { + @Override + public void afterCompletion(int status) { + assertTrue("Correct completion status", status == TransactionSynchronization.STATUS_UNKNOWN); + } + }); + } + }); + fail("Should have thrown TransactionSystemException"); + } + catch (TransactionSystemException ex) { + // expected + } + } + + @Test + public void jtaTransactionManagerWithDoubleCommit() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, + Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionStatus status = ptm.getTransaction(new DefaultTransactionDefinition()); + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + // first commit + ptm.commit(status); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + try { + // second commit attempt + ptm.commit(status); + fail("Should have thrown IllegalTransactionStateException"); + } + catch (IllegalTransactionStateException ex) { + // expected + } + + verify(ut).begin(); + verify(ut).commit(); + } + + @Test + public void jtaTransactionManagerWithDoubleRollback() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionStatus status = ptm.getTransaction(new DefaultTransactionDefinition()); + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + // first rollback + ptm.rollback(status); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + try { + // second rollback attempt + ptm.rollback(status); + fail("Should have thrown IllegalTransactionStateException"); + } + catch (IllegalTransactionStateException ex) { + // expected + } + + verify(ut).begin(); + verify(ut).rollback(); + } + + @Test + public void jtaTransactionManagerWithRollbackAndCommit() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn(Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE); + + JtaTransactionManager ptm = newJtaTransactionManager(ut); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + TransactionStatus status = ptm.getTransaction(new DefaultTransactionDefinition()); + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + // first: rollback + ptm.rollback(status); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + try { + // second: commit attempt + ptm.commit(status); + fail("Should have thrown IllegalTransactionStateException"); + } + catch (IllegalTransactionStateException ex) { + // expected + } + + verify(ut).begin(); + verify(ut).rollback(); + } + + + protected JtaTransactionManager newJtaTransactionManager(UserTransaction ut) { + return new JtaTransactionManager(ut); + } + + protected JtaTransactionManager newJtaTransactionManager(TransactionManager tm) { + return new JtaTransactionManager(tm); + } + + protected JtaTransactionManager newJtaTransactionManager(UserTransaction ut, TransactionManager tm) { + return new JtaTransactionManager(ut, tm); + } + + + /** + * Prevent any side-effects due to this test modifying ThreadLocals that might + * affect subsequent tests when all tests are run in the same JVM, as with Eclipse. + */ + @After + public void tearDown() { + assertTrue(TransactionSynchronizationManager.getResourceMap().isEmpty()); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertNull(TransactionSynchronizationManager.getCurrentTransactionName()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + assertNull(TransactionSynchronizationManager.getCurrentTransactionIsolationLevel()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/MockCallbackPreferringTransactionManager.java b/spring-tx/src/test/java/org/springframework/transaction/MockCallbackPreferringTransactionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..6ed75056954cf7c42a557ee680d2dbca03979a5c --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/MockCallbackPreferringTransactionManager.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.support.CallbackPreferringPlatformTransactionManager; +import org.springframework.transaction.support.SimpleTransactionStatus; +import org.springframework.transaction.support.TransactionCallback; + +/** + * @author Juergen Hoeller + */ +public class MockCallbackPreferringTransactionManager implements CallbackPreferringPlatformTransactionManager { + + private TransactionDefinition definition; + + private TransactionStatus status; + + + @Override + public T execute(TransactionDefinition definition, TransactionCallback callback) throws TransactionException { + this.definition = definition; + this.status = new SimpleTransactionStatus(); + return callback.doInTransaction(this.status); + } + + public TransactionDefinition getDefinition() { + return this.definition; + } + + public TransactionStatus getStatus() { + return this.status; + } + + + @Override + public TransactionStatus getTransaction(@Nullable TransactionDefinition definition) throws TransactionException { + throw new UnsupportedOperationException(); + } + + @Override + public void commit(TransactionStatus status) throws TransactionException { + throw new UnsupportedOperationException(); + } + + @Override + public void rollback(TransactionStatus status) throws TransactionException { + throw new UnsupportedOperationException(); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/TestTransactionManager.java b/spring-tx/src/test/java/org/springframework/transaction/TestTransactionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..3087387681d7617e55a5a308b6c47f04be316b2d --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/TestTransactionManager.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +import org.springframework.transaction.support.AbstractPlatformTransactionManager; +import org.springframework.transaction.support.DefaultTransactionStatus; + +/** + * @author Juergen Hoeller + * @since 29.04.2003 + */ +@SuppressWarnings("serial") +class TestTransactionManager extends AbstractPlatformTransactionManager { + + private static final Object TRANSACTION = "transaction"; + + private final boolean existingTransaction; + + private final boolean canCreateTransaction; + + protected boolean begin = false; + + protected boolean commit = false; + + protected boolean rollback = false; + + protected boolean rollbackOnly = false; + + protected TestTransactionManager(boolean existingTransaction, boolean canCreateTransaction) { + this.existingTransaction = existingTransaction; + this.canCreateTransaction = canCreateTransaction; + setTransactionSynchronization(SYNCHRONIZATION_NEVER); + } + + @Override + protected Object doGetTransaction() { + return TRANSACTION; + } + + @Override + protected boolean isExistingTransaction(Object transaction) { + return existingTransaction; + } + + @Override + protected void doBegin(Object transaction, TransactionDefinition definition) { + if (!TRANSACTION.equals(transaction)) { + throw new IllegalArgumentException("Not the same transaction object"); + } + if (!this.canCreateTransaction) { + throw new CannotCreateTransactionException("Cannot create transaction"); + } + this.begin = true; + } + + @Override + protected void doCommit(DefaultTransactionStatus status) { + if (!TRANSACTION.equals(status.getTransaction())) { + throw new IllegalArgumentException("Not the same transaction object"); + } + this.commit = true; + } + + @Override + protected void doRollback(DefaultTransactionStatus status) { + if (!TRANSACTION.equals(status.getTransaction())) { + throw new IllegalArgumentException("Not the same transaction object"); + } + this.rollback = true; + } + + @Override + protected void doSetRollbackOnly(DefaultTransactionStatus status) { + if (!TRANSACTION.equals(status.getTransaction())) { + throw new IllegalArgumentException("Not the same transaction object"); + } + this.rollbackOnly = true; + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/TransactionSupportTests.java b/spring-tx/src/test/java/org/springframework/transaction/TransactionSupportTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b8c08537d98358882c512d7a4d0f2511238dbb14 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/TransactionSupportTests.java @@ -0,0 +1,331 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +import org.junit.After; +import org.junit.Test; + +import org.springframework.transaction.support.DefaultTransactionDefinition; +import org.springframework.transaction.support.DefaultTransactionStatus; +import org.springframework.transaction.support.TransactionCallbackWithoutResult; +import org.springframework.transaction.support.TransactionSynchronizationManager; +import org.springframework.transaction.support.TransactionTemplate; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + * @since 29.04.2003 + */ +public class TransactionSupportTests { + + @Test + public void noExistingTransaction() { + PlatformTransactionManager tm = new TestTransactionManager(false, true); + DefaultTransactionStatus status1 = (DefaultTransactionStatus) + tm.getTransaction(new DefaultTransactionDefinition(TransactionDefinition.PROPAGATION_SUPPORTS)); + assertFalse("Must not have transaction", status1.hasTransaction()); + + DefaultTransactionStatus status2 = (DefaultTransactionStatus) + tm.getTransaction(new DefaultTransactionDefinition(TransactionDefinition.PROPAGATION_REQUIRED)); + assertTrue("Must have transaction", status2.hasTransaction()); + assertTrue("Must be new transaction", status2.isNewTransaction()); + + try { + tm.getTransaction(new DefaultTransactionDefinition(TransactionDefinition.PROPAGATION_MANDATORY)); + fail("Should not have thrown NoTransactionException"); + } + catch (IllegalTransactionStateException ex) { + // expected + } + } + + @Test + public void existingTransaction() { + PlatformTransactionManager tm = new TestTransactionManager(true, true); + DefaultTransactionStatus status1 = (DefaultTransactionStatus) + tm.getTransaction(new DefaultTransactionDefinition(TransactionDefinition.PROPAGATION_SUPPORTS)); + assertTrue("Must have transaction", status1.getTransaction() != null); + assertTrue("Must not be new transaction", !status1.isNewTransaction()); + + DefaultTransactionStatus status2 = (DefaultTransactionStatus) + tm.getTransaction(new DefaultTransactionDefinition(TransactionDefinition.PROPAGATION_REQUIRED)); + assertTrue("Must have transaction", status2.getTransaction() != null); + assertTrue("Must not be new transaction", !status2.isNewTransaction()); + + try { + DefaultTransactionStatus status3 = (DefaultTransactionStatus) + tm.getTransaction(new DefaultTransactionDefinition(TransactionDefinition.PROPAGATION_MANDATORY)); + assertTrue("Must have transaction", status3.getTransaction() != null); + assertTrue("Must not be new transaction", !status3.isNewTransaction()); + } + catch (NoTransactionException ex) { + fail("Should not have thrown NoTransactionException"); + } + } + + @Test + public void commitWithoutExistingTransaction() { + TestTransactionManager tm = new TestTransactionManager(false, true); + TransactionStatus status = tm.getTransaction(null); + tm.commit(status); + + assertTrue("triggered begin", tm.begin); + assertTrue("triggered commit", tm.commit); + assertTrue("no rollback", !tm.rollback); + assertTrue("no rollbackOnly", !tm.rollbackOnly); + } + + @Test + public void rollbackWithoutExistingTransaction() { + TestTransactionManager tm = new TestTransactionManager(false, true); + TransactionStatus status = tm.getTransaction(null); + tm.rollback(status); + + assertTrue("triggered begin", tm.begin); + assertTrue("no commit", !tm.commit); + assertTrue("triggered rollback", tm.rollback); + assertTrue("no rollbackOnly", !tm.rollbackOnly); + } + + @Test + public void rollbackOnlyWithoutExistingTransaction() { + TestTransactionManager tm = new TestTransactionManager(false, true); + TransactionStatus status = tm.getTransaction(null); + status.setRollbackOnly(); + tm.commit(status); + + assertTrue("triggered begin", tm.begin); + assertTrue("no commit", !tm.commit); + assertTrue("triggered rollback", tm.rollback); + assertTrue("no rollbackOnly", !tm.rollbackOnly); + } + + @Test + public void commitWithExistingTransaction() { + TestTransactionManager tm = new TestTransactionManager(true, true); + TransactionStatus status = tm.getTransaction(null); + tm.commit(status); + + assertTrue("no begin", !tm.begin); + assertTrue("no commit", !tm.commit); + assertTrue("no rollback", !tm.rollback); + assertTrue("no rollbackOnly", !tm.rollbackOnly); + } + + @Test + public void rollbackWithExistingTransaction() { + TestTransactionManager tm = new TestTransactionManager(true, true); + TransactionStatus status = tm.getTransaction(null); + tm.rollback(status); + + assertTrue("no begin", !tm.begin); + assertTrue("no commit", !tm.commit); + assertTrue("no rollback", !tm.rollback); + assertTrue("triggered rollbackOnly", tm.rollbackOnly); + } + + @Test + public void rollbackOnlyWithExistingTransaction() { + TestTransactionManager tm = new TestTransactionManager(true, true); + TransactionStatus status = tm.getTransaction(null); + status.setRollbackOnly(); + tm.commit(status); + + assertTrue("no begin", !tm.begin); + assertTrue("no commit", !tm.commit); + assertTrue("no rollback", !tm.rollback); + assertTrue("triggered rollbackOnly", tm.rollbackOnly); + } + + @Test + public void transactionTemplate() { + TestTransactionManager tm = new TestTransactionManager(false, true); + TransactionTemplate template = new TransactionTemplate(tm); + template.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + } + }); + + assertTrue("triggered begin", tm.begin); + assertTrue("triggered commit", tm.commit); + assertTrue("no rollback", !tm.rollback); + assertTrue("no rollbackOnly", !tm.rollbackOnly); + } + + @Test + public void transactionTemplateWithCallbackPreference() { + MockCallbackPreferringTransactionManager ptm = new MockCallbackPreferringTransactionManager(); + TransactionTemplate template = new TransactionTemplate(ptm); + template.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + } + }); + + assertSame(template, ptm.getDefinition()); + assertFalse(ptm.getStatus().isRollbackOnly()); + } + + @Test + public void transactionTemplateWithException() { + TestTransactionManager tm = new TestTransactionManager(false, true); + TransactionTemplate template = new TransactionTemplate(tm); + final RuntimeException ex = new RuntimeException("Some application exception"); + try { + template.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + throw ex; + } + }); + fail("Should have propagated RuntimeException"); + } + catch (RuntimeException caught) { + // expected + assertTrue("Correct exception", caught == ex); + assertTrue("triggered begin", tm.begin); + assertTrue("no commit", !tm.commit); + assertTrue("triggered rollback", tm.rollback); + assertTrue("no rollbackOnly", !tm.rollbackOnly); + } + } + + @SuppressWarnings("serial") + @Test + public void transactionTemplateWithRollbackException() { + final TransactionSystemException tex = new TransactionSystemException("system exception"); + TestTransactionManager tm = new TestTransactionManager(false, true) { + @Override + protected void doRollback(DefaultTransactionStatus status) { + super.doRollback(status); + throw tex; + } + }; + TransactionTemplate template = new TransactionTemplate(tm); + final RuntimeException ex = new RuntimeException("Some application exception"); + try { + template.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + throw ex; + } + }); + fail("Should have propagated RuntimeException"); + } + catch (RuntimeException caught) { + // expected + assertTrue("Correct exception", caught == tex); + assertTrue("triggered begin", tm.begin); + assertTrue("no commit", !tm.commit); + assertTrue("triggered rollback", tm.rollback); + assertTrue("no rollbackOnly", !tm.rollbackOnly); + } + } + + @Test + public void transactionTemplateWithError() { + TestTransactionManager tm = new TestTransactionManager(false, true); + TransactionTemplate template = new TransactionTemplate(tm); + try { + template.execute(new TransactionCallbackWithoutResult() { + @Override + protected void doInTransactionWithoutResult(TransactionStatus status) { + throw new Error("Some application error"); + } + }); + fail("Should have propagated Error"); + } + catch (Error err) { + // expected + assertTrue("triggered begin", tm.begin); + assertTrue("no commit", !tm.commit); + assertTrue("triggered rollback", tm.rollback); + assertTrue("no rollbackOnly", !tm.rollbackOnly); + } + } + + @Test + public void transactionTemplateInitialization() { + TestTransactionManager tm = new TestTransactionManager(false, true); + TransactionTemplate template = new TransactionTemplate(); + template.setTransactionManager(tm); + assertTrue("correct transaction manager set", template.getTransactionManager() == tm); + + try { + template.setPropagationBehaviorName("TIMEOUT_DEFAULT"); + fail("Should have thrown IllegalArgumentException"); + } + catch (IllegalArgumentException ex) { + // expected + } + template.setPropagationBehaviorName("PROPAGATION_SUPPORTS"); + assertTrue("Correct propagation behavior set", template.getPropagationBehavior() == TransactionDefinition.PROPAGATION_SUPPORTS); + + try { + template.setPropagationBehavior(999); + fail("Should have thrown IllegalArgumentException"); + } + catch (IllegalArgumentException ex) { + // expected + } + template.setPropagationBehavior(TransactionDefinition.PROPAGATION_MANDATORY); + assertTrue("Correct propagation behavior set", template.getPropagationBehavior() == TransactionDefinition.PROPAGATION_MANDATORY); + + try { + template.setIsolationLevelName("TIMEOUT_DEFAULT"); + fail("Should have thrown IllegalArgumentException"); + } + catch (IllegalArgumentException ex) { + // expected + } + template.setIsolationLevelName("ISOLATION_SERIALIZABLE"); + assertTrue("Correct isolation level set", template.getIsolationLevel() == TransactionDefinition.ISOLATION_SERIALIZABLE); + + try { + template.setIsolationLevel(999); + fail("Should have thrown IllegalArgumentException"); + } + catch (IllegalArgumentException ex) { + // expected + } + template.setIsolationLevel(TransactionDefinition.ISOLATION_REPEATABLE_READ); + assertTrue("Correct isolation level set", template.getIsolationLevel() == TransactionDefinition.ISOLATION_REPEATABLE_READ); + } + + @Test + public void transactionTemplateEquality() { + TestTransactionManager tm1 = new TestTransactionManager(false, true); + TestTransactionManager tm2 = new TestTransactionManager(false, true); + TransactionTemplate template1 = new TransactionTemplate(tm1); + TransactionTemplate template2 = new TransactionTemplate(tm2); + TransactionTemplate template3 = new TransactionTemplate(tm2); + + assertNotEquals(template1, template2); + assertNotEquals(template1, template3); + assertEquals(template2, template3); + } + + + @After + public void clear() { + assertTrue(TransactionSynchronizationManager.getResourceMap().isEmpty()); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/TxNamespaceHandlerEventTests.java b/spring-tx/src/test/java/org/springframework/transaction/TxNamespaceHandlerEventTests.java new file mode 100644 index 0000000000000000000000000000000000000000..40dd4d57854257e450e343832e9ce39f151c15fc --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/TxNamespaceHandlerEventTests.java @@ -0,0 +1,56 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.beans.factory.parsing.BeanComponentDefinition; +import org.springframework.beans.factory.parsing.ComponentDefinition; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; +import org.springframework.core.io.ClassPathResource; +import org.springframework.tests.beans.CollectingReaderEventListener; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +/** + * @author Torsten Juergeleit + * @author Juergen Hoeller + */ +public class TxNamespaceHandlerEventTests { + + private DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + + private CollectingReaderEventListener eventListener = new CollectingReaderEventListener(); + + + @Before + public void setUp() throws Exception { + XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(this.beanFactory); + reader.setEventListener(this.eventListener); + reader.loadBeanDefinitions(new ClassPathResource("txNamespaceHandlerTests.xml", getClass())); + } + + @Test + public void componentEventReceived() { + ComponentDefinition component = this.eventListener.getComponentDefinition("txAdvice"); + assertThat(component, instanceOf(BeanComponentDefinition.class)); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/TxNamespaceHandlerTests.java b/spring-tx/src/test/java/org/springframework/transaction/TxNamespaceHandlerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..f0c0a0b37042db5e9807440b982e27c7edc632c4 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/TxNamespaceHandlerTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction; + +import java.lang.reflect.Method; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.aop.support.AopUtils; +import org.springframework.context.ApplicationContext; +import org.springframework.context.support.ClassPathXmlApplicationContext; +import org.springframework.tests.sample.beans.ITestBean; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.interceptor.TransactionAttribute; +import org.springframework.transaction.interceptor.TransactionAttributeSource; +import org.springframework.transaction.interceptor.TransactionInterceptor; + +import static org.junit.Assert.*; + +/** + * @author Rob Harrop + * @author Adrian Colyer + */ +public class TxNamespaceHandlerTests { + + private ApplicationContext context; + + private Method getAgeMethod; + + private Method setAgeMethod; + + + @Before + public void setup() throws Exception { + this.context = new ClassPathXmlApplicationContext("txNamespaceHandlerTests.xml", getClass()); + this.getAgeMethod = ITestBean.class.getMethod("getAge"); + this.setAgeMethod = ITestBean.class.getMethod("setAge", int.class); + } + + + @Test + public void isProxy() { + ITestBean bean = getTestBean(); + assertTrue("testBean is not a proxy", AopUtils.isAopProxy(bean)); + } + + @Test + public void invokeTransactional() { + ITestBean testBean = getTestBean(); + CallCountingTransactionManager ptm = (CallCountingTransactionManager) context.getBean("transactionManager"); + + // try with transactional + assertEquals("Should not have any started transactions", 0, ptm.begun); + testBean.getName(); + assertTrue(ptm.lastDefinition.isReadOnly()); + assertEquals("Should have 1 started transaction", 1, ptm.begun); + assertEquals("Should have 1 committed transaction", 1, ptm.commits); + + // try with non-transaction + testBean.haveBirthday(); + assertEquals("Should not have started another transaction", 1, ptm.begun); + + // try with exceptional + try { + testBean.exceptional(new IllegalArgumentException("foo")); + fail("Should NEVER get here"); + } + catch (Throwable throwable) { + assertEquals("Should have another started transaction", 2, ptm.begun); + assertEquals("Should have 1 rolled back transaction", 1, ptm.rollbacks); + } + } + + @Test + public void rollbackRules() { + TransactionInterceptor txInterceptor = (TransactionInterceptor) context.getBean("txRollbackAdvice"); + TransactionAttributeSource txAttrSource = txInterceptor.getTransactionAttributeSource(); + TransactionAttribute txAttr = txAttrSource.getTransactionAttribute(getAgeMethod,ITestBean.class); + assertTrue("should be configured to rollback on Exception",txAttr.rollbackOn(new Exception())); + + txAttr = txAttrSource.getTransactionAttribute(setAgeMethod, ITestBean.class); + assertFalse("should not rollback on RuntimeException",txAttr.rollbackOn(new RuntimeException())); + } + + private ITestBean getTestBean() { + return (ITestBean) context.getBean("testBean"); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/annotation/AnnotationTransactionAttributeSourceTests.java b/spring-tx/src/test/java/org/springframework/transaction/annotation/AnnotationTransactionAttributeSourceTests.java new file mode 100644 index 0000000000000000000000000000000000000000..541965649f85d5738f53d6ec33124e999c52f497 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/annotation/AnnotationTransactionAttributeSourceTests.java @@ -0,0 +1,953 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import java.io.IOException; +import java.io.Serializable; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.reflect.Method; + +import javax.ejb.TransactionAttributeType; + +import groovy.lang.GroovyObject; +import groovy.lang.MetaClass; +import org.junit.Test; + +import org.springframework.aop.framework.Advised; +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.interceptor.NoRollbackRuleAttribute; +import org.springframework.transaction.interceptor.RollbackRuleAttribute; +import org.springframework.transaction.interceptor.RuleBasedTransactionAttribute; +import org.springframework.transaction.interceptor.TransactionAttribute; +import org.springframework.transaction.interceptor.TransactionInterceptor; +import org.springframework.util.SerializationTestUtils; + +import static org.junit.Assert.*; + +/** + * @author Colin Sampaleanu + * @author Juergen Hoeller + * @author Sam Brannen + */ +public class AnnotationTransactionAttributeSourceTests { + + @Test + public void serializable() throws Exception { + TestBean1 tb = new TestBean1(); + CallCountingTransactionManager ptm = new CallCountingTransactionManager(); + AnnotationTransactionAttributeSource tas = new AnnotationTransactionAttributeSource(); + TransactionInterceptor ti = new TransactionInterceptor(ptm, tas); + + ProxyFactory proxyFactory = new ProxyFactory(); + proxyFactory.setInterfaces(ITestBean1.class); + proxyFactory.addAdvice(ti); + proxyFactory.setTarget(tb); + ITestBean1 proxy = (ITestBean1) proxyFactory.getProxy(); + proxy.getAge(); + assertEquals(1, ptm.commits); + + ITestBean1 serializedProxy = (ITestBean1) SerializationTestUtils.serializeAndDeserialize(proxy); + serializedProxy.getAge(); + Advised advised = (Advised) serializedProxy; + TransactionInterceptor serializedTi = (TransactionInterceptor) advised.getAdvisors()[0].getAdvice(); + CallCountingTransactionManager serializedPtm = + (CallCountingTransactionManager) serializedTi.getTransactionManager(); + assertEquals(2, serializedPtm.commits); + } + + @Test + public void nullOrEmpty() throws Exception { + Method method = Empty.class.getMethod("getAge"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + assertNull(atas.getTransactionAttribute(method, null)); + + // Try again in case of caching + assertNull(atas.getTransactionAttribute(method, null)); + } + + /** + * Test the important case where the invocation is on a proxied interface method + * but the attribute is defined on the target class. + */ + @Test + public void transactionAttributeDeclaredOnClassMethod() throws Exception { + Method classMethod = ITestBean1.class.getMethod("getAge"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute actual = atas.getTransactionAttribute(classMethod, TestBean1.class); + + RuleBasedTransactionAttribute rbta = new RuleBasedTransactionAttribute(); + rbta.getRollbackRules().add(new RollbackRuleAttribute(Exception.class)); + assertEquals(rbta.getRollbackRules(), ((RuleBasedTransactionAttribute) actual).getRollbackRules()); + } + + /** + * Test the important case where the invocation is on a proxied interface method + * but the attribute is defined on the target class. + */ + @Test + public void transactionAttributeDeclaredOnCglibClassMethod() throws Exception { + Method classMethod = ITestBean1.class.getMethod("getAge"); + TestBean1 tb = new TestBean1(); + ProxyFactory pf = new ProxyFactory(tb); + pf.setProxyTargetClass(true); + Object proxy = pf.getProxy(); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute actual = atas.getTransactionAttribute(classMethod, proxy.getClass()); + + RuleBasedTransactionAttribute rbta = new RuleBasedTransactionAttribute(); + rbta.getRollbackRules().add(new RollbackRuleAttribute(Exception.class)); + assertEquals(rbta.getRollbackRules(), ((RuleBasedTransactionAttribute) actual).getRollbackRules()); + } + + /** + * Test case where attribute is on the interface method. + */ + @Test + public void transactionAttributeDeclaredOnInterfaceMethodOnly() throws Exception { + Method interfaceMethod = ITestBean2.class.getMethod("getAge"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute actual = atas.getTransactionAttribute(interfaceMethod, TestBean2.class); + + RuleBasedTransactionAttribute rbta = new RuleBasedTransactionAttribute(); + assertEquals(rbta.getRollbackRules(), ((RuleBasedTransactionAttribute) actual).getRollbackRules()); + } + + /** + * Test that when an attribute exists on both class and interface, class takes precedence. + */ + @Test + public void transactionAttributeOnTargetClassMethodOverridesAttributeOnInterfaceMethod() throws Exception { + Method interfaceMethod = ITestBean3.class.getMethod("getAge"); + Method interfaceMethod2 = ITestBean3.class.getMethod("getName"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute actual = atas.getTransactionAttribute(interfaceMethod, TestBean3.class); + assertEquals(TransactionAttribute.PROPAGATION_REQUIRES_NEW, actual.getPropagationBehavior()); + assertEquals(TransactionAttribute.ISOLATION_REPEATABLE_READ, actual.getIsolationLevel()); + assertEquals(5, actual.getTimeout()); + assertTrue(actual.isReadOnly()); + + RuleBasedTransactionAttribute rbta = new RuleBasedTransactionAttribute(); + rbta.getRollbackRules().add(new RollbackRuleAttribute(Exception.class)); + rbta.getRollbackRules().add(new NoRollbackRuleAttribute(IOException.class)); + assertEquals(rbta.getRollbackRules(), ((RuleBasedTransactionAttribute) actual).getRollbackRules()); + + TransactionAttribute actual2 = atas.getTransactionAttribute(interfaceMethod2, TestBean3.class); + assertEquals(TransactionAttribute.PROPAGATION_REQUIRED, actual2.getPropagationBehavior()); + } + + @Test + public void rollbackRulesAreApplied() throws Exception { + Method method = TestBean3.class.getMethod("getAge"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute actual = atas.getTransactionAttribute(method, TestBean3.class); + + RuleBasedTransactionAttribute rbta = new RuleBasedTransactionAttribute(); + rbta.getRollbackRules().add(new RollbackRuleAttribute("java.lang.Exception")); + rbta.getRollbackRules().add(new NoRollbackRuleAttribute(IOException.class)); + + assertEquals(rbta.getRollbackRules(), ((RuleBasedTransactionAttribute) actual).getRollbackRules()); + assertTrue(actual.rollbackOn(new Exception())); + assertFalse(actual.rollbackOn(new IOException())); + + actual = atas.getTransactionAttribute(method, method.getDeclaringClass()); + + rbta = new RuleBasedTransactionAttribute(); + rbta.getRollbackRules().add(new RollbackRuleAttribute("java.lang.Exception")); + rbta.getRollbackRules().add(new NoRollbackRuleAttribute(IOException.class)); + + assertEquals(rbta.getRollbackRules(), ((RuleBasedTransactionAttribute) actual).getRollbackRules()); + assertTrue(actual.rollbackOn(new Exception())); + assertFalse(actual.rollbackOn(new IOException())); + } + + /** + * Test that transaction attribute is inherited from class + * if not specified on method. + */ + @Test + public void defaultsToClassTransactionAttribute() throws Exception { + Method method = TestBean4.class.getMethod("getAge"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute actual = atas.getTransactionAttribute(method, TestBean4.class); + + RuleBasedTransactionAttribute rbta = new RuleBasedTransactionAttribute(); + rbta.getRollbackRules().add(new RollbackRuleAttribute(Exception.class)); + rbta.getRollbackRules().add(new NoRollbackRuleAttribute(IOException.class)); + assertEquals(rbta.getRollbackRules(), ((RuleBasedTransactionAttribute) actual).getRollbackRules()); + } + + @Test + public void customClassAttributeDetected() throws Exception { + Method method = TestBean5.class.getMethod("getAge"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute actual = atas.getTransactionAttribute(method, TestBean5.class); + + RuleBasedTransactionAttribute rbta = new RuleBasedTransactionAttribute(); + rbta.getRollbackRules().add(new RollbackRuleAttribute(Exception.class)); + rbta.getRollbackRules().add(new NoRollbackRuleAttribute(IOException.class)); + assertEquals(rbta.getRollbackRules(), ((RuleBasedTransactionAttribute) actual).getRollbackRules()); + } + + @Test + public void customMethodAttributeDetected() throws Exception { + Method method = TestBean6.class.getMethod("getAge"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute actual = atas.getTransactionAttribute(method, TestBean6.class); + + RuleBasedTransactionAttribute rbta = new RuleBasedTransactionAttribute(); + rbta.getRollbackRules().add(new RollbackRuleAttribute(Exception.class)); + rbta.getRollbackRules().add(new NoRollbackRuleAttribute(IOException.class)); + assertEquals(rbta.getRollbackRules(), ((RuleBasedTransactionAttribute) actual).getRollbackRules()); + } + + @Test + public void customClassAttributeWithReadOnlyOverrideDetected() throws Exception { + Method method = TestBean7.class.getMethod("getAge"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute actual = atas.getTransactionAttribute(method, TestBean7.class); + + RuleBasedTransactionAttribute rbta = new RuleBasedTransactionAttribute(); + rbta.getRollbackRules().add(new RollbackRuleAttribute(Exception.class)); + rbta.getRollbackRules().add(new NoRollbackRuleAttribute(IOException.class)); + assertEquals(rbta.getRollbackRules(), ((RuleBasedTransactionAttribute) actual).getRollbackRules()); + + assertTrue(actual.isReadOnly()); + } + + @Test + public void customMethodAttributeWithReadOnlyOverrideDetected() throws Exception { + Method method = TestBean8.class.getMethod("getAge"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute actual = atas.getTransactionAttribute(method, TestBean8.class); + + RuleBasedTransactionAttribute rbta = new RuleBasedTransactionAttribute(); + rbta.getRollbackRules().add(new RollbackRuleAttribute(Exception.class)); + rbta.getRollbackRules().add(new NoRollbackRuleAttribute(IOException.class)); + assertEquals(rbta.getRollbackRules(), ((RuleBasedTransactionAttribute) actual).getRollbackRules()); + + assertTrue(actual.isReadOnly()); + } + + @Test + public void customClassAttributeWithReadOnlyOverrideOnInterface() throws Exception { + Method method = TestInterface9.class.getMethod("getAge"); + + Transactional annotation = AnnotationUtils.findAnnotation(method, Transactional.class); + assertNull("AnnotationUtils.findAnnotation should not find @Transactional for TestBean9.getAge()", annotation); + annotation = AnnotationUtils.findAnnotation(TestBean9.class, Transactional.class); + assertNotNull("AnnotationUtils.findAnnotation failed to find @Transactional for TestBean9", annotation); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute actual = atas.getTransactionAttribute(method, TestBean9.class); + assertNotNull("Failed to retrieve TransactionAttribute for TestBean9.getAge()", actual); + + RuleBasedTransactionAttribute rbta = new RuleBasedTransactionAttribute(); + rbta.getRollbackRules().add(new RollbackRuleAttribute(Exception.class)); + rbta.getRollbackRules().add(new NoRollbackRuleAttribute(IOException.class)); + assertEquals(rbta.getRollbackRules(), ((RuleBasedTransactionAttribute) actual).getRollbackRules()); + + assertTrue(actual.isReadOnly()); + } + + @Test + public void customMethodAttributeWithReadOnlyOverrideOnInterface() throws Exception { + Method method = TestInterface10.class.getMethod("getAge"); + + Transactional annotation = AnnotationUtils.findAnnotation(method, Transactional.class); + assertNotNull("AnnotationUtils.findAnnotation failed to find @Transactional for TestBean10.getAge()", + annotation); + annotation = AnnotationUtils.findAnnotation(TestBean10.class, Transactional.class); + assertNull("AnnotationUtils.findAnnotation should not find @Transactional for TestBean10", annotation); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute actual = atas.getTransactionAttribute(method, TestBean10.class); + assertNotNull("Failed to retrieve TransactionAttribute for TestBean10.getAge()", actual); + + RuleBasedTransactionAttribute rbta = new RuleBasedTransactionAttribute(); + rbta.getRollbackRules().add(new RollbackRuleAttribute(Exception.class)); + rbta.getRollbackRules().add(new NoRollbackRuleAttribute(IOException.class)); + assertEquals(rbta.getRollbackRules(), ((RuleBasedTransactionAttribute) actual).getRollbackRules()); + + assertTrue(actual.isReadOnly()); + } + + @Test + public void transactionAttributeDeclaredOnClassMethodWithEjb3() throws Exception { + Method getAgeMethod = ITestBean1.class.getMethod("getAge"); + Method getNameMethod = ITestBean1.class.getMethod("getName"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute getAgeAttr = atas.getTransactionAttribute(getAgeMethod, Ejb3AnnotatedBean1.class); + assertEquals(TransactionAttribute.PROPAGATION_REQUIRED, getAgeAttr.getPropagationBehavior()); + TransactionAttribute getNameAttr = atas.getTransactionAttribute(getNameMethod, Ejb3AnnotatedBean1.class); + assertEquals(TransactionAttribute.PROPAGATION_SUPPORTS, getNameAttr.getPropagationBehavior()); + } + + @Test + public void transactionAttributeDeclaredOnClassWithEjb3() throws Exception { + Method getAgeMethod = ITestBean1.class.getMethod("getAge"); + Method getNameMethod = ITestBean1.class.getMethod("getName"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute getAgeAttr = atas.getTransactionAttribute(getAgeMethod, Ejb3AnnotatedBean2.class); + assertEquals(TransactionAttribute.PROPAGATION_REQUIRED, getAgeAttr.getPropagationBehavior()); + TransactionAttribute getNameAttr = atas.getTransactionAttribute(getNameMethod, Ejb3AnnotatedBean2.class); + assertEquals(TransactionAttribute.PROPAGATION_SUPPORTS, getNameAttr.getPropagationBehavior()); + } + + @Test + public void transactionAttributeDeclaredOnInterfaceWithEjb3() throws Exception { + Method getAgeMethod = ITestEjb.class.getMethod("getAge"); + Method getNameMethod = ITestEjb.class.getMethod("getName"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute getAgeAttr = atas.getTransactionAttribute(getAgeMethod, Ejb3AnnotatedBean3.class); + assertEquals(TransactionAttribute.PROPAGATION_REQUIRED, getAgeAttr.getPropagationBehavior()); + TransactionAttribute getNameAttr = atas.getTransactionAttribute(getNameMethod, Ejb3AnnotatedBean3.class); + assertEquals(TransactionAttribute.PROPAGATION_SUPPORTS, getNameAttr.getPropagationBehavior()); + } + + @Test + public void transactionAttributeDeclaredOnClassMethodWithJta() throws Exception { + Method getAgeMethod = ITestBean1.class.getMethod("getAge"); + Method getNameMethod = ITestBean1.class.getMethod("getName"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute getAgeAttr = atas.getTransactionAttribute(getAgeMethod, JtaAnnotatedBean1.class); + assertEquals(TransactionAttribute.PROPAGATION_REQUIRED, getAgeAttr.getPropagationBehavior()); + TransactionAttribute getNameAttr = atas.getTransactionAttribute(getNameMethod, JtaAnnotatedBean1.class); + assertEquals(TransactionAttribute.PROPAGATION_SUPPORTS, getNameAttr.getPropagationBehavior()); + } + + @Test + public void transactionAttributeDeclaredOnClassWithJta() throws Exception { + Method getAgeMethod = ITestBean1.class.getMethod("getAge"); + Method getNameMethod = ITestBean1.class.getMethod("getName"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute getAgeAttr = atas.getTransactionAttribute(getAgeMethod, JtaAnnotatedBean2.class); + assertEquals(TransactionAttribute.PROPAGATION_REQUIRED, getAgeAttr.getPropagationBehavior()); + TransactionAttribute getNameAttr = atas.getTransactionAttribute(getNameMethod, JtaAnnotatedBean2.class); + assertEquals(TransactionAttribute.PROPAGATION_SUPPORTS, getNameAttr.getPropagationBehavior()); + } + + @Test + public void transactionAttributeDeclaredOnInterfaceWithJta() throws Exception { + Method getAgeMethod = ITestEjb.class.getMethod("getAge"); + Method getNameMethod = ITestEjb.class.getMethod("getName"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute getAgeAttr = atas.getTransactionAttribute(getAgeMethod, JtaAnnotatedBean3.class); + assertEquals(TransactionAttribute.PROPAGATION_REQUIRED, getAgeAttr.getPropagationBehavior()); + TransactionAttribute getNameAttr = atas.getTransactionAttribute(getNameMethod, JtaAnnotatedBean3.class); + assertEquals(TransactionAttribute.PROPAGATION_SUPPORTS, getNameAttr.getPropagationBehavior()); + } + + @Test + public void transactionAttributeDeclaredOnGroovyClass() throws Exception { + Method getAgeMethod = ITestBean1.class.getMethod("getAge"); + Method getNameMethod = ITestBean1.class.getMethod("getName"); + Method getMetaClassMethod = GroovyObject.class.getMethod("getMetaClass"); + + AnnotationTransactionAttributeSource atas = new AnnotationTransactionAttributeSource(); + TransactionAttribute getAgeAttr = atas.getTransactionAttribute(getAgeMethod, GroovyTestBean.class); + assertEquals(TransactionAttribute.PROPAGATION_REQUIRED, getAgeAttr.getPropagationBehavior()); + TransactionAttribute getNameAttr = atas.getTransactionAttribute(getNameMethod, GroovyTestBean.class); + assertEquals(TransactionAttribute.PROPAGATION_REQUIRED, getNameAttr.getPropagationBehavior()); + assertNull(atas.getTransactionAttribute(getMetaClassMethod, GroovyTestBean.class)); + } + + + interface ITestBean1 { + + int getAge(); + + void setAge(int age); + + String getName(); + + void setName(String name); + } + + + interface ITestBean2 { + + @Transactional + int getAge(); + + void setAge(int age); + } + + + interface ITestBean2X extends ITestBean2 { + + String getName(); + + void setName(String name); + } + + + @Transactional + interface ITestBean3 { + + int getAge(); + + void setAge(int age); + + String getName(); + + void setName(String name); + } + + + static class Empty implements ITestBean1 { + + private String name; + + private int age; + + public Empty() { + } + + public Empty(String name, int age) { + this.name = name; + this.age = age; + } + + @Override + public String getName() { + return name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + public int getAge() { + return age; + } + + @Override + public void setAge(int age) { + this.age = age; + } + } + + + @SuppressWarnings("serial") + static class TestBean1 implements ITestBean1, Serializable { + + private String name; + + private int age; + + public TestBean1() { + } + + public TestBean1(String name, int age) { + this.name = name; + this.age = age; + } + + @Override + public String getName() { + return name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + @Transactional(rollbackFor = Exception.class) + public int getAge() { + return age; + } + + @Override + public void setAge(int age) { + this.age = age; + } + } + + + static class TestBean2 implements ITestBean2X { + + private String name; + + private int age; + + public TestBean2() { + } + + public TestBean2(String name, int age) { + this.name = name; + this.age = age; + } + + @Override + public String getName() { + return name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + public int getAge() { + return age; + } + + @Override + public void setAge(int age) { + this.age = age; + } + } + + + static class TestBean3 implements ITestBean3 { + + private String name; + + private int age; + + public TestBean3() { + } + + public TestBean3(String name, int age) { + this.name = name; + this.age = age; + } + + @Override + public String getName() { + return name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + @Transactional(propagation = Propagation.REQUIRES_NEW, isolation=Isolation.REPEATABLE_READ, + timeout = 5, readOnly = true, rollbackFor = Exception.class, noRollbackFor = IOException.class) + public int getAge() { + return age; + } + + @Override + public void setAge(int age) { + this.age = age; + } + } + + + @Transactional(rollbackFor = Exception.class, noRollbackFor = IOException.class) + static class TestBean4 implements ITestBean3 { + + private String name; + + private int age; + + public TestBean4() { + } + + public TestBean4(String name, int age) { + this.name = name; + this.age = age; + } + + @Override + public String getName() { + return name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + public int getAge() { + return age; + } + + @Override + public void setAge(int age) { + this.age = age; + } + } + + + @Retention(RetentionPolicy.RUNTIME) + @Transactional(rollbackFor = Exception.class, noRollbackFor = IOException.class) + @interface Tx { + } + + + @Tx + static class TestBean5 { + + public int getAge() { + return 10; + } + } + + + static class TestBean6 { + + @Tx + public int getAge() { + return 10; + } + } + + + @Retention(RetentionPolicy.RUNTIME) + @Transactional(rollbackFor = Exception.class, noRollbackFor = IOException.class) + @interface TxWithAttribute { + + boolean readOnly(); + } + + + @TxWithAttribute(readOnly = true) + static class TestBean7 { + + public int getAge() { + return 10; + } + } + + + static class TestBean8 { + + @TxWithAttribute(readOnly = true) + public int getAge() { + return 10; + } + } + + + @TxWithAttribute(readOnly = true) + interface TestInterface9 { + + int getAge(); + } + + + static class TestBean9 implements TestInterface9 { + + @Override + public int getAge() { + return 10; + } + } + + + interface TestInterface10 { + + @TxWithAttribute(readOnly = true) + int getAge(); + } + + + static class TestBean10 implements TestInterface10 { + + @Override + public int getAge() { + return 10; + } + } + + + static class Ejb3AnnotatedBean1 implements ITestBean1 { + + private String name; + + private int age; + + @Override + @javax.ejb.TransactionAttribute(TransactionAttributeType.SUPPORTS) + public String getName() { + return name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + @javax.ejb.TransactionAttribute + public int getAge() { + return age; + } + + @Override + public void setAge(int age) { + this.age = age; + } + } + + + @javax.ejb.TransactionAttribute(TransactionAttributeType.SUPPORTS) + static class Ejb3AnnotatedBean2 implements ITestBean1 { + + private String name; + + private int age; + + @Override + public String getName() { + return name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + @javax.ejb.TransactionAttribute + public int getAge() { + return age; + } + + @Override + public void setAge(int age) { + this.age = age; + } + } + + + @javax.ejb.TransactionAttribute(TransactionAttributeType.SUPPORTS) + interface ITestEjb { + + @javax.ejb.TransactionAttribute + int getAge(); + + void setAge(int age); + + String getName(); + + void setName(String name); + } + + + static class Ejb3AnnotatedBean3 implements ITestEjb { + + private String name; + + private int age; + + @Override + public String getName() { + return name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + public int getAge() { + return age; + } + + @Override + public void setAge(int age) { + this.age = age; + } + } + + + static class JtaAnnotatedBean1 implements ITestBean1 { + + private String name; + + private int age; + + @Override + @javax.transaction.Transactional(javax.transaction.Transactional.TxType.SUPPORTS) + public String getName() { + return name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + @javax.transaction.Transactional + public int getAge() { + return age; + } + + @Override + public void setAge(int age) { + this.age = age; + } + } + + + @javax.transaction.Transactional(javax.transaction.Transactional.TxType.SUPPORTS) + static class JtaAnnotatedBean2 implements ITestBean1 { + + private String name; + + private int age; + + @Override + public String getName() { + return name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + @javax.transaction.Transactional + public int getAge() { + return age; + } + + @Override + public void setAge(int age) { + this.age = age; + } + } + + + @javax.transaction.Transactional(javax.transaction.Transactional.TxType.SUPPORTS) + interface ITestJta { + + @javax.transaction.Transactional + int getAge(); + + void setAge(int age); + + String getName(); + + void setName(String name); + } + + + static class JtaAnnotatedBean3 implements ITestEjb { + + private String name; + + private int age; + + @Override + public String getName() { + return name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + public int getAge() { + return age; + } + + @Override + public void setAge(int age) { + this.age = age; + } + } + + + @Transactional + static class GroovyTestBean implements ITestBean1, GroovyObject { + + private String name; + + private int age; + + @Override + public String getName() { + return name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + public int getAge() { + return age; + } + + @Override + public void setAge(int age) { + this.age = age; + } + + @Override + public Object invokeMethod(String name, Object args) { + return null; + } + + @Override + public Object getProperty(String propertyName) { + return null; + } + + @Override + public void setProperty(String propertyName, Object newValue) { + } + + @Override + public MetaClass getMetaClass() { + return null; + } + + @Override + public void setMetaClass(MetaClass metaClass) { + } + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/annotation/AnnotationTransactionInterceptorTests.java b/spring-tx/src/test/java/org/springframework/transaction/annotation/AnnotationTransactionInterceptorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..ce3ce2ae17eac819b31055e11a0169d8cddb4ad1 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/annotation/AnnotationTransactionInterceptorTests.java @@ -0,0 +1,452 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import org.junit.Test; + +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.interceptor.TransactionInterceptor; +import org.springframework.transaction.support.TransactionSynchronizationManager; + +import static org.junit.Assert.*; + +/** + * @author Rob Harrop + * @author Juergen Hoeller + */ +public class AnnotationTransactionInterceptorTests { + + private final CallCountingTransactionManager ptm = new CallCountingTransactionManager(); + + private final AnnotationTransactionAttributeSource source = new AnnotationTransactionAttributeSource(); + + private final TransactionInterceptor ti = new TransactionInterceptor(this.ptm, this.source); + + + @Test + public void classLevelOnly() { + ProxyFactory proxyFactory = new ProxyFactory(); + proxyFactory.setTarget(new TestClassLevelOnly()); + proxyFactory.addAdvice(this.ti); + + TestClassLevelOnly proxy = (TestClassLevelOnly) proxyFactory.getProxy(); + + proxy.doSomething(); + assertGetTransactionAndCommitCount(1); + + proxy.doSomethingElse(); + assertGetTransactionAndCommitCount(2); + + proxy.doSomething(); + assertGetTransactionAndCommitCount(3); + + proxy.doSomethingElse(); + assertGetTransactionAndCommitCount(4); + } + + @Test + public void withSingleMethodOverride() { + ProxyFactory proxyFactory = new ProxyFactory(); + proxyFactory.setTarget(new TestWithSingleMethodOverride()); + proxyFactory.addAdvice(this.ti); + + TestWithSingleMethodOverride proxy = (TestWithSingleMethodOverride) proxyFactory.getProxy(); + + proxy.doSomething(); + assertGetTransactionAndCommitCount(1); + + proxy.doSomethingElse(); + assertGetTransactionAndCommitCount(2); + + proxy.doSomethingCompletelyElse(); + assertGetTransactionAndCommitCount(3); + + proxy.doSomething(); + assertGetTransactionAndCommitCount(4); + } + + @Test + public void withSingleMethodOverrideInverted() { + ProxyFactory proxyFactory = new ProxyFactory(); + proxyFactory.setTarget(new TestWithSingleMethodOverrideInverted()); + proxyFactory.addAdvice(this.ti); + + TestWithSingleMethodOverrideInverted proxy = (TestWithSingleMethodOverrideInverted) proxyFactory.getProxy(); + + proxy.doSomething(); + assertGetTransactionAndCommitCount(1); + + proxy.doSomethingElse(); + assertGetTransactionAndCommitCount(2); + + proxy.doSomethingCompletelyElse(); + assertGetTransactionAndCommitCount(3); + + proxy.doSomething(); + assertGetTransactionAndCommitCount(4); + } + + @Test + public void withMultiMethodOverride() { + ProxyFactory proxyFactory = new ProxyFactory(); + proxyFactory.setTarget(new TestWithMultiMethodOverride()); + proxyFactory.addAdvice(this.ti); + + TestWithMultiMethodOverride proxy = (TestWithMultiMethodOverride) proxyFactory.getProxy(); + + proxy.doSomething(); + assertGetTransactionAndCommitCount(1); + + proxy.doSomethingElse(); + assertGetTransactionAndCommitCount(2); + + proxy.doSomethingCompletelyElse(); + assertGetTransactionAndCommitCount(3); + + proxy.doSomething(); + assertGetTransactionAndCommitCount(4); + } + + @Test + public void withRollback() { + ProxyFactory proxyFactory = new ProxyFactory(); + proxyFactory.setTarget(new TestWithRollback()); + proxyFactory.addAdvice(this.ti); + + TestWithRollback proxy = (TestWithRollback) proxyFactory.getProxy(); + + try { + proxy.doSomethingErroneous(); + fail("Should throw IllegalStateException"); + } + catch (IllegalStateException ex) { + assertGetTransactionAndRollbackCount(1); + } + + try { + proxy.doSomethingElseErroneous(); + fail("Should throw IllegalArgumentException"); + } + catch (IllegalArgumentException ex) { + assertGetTransactionAndRollbackCount(2); + } + } + + @Test + public void withInterface() { + ProxyFactory proxyFactory = new ProxyFactory(); + proxyFactory.setTarget(new TestWithInterfaceImpl()); + proxyFactory.addInterface(TestWithInterface.class); + proxyFactory.addAdvice(this.ti); + + TestWithInterface proxy = (TestWithInterface) proxyFactory.getProxy(); + + proxy.doSomething(); + assertGetTransactionAndCommitCount(1); + + proxy.doSomethingElse(); + assertGetTransactionAndCommitCount(2); + + proxy.doSomethingElse(); + assertGetTransactionAndCommitCount(3); + + proxy.doSomething(); + assertGetTransactionAndCommitCount(4); + + proxy.doSomethingDefault(); + assertGetTransactionAndCommitCount(5); + } + + @Test + public void crossClassInterfaceMethodLevelOnJdkProxy() { + ProxyFactory proxyFactory = new ProxyFactory(); + proxyFactory.setTarget(new SomeServiceImpl()); + proxyFactory.addInterface(SomeService.class); + proxyFactory.addAdvice(this.ti); + + SomeService someService = (SomeService) proxyFactory.getProxy(); + + someService.bar(); + assertGetTransactionAndCommitCount(1); + + someService.foo(); + assertGetTransactionAndCommitCount(2); + + someService.fooBar(); + assertGetTransactionAndCommitCount(3); + } + + @Test + public void crossClassInterfaceOnJdkProxy() { + ProxyFactory proxyFactory = new ProxyFactory(); + proxyFactory.setTarget(new OtherServiceImpl()); + proxyFactory.addInterface(OtherService.class); + proxyFactory.addAdvice(this.ti); + + OtherService otherService = (OtherService) proxyFactory.getProxy(); + + otherService.foo(); + assertGetTransactionAndCommitCount(1); + } + + @Test + public void withInterfaceOnTargetJdkProxy() { + ProxyFactory targetFactory = new ProxyFactory(); + targetFactory.setTarget(new TestWithInterfaceImpl()); + targetFactory.addInterface(TestWithInterface.class); + + ProxyFactory proxyFactory = new ProxyFactory(); + proxyFactory.setTarget(targetFactory.getProxy()); + proxyFactory.addInterface(TestWithInterface.class); + proxyFactory.addAdvice(this.ti); + + TestWithInterface proxy = (TestWithInterface) proxyFactory.getProxy(); + + proxy.doSomething(); + assertGetTransactionAndCommitCount(1); + + proxy.doSomethingElse(); + assertGetTransactionAndCommitCount(2); + + proxy.doSomethingElse(); + assertGetTransactionAndCommitCount(3); + + proxy.doSomething(); + assertGetTransactionAndCommitCount(4); + + proxy.doSomethingDefault(); + assertGetTransactionAndCommitCount(5); + } + + @Test + public void withInterfaceOnTargetCglibProxy() { + ProxyFactory targetFactory = new ProxyFactory(); + targetFactory.setTarget(new TestWithInterfaceImpl()); + targetFactory.setProxyTargetClass(true); + + ProxyFactory proxyFactory = new ProxyFactory(); + proxyFactory.setTarget(targetFactory.getProxy()); + proxyFactory.addInterface(TestWithInterface.class); + proxyFactory.addAdvice(this.ti); + + TestWithInterface proxy = (TestWithInterface) proxyFactory.getProxy(); + + proxy.doSomething(); + assertGetTransactionAndCommitCount(1); + + proxy.doSomethingElse(); + assertGetTransactionAndCommitCount(2); + + proxy.doSomethingElse(); + assertGetTransactionAndCommitCount(3); + + proxy.doSomething(); + assertGetTransactionAndCommitCount(4); + + proxy.doSomethingDefault(); + assertGetTransactionAndCommitCount(5); + } + + private void assertGetTransactionAndCommitCount(int expectedCount) { + assertEquals(expectedCount, this.ptm.begun); + assertEquals(expectedCount, this.ptm.commits); + } + + private void assertGetTransactionAndRollbackCount(int expectedCount) { + assertEquals(expectedCount, this.ptm.begun); + assertEquals(expectedCount, this.ptm.rollbacks); + } + + + @Transactional + public static class TestClassLevelOnly { + + public void doSomething() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + + public void doSomethingElse() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + } + + + @Transactional + public static class TestWithSingleMethodOverride { + + public void doSomething() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + + @Transactional(readOnly = true) + public void doSomethingElse() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertTrue(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + + public void doSomethingCompletelyElse() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + } + + + @Transactional(readOnly = true) + public static class TestWithSingleMethodOverrideInverted { + + @Transactional + public void doSomething() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + + public void doSomethingElse() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertTrue(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + + public void doSomethingCompletelyElse() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertTrue(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + } + + + @Transactional + public static class TestWithMultiMethodOverride { + + @Transactional(readOnly = true) + public void doSomething() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertTrue(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + + @Transactional(readOnly = true) + public void doSomethingElse() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertTrue(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + + public void doSomethingCompletelyElse() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + } + + + @Transactional(rollbackFor = IllegalStateException.class) + public static class TestWithRollback { + + public void doSomethingErroneous() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + throw new IllegalStateException(); + } + + @Transactional(rollbackFor = IllegalArgumentException.class) + public void doSomethingElseErroneous() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + throw new IllegalArgumentException(); + } + } + + + public interface BaseInterface { + + void doSomething(); + } + + + @Transactional + public interface TestWithInterface extends BaseInterface { + + @Transactional(readOnly = true) + void doSomethingElse(); + + default void doSomethingDefault() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + } + + + public static class TestWithInterfaceImpl implements TestWithInterface { + + @Override + public void doSomething() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + + @Override + public void doSomethingElse() { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertTrue(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + } + + + public interface SomeService { + + void foo(); + + @Transactional + void bar(); + + @Transactional(readOnly = true) + void fooBar(); + } + + + public static class SomeServiceImpl implements SomeService { + + @Override + public void bar() { + } + + @Override + @Transactional + public void foo() { + } + + @Override + @Transactional(readOnly = false) + public void fooBar() { + } + } + + + public interface OtherService { + + void foo(); + } + + + @Transactional + public static class OtherServiceImpl implements OtherService { + + @Override + public void foo() { + } + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/annotation/AnnotationTransactionNamespaceHandlerTests.java b/spring-tx/src/test/java/org/springframework/transaction/annotation/AnnotationTransactionNamespaceHandlerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..6c7ce330eed2c0d836839c5d6efc498fe0d98f3f --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/annotation/AnnotationTransactionNamespaceHandlerTests.java @@ -0,0 +1,155 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import java.lang.management.ManagementFactory; +import java.util.Collection; +import java.util.Map; + +import javax.management.MBeanServer; +import javax.management.ObjectName; + +import org.junit.After; +import org.junit.Test; + +import org.springframework.aop.support.AopUtils; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.context.support.ClassPathXmlApplicationContext; +import org.springframework.jmx.export.annotation.ManagedOperation; +import org.springframework.jmx.export.annotation.ManagedResource; +import org.springframework.stereotype.Service; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.config.TransactionManagementConfigUtils; +import org.springframework.transaction.event.TransactionalEventListenerFactory; + +import static org.junit.Assert.*; + +/** + * @author Rob Harrop + * @author Juergen Hoeller + * @author Sam Brannen + */ +public class AnnotationTransactionNamespaceHandlerTests { + + private final ConfigurableApplicationContext context = new ClassPathXmlApplicationContext( + "org/springframework/transaction/annotation/annotationTransactionNamespaceHandlerTests.xml"); + + @After + public void tearDown() { + this.context.close(); + } + + @Test + public void isProxy() throws Exception { + TransactionalTestBean bean = getTestBean(); + assertTrue("testBean is not a proxy", AopUtils.isAopProxy(bean)); + Map services = this.context.getBeansWithAnnotation(Service.class); + assertTrue("Stereotype annotation not visible", services.containsKey("testBean")); + } + + @Test + public void invokeTransactional() throws Exception { + TransactionalTestBean testBean = getTestBean(); + CallCountingTransactionManager ptm = (CallCountingTransactionManager) context.getBean("transactionManager"); + + // try with transactional + assertEquals("Should not have any started transactions", 0, ptm.begun); + testBean.findAllFoos(); + assertEquals("Should have 1 started transaction", 1, ptm.begun); + assertEquals("Should have 1 committed transaction", 1, ptm.commits); + + // try with non-transaction + testBean.doSomething(); + assertEquals("Should not have started another transaction", 1, ptm.begun); + + // try with exceptional + try { + testBean.exceptional(new IllegalArgumentException("foo")); + fail("Should NEVER get here"); + } + catch (Throwable throwable) { + assertEquals("Should have another started transaction", 2, ptm.begun); + assertEquals("Should have 1 rolled back transaction", 1, ptm.rollbacks); + } + } + + @Test + public void nonPublicMethodsNotAdvised() { + TransactionalTestBean testBean = getTestBean(); + CallCountingTransactionManager ptm = (CallCountingTransactionManager) context.getBean("transactionManager"); + + assertEquals("Should not have any started transactions", 0, ptm.begun); + testBean.annotationsOnProtectedAreIgnored(); + assertEquals("Should not have any started transactions", 0, ptm.begun); + } + + @Test + public void mBeanExportAlsoWorks() throws Exception { + MBeanServer server = ManagementFactory.getPlatformMBeanServer(); + assertEquals("done", + server.invoke(ObjectName.getInstance("test:type=TestBean"), "doSomething", new Object[0], new String[0])); + } + + @Test + public void transactionalEventListenerRegisteredProperly() { + assertTrue(this.context.containsBean(TransactionManagementConfigUtils + .TRANSACTIONAL_EVENT_LISTENER_FACTORY_BEAN_NAME)); + assertEquals(1, this.context.getBeansOfType(TransactionalEventListenerFactory.class).size()); + } + + private TransactionalTestBean getTestBean() { + return (TransactionalTestBean) context.getBean("testBean"); + } + + + @Service + @ManagedResource("test:type=TestBean") + public static class TransactionalTestBean { + + @Transactional(readOnly = true) + public Collection findAllFoos() { + return null; + } + + @Transactional + public void saveFoo() { + } + + @Transactional("qualifiedTransactionManager") + public void saveQualifiedFoo() { + } + + @Transactional(transactionManager = "qualifiedTransactionManager") + public void saveQualifiedFooWithAttributeAlias() { + } + + @Transactional + public void exceptional(Throwable t) throws Throwable { + throw t; + } + + @ManagedOperation + public String doSomething() { + return "done"; + } + + @Transactional + protected void annotationsOnProtectedAreIgnored() { + } + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/annotation/EnableTransactionManagementTests.java b/spring-tx/src/test/java/org/springframework/transaction/annotation/EnableTransactionManagementTests.java new file mode 100644 index 0000000000000000000000000000000000000000..93311ba359c39f69340e3cd19c744f1d49819a46 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/annotation/EnableTransactionManagementTests.java @@ -0,0 +1,358 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.annotation; + +import java.util.Collection; +import java.util.Map; + +import org.junit.Test; + +import org.springframework.aop.support.AopUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.context.annotation.AdviceMode; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ConditionContext; +import org.springframework.context.annotation.Conditional; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.ConfigurationCondition; +import org.springframework.core.type.AnnotatedTypeMetadata; +import org.springframework.stereotype.Service; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.config.TransactionManagementConfigUtils; +import org.springframework.transaction.event.TransactionalEventListenerFactory; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +/** + * Tests demonstrating use of @EnableTransactionManagement @Configuration classes. + * + * @author Chris Beams + * @author Juergen Hoeller + * @author Stephane Nicoll + * @author Sam Brannen + * @since 3.1 + */ +public class EnableTransactionManagementTests { + + @Test + public void transactionProxyIsCreated() { + AnnotationConfigApplicationContext ctx = new AnnotationConfigApplicationContext( + EnableTxConfig.class, TxManagerConfig.class); + TransactionalTestBean bean = ctx.getBean(TransactionalTestBean.class); + assertTrue("testBean is not a proxy", AopUtils.isAopProxy(bean)); + Map services = ctx.getBeansWithAnnotation(Service.class); + assertTrue("Stereotype annotation not visible", services.containsKey("testBean")); + ctx.close(); + } + + @Test + public void transactionProxyIsCreatedWithEnableOnSuperclass() { + AnnotationConfigApplicationContext ctx = new AnnotationConfigApplicationContext( + InheritedEnableTxConfig.class, TxManagerConfig.class); + TransactionalTestBean bean = ctx.getBean(TransactionalTestBean.class); + assertTrue("testBean is not a proxy", AopUtils.isAopProxy(bean)); + Map services = ctx.getBeansWithAnnotation(Service.class); + assertTrue("Stereotype annotation not visible", services.containsKey("testBean")); + ctx.close(); + } + + @Test + public void transactionProxyIsCreatedWithEnableOnExcludedSuperclass() { + AnnotationConfigApplicationContext ctx = new AnnotationConfigApplicationContext( + ParentEnableTxConfig.class, ChildEnableTxConfig.class, TxManagerConfig.class); + TransactionalTestBean bean = ctx.getBean(TransactionalTestBean.class); + assertTrue("testBean is not a proxy", AopUtils.isAopProxy(bean)); + Map services = ctx.getBeansWithAnnotation(Service.class); + assertTrue("Stereotype annotation not visible", services.containsKey("testBean")); + ctx.close(); + } + + @Test + public void txManagerIsResolvedOnInvocationOfTransactionalMethod() { + AnnotationConfigApplicationContext ctx = new AnnotationConfigApplicationContext( + EnableTxConfig.class, TxManagerConfig.class); + TransactionalTestBean bean = ctx.getBean(TransactionalTestBean.class); + + // invoke a transactional method, causing the PlatformTransactionManager bean to be resolved. + bean.findAllFoos(); + ctx.close(); + } + + @Test + public void txManagerIsResolvedCorrectlyWhenMultipleManagersArePresent() { + AnnotationConfigApplicationContext ctx = new AnnotationConfigApplicationContext( + EnableTxConfig.class, MultiTxManagerConfig.class); + TransactionalTestBean bean = ctx.getBean(TransactionalTestBean.class); + + // invoke a transactional method, causing the PlatformTransactionManager bean to be resolved. + bean.findAllFoos(); + ctx.close(); + } + + /** + * A cheap test just to prove that in ASPECTJ mode, the AnnotationTransactionAspect does indeed + * get loaded -- or in this case, attempted to be loaded at which point the test fails. + */ + @Test + @SuppressWarnings("resource") + public void proxyTypeAspectJCausesRegistrationOfAnnotationTransactionAspect() { + try { + new AnnotationConfigApplicationContext(EnableAspectjTxConfig.class, TxManagerConfig.class); + fail("should have thrown CNFE when trying to load AnnotationTransactionAspect. " + + "Do you actually have org.springframework.aspects on the classpath?"); + } + catch (Exception ex) { + assertThat(ex.getMessage(), containsString("AspectJJtaTransactionManagementConfiguration")); + } + } + + @Test + public void transactionalEventListenerRegisteredProperly() { + AnnotationConfigApplicationContext ctx = new AnnotationConfigApplicationContext(EnableTxConfig.class); + assertTrue(ctx.containsBean(TransactionManagementConfigUtils.TRANSACTIONAL_EVENT_LISTENER_FACTORY_BEAN_NAME)); + assertEquals(1, ctx.getBeansOfType(TransactionalEventListenerFactory.class).size()); + ctx.close(); + } + + @Test + public void spr11915TransactionManagerAsManualSingleton() { + AnnotationConfigApplicationContext ctx = new AnnotationConfigApplicationContext(Spr11915Config.class); + TransactionalTestBean bean = ctx.getBean(TransactionalTestBean.class); + CallCountingTransactionManager txManager = ctx.getBean("qualifiedTransactionManager", CallCountingTransactionManager.class); + + bean.saveQualifiedFoo(); + assertThat(txManager.begun, equalTo(1)); + assertThat(txManager.commits, equalTo(1)); + assertThat(txManager.rollbacks, equalTo(0)); + + bean.saveQualifiedFooWithAttributeAlias(); + assertThat(txManager.begun, equalTo(2)); + assertThat(txManager.commits, equalTo(2)); + assertThat(txManager.rollbacks, equalTo(0)); + + ctx.close(); + } + + @Test + public void spr14322FindsOnInterfaceWithInterfaceProxy() { + AnnotationConfigApplicationContext ctx = new AnnotationConfigApplicationContext(Spr14322ConfigA.class); + TransactionalTestInterface bean = ctx.getBean(TransactionalTestInterface.class); + CallCountingTransactionManager txManager = ctx.getBean(CallCountingTransactionManager.class); + + bean.saveFoo(); + bean.saveBar(); + assertThat(txManager.begun, equalTo(2)); + assertThat(txManager.commits, equalTo(2)); + assertThat(txManager.rollbacks, equalTo(0)); + + ctx.close(); + } + + @Test + public void spr14322FindsOnInterfaceWithCglibProxy() { + AnnotationConfigApplicationContext ctx = new AnnotationConfigApplicationContext(Spr14322ConfigB.class); + TransactionalTestInterface bean = ctx.getBean(TransactionalTestInterface.class); + CallCountingTransactionManager txManager = ctx.getBean(CallCountingTransactionManager.class); + + bean.saveFoo(); + bean.saveBar(); + assertThat(txManager.begun, equalTo(2)); + assertThat(txManager.commits, equalTo(2)); + assertThat(txManager.rollbacks, equalTo(0)); + + ctx.close(); + } + + + @Service + public static class TransactionalTestBean { + + @Transactional(readOnly = true) + public Collection findAllFoos() { + return null; + } + + @Transactional("qualifiedTransactionManager") + public void saveQualifiedFoo() { + } + + @Transactional(transactionManager = "qualifiedTransactionManager") + public void saveQualifiedFooWithAttributeAlias() { + } + } + + + @Configuration + @EnableTransactionManagement + static class EnableTxConfig { + } + + + @Configuration + static class InheritedEnableTxConfig extends EnableTxConfig { + } + + + @Configuration + @EnableTransactionManagement + @Conditional(NeverCondition.class) + static class ParentEnableTxConfig { + + @Bean + Object someBean() { + return new Object(); + } + } + + + @Configuration + static class ChildEnableTxConfig extends ParentEnableTxConfig { + + @Override + Object someBean() { + return "X"; + } + } + + + private static class NeverCondition implements ConfigurationCondition { + + @Override + public boolean matches(ConditionContext context, AnnotatedTypeMetadata metadata) { + return false; + } + + @Override + public ConfigurationPhase getConfigurationPhase() { + return ConfigurationPhase.REGISTER_BEAN; + } + } + + + @Configuration + @EnableTransactionManagement(mode = AdviceMode.ASPECTJ) + static class EnableAspectjTxConfig { + } + + + @Configuration + static class TxManagerConfig { + + @Bean + public TransactionalTestBean testBean() { + return new TransactionalTestBean(); + } + + @Bean + public PlatformTransactionManager txManager() { + return new CallCountingTransactionManager(); + } + } + + + @Configuration + static class MultiTxManagerConfig extends TxManagerConfig implements TransactionManagementConfigurer { + + @Bean + public PlatformTransactionManager txManager2() { + return new CallCountingTransactionManager(); + } + + @Override + public PlatformTransactionManager annotationDrivenTransactionManager() { + return txManager2(); + } + } + + + @Configuration + @EnableTransactionManagement + static class Spr11915Config { + + @Autowired + public void initializeApp(ConfigurableApplicationContext applicationContext) { + applicationContext.getBeanFactory().registerSingleton( + "qualifiedTransactionManager", new CallCountingTransactionManager()); + } + + @Bean + public TransactionalTestBean testBean() { + return new TransactionalTestBean(); + } + } + + + public interface BaseTransactionalInterface { + + @Transactional + default void saveBar() { + } + } + + + public interface TransactionalTestInterface extends BaseTransactionalInterface { + + @Transactional + void saveFoo(); + } + + + @Service + public static class TransactionalTestService implements TransactionalTestInterface { + + @Override + public void saveFoo() { + } + } + + + @Configuration + @EnableTransactionManagement + static class Spr14322ConfigA { + + @Bean + public TransactionalTestInterface testBean() { + return new TransactionalTestService(); + } + + @Bean + public PlatformTransactionManager txManager() { + return new CallCountingTransactionManager(); + } + } + + + @Configuration + @EnableTransactionManagement(proxyTargetClass = true) + static class Spr14322ConfigB { + + @Bean + public TransactionalTestInterface testBean() { + return new TransactionalTestService(); + } + + @Bean + public PlatformTransactionManager txManager() { + return new CallCountingTransactionManager(); + } + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/config/AnnotationDrivenTests.java b/spring-tx/src/test/java/org/springframework/transaction/config/AnnotationDrivenTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a63a9e0905db36d579b916858e8949550415f276 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/config/AnnotationDrivenTests.java @@ -0,0 +1,121 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.config; + +import java.io.Serializable; + +import org.aopalliance.intercept.MethodInterceptor; +import org.aopalliance.intercept.MethodInvocation; +import org.junit.Test; + +import org.springframework.aop.support.AopUtils; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.support.ClassPathXmlApplicationContext; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.support.TransactionSynchronizationManager; +import org.springframework.util.SerializationTestUtils; + +import static org.junit.Assert.*; + +/** + * @author Rob Harrop + * @author Juergen Hoeller + */ +public class AnnotationDrivenTests { + + @Test + public void withProxyTargetClass() throws Exception { + ClassPathXmlApplicationContext context = new ClassPathXmlApplicationContext("annotationDrivenProxyTargetClassTests.xml", getClass()); + doTestWithMultipleTransactionManagers(context); + } + + @Test + public void withConfigurationClass() throws Exception { + ApplicationContext parent = new AnnotationConfigApplicationContext(TransactionManagerConfiguration.class); + ClassPathXmlApplicationContext context = new ClassPathXmlApplicationContext(new String[] {"annotationDrivenConfigurationClassTests.xml"}, getClass(), parent); + doTestWithMultipleTransactionManagers(context); + } + + @Test + public void withAnnotatedTransactionManagers() throws Exception { + AnnotationConfigApplicationContext parent = new AnnotationConfigApplicationContext(); + parent.registerBeanDefinition("transactionManager1", new RootBeanDefinition(SynchTransactionManager.class)); + parent.registerBeanDefinition("transactionManager2", new RootBeanDefinition(NoSynchTransactionManager.class)); + parent.refresh(); + ClassPathXmlApplicationContext context = new ClassPathXmlApplicationContext(new String[] {"annotationDrivenConfigurationClassTests.xml"}, getClass(), parent); + doTestWithMultipleTransactionManagers(context); + } + + private void doTestWithMultipleTransactionManagers(ApplicationContext context) { + CallCountingTransactionManager tm1 = context.getBean("transactionManager1", CallCountingTransactionManager.class); + CallCountingTransactionManager tm2 = context.getBean("transactionManager2", CallCountingTransactionManager.class); + TransactionalService service = context.getBean("service", TransactionalService.class); + assertTrue(AopUtils.isCglibProxy(service)); + service.setSomething("someName"); + assertEquals(1, tm1.commits); + assertEquals(0, tm2.commits); + service.doSomething(); + assertEquals(1, tm1.commits); + assertEquals(1, tm2.commits); + service.setSomething("someName"); + assertEquals(2, tm1.commits); + assertEquals(1, tm2.commits); + service.doSomething(); + assertEquals(2, tm1.commits); + assertEquals(2, tm2.commits); + } + + @Test + @SuppressWarnings("resource") + public void serializableWithPreviousUsage() throws Exception { + ClassPathXmlApplicationContext context = new ClassPathXmlApplicationContext("annotationDrivenProxyTargetClassTests.xml", getClass()); + TransactionalService service = context.getBean("service", TransactionalService.class); + service.setSomething("someName"); + service = (TransactionalService) SerializationTestUtils.serializeAndDeserialize(service); + service.setSomething("someName"); + } + + @Test + @SuppressWarnings("resource") + public void serializableWithoutPreviousUsage() throws Exception { + ClassPathXmlApplicationContext context = new ClassPathXmlApplicationContext("annotationDrivenProxyTargetClassTests.xml", getClass()); + TransactionalService service = context.getBean("service", TransactionalService.class); + service = (TransactionalService) SerializationTestUtils.serializeAndDeserialize(service); + service.setSomething("someName"); + } + + + @SuppressWarnings("serial") + public static class TransactionCheckingInterceptor implements MethodInterceptor, Serializable { + + @Override + public Object invoke(MethodInvocation methodInvocation) throws Throwable { + if (methodInvocation.getMethod().getName().equals("setSomething")) { + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + } + else { + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + } + return methodInvocation.proceed(); + } + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/config/NoSynch.java b/spring-tx/src/test/java/org/springframework/transaction/config/NoSynch.java new file mode 100644 index 0000000000000000000000000000000000000000..5abeae194c3a30f906769fedead3ca243c01a5d3 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/config/NoSynch.java @@ -0,0 +1,31 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.config; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +import org.springframework.beans.factory.annotation.Qualifier; + +/** + * @author Juergen Hoeller + */ +@Qualifier("noSynch") +@Retention(RetentionPolicy.RUNTIME) +public @interface NoSynch { + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/config/NoSynchTransactionManager.java b/spring-tx/src/test/java/org/springframework/transaction/config/NoSynchTransactionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..a4e6ec1b548ced1bd751a80f1f6bf6d0da4a2f17 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/config/NoSynchTransactionManager.java @@ -0,0 +1,32 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.config; + +import org.springframework.tests.transaction.CallCountingTransactionManager; + +/** + * @author Juergen Hoeller + */ +@NoSynch +@SuppressWarnings("serial") +public class NoSynchTransactionManager extends CallCountingTransactionManager { + + public NoSynchTransactionManager() { + setTransactionSynchronization(CallCountingTransactionManager.SYNCHRONIZATION_NEVER); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/config/SynchTransactionManager.java b/spring-tx/src/test/java/org/springframework/transaction/config/SynchTransactionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..a09e1582df3e2683921a1eb63d0882af155c3aca --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/config/SynchTransactionManager.java @@ -0,0 +1,29 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.config; + +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.tests.transaction.CallCountingTransactionManager; + +/** + * @author Juergen Hoeller + */ +@Qualifier("synch") +@SuppressWarnings("serial") +public class SynchTransactionManager extends CallCountingTransactionManager { + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/config/TransactionManagerConfiguration.java b/spring-tx/src/test/java/org/springframework/transaction/config/TransactionManagerConfiguration.java new file mode 100644 index 0000000000000000000000000000000000000000..af0e5a351cbcfbc85325f7abe1e098eaae19155f --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/config/TransactionManagerConfiguration.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.config; + +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.PlatformTransactionManager; + +/** + * @author Juergen Hoeller + */ +@Configuration +public class TransactionManagerConfiguration { + + @Bean + @Qualifier("synch") + public PlatformTransactionManager transactionManager1() { + return new CallCountingTransactionManager(); + } + + @Bean + @NoSynch + public PlatformTransactionManager transactionManager2() { + CallCountingTransactionManager tm = new CallCountingTransactionManager(); + tm.setTransactionSynchronization(CallCountingTransactionManager.SYNCHRONIZATION_NEVER); + return tm; + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/config/TransactionalService.java b/spring-tx/src/test/java/org/springframework/transaction/config/TransactionalService.java new file mode 100644 index 0000000000000000000000000000000000000000..cef9bcfb7448ad32b6e3cdbdaac58cf4431038c6 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/config/TransactionalService.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.config; + +import java.io.Serializable; + +import org.springframework.transaction.annotation.Transactional; + +/** + * @author Rob Harrop + * @author Juergen Hoeller + */ +@SuppressWarnings("serial") +public class TransactionalService implements Serializable { + + @Transactional("synch") + public void setSomething(String name) { + } + + @Transactional("noSynch") + public void doSomething() { + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/event/ApplicationListenerMethodTransactionalAdapterTests.java b/spring-tx/src/test/java/org/springframework/transaction/event/ApplicationListenerMethodTransactionalAdapterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..bd22c4cd248e888ee1bc24546c949a289aa04618 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/event/ApplicationListenerMethodTransactionalAdapterTests.java @@ -0,0 +1,113 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.event; + +import java.lang.reflect.Method; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.context.PayloadApplicationEvent; +import org.springframework.context.event.ApplicationListenerMethodAdapter; +import org.springframework.core.ResolvableType; +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.util.ReflectionUtils; + +import static org.junit.Assert.*; + +/** + * @author Stephane Nicoll + */ +public class ApplicationListenerMethodTransactionalAdapterTests { + + @Rule + public final ExpectedException thrown = ExpectedException.none(); + + + @Test + public void defaultPhase() { + Method m = ReflectionUtils.findMethod(SampleEvents.class, "defaultPhase", String.class); + assertPhase(m, TransactionPhase.AFTER_COMMIT); + } + + @Test + public void phaseSet() { + Method m = ReflectionUtils.findMethod(SampleEvents.class, "phaseSet", String.class); + assertPhase(m, TransactionPhase.AFTER_ROLLBACK); + } + + @Test + public void phaseAndClassesSet() { + Method m = ReflectionUtils.findMethod(SampleEvents.class, "phaseAndClassesSet"); + assertPhase(m, TransactionPhase.AFTER_COMPLETION); + supportsEventType(true, m, createGenericEventType(String.class)); + supportsEventType(true, m, createGenericEventType(Integer.class)); + supportsEventType(false, m, createGenericEventType(Double.class)); + } + + @Test + public void valueSet() { + Method m = ReflectionUtils.findMethod(SampleEvents.class, "valueSet"); + assertPhase(m, TransactionPhase.AFTER_COMMIT); + supportsEventType(true, m, createGenericEventType(String.class)); + supportsEventType(false, m, createGenericEventType(Double.class)); + } + + private void assertPhase(Method method, TransactionPhase expected) { + assertNotNull("Method must not be null", method); + TransactionalEventListener annotation = + AnnotatedElementUtils.findMergedAnnotation(method, TransactionalEventListener.class); + assertEquals("Wrong phase for '" + method + "'", expected, annotation.phase()); + } + + private void supportsEventType(boolean match, Method method, ResolvableType eventType) { + ApplicationListenerMethodAdapter adapter = createTestInstance(method); + assertEquals("Wrong match for event '" + eventType + "' on " + method, + match, adapter.supportsEventType(eventType)); + } + + private ApplicationListenerMethodTransactionalAdapter createTestInstance(Method m) { + return new ApplicationListenerMethodTransactionalAdapter("test", SampleEvents.class, m); + } + + private ResolvableType createGenericEventType(Class payloadType) { + return ResolvableType.forClassWithGenerics(PayloadApplicationEvent.class, payloadType); + } + + + static class SampleEvents { + + @TransactionalEventListener + public void defaultPhase(String data) { + } + + @TransactionalEventListener(phase = TransactionPhase.AFTER_ROLLBACK) + public void phaseSet(String data) { + } + + @TransactionalEventListener(classes = {String.class, Integer.class}, + phase = TransactionPhase.AFTER_COMPLETION) + public void phaseAndClassesSet() { + } + + @TransactionalEventListener(String.class) + public void valueSet() { + } + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/event/TransactionalEventListenerTests.java b/spring-tx/src/test/java/org/springframework/transaction/event/TransactionalEventListenerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..1529c7ebbd00353a1f4a6333ccee1f871e62f890 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/event/TransactionalEventListenerTests.java @@ -0,0 +1,549 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.event; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.After; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.event.EventListener; +import org.springframework.core.annotation.Order; +import org.springframework.stereotype.Component; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.annotation.EnableTransactionManagement; +import org.springframework.transaction.annotation.Transactional; +import org.springframework.transaction.support.TransactionSynchronizationAdapter; +import org.springframework.transaction.support.TransactionSynchronizationManager; +import org.springframework.transaction.support.TransactionTemplate; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import static org.junit.Assert.*; +import static org.springframework.transaction.event.TransactionPhase.*; + +/** + * Integration tests for {@link TransactionalEventListener} support + * + * @author Stephane Nicoll + * @author Sam Brannen + * @since 4.2 + */ +public class TransactionalEventListenerTests { + + private ConfigurableApplicationContext context; + + private EventCollector eventCollector; + + private TransactionTemplate transactionTemplate = new TransactionTemplate(new CallCountingTransactionManager()); + + @Rule + public final ExpectedException thrown = ExpectedException.none(); + + + @After + public void closeContext() { + if (this.context != null) { + this.context.close(); + } + } + + + @Test + public void immediately() { + load(ImmediateTestListener.class); + this.transactionTemplate.execute(status -> { + getContext().publishEvent("test"); + getEventCollector().assertEvents(EventCollector.IMMEDIATELY, "test"); + getEventCollector().assertTotalEventsCount(1); + return null; + + }); + getEventCollector().assertEvents(EventCollector.IMMEDIATELY, "test"); + getEventCollector().assertTotalEventsCount(1); + } + + @Test + public void immediatelyImpactsCurrentTransaction() { + load(ImmediateTestListener.class, BeforeCommitTestListener.class); + try { + this.transactionTemplate.execute(status -> { + getContext().publishEvent("FAIL"); + fail("Should have thrown an exception at this point"); + return null; + }); + } + catch (IllegalStateException e) { + assertTrue(e.getMessage().contains("Test exception")); + assertTrue(e.getMessage().contains(EventCollector.IMMEDIATELY)); + } + getEventCollector().assertEvents(EventCollector.IMMEDIATELY, "FAIL"); + getEventCollector().assertTotalEventsCount(1); + } + + @Test + public void afterCompletionCommit() { + load(AfterCompletionTestListener.class); + this.transactionTemplate.execute(status -> { + getContext().publishEvent("test"); + getEventCollector().assertNoEventReceived(); + return null; + + }); + getEventCollector().assertEvents(EventCollector.AFTER_COMPLETION, "test"); + getEventCollector().assertTotalEventsCount(1); // After rollback not invoked + } + + @Test + public void afterCompletionRollback() { + load(AfterCompletionTestListener.class); + this.transactionTemplate.execute(status -> { + getContext().publishEvent("test"); + getEventCollector().assertNoEventReceived(); + status.setRollbackOnly(); + return null; + + }); + getEventCollector().assertEvents(EventCollector.AFTER_COMPLETION, "test"); + getEventCollector().assertTotalEventsCount(1); // After rollback not invoked + } + + @Test + public void afterCommit() { + load(AfterCompletionExplicitTestListener.class); + this.transactionTemplate.execute(status -> { + getContext().publishEvent("test"); + getEventCollector().assertNoEventReceived(); + return null; + + }); + getEventCollector().assertEvents(EventCollector.AFTER_COMMIT, "test"); + getEventCollector().assertTotalEventsCount(1); // After rollback not invoked + } + + @Test + public void afterCommitWithTransactionalComponentListenerProxiedViaDynamicProxy() { + load(TransactionalConfiguration.class, TransactionalComponentTestListener.class); + this.transactionTemplate.execute(status -> { + getContext().publishEvent("SKIP"); + getEventCollector().assertNoEventReceived(); + return null; + }); + getEventCollector().assertNoEventReceived(); + } + + @Test + public void afterRollback() { + load(AfterCompletionExplicitTestListener.class); + this.transactionTemplate.execute(status -> { + getContext().publishEvent("test"); + getEventCollector().assertNoEventReceived(); + status.setRollbackOnly(); + return null; + }); + getEventCollector().assertEvents(EventCollector.AFTER_ROLLBACK, "test"); + getEventCollector().assertTotalEventsCount(1); // After commit not invoked + } + + @Test + public void beforeCommit() { + load(BeforeCommitTestListener.class); + this.transactionTemplate.execute(status -> { + TransactionSynchronizationManager.registerSynchronization(new EventTransactionSynchronization(10) { + @Override + public void beforeCommit(boolean readOnly) { + getEventCollector().assertNoEventReceived(); // Not seen yet + } + }); + TransactionSynchronizationManager.registerSynchronization(new EventTransactionSynchronization(20) { + @Override + public void beforeCommit(boolean readOnly) { + getEventCollector().assertEvents(EventCollector.BEFORE_COMMIT, "test"); + getEventCollector().assertTotalEventsCount(1); + } + }); + getContext().publishEvent("test"); + getEventCollector().assertNoEventReceived(); + return null; + + }); + getEventCollector().assertEvents(EventCollector.BEFORE_COMMIT, "test"); + getEventCollector().assertTotalEventsCount(1); + } + + @Test + public void beforeCommitWithException() { // Validates the custom synchronization is invoked + load(BeforeCommitTestListener.class); + try { + this.transactionTemplate.execute(status -> { + TransactionSynchronizationManager.registerSynchronization(new EventTransactionSynchronization(10) { + @Override + public void beforeCommit(boolean readOnly) { + throw new IllegalStateException("test"); + } + }); + getContext().publishEvent("test"); + getEventCollector().assertNoEventReceived(); + return null; + + }); + fail("Should have thrown an exception"); + } + catch (IllegalStateException e) { + // Test exception - ignore + } + getEventCollector().assertNoEventReceived(); // Before commit not invoked + } + + @Test + public void regularTransaction() { + load(ImmediateTestListener.class, BeforeCommitTestListener.class, AfterCompletionExplicitTestListener.class); + this.transactionTemplate.execute(status -> { + TransactionSynchronizationManager.registerSynchronization(new EventTransactionSynchronization(10) { + @Override + public void beforeCommit(boolean readOnly) { + getEventCollector().assertTotalEventsCount(1); // Immediate event + getEventCollector().assertEvents(EventCollector.IMMEDIATELY, "test"); + } + }); + TransactionSynchronizationManager.registerSynchronization(new EventTransactionSynchronization(20) { + @Override + public void beforeCommit(boolean readOnly) { + getEventCollector().assertEvents(EventCollector.BEFORE_COMMIT, "test"); + getEventCollector().assertTotalEventsCount(2); + } + }); + getContext().publishEvent("test"); + getEventCollector().assertTotalEventsCount(1); + return null; + + }); + getEventCollector().assertEvents(EventCollector.AFTER_COMMIT, "test"); + getEventCollector().assertTotalEventsCount(3); // Immediate, before commit, after commit + } + + @Test + public void noTransaction() { + load(BeforeCommitTestListener.class, AfterCompletionTestListener.class, + AfterCompletionExplicitTestListener.class); + this.context.publishEvent("test"); + getEventCollector().assertTotalEventsCount(0); + } + + @Test + public void noTransactionWithFallbackExecution() { + load(FallbackExecutionTestListener.class); + this.context.publishEvent("test"); + this.eventCollector.assertEvents(EventCollector.BEFORE_COMMIT, "test"); + this.eventCollector.assertEvents(EventCollector.AFTER_COMMIT, "test"); + this.eventCollector.assertEvents(EventCollector.AFTER_ROLLBACK, "test"); + this.eventCollector.assertEvents(EventCollector.AFTER_COMPLETION, "test"); + getEventCollector().assertTotalEventsCount(4); + } + + @Test + public void conditionFoundOnTransactionalEventListener() { + load(ImmediateTestListener.class); + this.transactionTemplate.execute(status -> { + getContext().publishEvent("SKIP"); + getEventCollector().assertNoEventReceived(); + return null; + }); + getEventCollector().assertNoEventReceived(); + } + + @Test + public void afterCommitMetaAnnotation() throws Exception { + load(AfterCommitMetaAnnotationTestListener.class); + this.transactionTemplate.execute(status -> { + getContext().publishEvent("test"); + getEventCollector().assertNoEventReceived(); + return null; + + }); + getEventCollector().assertEvents(EventCollector.AFTER_COMMIT, "test"); + getEventCollector().assertTotalEventsCount(1); + } + + @Test + public void conditionFoundOnMetaAnnotation() { + load(AfterCommitMetaAnnotationTestListener.class); + this.transactionTemplate.execute(status -> { + getContext().publishEvent("SKIP"); + getEventCollector().assertNoEventReceived(); + return null; + + }); + getEventCollector().assertNoEventReceived(); + } + + + protected EventCollector getEventCollector() { + return eventCollector; + } + + protected ConfigurableApplicationContext getContext() { + return context; + } + + private void load(Class... classes) { + List> allClasses = new ArrayList<>(); + allClasses.add(BasicConfiguration.class); + allClasses.addAll(Arrays.asList(classes)); + doLoad(allClasses.toArray(new Class[allClasses.size()])); + } + + private void doLoad(Class... classes) { + this.context = new AnnotationConfigApplicationContext(classes); + this.eventCollector = this.context.getBean(EventCollector.class); + } + + + @Configuration + static class BasicConfiguration { + + @Bean // set automatically with tx management + public TransactionalEventListenerFactory transactionalEventListenerFactory() { + return new TransactionalEventListenerFactory(); + } + + @Bean + public EventCollector eventCollector() { + return new EventCollector(); + } + } + + + @EnableTransactionManagement + @Configuration + static class TransactionalConfiguration { + + @Bean + public CallCountingTransactionManager transactionManager() { + return new CallCountingTransactionManager(); + } + } + + + static class EventCollector { + + public static final String IMMEDIATELY = "IMMEDIATELY"; + + public static final String BEFORE_COMMIT = "BEFORE_COMMIT"; + + public static final String AFTER_COMPLETION = "AFTER_COMPLETION"; + + public static final String AFTER_COMMIT = "AFTER_COMMIT"; + + public static final String AFTER_ROLLBACK = "AFTER_ROLLBACK"; + + public static final String[] ALL_PHASES = {IMMEDIATELY, BEFORE_COMMIT, AFTER_COMMIT, AFTER_ROLLBACK}; + + private final MultiValueMap events = new LinkedMultiValueMap<>(); + + public void addEvent(String phase, Object event) { + this.events.add(phase, event); + } + + public List getEvents(String phase) { + return this.events.getOrDefault(phase, Collections.emptyList()); + } + + public void assertNoEventReceived(String... phases) { + if (phases.length == 0) { // All values if none set + phases = ALL_PHASES; + } + for (String phase : phases) { + List eventsForPhase = getEvents(phase); + assertEquals("Expected no events for phase '" + phase + "' " + + "but got " + eventsForPhase + ":", 0, eventsForPhase.size()); + } + } + + public void assertEvents(String phase, Object... expected) { + List actual = getEvents(phase); + assertEquals("wrong number of events for phase '" + phase + "'", expected.length, actual.size()); + for (int i = 0; i < expected.length; i++) { + assertEquals("Wrong event for phase '" + phase + "' at index " + i, expected[i], actual.get(i)); + } + } + + public void assertTotalEventsCount(int number) { + int size = 0; + for (Map.Entry> entry : this.events.entrySet()) { + size += entry.getValue().size(); + } + assertEquals("Wrong number of total events (" + this.events.size() + ") " + + "registered phase(s)", number, size); + } + } + + + static abstract class BaseTransactionalTestListener { + + static final String FAIL_MSG = "FAIL"; + + @Autowired + private EventCollector eventCollector; + + public void handleEvent(String phase, String data) { + this.eventCollector.addEvent(phase, data); + if (FAIL_MSG.equals(data)) { + throw new IllegalStateException("Test exception on phase '" + phase + "'"); + } + } + } + + + @Component + static class ImmediateTestListener extends BaseTransactionalTestListener { + + @EventListener(condition = "!'SKIP'.equals(#data)") + public void handleImmediately(String data) { + handleEvent(EventCollector.IMMEDIATELY, data); + } + } + + + @Component + static class AfterCompletionTestListener extends BaseTransactionalTestListener { + + @TransactionalEventListener(phase = AFTER_COMPLETION) + public void handleAfterCompletion(String data) { + handleEvent(EventCollector.AFTER_COMPLETION, data); + } + } + + + @Component + static class AfterCompletionExplicitTestListener extends BaseTransactionalTestListener { + + @TransactionalEventListener(phase = AFTER_COMMIT) + public void handleAfterCommit(String data) { + handleEvent(EventCollector.AFTER_COMMIT, data); + } + + @TransactionalEventListener(phase = AFTER_ROLLBACK) + public void handleAfterRollback(String data) { + handleEvent(EventCollector.AFTER_ROLLBACK, data); + } + } + + + @Transactional + @Component + static interface TransactionalComponentTestListenerInterface { + + // Cannot use #data in condition due to dynamic proxy. + @TransactionalEventListener(condition = "!'SKIP'.equals(#p0)") + void handleAfterCommit(String data); + } + + + static class TransactionalComponentTestListener extends BaseTransactionalTestListener implements + TransactionalComponentTestListenerInterface { + + @Override + public void handleAfterCommit(String data) { + handleEvent(EventCollector.AFTER_COMMIT, data); + } + } + + + @Component + static class BeforeCommitTestListener extends BaseTransactionalTestListener { + + @TransactionalEventListener(phase = BEFORE_COMMIT) + @Order(15) + public void handleBeforeCommit(String data) { + handleEvent(EventCollector.BEFORE_COMMIT, data); + } + } + + + @Component + static class FallbackExecutionTestListener extends BaseTransactionalTestListener { + + @TransactionalEventListener(phase = BEFORE_COMMIT, fallbackExecution = true) + public void handleBeforeCommit(String data) { + handleEvent(EventCollector.BEFORE_COMMIT, data); + } + + @TransactionalEventListener(phase = AFTER_COMMIT, fallbackExecution = true) + public void handleAfterCommit(String data) { + handleEvent(EventCollector.AFTER_COMMIT, data); + } + + @TransactionalEventListener(phase = AFTER_ROLLBACK, fallbackExecution = true) + public void handleAfterRollback(String data) { + handleEvent(EventCollector.AFTER_ROLLBACK, data); + } + + @TransactionalEventListener(phase = AFTER_COMPLETION, fallbackExecution = true) + public void handleAfterCompletion(String data) { + handleEvent(EventCollector.AFTER_COMPLETION, data); + } + } + + + @TransactionalEventListener(phase = AFTER_COMMIT, condition = "!'SKIP'.equals(#p0)") + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @interface AfterCommitEventListener { + } + + + @Component + static class AfterCommitMetaAnnotationTestListener extends BaseTransactionalTestListener { + + @AfterCommitEventListener + public void handleAfterCommit(String data) { + handleEvent(EventCollector.AFTER_COMMIT, data); + } + } + + + static class EventTransactionSynchronization extends TransactionSynchronizationAdapter { + + private final int order; + + EventTransactionSynchronization(int order) { + this.order = order; + } + + @Override + public int getOrder() { + return order; + } + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/interceptor/AbstractTransactionAspectTests.java b/spring-tx/src/test/java/org/springframework/transaction/interceptor/AbstractTransactionAspectTests.java new file mode 100644 index 0000000000000000000000000000000000000000..50ef7655ceebc05d39633a35dad8e9dd09e0ef51 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/interceptor/AbstractTransactionAspectTests.java @@ -0,0 +1,567 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.lang.reflect.Method; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.dao.OptimisticLockingFailureException; +import org.springframework.tests.sample.beans.ITestBean; +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.transaction.CannotCreateTransactionException; +import org.springframework.transaction.MockCallbackPreferringTransactionManager; +import org.springframework.transaction.NoTransactionException; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.TransactionStatus; +import org.springframework.transaction.TransactionSystemException; +import org.springframework.transaction.UnexpectedRollbackException; +import org.springframework.transaction.interceptor.TransactionAspectSupport.TransactionInfo; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * Mock object based tests for transaction aspects. + * True unit test in that it tests how the transaction aspect uses + * the PlatformTransactionManager helper, rather than indirectly + * testing the helper implementation. + * + * This is a superclass to allow testing both the AOP Alliance MethodInterceptor + * and the AspectJ aspect. + * + * @author Rod Johnson + * @since 16.03.2003 + */ +public abstract class AbstractTransactionAspectTests { + + protected Method exceptionalMethod; + + protected Method getNameMethod; + + protected Method setNameMethod; + + + @Before + public void setup() throws Exception { + exceptionalMethod = ITestBean.class.getMethod("exceptional", Throwable.class); + getNameMethod = ITestBean.class.getMethod("getName"); + setNameMethod = ITestBean.class.getMethod("setName", String.class); + } + + + @Test + public void noTransaction() throws Exception { + PlatformTransactionManager ptm = mock(PlatformTransactionManager.class); + + TestBean tb = new TestBean(); + TransactionAttributeSource tas = new MapTransactionAttributeSource(); + + // All the methods in this class use the advised() template method + // to obtain a transaction object, configured with the given PlatformTransactionManager + // and transaction attribute source + ITestBean itb = (ITestBean) advised(tb, ptm, tas); + + checkTransactionStatus(false); + itb.getName(); + checkTransactionStatus(false); + + // expect no calls + verifyZeroInteractions(ptm); + } + + /** + * Check that a transaction is created and committed. + */ + @Test + public void transactionShouldSucceed() throws Exception { + TransactionAttribute txatt = new DefaultTransactionAttribute(); + + MapTransactionAttributeSource tas = new MapTransactionAttributeSource(); + tas.register(getNameMethod, txatt); + + TransactionStatus status = mock(TransactionStatus.class); + PlatformTransactionManager ptm = mock(PlatformTransactionManager.class); + // expect a transaction + given(ptm.getTransaction(txatt)).willReturn(status); + + TestBean tb = new TestBean(); + ITestBean itb = (ITestBean) advised(tb, ptm, tas); + + checkTransactionStatus(false); + itb.getName(); + checkTransactionStatus(false); + + verify(ptm).commit(status); + } + + /** + * Check that a transaction is created and committed using + * CallbackPreferringPlatformTransactionManager. + */ + @Test + public void transactionShouldSucceedWithCallbackPreference() throws Exception { + TransactionAttribute txatt = new DefaultTransactionAttribute(); + + MapTransactionAttributeSource tas = new MapTransactionAttributeSource(); + tas.register(getNameMethod, txatt); + + MockCallbackPreferringTransactionManager ptm = new MockCallbackPreferringTransactionManager(); + + TestBean tb = new TestBean(); + ITestBean itb = (ITestBean) advised(tb, ptm, tas); + + checkTransactionStatus(false); + itb.getName(); + checkTransactionStatus(false); + + assertSame(txatt, ptm.getDefinition()); + assertFalse(ptm.getStatus().isRollbackOnly()); + } + + @Test + public void transactionExceptionPropagatedWithCallbackPreference() throws Throwable { + TransactionAttribute txatt = new DefaultTransactionAttribute(); + + MapTransactionAttributeSource tas = new MapTransactionAttributeSource(); + tas.register(exceptionalMethod, txatt); + + MockCallbackPreferringTransactionManager ptm = new MockCallbackPreferringTransactionManager(); + + TestBean tb = new TestBean(); + ITestBean itb = (ITestBean) advised(tb, ptm, tas); + + checkTransactionStatus(false); + try { + itb.exceptional(new OptimisticLockingFailureException("")); + fail("Should have thrown OptimisticLockingFailureException"); + } + catch (OptimisticLockingFailureException ex) { + // expected + } + checkTransactionStatus(false); + + assertSame(txatt, ptm.getDefinition()); + assertFalse(ptm.getStatus().isRollbackOnly()); + } + + /** + * Check that two transactions are created and committed. + */ + @Test + public void twoTransactionsShouldSucceed() throws Exception { + TransactionAttribute txatt = new DefaultTransactionAttribute(); + + MapTransactionAttributeSource tas1 = new MapTransactionAttributeSource(); + tas1.register(getNameMethod, txatt); + MapTransactionAttributeSource tas2 = new MapTransactionAttributeSource(); + tas2.register(setNameMethod, txatt); + + TransactionStatus status = mock(TransactionStatus.class); + PlatformTransactionManager ptm = mock(PlatformTransactionManager.class); + // expect a transaction + given(ptm.getTransaction(txatt)).willReturn(status); + + TestBean tb = new TestBean(); + ITestBean itb = (ITestBean) advised(tb, ptm, new TransactionAttributeSource[] {tas1, tas2}); + + checkTransactionStatus(false); + itb.getName(); + checkTransactionStatus(false); + itb.setName("myName"); + checkTransactionStatus(false); + + verify(ptm, times(2)).commit(status); + } + + /** + * Check that a transaction is created and committed. + */ + @Test + public void transactionShouldSucceedWithNotNew() throws Exception { + TransactionAttribute txatt = new DefaultTransactionAttribute(); + + MapTransactionAttributeSource tas = new MapTransactionAttributeSource(); + tas.register(getNameMethod, txatt); + + TransactionStatus status = mock(TransactionStatus.class); + PlatformTransactionManager ptm = mock(PlatformTransactionManager.class); + // expect a transaction + given(ptm.getTransaction(txatt)).willReturn(status); + + TestBean tb = new TestBean(); + ITestBean itb = (ITestBean) advised(tb, ptm, tas); + + checkTransactionStatus(false); + // verification!? + itb.getName(); + checkTransactionStatus(false); + + verify(ptm).commit(status); + } + + @Test + public void enclosingTransactionWithNonTransactionMethodOnAdvisedInside() throws Throwable { + TransactionAttribute txatt = new DefaultTransactionAttribute(); + + MapTransactionAttributeSource tas = new MapTransactionAttributeSource(); + tas.register(exceptionalMethod, txatt); + + TransactionStatus status = mock(TransactionStatus.class); + PlatformTransactionManager ptm = mock(PlatformTransactionManager.class); + // Expect a transaction + given(ptm.getTransaction(txatt)).willReturn(status); + + final String spouseName = "innerName"; + + TestBean outer = new TestBean() { + @Override + public void exceptional(Throwable t) throws Throwable { + TransactionInfo ti = TransactionAspectSupport.currentTransactionInfo(); + assertTrue(ti.hasTransaction()); + assertEquals(spouseName, getSpouse().getName()); + } + }; + TestBean inner = new TestBean() { + @Override + public String getName() { + // Assert that we're in the inner proxy + TransactionInfo ti = TransactionAspectSupport.currentTransactionInfo(); + assertFalse(ti.hasTransaction()); + return spouseName; + } + }; + + ITestBean outerProxy = (ITestBean) advised(outer, ptm, tas); + ITestBean innerProxy = (ITestBean) advised(inner, ptm, tas); + outer.setSpouse(innerProxy); + + checkTransactionStatus(false); + + // Will invoke inner.getName, which is non-transactional + outerProxy.exceptional(null); + + checkTransactionStatus(false); + + verify(ptm).commit(status); + } + + @Test + public void enclosingTransactionWithNestedTransactionOnAdvisedInside() throws Throwable { + final TransactionAttribute outerTxatt = new DefaultTransactionAttribute(); + final TransactionAttribute innerTxatt = new DefaultTransactionAttribute(TransactionDefinition.PROPAGATION_NESTED); + + Method outerMethod = exceptionalMethod; + Method innerMethod = getNameMethod; + MapTransactionAttributeSource tas = new MapTransactionAttributeSource(); + tas.register(outerMethod, outerTxatt); + tas.register(innerMethod, innerTxatt); + + TransactionStatus outerStatus = mock(TransactionStatus.class); + TransactionStatus innerStatus = mock(TransactionStatus.class); + + PlatformTransactionManager ptm = mock(PlatformTransactionManager.class); + // Expect a transaction + given(ptm.getTransaction(outerTxatt)).willReturn(outerStatus); + given(ptm.getTransaction(innerTxatt)).willReturn(innerStatus); + + final String spouseName = "innerName"; + + TestBean outer = new TestBean() { + @Override + public void exceptional(Throwable t) throws Throwable { + TransactionInfo ti = TransactionAspectSupport.currentTransactionInfo(); + assertTrue(ti.hasTransaction()); + assertEquals(outerTxatt, ti.getTransactionAttribute()); + assertEquals(spouseName, getSpouse().getName()); + } + }; + TestBean inner = new TestBean() { + @Override + public String getName() { + // Assert that we're in the inner proxy + TransactionInfo ti = TransactionAspectSupport.currentTransactionInfo(); + // Has nested transaction + assertTrue(ti.hasTransaction()); + assertEquals(innerTxatt, ti.getTransactionAttribute()); + return spouseName; + } + }; + + ITestBean outerProxy = (ITestBean) advised(outer, ptm, tas); + ITestBean innerProxy = (ITestBean) advised(inner, ptm, tas); + outer.setSpouse(innerProxy); + + checkTransactionStatus(false); + + // Will invoke inner.getName, which is non-transactional + outerProxy.exceptional(null); + + checkTransactionStatus(false); + + verify(ptm).commit(innerStatus); + verify(ptm).commit(outerStatus); + } + + @Test + public void rollbackOnCheckedException() throws Throwable { + doTestRollbackOnException(new Exception(), true, false); + } + + @Test + public void noRollbackOnCheckedException() throws Throwable { + doTestRollbackOnException(new Exception(), false, false); + } + + @Test + public void rollbackOnUncheckedException() throws Throwable { + doTestRollbackOnException(new RuntimeException(), true, false); + } + + @Test + public void noRollbackOnUncheckedException() throws Throwable { + doTestRollbackOnException(new RuntimeException(), false, false); + } + + @Test + public void rollbackOnCheckedExceptionWithRollbackException() throws Throwable { + doTestRollbackOnException(new Exception(), true, true); + } + + @Test + public void noRollbackOnCheckedExceptionWithRollbackException() throws Throwable { + doTestRollbackOnException(new Exception(), false, true); + } + + @Test + public void rollbackOnUncheckedExceptionWithRollbackException() throws Throwable { + doTestRollbackOnException(new RuntimeException(), true, true); + } + + @Test + public void noRollbackOnUncheckedExceptionWithRollbackException() throws Throwable { + doTestRollbackOnException(new RuntimeException(), false, true); + } + + /** + * Check that the given exception thrown by the target can produce the + * desired behavior with the appropriate transaction attribute. + * @param ex exception to be thrown by the target + * @param shouldRollback whether this should cause a transaction rollback + */ + @SuppressWarnings("serial") + protected void doTestRollbackOnException( + final Exception ex, final boolean shouldRollback, boolean rollbackException) throws Exception { + + TransactionAttribute txatt = new DefaultTransactionAttribute() { + @Override + public boolean rollbackOn(Throwable t) { + assertTrue(t == ex); + return shouldRollback; + } + }; + + Method m = exceptionalMethod; + MapTransactionAttributeSource tas = new MapTransactionAttributeSource(); + tas.register(m, txatt); + + TransactionStatus status = mock(TransactionStatus.class); + PlatformTransactionManager ptm = mock(PlatformTransactionManager.class); + // Gets additional call(s) from TransactionControl + + given(ptm.getTransaction(txatt)).willReturn(status); + + TransactionSystemException tex = new TransactionSystemException("system exception"); + if (rollbackException) { + if (shouldRollback) { + willThrow(tex).given(ptm).rollback(status); + } + else { + willThrow(tex).given(ptm).commit(status); + } + } + + TestBean tb = new TestBean(); + ITestBean itb = (ITestBean) advised(tb, ptm, tas); + + try { + itb.exceptional(ex); + fail("Should have thrown exception"); + } + catch (Throwable t) { + if (rollbackException) { + assertEquals("Caught wrong exception", tex, t); + } + else { + assertEquals("Caught wrong exception", ex, t); + } + } + + if (!rollbackException) { + if (shouldRollback) { + verify(ptm).rollback(status); + } + else { + verify(ptm).commit(status); + } + } + } + + /** + * Test that TransactionStatus.setRollbackOnly works. + */ + @Test + public void programmaticRollback() throws Exception { + TransactionAttribute txatt = new DefaultTransactionAttribute(); + + Method m = getNameMethod; + MapTransactionAttributeSource tas = new MapTransactionAttributeSource(); + tas.register(m, txatt); + + TransactionStatus status = mock(TransactionStatus.class); + PlatformTransactionManager ptm = mock(PlatformTransactionManager.class); + + given(ptm.getTransaction(txatt)).willReturn(status); + + final String name = "jenny"; + TestBean tb = new TestBean() { + @Override + public String getName() { + TransactionStatus txStatus = TransactionInterceptor.currentTransactionStatus(); + txStatus.setRollbackOnly(); + return name; + } + }; + + ITestBean itb = (ITestBean) advised(tb, ptm, tas); + + // verification!? + assertTrue(name.equals(itb.getName())); + + verify(ptm).commit(status); + } + + /** + * Simulate a transaction infrastructure failure. + * Shouldn't invoke target method. + */ + @Test + public void cannotCreateTransaction() throws Exception { + TransactionAttribute txatt = new DefaultTransactionAttribute(); + + Method m = getNameMethod; + MapTransactionAttributeSource tas = new MapTransactionAttributeSource(); + tas.register(m, txatt); + + PlatformTransactionManager ptm = mock(PlatformTransactionManager.class); + // Expect a transaction + CannotCreateTransactionException ex = new CannotCreateTransactionException("foobar", null); + given(ptm.getTransaction(txatt)).willThrow(ex); + + TestBean tb = new TestBean() { + @Override + public String getName() { + throw new UnsupportedOperationException( + "Shouldn't have invoked target method when couldn't create transaction for transactional method"); + } + }; + ITestBean itb = (ITestBean) advised(tb, ptm, tas); + + try { + itb.getName(); + fail("Shouldn't have invoked method"); + } + catch (CannotCreateTransactionException thrown) { + assertTrue(thrown == ex); + } + } + + /** + * Simulate failure of the underlying transaction infrastructure to commit. + * Check that the target method was invoked, but that the transaction + * infrastructure exception was thrown to the client + */ + @Test + public void cannotCommitTransaction() throws Exception { + TransactionAttribute txatt = new DefaultTransactionAttribute(); + + Method m = setNameMethod; + MapTransactionAttributeSource tas = new MapTransactionAttributeSource(); + tas.register(m, txatt); + // Method m2 = getNameMethod; + // No attributes for m2 + + PlatformTransactionManager ptm = mock(PlatformTransactionManager.class); + + TransactionStatus status = mock(TransactionStatus.class); + given(ptm.getTransaction(txatt)).willReturn(status); + UnexpectedRollbackException ex = new UnexpectedRollbackException("foobar", null); + willThrow(ex).given(ptm).commit(status); + + TestBean tb = new TestBean(); + ITestBean itb = (ITestBean) advised(tb, ptm, tas); + + String name = "new name"; + try { + itb.setName(name); + fail("Shouldn't have succeeded"); + } + catch (UnexpectedRollbackException thrown) { + assertTrue(thrown == ex); + } + + // Should have invoked target and changed name + assertTrue(itb.getName() == name); + } + + protected void checkTransactionStatus(boolean expected) { + try { + TransactionInterceptor.currentTransactionStatus(); + if (!expected) { + fail("Should have thrown NoTransactionException"); + } + } + catch (NoTransactionException ex) { + if (expected) { + fail("Should have current TransactionStatus"); + } + } + } + + + protected Object advised( + Object target, PlatformTransactionManager ptm, TransactionAttributeSource[] tas) throws Exception { + + return advised(target, ptm, new CompositeTransactionAttributeSource(tas)); + } + + /** + * Subclasses must implement this to create an advised object based on the + * given target. In the case of AspectJ, the advised object will already + * have been created, as there's no distinction between target and proxy. + * In the case of Spring's own AOP framework, a proxy must be created + * using a suitably configured transaction interceptor + * @param target target if there's a distinct target. If not (AspectJ), + * return target. + * @return transactional advised object + */ + protected abstract Object advised( + Object target, PlatformTransactionManager ptm, TransactionAttributeSource tas) throws Exception; + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/interceptor/BeanFactoryTransactionTests.java b/spring-tx/src/test/java/org/springframework/transaction/interceptor/BeanFactoryTransactionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..805bb3c0155af9c29dfc9a62bb72d3c9b571e4d2 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/interceptor/BeanFactoryTransactionTests.java @@ -0,0 +1,252 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.util.Map; + +import org.aopalliance.intercept.MethodInterceptor; +import org.aopalliance.intercept.MethodInvocation; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.aop.support.AopUtils; +import org.springframework.aop.support.StaticMethodMatcherPointcut; +import org.springframework.aop.target.HotSwappableTargetSource; +import org.springframework.beans.FatalBeanException; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; +import org.springframework.core.io.ClassPathResource; +import org.springframework.lang.Nullable; +import org.springframework.tests.sample.beans.DerivedTestBean; +import org.springframework.tests.sample.beans.ITestBean; +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.tests.transaction.CallCountingTransactionManager; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.TransactionException; +import org.springframework.transaction.TransactionStatus; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * Test cases for AOP transaction management. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 23.04.2003 + */ +public class BeanFactoryTransactionTests { + + private DefaultListableBeanFactory factory; + + + @Before + public void setUp() { + this.factory = new DefaultListableBeanFactory(); + new XmlBeanDefinitionReader(this.factory).loadBeanDefinitions( + new ClassPathResource("transactionalBeanFactory.xml", getClass())); + } + + + @Test + public void testGetsAreNotTransactionalWithProxyFactory1() { + ITestBean testBean = (ITestBean) factory.getBean("proxyFactory1"); + assertTrue("testBean is a dynamic proxy", Proxy.isProxyClass(testBean.getClass())); + assertFalse(testBean instanceof TransactionalProxy); + doTestGetsAreNotTransactional(testBean); + } + + @Test + public void testGetsAreNotTransactionalWithProxyFactory2DynamicProxy() { + this.factory.preInstantiateSingletons(); + ITestBean testBean = (ITestBean) factory.getBean("proxyFactory2DynamicProxy"); + assertTrue("testBean is a dynamic proxy", Proxy.isProxyClass(testBean.getClass())); + assertTrue(testBean instanceof TransactionalProxy); + doTestGetsAreNotTransactional(testBean); + } + + @Test + public void testGetsAreNotTransactionalWithProxyFactory2Cglib() { + ITestBean testBean = (ITestBean) factory.getBean("proxyFactory2Cglib"); + assertTrue("testBean is CGLIB advised", AopUtils.isCglibProxy(testBean)); + assertTrue(testBean instanceof TransactionalProxy); + doTestGetsAreNotTransactional(testBean); + } + + @Test + public void testProxyFactory2Lazy() { + ITestBean testBean = (ITestBean) factory.getBean("proxyFactory2Lazy"); + assertFalse(factory.containsSingleton("target")); + assertEquals(666, testBean.getAge()); + assertTrue(factory.containsSingleton("target")); + } + + @Test + public void testCglibTransactionProxyImplementsNoInterfaces() { + ImplementsNoInterfaces ini = (ImplementsNoInterfaces) factory.getBean("cglibNoInterfaces"); + assertTrue("testBean is CGLIB advised", AopUtils.isCglibProxy(ini)); + assertTrue(ini instanceof TransactionalProxy); + String newName = "Gordon"; + + // Install facade + CallCountingTransactionManager ptm = new CallCountingTransactionManager(); + PlatformTransactionManagerFacade.delegate = ptm; + + ini.setName(newName); + assertEquals(newName, ini.getName()); + assertEquals(2, ptm.commits); + } + + @Test + public void testGetsAreNotTransactionalWithProxyFactory3() { + ITestBean testBean = (ITestBean) factory.getBean("proxyFactory3"); + assertTrue("testBean is a full proxy", testBean instanceof DerivedTestBean); + assertTrue(testBean instanceof TransactionalProxy); + InvocationCounterPointcut txnCounter = (InvocationCounterPointcut) factory.getBean("txnInvocationCounterPointcut"); + InvocationCounterInterceptor preCounter = (InvocationCounterInterceptor) factory.getBean("preInvocationCounterInterceptor"); + InvocationCounterInterceptor postCounter = (InvocationCounterInterceptor) factory.getBean("postInvocationCounterInterceptor"); + txnCounter.counter = 0; + preCounter.counter = 0; + postCounter.counter = 0; + doTestGetsAreNotTransactional(testBean); + // Can't assert it's equal to 4 as the pointcut may be optimized and only invoked once + assertTrue(0 < txnCounter.counter && txnCounter.counter <= 4); + assertEquals(4, preCounter.counter); + assertEquals(4, postCounter.counter); + } + + private void doTestGetsAreNotTransactional(final ITestBean testBean) { + // Install facade + PlatformTransactionManager ptm = mock(PlatformTransactionManager.class); + PlatformTransactionManagerFacade.delegate = ptm; + + assertTrue("Age should not be " + testBean.getAge(), testBean.getAge() == 666); + + // Expect no methods + verifyZeroInteractions(ptm); + + // Install facade expecting a call + final TransactionStatus ts = mock(TransactionStatus.class); + ptm = new PlatformTransactionManager() { + private boolean invoked; + @Override + public TransactionStatus getTransaction(@Nullable TransactionDefinition def) throws TransactionException { + if (invoked) { + throw new IllegalStateException("getTransaction should not get invoked more than once"); + } + invoked = true; + if (!(def.getName().contains(DerivedTestBean.class.getName()) && def.getName().contains("setAge"))) { + throw new IllegalStateException( + "transaction name should contain class and method name: " + def.getName()); + } + return ts; + } + @Override + public void commit(TransactionStatus status) throws TransactionException { + assertTrue(status == ts); + } + @Override + public void rollback(TransactionStatus status) throws TransactionException { + throw new IllegalStateException("rollback should not get invoked"); + } + }; + PlatformTransactionManagerFacade.delegate = ptm; + + // TODO same as old age to avoid ordering effect for now + int age = 666; + testBean.setAge(age); + assertTrue(testBean.getAge() == age); + } + + @Test + public void testGetBeansOfTypeWithAbstract() { + Map beansOfType = factory.getBeansOfType(ITestBean.class, true, true); + assertNotNull(beansOfType); + } + + /** + * Check that we fail gracefully if the user doesn't set any transaction attributes. + */ + @Test + public void testNoTransactionAttributeSource() { + try { + DefaultListableBeanFactory bf = new DefaultListableBeanFactory(); + new XmlBeanDefinitionReader(bf).loadBeanDefinitions(new ClassPathResource("noTransactionAttributeSource.xml", getClass())); + bf.getBean("noTransactionAttributeSource"); + fail("Should require TransactionAttributeSource to be set"); + } + catch (FatalBeanException ex) { + // Ok + } + } + + /** + * Test that we can set the target to a dynamic TargetSource. + */ + @Test + public void testDynamicTargetSource() { + // Install facade + CallCountingTransactionManager txMan = new CallCountingTransactionManager(); + PlatformTransactionManagerFacade.delegate = txMan; + + TestBean tb = (TestBean) factory.getBean("hotSwapped"); + assertEquals(666, tb.getAge()); + int newAge = 557; + tb.setAge(newAge); + assertEquals(newAge, tb.getAge()); + + TestBean target2 = new TestBean(); + target2.setAge(65); + HotSwappableTargetSource ts = (HotSwappableTargetSource) factory.getBean("swapper"); + ts.swap(target2); + assertEquals(target2.getAge(), tb.getAge()); + tb.setAge(newAge); + assertEquals(newAge, target2.getAge()); + + assertEquals(0, txMan.inflight); + assertEquals(2, txMan.commits); + assertEquals(0, txMan.rollbacks); + } + + + public static class InvocationCounterPointcut extends StaticMethodMatcherPointcut { + + int counter = 0; + + @Override + public boolean matches(Method method, @Nullable Class clazz) { + counter++; + return true; + } + } + + + public static class InvocationCounterInterceptor implements MethodInterceptor { + + int counter = 0; + + @Override + public Object invoke(MethodInvocation methodInvocation) throws Throwable { + counter++; + return methodInvocation.proceed(); + } + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/interceptor/ImplementsNoInterfaces.java b/spring-tx/src/test/java/org/springframework/transaction/interceptor/ImplementsNoInterfaces.java new file mode 100644 index 0000000000000000000000000000000000000000..fa6ddb98a1561d62fed938ad793f17aee5e69ee8 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/interceptor/ImplementsNoInterfaces.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import org.springframework.tests.sample.beans.TestBean; + +/** + * Test for CGLIB proxying that implements no interfaces + * and has one dependency. + * + * @author Rod Johnson + */ +public class ImplementsNoInterfaces { + + private TestBean testBean; + + public void setDependency(TestBean testBean) { + this.testBean = testBean; + } + + public String getName() { + return testBean.getName(); + } + + public void setName(String name) { + testBean.setName(name); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/interceptor/MapTransactionAttributeSource.java b/spring-tx/src/test/java/org/springframework/transaction/interceptor/MapTransactionAttributeSource.java new file mode 100644 index 0000000000000000000000000000000000000000..b4c7d382932b44e3af443662128456323077be2c --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/interceptor/MapTransactionAttributeSource.java @@ -0,0 +1,53 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.Map; + +/** + * Inherits fallback behavior from AbstractFallbackTransactionAttributeSource. + * + * @author Rod Johnson + * @author Juergen Hoeller + */ +public class MapTransactionAttributeSource extends AbstractFallbackTransactionAttributeSource { + + private final Map attributeMap = new HashMap<>(); + + + public void register(Class clazz, TransactionAttribute txAttr) { + this.attributeMap.put(clazz, txAttr); + } + + public void register(Method method, TransactionAttribute txAttr) { + this.attributeMap.put(method, txAttr); + } + + + @Override + protected TransactionAttribute findTransactionAttribute(Class clazz) { + return this.attributeMap.get(clazz); + } + + @Override + protected TransactionAttribute findTransactionAttribute(Method method) { + return this.attributeMap.get(method); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/interceptor/MyRuntimeException.java b/spring-tx/src/test/java/org/springframework/transaction/interceptor/MyRuntimeException.java new file mode 100644 index 0000000000000000000000000000000000000000..80affe6bc24f9e47f6cc05a38e893d02bc6ffc4a --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/interceptor/MyRuntimeException.java @@ -0,0 +1,31 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import org.springframework.core.NestedRuntimeException; + +/** + * An example {@link RuntimeException} for use in testing rollback rules. + * + * @author Chris Beams + */ +@SuppressWarnings("serial") +class MyRuntimeException extends NestedRuntimeException { + public MyRuntimeException(String msg) { + super(msg); + } +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/interceptor/PlatformTransactionManagerFacade.java b/spring-tx/src/test/java/org/springframework/transaction/interceptor/PlatformTransactionManagerFacade.java new file mode 100644 index 0000000000000000000000000000000000000000..c8beae777fdddf1c331d42d09fde630413d8e050 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/interceptor/PlatformTransactionManagerFacade.java @@ -0,0 +1,56 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.TransactionStatus; + +/** + * Used for testing only (for example, when we must replace the + * behavior of a PlatformTransactionManager bean we don't have access to). + * + *

Allows behavior of an entire class to change with static delegate change. + * Not multi-threaded. + * + * @author Rod Johnson + * @since 26.04.2003 + */ +public class PlatformTransactionManagerFacade implements PlatformTransactionManager { + + /** + * This member can be changed to change behavior class-wide. + */ + public static PlatformTransactionManager delegate; + + @Override + public TransactionStatus getTransaction(@Nullable TransactionDefinition definition) { + return delegate.getTransaction(definition); + } + + @Override + public void commit(TransactionStatus status) { + delegate.commit(status); + } + + @Override + public void rollback(TransactionStatus status) { + delegate.rollback(status); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/interceptor/RollbackRuleTests.java b/spring-tx/src/test/java/org/springframework/transaction/interceptor/RollbackRuleTests.java new file mode 100644 index 0000000000000000000000000000000000000000..6d529d6fca14e5f7ac5e3d73ac310bd74ba7f340 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/interceptor/RollbackRuleTests.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.io.IOException; + +import org.junit.Test; + +import org.springframework.beans.FatalBeanException; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +/** + * Unit tests for the {@link RollbackRuleAttribute} class. + * + * @author Rod Johnson + * @author Rick Evans + * @author Chris Beams + * @author Sam Brannen + * @since 09.04.2003 + */ +public class RollbackRuleTests { + + @Test + public void foundImmediatelyWithString() { + RollbackRuleAttribute rr = new RollbackRuleAttribute(java.lang.Exception.class.getName()); + assertEquals(0, rr.getDepth(new Exception())); + } + + @Test + public void foundImmediatelyWithClass() { + RollbackRuleAttribute rr = new RollbackRuleAttribute(Exception.class); + assertEquals(0, rr.getDepth(new Exception())); + } + + @Test + public void notFound() { + RollbackRuleAttribute rr = new RollbackRuleAttribute(java.io.IOException.class.getName()); + assertEquals(-1, rr.getDepth(new MyRuntimeException(""))); + } + + @Test + public void ancestry() { + RollbackRuleAttribute rr = new RollbackRuleAttribute(java.lang.Exception.class.getName()); + // Exception -> Runtime -> NestedRuntime -> MyRuntimeException + assertThat(rr.getDepth(new MyRuntimeException("")), equalTo(3)); + } + + @Test + public void alwaysTrueForThrowable() { + RollbackRuleAttribute rr = new RollbackRuleAttribute(java.lang.Throwable.class.getName()); + assertTrue(rr.getDepth(new MyRuntimeException("")) > 0); + assertTrue(rr.getDepth(new IOException()) > 0); + assertTrue(rr.getDepth(new FatalBeanException(null,null)) > 0); + assertTrue(rr.getDepth(new RuntimeException()) > 0); + } + + @Test(expected = IllegalArgumentException.class) + public void ctorArgMustBeAThrowableClassWithNonThrowableType() { + new RollbackRuleAttribute(StringBuffer.class); + } + + @Test(expected = IllegalArgumentException.class) + public void ctorArgMustBeAThrowableClassWithNullThrowableType() { + new RollbackRuleAttribute((Class) null); + } + + @Test(expected = IllegalArgumentException.class) + public void ctorArgExceptionStringNameVersionWithNull() { + new RollbackRuleAttribute((String) null); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/interceptor/RuleBasedTransactionAttributeTests.java b/spring-tx/src/test/java/org/springframework/transaction/interceptor/RuleBasedTransactionAttributeTests.java new file mode 100644 index 0000000000000000000000000000000000000000..4a6f77ac304dd3016097209284cbbbab07dd680c --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/interceptor/RuleBasedTransactionAttributeTests.java @@ -0,0 +1,174 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.io.IOException; +import java.rmi.RemoteException; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; + +import org.junit.Test; + +import org.springframework.transaction.TransactionDefinition; + +import static org.junit.Assert.*; + +/** + * @author Rod Johnson + * @author Juergen Hoeller + * @author Rick Evans + * @author Chris Beams + * @since 09.04.2003 + */ +public class RuleBasedTransactionAttributeTests { + + @Test + public void testDefaultRule() { + RuleBasedTransactionAttribute rta = new RuleBasedTransactionAttribute(); + assertTrue(rta.rollbackOn(new RuntimeException())); + assertTrue(rta.rollbackOn(new MyRuntimeException(""))); + assertFalse(rta.rollbackOn(new Exception())); + assertFalse(rta.rollbackOn(new IOException())); + } + + /** + * Test one checked exception that should roll back. + */ + @Test + public void testRuleForRollbackOnChecked() { + List list = new LinkedList<>(); + list.add(new RollbackRuleAttribute(IOException.class.getName())); + RuleBasedTransactionAttribute rta = new RuleBasedTransactionAttribute(TransactionDefinition.PROPAGATION_REQUIRED, list); + + assertTrue(rta.rollbackOn(new RuntimeException())); + assertTrue(rta.rollbackOn(new MyRuntimeException(""))); + assertFalse(rta.rollbackOn(new Exception())); + // Check that default behaviour is overridden + assertTrue(rta.rollbackOn(new IOException())); + } + + @Test + public void testRuleForCommitOnUnchecked() { + List list = new LinkedList<>(); + list.add(new NoRollbackRuleAttribute(MyRuntimeException.class.getName())); + list.add(new RollbackRuleAttribute(IOException.class.getName())); + RuleBasedTransactionAttribute rta = new RuleBasedTransactionAttribute(TransactionDefinition.PROPAGATION_REQUIRED, list); + + assertTrue(rta.rollbackOn(new RuntimeException())); + // Check default behaviour is overridden + assertFalse(rta.rollbackOn(new MyRuntimeException(""))); + assertFalse(rta.rollbackOn(new Exception())); + // Check that default behaviour is overridden + assertTrue(rta.rollbackOn(new IOException())); + } + + @Test + public void testRuleForSelectiveRollbackOnCheckedWithString() { + List l = new LinkedList<>(); + l.add(new RollbackRuleAttribute(java.rmi.RemoteException.class.getName())); + RuleBasedTransactionAttribute rta = new RuleBasedTransactionAttribute(TransactionDefinition.PROPAGATION_REQUIRED, l); + doTestRuleForSelectiveRollbackOnChecked(rta); + } + + @Test + public void testRuleForSelectiveRollbackOnCheckedWithClass() { + List l = Collections.singletonList(new RollbackRuleAttribute(RemoteException.class)); + RuleBasedTransactionAttribute rta = new RuleBasedTransactionAttribute(TransactionDefinition.PROPAGATION_REQUIRED, l); + doTestRuleForSelectiveRollbackOnChecked(rta); + } + + private void doTestRuleForSelectiveRollbackOnChecked(RuleBasedTransactionAttribute rta) { + assertTrue(rta.rollbackOn(new RuntimeException())); + // Check default behaviour is overridden + assertFalse(rta.rollbackOn(new Exception())); + // Check that default behaviour is overridden + assertTrue(rta.rollbackOn(new RemoteException())); + } + + /** + * Check that a rule can cause commit on a IOException + * when Exception prompts a rollback. + */ + @Test + public void testRuleForCommitOnSubclassOfChecked() { + List list = new LinkedList<>(); + // Note that it's important to ensure that we have this as + // a FQN: otherwise it will match everything! + list.add(new RollbackRuleAttribute("java.lang.Exception")); + list.add(new NoRollbackRuleAttribute("IOException")); + RuleBasedTransactionAttribute rta = new RuleBasedTransactionAttribute(TransactionDefinition.PROPAGATION_REQUIRED, list); + + assertTrue(rta.rollbackOn(new RuntimeException())); + assertTrue(rta.rollbackOn(new Exception())); + // Check that default behaviour is overridden + assertFalse(rta.rollbackOn(new IOException())); + } + + @Test + public void testRollbackNever() { + List list = new LinkedList<>(); + list.add(new NoRollbackRuleAttribute("Throwable")); + RuleBasedTransactionAttribute rta = new RuleBasedTransactionAttribute(TransactionDefinition.PROPAGATION_REQUIRED, list); + + assertFalse(rta.rollbackOn(new Throwable())); + assertFalse(rta.rollbackOn(new RuntimeException())); + assertFalse(rta.rollbackOn(new MyRuntimeException(""))); + assertFalse(rta.rollbackOn(new Exception())); + assertFalse(rta.rollbackOn(new IOException())); + } + + @Test + public void testToStringMatchesEditor() { + List list = new LinkedList<>(); + list.add(new NoRollbackRuleAttribute("Throwable")); + RuleBasedTransactionAttribute rta = new RuleBasedTransactionAttribute(TransactionDefinition.PROPAGATION_REQUIRED, list); + + TransactionAttributeEditor tae = new TransactionAttributeEditor(); + tae.setAsText(rta.toString()); + rta = (RuleBasedTransactionAttribute) tae.getValue(); + + assertFalse(rta.rollbackOn(new Throwable())); + assertFalse(rta.rollbackOn(new RuntimeException())); + assertFalse(rta.rollbackOn(new MyRuntimeException(""))); + assertFalse(rta.rollbackOn(new Exception())); + assertFalse(rta.rollbackOn(new IOException())); + } + + /** + * See this forum post. + */ + @Test + public void testConflictingRulesToDetermineExactContract() { + List list = new LinkedList<>(); + list.add(new NoRollbackRuleAttribute(MyBusinessWarningException.class)); + list.add(new RollbackRuleAttribute(MyBusinessException.class)); + RuleBasedTransactionAttribute rta = new RuleBasedTransactionAttribute(TransactionDefinition.PROPAGATION_REQUIRED, list); + + assertTrue(rta.rollbackOn(new MyBusinessException())); + assertFalse(rta.rollbackOn(new MyBusinessWarningException())); + } + + + @SuppressWarnings("serial") + private static class MyBusinessException extends Exception {} + + + @SuppressWarnings("serial") + private static final class MyBusinessWarningException extends MyBusinessException {} + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/interceptor/TransactionAttributeEditorTests.java b/spring-tx/src/test/java/org/springframework/transaction/interceptor/TransactionAttributeEditorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3fa70c3847a991aab995bb86fe987c0745367ea0 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/interceptor/TransactionAttributeEditorTests.java @@ -0,0 +1,176 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + + +import java.io.IOException; + +import org.junit.Test; + +import org.springframework.transaction.TransactionDefinition; + +import static org.junit.Assert.*; + +/** + * Tests to check conversion from String to TransactionAttribute. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @author Chris Beams + * @since 26.04.2003 + */ +public class TransactionAttributeEditorTests { + + @Test + public void testNull() { + TransactionAttributeEditor pe = new TransactionAttributeEditor(); + pe.setAsText(null); + TransactionAttribute ta = (TransactionAttribute) pe.getValue(); + assertTrue(ta == null); + } + + @Test + public void testEmptyString() { + TransactionAttributeEditor pe = new TransactionAttributeEditor(); + pe.setAsText(""); + TransactionAttribute ta = (TransactionAttribute) pe.getValue(); + assertTrue(ta == null); + } + + @Test + public void testValidPropagationCodeOnly() { + TransactionAttributeEditor pe = new TransactionAttributeEditor(); + pe.setAsText("PROPAGATION_REQUIRED"); + TransactionAttribute ta = (TransactionAttribute) pe.getValue(); + assertTrue(ta != null); + assertTrue(ta.getPropagationBehavior() == TransactionDefinition.PROPAGATION_REQUIRED); + assertTrue(ta.getIsolationLevel() == TransactionDefinition.ISOLATION_DEFAULT); + assertTrue(!ta.isReadOnly()); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidPropagationCodeOnly() { + TransactionAttributeEditor pe = new TransactionAttributeEditor(); + // should have failed with bogus propagation code + pe.setAsText("XXPROPAGATION_REQUIRED"); + } + + @Test + public void testValidPropagationCodeAndIsolationCode() { + TransactionAttributeEditor pe = new TransactionAttributeEditor(); + pe.setAsText("PROPAGATION_REQUIRED, ISOLATION_READ_UNCOMMITTED"); + TransactionAttribute ta = (TransactionAttribute) pe.getValue(); + assertTrue(ta != null); + assertTrue(ta.getPropagationBehavior() == TransactionDefinition.PROPAGATION_REQUIRED); + assertTrue(ta.getIsolationLevel() == TransactionDefinition.ISOLATION_READ_UNCOMMITTED); + } + + @Test(expected = IllegalArgumentException.class) + public void testValidPropagationAndIsolationCodesAndInvalidRollbackRule() { + TransactionAttributeEditor pe = new TransactionAttributeEditor(); + // should fail with bogus rollback rule + pe.setAsText("PROPAGATION_REQUIRED,ISOLATION_READ_UNCOMMITTED,XXX"); + } + + @Test + public void testValidPropagationCodeAndIsolationCodeAndRollbackRules1() { + TransactionAttributeEditor pe = new TransactionAttributeEditor(); + pe.setAsText("PROPAGATION_MANDATORY,ISOLATION_REPEATABLE_READ,timeout_10,-IOException,+MyRuntimeException"); + TransactionAttribute ta = (TransactionAttribute) pe.getValue(); + assertNotNull(ta); + assertEquals(TransactionDefinition.PROPAGATION_MANDATORY, ta.getPropagationBehavior()); + assertEquals(TransactionDefinition.ISOLATION_REPEATABLE_READ, ta.getIsolationLevel()); + assertEquals(10, ta.getTimeout()); + assertFalse(ta.isReadOnly()); + assertTrue(ta.rollbackOn(new RuntimeException())); + assertFalse(ta.rollbackOn(new Exception())); + // Check for our bizarre customized rollback rules + assertTrue(ta.rollbackOn(new IOException())); + assertTrue(!ta.rollbackOn(new MyRuntimeException(""))); + } + + @Test + public void testValidPropagationCodeAndIsolationCodeAndRollbackRules2() { + TransactionAttributeEditor pe = new TransactionAttributeEditor(); + pe.setAsText("+IOException,readOnly,ISOLATION_READ_COMMITTED,-MyRuntimeException,PROPAGATION_SUPPORTS"); + TransactionAttribute ta = (TransactionAttribute) pe.getValue(); + assertNotNull(ta); + assertEquals(TransactionDefinition.PROPAGATION_SUPPORTS, ta.getPropagationBehavior()); + assertEquals(TransactionDefinition.ISOLATION_READ_COMMITTED, ta.getIsolationLevel()); + assertEquals(TransactionDefinition.TIMEOUT_DEFAULT, ta.getTimeout()); + assertTrue(ta.isReadOnly()); + assertTrue(ta.rollbackOn(new RuntimeException())); + assertFalse(ta.rollbackOn(new Exception())); + // Check for our bizarre customized rollback rules + assertFalse(ta.rollbackOn(new IOException())); + assertTrue(ta.rollbackOn(new MyRuntimeException(""))); + } + + @Test + public void testDefaultTransactionAttributeToString() { + DefaultTransactionAttribute source = new DefaultTransactionAttribute(); + source.setPropagationBehavior(TransactionDefinition.PROPAGATION_SUPPORTS); + source.setIsolationLevel(TransactionDefinition.ISOLATION_REPEATABLE_READ); + source.setTimeout(10); + source.setReadOnly(true); + + TransactionAttributeEditor pe = new TransactionAttributeEditor(); + pe.setAsText(source.toString()); + TransactionAttribute ta = (TransactionAttribute) pe.getValue(); + assertEquals(ta, source); + assertEquals(TransactionDefinition.PROPAGATION_SUPPORTS, ta.getPropagationBehavior()); + assertEquals(TransactionDefinition.ISOLATION_REPEATABLE_READ, ta.getIsolationLevel()); + assertEquals(10, ta.getTimeout()); + assertTrue(ta.isReadOnly()); + assertTrue(ta.rollbackOn(new RuntimeException())); + assertFalse(ta.rollbackOn(new Exception())); + + source.setTimeout(9); + assertNotSame(ta, source); + source.setTimeout(10); + assertEquals(ta, source); + } + + @Test + public void testRuleBasedTransactionAttributeToString() { + RuleBasedTransactionAttribute source = new RuleBasedTransactionAttribute(); + source.setPropagationBehavior(TransactionDefinition.PROPAGATION_SUPPORTS); + source.setIsolationLevel(TransactionDefinition.ISOLATION_REPEATABLE_READ); + source.setTimeout(10); + source.setReadOnly(true); + source.getRollbackRules().add(new RollbackRuleAttribute("IllegalArgumentException")); + source.getRollbackRules().add(new NoRollbackRuleAttribute("IllegalStateException")); + + TransactionAttributeEditor pe = new TransactionAttributeEditor(); + pe.setAsText(source.toString()); + TransactionAttribute ta = (TransactionAttribute) pe.getValue(); + assertEquals(ta, source); + assertEquals(TransactionDefinition.PROPAGATION_SUPPORTS, ta.getPropagationBehavior()); + assertEquals(TransactionDefinition.ISOLATION_REPEATABLE_READ, ta.getIsolationLevel()); + assertEquals(10, ta.getTimeout()); + assertTrue(ta.isReadOnly()); + assertTrue(ta.rollbackOn(new IllegalArgumentException())); + assertFalse(ta.rollbackOn(new IllegalStateException())); + + source.getRollbackRules().clear(); + assertNotSame(ta, source); + source.getRollbackRules().add(new RollbackRuleAttribute("IllegalArgumentException")); + source.getRollbackRules().add(new NoRollbackRuleAttribute("IllegalStateException")); + assertEquals(ta, source); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/interceptor/TransactionAttributeSourceAdvisorTests.java b/spring-tx/src/test/java/org/springframework/transaction/interceptor/TransactionAttributeSourceAdvisorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0314dea89c6241d42f836cd923e7296d1d35a128 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/interceptor/TransactionAttributeSourceAdvisorTests.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.util.Properties; + +import org.junit.Test; + +import org.springframework.util.SerializationTestUtils; + +/** + * @author Rod Johnson + */ +public class TransactionAttributeSourceAdvisorTests { + + @Test + public void serializability() throws Exception { + TransactionInterceptor ti = new TransactionInterceptor(); + ti.setTransactionAttributes(new Properties()); + TransactionAttributeSourceAdvisor tas = new TransactionAttributeSourceAdvisor(ti); + SerializationTestUtils.serializeAndDeserialize(tas); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/interceptor/TransactionAttributeSourceEditorTests.java b/spring-tx/src/test/java/org/springframework/transaction/interceptor/TransactionAttributeSourceEditorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a58b956d3e10528171478084b2b7ade53872d3fe --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/interceptor/TransactionAttributeSourceEditorTests.java @@ -0,0 +1,117 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.lang.reflect.Method; + +import org.junit.Test; + +import org.springframework.transaction.TransactionDefinition; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link TransactionAttributeSourceEditor}. + * + *

Format is: {@code FQN.Method=tx attribute representation} + * + * @author Rod Johnson + * @author Sam Brannen + * @since 26.04.2003 + */ +public class TransactionAttributeSourceEditorTests { + + private final TransactionAttributeSourceEditor editor = new TransactionAttributeSourceEditor(); + + + @Test + public void nullValue() throws Exception { + editor.setAsText(null); + TransactionAttributeSource tas = (TransactionAttributeSource) editor.getValue(); + + Method m = Object.class.getMethod("hashCode"); + assertNull(tas.getTransactionAttribute(m, null)); + } + + @Test(expected = IllegalArgumentException.class) + public void invalidFormat() throws Exception { + editor.setAsText("foo=bar"); + } + + @Test + public void matchesSpecific() throws Exception { + editor.setAsText( + "java.lang.Object.hashCode=PROPAGATION_REQUIRED\n" + + "java.lang.Object.equals=PROPAGATION_MANDATORY\n" + + "java.lang.Object.*it=PROPAGATION_SUPPORTS\n" + + "java.lang.Object.notify=PROPAGATION_SUPPORTS\n" + + "java.lang.Object.not*=PROPAGATION_REQUIRED"); + TransactionAttributeSource tas = (TransactionAttributeSource) editor.getValue(); + + checkTransactionProperties(tas, Object.class.getMethod("hashCode"), + TransactionDefinition.PROPAGATION_REQUIRED); + checkTransactionProperties(tas, Object.class.getMethod("equals", Object.class), + TransactionDefinition.PROPAGATION_MANDATORY); + checkTransactionProperties(tas, Object.class.getMethod("wait"), + TransactionDefinition.PROPAGATION_SUPPORTS); + checkTransactionProperties(tas, Object.class.getMethod("wait", long.class), + TransactionDefinition.PROPAGATION_SUPPORTS); + checkTransactionProperties(tas, Object.class.getMethod("wait", long.class, int.class), + TransactionDefinition.PROPAGATION_SUPPORTS); + checkTransactionProperties(tas, Object.class.getMethod("notify"), + TransactionDefinition.PROPAGATION_SUPPORTS); + checkTransactionProperties(tas, Object.class.getMethod("notifyAll"), + TransactionDefinition.PROPAGATION_REQUIRED); + checkTransactionProperties(tas, Object.class.getMethod("toString"), -1); + } + + @Test + public void matchesAll() throws Exception { + editor.setAsText("java.lang.Object.*=PROPAGATION_REQUIRED"); + TransactionAttributeSource tas = (TransactionAttributeSource) editor.getValue(); + + checkTransactionProperties(tas, Object.class.getMethod("hashCode"), + TransactionDefinition.PROPAGATION_REQUIRED); + checkTransactionProperties(tas, Object.class.getMethod("equals", Object.class), + TransactionDefinition.PROPAGATION_REQUIRED); + checkTransactionProperties(tas, Object.class.getMethod("wait"), + TransactionDefinition.PROPAGATION_REQUIRED); + checkTransactionProperties(tas, Object.class.getMethod("wait", long.class), + TransactionDefinition.PROPAGATION_REQUIRED); + checkTransactionProperties(tas, Object.class.getMethod("wait", long.class, int.class), + TransactionDefinition.PROPAGATION_REQUIRED); + checkTransactionProperties(tas, Object.class.getMethod("notify"), + TransactionDefinition.PROPAGATION_REQUIRED); + checkTransactionProperties(tas, Object.class.getMethod("notifyAll"), + TransactionDefinition.PROPAGATION_REQUIRED); + checkTransactionProperties(tas, Object.class.getMethod("toString"), + TransactionDefinition.PROPAGATION_REQUIRED); + } + + private void checkTransactionProperties(TransactionAttributeSource tas, Method method, int propagationBehavior) { + TransactionAttribute ta = tas.getTransactionAttribute(method, null); + if (propagationBehavior >= 0) { + assertNotNull(ta); + assertEquals(TransactionDefinition.ISOLATION_DEFAULT, ta.getIsolationLevel()); + assertEquals(propagationBehavior, ta.getPropagationBehavior()); + } + else { + assertNull(ta); + } + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/interceptor/TransactionAttributeSourceTests.java b/spring-tx/src/test/java/org/springframework/transaction/interceptor/TransactionAttributeSourceTests.java new file mode 100644 index 0000000000000000000000000000000000000000..23e2ce7d95fe3ac67308dcf1821e0d4e794b505e --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/interceptor/TransactionAttributeSourceTests.java @@ -0,0 +1,97 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.io.IOException; +import java.util.Properties; + +import org.junit.Test; + +import org.springframework.transaction.TransactionDefinition; + +import static org.junit.Assert.*; + +/** + * Unit tests for the various {@link TransactionAttributeSource} implementations. + * + * @author Colin Sampaleanu + * @author Juergen Hoeller + * @author Rick Evans + * @author Chris Beams + * @since 15.10.2003 + * @see org.springframework.transaction.interceptor.TransactionProxyFactoryBean + */ +public class TransactionAttributeSourceTests { + + @Test + public void matchAlwaysTransactionAttributeSource() throws Exception { + MatchAlwaysTransactionAttributeSource tas = new MatchAlwaysTransactionAttributeSource(); + TransactionAttribute ta = tas.getTransactionAttribute(Object.class.getMethod("hashCode"), null); + assertNotNull(ta); + assertTrue(TransactionDefinition.PROPAGATION_REQUIRED == ta.getPropagationBehavior()); + + tas.setTransactionAttribute(new DefaultTransactionAttribute(TransactionDefinition.PROPAGATION_SUPPORTS)); + ta = tas.getTransactionAttribute(IOException.class.getMethod("getMessage"), IOException.class); + assertNotNull(ta); + assertTrue(TransactionDefinition.PROPAGATION_SUPPORTS == ta.getPropagationBehavior()); + } + + @Test + public void nameMatchTransactionAttributeSourceWithStarAtStartOfMethodName() throws Exception { + NameMatchTransactionAttributeSource tas = new NameMatchTransactionAttributeSource(); + Properties attributes = new Properties(); + attributes.put("*ashCode", "PROPAGATION_REQUIRED"); + tas.setProperties(attributes); + TransactionAttribute ta = tas.getTransactionAttribute(Object.class.getMethod("hashCode"), null); + assertNotNull(ta); + assertEquals(TransactionDefinition.PROPAGATION_REQUIRED, ta.getPropagationBehavior()); + } + + @Test + public void nameMatchTransactionAttributeSourceWithStarAtEndOfMethodName() throws Exception { + NameMatchTransactionAttributeSource tas = new NameMatchTransactionAttributeSource(); + Properties attributes = new Properties(); + attributes.put("hashCod*", "PROPAGATION_REQUIRED"); + tas.setProperties(attributes); + TransactionAttribute ta = tas.getTransactionAttribute(Object.class.getMethod("hashCode"), null); + assertNotNull(ta); + assertEquals(TransactionDefinition.PROPAGATION_REQUIRED, ta.getPropagationBehavior()); + } + + @Test + public void nameMatchTransactionAttributeSourceMostSpecificMethodNameIsDefinitelyMatched() throws Exception { + NameMatchTransactionAttributeSource tas = new NameMatchTransactionAttributeSource(); + Properties attributes = new Properties(); + attributes.put("*", "PROPAGATION_REQUIRED"); + attributes.put("hashCode", "PROPAGATION_MANDATORY"); + tas.setProperties(attributes); + TransactionAttribute ta = tas.getTransactionAttribute(Object.class.getMethod("hashCode"), null); + assertNotNull(ta); + assertEquals(TransactionDefinition.PROPAGATION_MANDATORY, ta.getPropagationBehavior()); + } + + @Test + public void nameMatchTransactionAttributeSourceWithEmptyMethodName() throws Exception { + NameMatchTransactionAttributeSource tas = new NameMatchTransactionAttributeSource(); + Properties attributes = new Properties(); + attributes.put("", "PROPAGATION_MANDATORY"); + tas.setProperties(attributes); + TransactionAttribute ta = tas.getTransactionAttribute(Object.class.getMethod("hashCode"), null); + assertNull(ta); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/interceptor/TransactionInterceptorTests.java b/spring-tx/src/test/java/org/springframework/transaction/interceptor/TransactionInterceptorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..44fc6c641f98b4793764e40a8d6cace1e28e57aa --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/interceptor/TransactionInterceptorTests.java @@ -0,0 +1,327 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.interceptor; + +import java.io.Serializable; +import java.util.Properties; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.lang.Nullable; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.TransactionException; +import org.springframework.transaction.TransactionStatus; +import org.springframework.util.SerializationTestUtils; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * Mock object based tests for TransactionInterceptor. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 16.03.2003 + */ +public class TransactionInterceptorTests extends AbstractTransactionAspectTests { + + @Rule + public final ExpectedException thrown = ExpectedException.none(); + + + @Override + protected Object advised(Object target, PlatformTransactionManager ptm, TransactionAttributeSource[] tas) { + TransactionInterceptor ti = new TransactionInterceptor(); + ti.setTransactionManager(ptm); + ti.setTransactionAttributeSources(tas); + + ProxyFactory pf = new ProxyFactory(target); + pf.addAdvice(0, ti); + return pf.getProxy(); + } + + /** + * Template method to create an advised object given the + * target object and transaction setup. + * Creates a TransactionInterceptor and applies it. + */ + @Override + protected Object advised(Object target, PlatformTransactionManager ptm, TransactionAttributeSource tas) { + TransactionInterceptor ti = new TransactionInterceptor(); + ti.setTransactionManager(ptm); + assertEquals(ptm, ti.getTransactionManager()); + ti.setTransactionAttributeSource(tas); + assertEquals(tas, ti.getTransactionAttributeSource()); + + ProxyFactory pf = new ProxyFactory(target); + pf.addAdvice(0, ti); + return pf.getProxy(); + } + + + /** + * A TransactionInterceptor should be serializable if its + * PlatformTransactionManager is. + */ + @Test + public void serializableWithAttributeProperties() throws Exception { + TransactionInterceptor ti = new TransactionInterceptor(); + Properties props = new Properties(); + props.setProperty("methodName", "PROPAGATION_REQUIRED"); + ti.setTransactionAttributes(props); + PlatformTransactionManager ptm = new SerializableTransactionManager(); + ti.setTransactionManager(ptm); + ti = (TransactionInterceptor) SerializationTestUtils.serializeAndDeserialize(ti); + + // Check that logger survived deserialization + assertNotNull(ti.logger); + assertTrue(ti.getTransactionManager() instanceof SerializableTransactionManager); + assertNotNull(ti.getTransactionAttributeSource()); + } + + @Test + public void serializableWithCompositeSource() throws Exception { + NameMatchTransactionAttributeSource tas1 = new NameMatchTransactionAttributeSource(); + Properties props = new Properties(); + props.setProperty("methodName", "PROPAGATION_REQUIRED"); + tas1.setProperties(props); + + NameMatchTransactionAttributeSource tas2 = new NameMatchTransactionAttributeSource(); + props = new Properties(); + props.setProperty("otherMethodName", "PROPAGATION_REQUIRES_NEW"); + tas2.setProperties(props); + + TransactionInterceptor ti = new TransactionInterceptor(); + ti.setTransactionAttributeSources(tas1, tas2); + PlatformTransactionManager ptm = new SerializableTransactionManager(); + ti.setTransactionManager(ptm); + ti = (TransactionInterceptor) SerializationTestUtils.serializeAndDeserialize(ti); + + assertTrue(ti.getTransactionManager() instanceof SerializableTransactionManager); + assertTrue(ti.getTransactionAttributeSource() instanceof CompositeTransactionAttributeSource); + CompositeTransactionAttributeSource ctas = (CompositeTransactionAttributeSource) ti.getTransactionAttributeSource(); + assertTrue(ctas.getTransactionAttributeSources()[0] instanceof NameMatchTransactionAttributeSource); + assertTrue(ctas.getTransactionAttributeSources()[1] instanceof NameMatchTransactionAttributeSource); + } + + @Test + public void determineTransactionManagerWithNoBeanFactory() { + PlatformTransactionManager transactionManager = mock(PlatformTransactionManager.class); + TransactionInterceptor ti = transactionInterceptorWithTransactionManager(transactionManager, null); + + assertSame(transactionManager, ti.determineTransactionManager(new DefaultTransactionAttribute())); + } + + @Test + public void determineTransactionManagerWithNoBeanFactoryAndNoTransactionAttribute() { + PlatformTransactionManager transactionManager = mock(PlatformTransactionManager.class); + TransactionInterceptor ti = transactionInterceptorWithTransactionManager(transactionManager, null); + + assertSame(transactionManager, ti.determineTransactionManager(null)); + } + + @Test + public void determineTransactionManagerWithNoTransactionAttribute() { + BeanFactory beanFactory = mock(BeanFactory.class); + TransactionInterceptor ti = simpleTransactionInterceptor(beanFactory); + + assertNull(ti.determineTransactionManager(null)); + } + + @Test + public void determineTransactionManagerWithQualifierUnknown() { + BeanFactory beanFactory = mock(BeanFactory.class); + TransactionInterceptor ti = simpleTransactionInterceptor(beanFactory); + DefaultTransactionAttribute attribute = new DefaultTransactionAttribute(); + attribute.setQualifier("fooTransactionManager"); + + thrown.expect(NoSuchBeanDefinitionException.class); + thrown.expectMessage("'fooTransactionManager'"); + ti.determineTransactionManager(attribute); + } + + @Test + public void determineTransactionManagerWithQualifierAndDefault() { + BeanFactory beanFactory = mock(BeanFactory.class); + PlatformTransactionManager transactionManager = mock(PlatformTransactionManager.class); + TransactionInterceptor ti = transactionInterceptorWithTransactionManager(transactionManager, beanFactory); + PlatformTransactionManager fooTransactionManager = + associateTransactionManager(beanFactory, "fooTransactionManager"); + + DefaultTransactionAttribute attribute = new DefaultTransactionAttribute(); + attribute.setQualifier("fooTransactionManager"); + + assertSame(fooTransactionManager, ti.determineTransactionManager(attribute)); + } + + @Test + public void determineTransactionManagerWithQualifierAndDefaultName() { + BeanFactory beanFactory = mock(BeanFactory.class); + associateTransactionManager(beanFactory, "defaultTransactionManager"); + TransactionInterceptor ti = transactionInterceptorWithTransactionManagerName( + "defaultTransactionManager", beanFactory); + + PlatformTransactionManager fooTransactionManager = + associateTransactionManager(beanFactory, "fooTransactionManager"); + DefaultTransactionAttribute attribute = new DefaultTransactionAttribute(); + attribute.setQualifier("fooTransactionManager"); + + assertSame(fooTransactionManager, ti.determineTransactionManager(attribute)); + } + + @Test + public void determineTransactionManagerWithEmptyQualifierAndDefaultName() { + BeanFactory beanFactory = mock(BeanFactory.class); + PlatformTransactionManager defaultTransactionManager + = associateTransactionManager(beanFactory, "defaultTransactionManager"); + TransactionInterceptor ti = transactionInterceptorWithTransactionManagerName( + "defaultTransactionManager", beanFactory); + + DefaultTransactionAttribute attribute = new DefaultTransactionAttribute(); + attribute.setQualifier(""); + + assertSame(defaultTransactionManager, ti.determineTransactionManager(attribute)); + } + + @Test + public void determineTransactionManagerWithQualifierSeveralTimes() { + BeanFactory beanFactory = mock(BeanFactory.class); + TransactionInterceptor ti = simpleTransactionInterceptor(beanFactory); + + PlatformTransactionManager txManager = associateTransactionManager(beanFactory, "fooTransactionManager"); + + DefaultTransactionAttribute attribute = new DefaultTransactionAttribute(); + attribute.setQualifier("fooTransactionManager"); + PlatformTransactionManager actual = ti.determineTransactionManager(attribute); + assertSame(txManager, actual); + + // Call again, should be cached + PlatformTransactionManager actual2 = ti.determineTransactionManager(attribute); + assertSame(txManager, actual2); + verify(beanFactory, times(1)).containsBean("fooTransactionManager"); + verify(beanFactory, times(1)).getBean("fooTransactionManager", PlatformTransactionManager.class); + } + + @Test + public void determineTransactionManagerWithBeanNameSeveralTimes() { + BeanFactory beanFactory = mock(BeanFactory.class); + TransactionInterceptor ti = transactionInterceptorWithTransactionManagerName( + "fooTransactionManager", beanFactory); + + PlatformTransactionManager txManager = associateTransactionManager(beanFactory, "fooTransactionManager"); + + DefaultTransactionAttribute attribute = new DefaultTransactionAttribute(); + PlatformTransactionManager actual = ti.determineTransactionManager(attribute); + assertSame(txManager, actual); + + // Call again, should be cached + PlatformTransactionManager actual2 = ti.determineTransactionManager(attribute); + assertSame(txManager, actual2); + verify(beanFactory, times(1)).getBean("fooTransactionManager", PlatformTransactionManager.class); + } + + @Test + public void determineTransactionManagerDefaultSeveralTimes() { + BeanFactory beanFactory = mock(BeanFactory.class); + TransactionInterceptor ti = simpleTransactionInterceptor(beanFactory); + + PlatformTransactionManager txManager = mock(PlatformTransactionManager.class); + given(beanFactory.getBean(PlatformTransactionManager.class)).willReturn(txManager); + + DefaultTransactionAttribute attribute = new DefaultTransactionAttribute(); + PlatformTransactionManager actual = ti.determineTransactionManager(attribute); + assertSame(txManager, actual); + + // Call again, should be cached + PlatformTransactionManager actual2 = ti.determineTransactionManager(attribute); + assertSame(txManager, actual2); + verify(beanFactory, times(1)).getBean(PlatformTransactionManager.class); + } + + + private TransactionInterceptor createTransactionInterceptor(BeanFactory beanFactory, + String transactionManagerName, PlatformTransactionManager transactionManager) { + + TransactionInterceptor ti = new TransactionInterceptor(); + if (beanFactory != null) { + ti.setBeanFactory(beanFactory); + } + if (transactionManagerName != null) { + ti.setTransactionManagerBeanName(transactionManagerName); + + } + if (transactionManager != null) { + ti.setTransactionManager(transactionManager); + } + ti.setTransactionAttributeSource(new NameMatchTransactionAttributeSource()); + ti.afterPropertiesSet(); + return ti; + } + + private TransactionInterceptor transactionInterceptorWithTransactionManager( + PlatformTransactionManager transactionManager, BeanFactory beanFactory) { + + return createTransactionInterceptor(beanFactory, null, transactionManager); + } + + private TransactionInterceptor transactionInterceptorWithTransactionManagerName( + String transactionManagerName, BeanFactory beanFactory) { + + return createTransactionInterceptor(beanFactory, transactionManagerName, null); + } + + private TransactionInterceptor simpleTransactionInterceptor(BeanFactory beanFactory) { + return createTransactionInterceptor(beanFactory, null, null); + } + + private PlatformTransactionManager associateTransactionManager(BeanFactory beanFactory, String name) { + PlatformTransactionManager transactionManager = mock(PlatformTransactionManager.class); + given(beanFactory.containsBean(name)).willReturn(true); + given(beanFactory.getBean(name, PlatformTransactionManager.class)).willReturn(transactionManager); + return transactionManager; + } + + + /** + * We won't use this: we just want to know it's serializable. + */ + @SuppressWarnings("serial") + public static class SerializableTransactionManager implements PlatformTransactionManager, Serializable { + + @Override + public TransactionStatus getTransaction(@Nullable TransactionDefinition definition) throws TransactionException { + throw new UnsupportedOperationException(); + } + + @Override + public void commit(TransactionStatus status) throws TransactionException { + throw new UnsupportedOperationException(); + } + + @Override + public void rollback(TransactionStatus status) throws TransactionException { + throw new UnsupportedOperationException(); + } + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/jta/MockUOWManager.java b/spring-tx/src/test/java/org/springframework/transaction/jta/MockUOWManager.java new file mode 100644 index 0000000000000000000000000000000000000000..96036c801187d1c824e6566eb8aff834590a1263 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/jta/MockUOWManager.java @@ -0,0 +1,131 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.jta; + +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +import javax.transaction.Synchronization; + +import com.ibm.wsspi.uow.UOWAction; +import com.ibm.wsspi.uow.UOWActionException; +import com.ibm.wsspi.uow.UOWException; +import com.ibm.wsspi.uow.UOWManager; + +/** + * @author Juergen Hoeller + */ +public class MockUOWManager implements UOWManager { + + private int type = UOW_TYPE_GLOBAL_TRANSACTION; + + private boolean joined; + + private int timeout; + + private boolean rollbackOnly; + + private int status = UOW_STATUS_NONE; + + private final Map resources = new HashMap<>(); + + private final List synchronizations = new LinkedList<>(); + + + @Override + public void runUnderUOW(int type, boolean join, UOWAction action) throws UOWActionException, UOWException { + this.type = type; + this.joined = join; + try { + this.status = UOW_STATUS_ACTIVE; + action.run(); + this.status = (this.rollbackOnly ? UOW_STATUS_ROLLEDBACK : UOW_STATUS_COMMITTED); + } + catch (Error | RuntimeException ex) { + this.status = UOW_STATUS_ROLLEDBACK; + throw ex; + } catch (Exception ex) { + this.status = UOW_STATUS_ROLLEDBACK; + throw new UOWActionException(ex); + } + } + + @Override + public int getUOWType() { + return this.type; + } + + public boolean getJoined() { + return this.joined; + } + + @Override + public long getLocalUOWId() { + return 0; + } + + @Override + public void setUOWTimeout(int uowType, int timeout) { + this.timeout = timeout; + } + + @Override + public int getUOWTimeout() { + return this.timeout; + } + + @Override + public void setRollbackOnly() { + this.rollbackOnly = true; + } + + @Override + public boolean getRollbackOnly() { + return this.rollbackOnly; + } + + public void setUOWStatus(int status) { + this.status = status; + } + + @Override + public int getUOWStatus() { + return this.status; + } + + @Override + public void putResource(Object key, Object value) { + this.resources.put(key, value); + } + + @Override + public Object getResource(Object key) throws NullPointerException { + return this.resources.get(key); + } + + @Override + public void registerInterposedSynchronization(Synchronization sync) { + this.synchronizations.add(sync); + } + + public List getSynchronizations() { + return this.synchronizations; + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/jta/WebSphereUowTransactionManagerTests.java b/spring-tx/src/test/java/org/springframework/transaction/jta/WebSphereUowTransactionManagerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a28a9040988ec3071e6c05e8a018eacaa2f8c9b4 --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/jta/WebSphereUowTransactionManagerTests.java @@ -0,0 +1,678 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.jta; + +import javax.transaction.RollbackException; +import javax.transaction.Status; +import javax.transaction.UserTransaction; + +import org.junit.Test; + +import org.springframework.dao.OptimisticLockingFailureException; +import org.springframework.tests.mock.jndi.ExpectedLookupTemplate; +import org.springframework.transaction.IllegalTransactionStateException; +import org.springframework.transaction.NestedTransactionNotSupportedException; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.TransactionStatus; +import org.springframework.transaction.TransactionSystemException; +import org.springframework.transaction.support.DefaultTransactionDefinition; +import org.springframework.transaction.support.TransactionCallback; +import org.springframework.transaction.support.TransactionSynchronizationManager; + +import com.ibm.wsspi.uow.UOWAction; +import com.ibm.wsspi.uow.UOWException; +import com.ibm.wsspi.uow.UOWManager; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * @author Juergen Hoeller + */ +public class WebSphereUowTransactionManagerTests { + + @Test + public void uowManagerFoundInJndi() { + MockUOWManager manager = new MockUOWManager(); + ExpectedLookupTemplate jndiTemplate = + new ExpectedLookupTemplate(WebSphereUowTransactionManager.DEFAULT_UOW_MANAGER_NAME, manager); + WebSphereUowTransactionManager ptm = new WebSphereUowTransactionManager(); + ptm.setJndiTemplate(jndiTemplate); + ptm.afterPropertiesSet(); + + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + assertEquals("result", ptm.execute(definition, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + return "result"; + } + })); + + assertEquals(UOWManager.UOW_TYPE_GLOBAL_TRANSACTION, manager.getUOWType()); + assertFalse(manager.getJoined()); + assertFalse(manager.getRollbackOnly()); + } + + @Test + public void uowManagerAndUserTransactionFoundInJndi() throws Exception { + UserTransaction ut = mock(UserTransaction.class); + given(ut.getStatus()).willReturn( Status.STATUS_NO_TRANSACTION, Status.STATUS_ACTIVE, Status.STATUS_ACTIVE); + + MockUOWManager manager = new MockUOWManager(); + ExpectedLookupTemplate jndiTemplate = new ExpectedLookupTemplate(); + jndiTemplate.addObject(WebSphereUowTransactionManager.DEFAULT_USER_TRANSACTION_NAME, ut); + jndiTemplate.addObject(WebSphereUowTransactionManager.DEFAULT_UOW_MANAGER_NAME, manager); + WebSphereUowTransactionManager ptm = new WebSphereUowTransactionManager(); + ptm.setJndiTemplate(jndiTemplate); + ptm.afterPropertiesSet(); + + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + TransactionStatus ts = ptm.getTransaction(definition); + ptm.commit(ts); + assertEquals("result", ptm.execute(definition, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + return "result"; + } + })); + + assertEquals(UOWManager.UOW_TYPE_GLOBAL_TRANSACTION, manager.getUOWType()); + assertFalse(manager.getJoined()); + assertFalse(manager.getRollbackOnly()); + verify(ut).begin(); + verify(ut).commit(); + } + + @Test + public void propagationMandatoryFailsInCaseOfNoExistingTransaction() { + MockUOWManager manager = new MockUOWManager(); + WebSphereUowTransactionManager ptm = new WebSphereUowTransactionManager(manager); + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_MANDATORY); + + try { + ptm.execute(definition, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + return "result"; + } + }); + fail("Should have thrown IllegalTransactionStateException"); + } + catch (IllegalTransactionStateException ex) { + // expected + } + } + + @Test + public void newTransactionSynchronizationUsingPropagationSupports() { + doTestNewTransactionSynchronization( + TransactionDefinition.PROPAGATION_SUPPORTS, WebSphereUowTransactionManager.SYNCHRONIZATION_ALWAYS); + } + + @Test + public void newTransactionSynchronizationUsingPropagationNotSupported() { + doTestNewTransactionSynchronization( + TransactionDefinition.PROPAGATION_NOT_SUPPORTED, WebSphereUowTransactionManager.SYNCHRONIZATION_ALWAYS); + } + + @Test + public void newTransactionSynchronizationUsingPropagationNever() { + doTestNewTransactionSynchronization( + TransactionDefinition.PROPAGATION_NEVER, WebSphereUowTransactionManager.SYNCHRONIZATION_ALWAYS); + } + + @Test + public void newTransactionSynchronizationUsingPropagationSupportsAndSynchOnActual() { + doTestNewTransactionSynchronization( + TransactionDefinition.PROPAGATION_SUPPORTS, WebSphereUowTransactionManager.SYNCHRONIZATION_ON_ACTUAL_TRANSACTION); + } + + @Test + public void newTransactionSynchronizationUsingPropagationNotSupportedAndSynchOnActual() { + doTestNewTransactionSynchronization( + TransactionDefinition.PROPAGATION_NOT_SUPPORTED, WebSphereUowTransactionManager.SYNCHRONIZATION_ON_ACTUAL_TRANSACTION); + } + + @Test + public void newTransactionSynchronizationUsingPropagationNeverAndSynchOnActual() { + doTestNewTransactionSynchronization( + TransactionDefinition.PROPAGATION_NEVER, WebSphereUowTransactionManager.SYNCHRONIZATION_ON_ACTUAL_TRANSACTION); + } + + @Test + public void newTransactionSynchronizationUsingPropagationSupportsAndSynchNever() { + doTestNewTransactionSynchronization( + TransactionDefinition.PROPAGATION_SUPPORTS, WebSphereUowTransactionManager.SYNCHRONIZATION_NEVER); + } + + @Test + public void newTransactionSynchronizationUsingPropagationNotSupportedAndSynchNever() { + doTestNewTransactionSynchronization( + TransactionDefinition.PROPAGATION_NOT_SUPPORTED, WebSphereUowTransactionManager.SYNCHRONIZATION_NEVER); + } + + @Test + public void newTransactionSynchronizationUsingPropagationNeverAndSynchNever() { + doTestNewTransactionSynchronization( + TransactionDefinition.PROPAGATION_NEVER, WebSphereUowTransactionManager.SYNCHRONIZATION_NEVER); + } + + private void doTestNewTransactionSynchronization(int propagationBehavior, final int synchMode) { + MockUOWManager manager = new MockUOWManager(); + WebSphereUowTransactionManager ptm = new WebSphereUowTransactionManager(manager); + ptm.setTransactionSynchronization(synchMode); + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + definition.setPropagationBehavior(propagationBehavior); + definition.setReadOnly(true); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals("result", ptm.execute(definition, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + if (synchMode == WebSphereUowTransactionManager.SYNCHRONIZATION_ALWAYS) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertTrue(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + else { + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + return "result"; + } + })); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals(0, manager.getUOWTimeout()); + assertEquals(UOWManager.UOW_TYPE_LOCAL_TRANSACTION, manager.getUOWType()); + assertFalse(manager.getJoined()); + assertFalse(manager.getRollbackOnly()); + } + + @Test + public void newTransactionWithCommitUsingPropagationRequired() { + doTestNewTransactionWithCommit( + TransactionDefinition.PROPAGATION_REQUIRED, WebSphereUowTransactionManager.SYNCHRONIZATION_ALWAYS); + } + + @Test + public void newTransactionWithCommitUsingPropagationRequiresNew() { + doTestNewTransactionWithCommit( + TransactionDefinition.PROPAGATION_REQUIRES_NEW, WebSphereUowTransactionManager.SYNCHRONIZATION_ALWAYS); + } + + @Test + public void newTransactionWithCommitUsingPropagationNested() { + doTestNewTransactionWithCommit( + TransactionDefinition.PROPAGATION_NESTED, WebSphereUowTransactionManager.SYNCHRONIZATION_ALWAYS); + } + + @Test + public void newTransactionWithCommitUsingPropagationRequiredAndSynchOnActual() { + doTestNewTransactionWithCommit( + TransactionDefinition.PROPAGATION_REQUIRED, WebSphereUowTransactionManager.SYNCHRONIZATION_ON_ACTUAL_TRANSACTION); + } + + @Test + public void newTransactionWithCommitUsingPropagationRequiresNewAndSynchOnActual() { + doTestNewTransactionWithCommit( + TransactionDefinition.PROPAGATION_REQUIRES_NEW, WebSphereUowTransactionManager.SYNCHRONIZATION_ON_ACTUAL_TRANSACTION); + } + + @Test + public void newTransactionWithCommitUsingPropagationNestedAndSynchOnActual() { + doTestNewTransactionWithCommit( + TransactionDefinition.PROPAGATION_NESTED, WebSphereUowTransactionManager.SYNCHRONIZATION_ON_ACTUAL_TRANSACTION); + } + + @Test + public void newTransactionWithCommitUsingPropagationRequiredAndSynchNever() { + doTestNewTransactionWithCommit( + TransactionDefinition.PROPAGATION_REQUIRED, WebSphereUowTransactionManager.SYNCHRONIZATION_NEVER); + } + + @Test + public void newTransactionWithCommitUsingPropagationRequiresNewAndSynchNever() { + doTestNewTransactionWithCommit( + TransactionDefinition.PROPAGATION_REQUIRES_NEW, WebSphereUowTransactionManager.SYNCHRONIZATION_NEVER); + } + + @Test + public void newTransactionWithCommitUsingPropagationNestedAndSynchNever() { + doTestNewTransactionWithCommit( + TransactionDefinition.PROPAGATION_NESTED, WebSphereUowTransactionManager.SYNCHRONIZATION_NEVER); + } + + private void doTestNewTransactionWithCommit(int propagationBehavior, final int synchMode) { + MockUOWManager manager = new MockUOWManager(); + WebSphereUowTransactionManager ptm = new WebSphereUowTransactionManager(manager); + ptm.setTransactionSynchronization(synchMode); + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + definition.setPropagationBehavior(propagationBehavior); + definition.setReadOnly(true); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals("result", ptm.execute(definition, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + if (synchMode != WebSphereUowTransactionManager.SYNCHRONIZATION_NEVER) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertTrue(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + else { + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + } + return "result"; + } + })); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals(0, manager.getUOWTimeout()); + assertEquals(UOWManager.UOW_TYPE_GLOBAL_TRANSACTION, manager.getUOWType()); + assertFalse(manager.getJoined()); + assertFalse(manager.getRollbackOnly()); + } + + @Test + public void newTransactionWithCommitAndTimeout() { + MockUOWManager manager = new MockUOWManager(); + WebSphereUowTransactionManager ptm = new WebSphereUowTransactionManager(manager); + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + definition.setTimeout(10); + definition.setReadOnly(true); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals("result", ptm.execute(definition, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertTrue(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + return "result"; + } + })); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals(10, manager.getUOWTimeout()); + assertEquals(UOWManager.UOW_TYPE_GLOBAL_TRANSACTION, manager.getUOWType()); + assertFalse(manager.getJoined()); + assertFalse(manager.getRollbackOnly()); + } + + @Test + public void newTransactionWithCommitException() { + final RollbackException rex = new RollbackException(); + MockUOWManager manager = new MockUOWManager() { + @Override + public void runUnderUOW(int type, boolean join, UOWAction action) throws UOWException { + throw new UOWException(rex); + } + }; + WebSphereUowTransactionManager ptm = new WebSphereUowTransactionManager(manager); + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + try { + ptm.execute(definition, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + return "result"; + } + }); + fail("Should have thrown TransactionSystemException"); + } + catch (TransactionSystemException ex) { + // expected + assertTrue(ex.getCause() instanceof UOWException); + assertSame(rex, ex.getRootCause()); + assertSame(rex, ex.getMostSpecificCause()); + } + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals(0, manager.getUOWTimeout()); + } + + @Test + public void newTransactionWithRollback() { + MockUOWManager manager = new MockUOWManager(); + WebSphereUowTransactionManager ptm = new WebSphereUowTransactionManager(manager); + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + try { + ptm.execute(definition, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + throw new OptimisticLockingFailureException(""); + } + }); + fail("Should have thrown OptimisticLockingFailureException"); + } + catch (OptimisticLockingFailureException ex) { + // expected + } + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals(0, manager.getUOWTimeout()); + assertEquals(UOWManager.UOW_TYPE_GLOBAL_TRANSACTION, manager.getUOWType()); + assertFalse(manager.getJoined()); + assertTrue(manager.getRollbackOnly()); + } + + @Test + public void newTransactionWithRollbackOnly() { + MockUOWManager manager = new MockUOWManager(); + WebSphereUowTransactionManager ptm = new WebSphereUowTransactionManager(manager); + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals("result", ptm.execute(definition, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + status.setRollbackOnly(); + return "result"; + } + })); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals(0, manager.getUOWTimeout()); + assertEquals(UOWManager.UOW_TYPE_GLOBAL_TRANSACTION, manager.getUOWType()); + assertFalse(manager.getJoined()); + assertTrue(manager.getRollbackOnly()); + } + + @Test + public void existingNonSpringTransaction() { + MockUOWManager manager = new MockUOWManager(); + manager.setUOWStatus(UOWManager.UOW_STATUS_ACTIVE); + WebSphereUowTransactionManager ptm = new WebSphereUowTransactionManager(manager); + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals("result", ptm.execute(definition, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + return "result"; + } + })); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals(0, manager.getUOWTimeout()); + assertEquals(UOWManager.UOW_TYPE_GLOBAL_TRANSACTION, manager.getUOWType()); + assertTrue(manager.getJoined()); + assertFalse(manager.getRollbackOnly()); + } + + @Test + public void propagationNeverFailsInCaseOfExistingTransaction() { + MockUOWManager manager = new MockUOWManager(); + manager.setUOWStatus(UOWManager.UOW_STATUS_ACTIVE); + WebSphereUowTransactionManager ptm = new WebSphereUowTransactionManager(manager); + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_NEVER); + + try { + ptm.execute(definition, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + return "result"; + } + }); + fail("Should have thrown IllegalTransactionStateException"); + } + catch (IllegalTransactionStateException ex) { + // expected + } + } + + @Test + public void propagationNestedFailsInCaseOfExistingTransaction() { + MockUOWManager manager = new MockUOWManager(); + manager.setUOWStatus(UOWManager.UOW_STATUS_ACTIVE); + WebSphereUowTransactionManager ptm = new WebSphereUowTransactionManager(manager); + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_NESTED); + + try { + ptm.execute(definition, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + return "result"; + } + }); + fail("Should have thrown NestedTransactionNotSupportedException"); + } + catch (NestedTransactionNotSupportedException ex) { + // expected + } + } + + @Test + public void existingTransactionWithParticipationUsingPropagationRequired() { + doTestExistingTransactionWithParticipation(TransactionDefinition.PROPAGATION_REQUIRED); + } + + @Test + public void existingTransactionWithParticipationUsingPropagationSupports() { + doTestExistingTransactionWithParticipation(TransactionDefinition.PROPAGATION_SUPPORTS); + } + + @Test + public void existingTransactionWithParticipationUsingPropagationMandatory() { + doTestExistingTransactionWithParticipation(TransactionDefinition.PROPAGATION_MANDATORY); + } + + private void doTestExistingTransactionWithParticipation(int propagationBehavior) { + MockUOWManager manager = new MockUOWManager(); + final WebSphereUowTransactionManager ptm = new WebSphereUowTransactionManager(manager); + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + final DefaultTransactionDefinition definition2 = new DefaultTransactionDefinition(); + definition2.setPropagationBehavior(propagationBehavior); + definition2.setReadOnly(true); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals("result", ptm.execute(definition, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + assertEquals("result2", ptm.execute(definition2, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + return "result2"; + } + })); + return "result"; + } + })); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals(0, manager.getUOWTimeout()); + assertEquals(UOWManager.UOW_TYPE_GLOBAL_TRANSACTION, manager.getUOWType()); + assertTrue(manager.getJoined()); + assertFalse(manager.getRollbackOnly()); + } + + @Test + public void existingTransactionWithSuspensionUsingPropagationRequiresNew() { + doTestExistingTransactionWithSuspension(TransactionDefinition.PROPAGATION_REQUIRES_NEW); + } + + @Test + public void existingTransactionWithSuspensionUsingPropagationNotSupported() { + doTestExistingTransactionWithSuspension(TransactionDefinition.PROPAGATION_NOT_SUPPORTED); + } + + private void doTestExistingTransactionWithSuspension(final int propagationBehavior) { + MockUOWManager manager = new MockUOWManager(); + final WebSphereUowTransactionManager ptm = new WebSphereUowTransactionManager(manager); + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + final DefaultTransactionDefinition definition2 = new DefaultTransactionDefinition(); + definition2.setPropagationBehavior(propagationBehavior); + definition2.setReadOnly(true); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals("result", ptm.execute(definition, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + assertEquals("result2", ptm.execute(definition2, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertEquals(propagationBehavior == TransactionDefinition.PROPAGATION_REQUIRES_NEW, + TransactionSynchronizationManager.isActualTransactionActive()); + assertTrue(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + return "result2"; + } + })); + return "result"; + } + })); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals(0, manager.getUOWTimeout()); + if (propagationBehavior == TransactionDefinition.PROPAGATION_REQUIRES_NEW) { + assertEquals(UOWManager.UOW_TYPE_GLOBAL_TRANSACTION, manager.getUOWType()); + } + else { + assertEquals(UOWManager.UOW_TYPE_LOCAL_TRANSACTION, manager.getUOWType()); + } + assertFalse(manager.getJoined()); + assertFalse(manager.getRollbackOnly()); + } + + @Test + public void existingTransactionUsingPropagationNotSupported() { + MockUOWManager manager = new MockUOWManager(); + final WebSphereUowTransactionManager ptm = new WebSphereUowTransactionManager(manager); + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + final DefaultTransactionDefinition definition2 = new DefaultTransactionDefinition(); + definition2.setPropagationBehavior(TransactionDefinition.PROPAGATION_NOT_SUPPORTED); + definition2.setReadOnly(true); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals("result", ptm.execute(definition, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertTrue(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + assertEquals("result2", ptm.execute(definition2, new TransactionCallback() { + @Override + public String doInTransaction(TransactionStatus status) { + assertTrue(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertTrue(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + return "result2"; + } + })); + return "result"; + } + })); + + assertFalse(TransactionSynchronizationManager.isSynchronizationActive()); + assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); + assertFalse(TransactionSynchronizationManager.isCurrentTransactionReadOnly()); + + assertEquals(0, manager.getUOWTimeout()); + assertEquals(UOWManager.UOW_TYPE_LOCAL_TRANSACTION, manager.getUOWType()); + assertFalse(manager.getJoined()); + assertFalse(manager.getRollbackOnly()); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/support/JtaTransactionManagerSerializationTests.java b/spring-tx/src/test/java/org/springframework/transaction/support/JtaTransactionManagerSerializationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..17b8657045d9186bb7d97adcb6eeb7914a12463d --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/support/JtaTransactionManagerSerializationTests.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import javax.transaction.TransactionManager; +import javax.transaction.UserTransaction; + +import org.junit.Test; + +import org.springframework.tests.mock.jndi.SimpleNamingContextBuilder; +import org.springframework.transaction.jta.JtaTransactionManager; +import org.springframework.util.SerializationTestUtils; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * @author Rod Johnson + */ +public class JtaTransactionManagerSerializationTests { + + @Test + public void serializable() throws Exception { + UserTransaction ut1 = mock(UserTransaction.class); + UserTransaction ut2 = mock(UserTransaction.class); + TransactionManager tm = mock(TransactionManager.class); + + JtaTransactionManager jtam = new JtaTransactionManager(); + jtam.setUserTransaction(ut1); + jtam.setTransactionManager(tm); + jtam.setRollbackOnCommitFailure(true); + jtam.afterPropertiesSet(); + + SimpleNamingContextBuilder jndiEnv = SimpleNamingContextBuilder + .emptyActivatedContextBuilder(); + jndiEnv.bind(JtaTransactionManager.DEFAULT_USER_TRANSACTION_NAME, ut2); + JtaTransactionManager serializedJtatm = (JtaTransactionManager) SerializationTestUtils + .serializeAndDeserialize(jtam); + + // should do client-side lookup + assertNotNull("Logger must survive serialization", + serializedJtatm.logger); + assertTrue("UserTransaction looked up on client", serializedJtatm + .getUserTransaction() == ut2); + assertNull("TransactionManager didn't survive", serializedJtatm + .getTransactionManager()); + assertEquals(true, serializedJtatm.isRollbackOnCommitFailure()); + } + +} diff --git a/spring-tx/src/test/java/org/springframework/transaction/support/SimpleTransactionScopeTests.java b/spring-tx/src/test/java/org/springframework/transaction/support/SimpleTransactionScopeTests.java new file mode 100644 index 0000000000000000000000000000000000000000..6770402fbc1c53875951272066724d461c9b5d8f --- /dev/null +++ b/spring-tx/src/test/java/org/springframework/transaction/support/SimpleTransactionScopeTests.java @@ -0,0 +1,198 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.transaction.support; + +import java.util.HashSet; +import java.util.Set; + +import org.junit.Test; + +import org.springframework.beans.factory.BeanCreationException; +import org.springframework.beans.factory.support.GenericBeanDefinition; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.tests.sample.beans.DerivedTestBean; +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.tests.transaction.CallCountingTransactionManager; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + */ +public class SimpleTransactionScopeTests { + + @Test + @SuppressWarnings("resource") + public void getFromScope() throws Exception { + GenericApplicationContext context = new GenericApplicationContext(); + context.getBeanFactory().registerScope("tx", new SimpleTransactionScope()); + + GenericBeanDefinition bd1 = new GenericBeanDefinition(); + bd1.setBeanClass(TestBean.class); + bd1.setScope("tx"); + bd1.setPrimary(true); + context.registerBeanDefinition("txScopedObject1", bd1); + + GenericBeanDefinition bd2 = new GenericBeanDefinition(); + bd2.setBeanClass(DerivedTestBean.class); + bd2.setScope("tx"); + context.registerBeanDefinition("txScopedObject2", bd2); + + context.refresh(); + + try { + context.getBean(TestBean.class); + fail("Should have thrown BeanCreationException"); + } + catch (BeanCreationException ex) { + // expected - no synchronization active + assertTrue(ex.getCause() instanceof IllegalStateException); + } + + try { + context.getBean(DerivedTestBean.class); + fail("Should have thrown BeanCreationException"); + } + catch (BeanCreationException ex) { + // expected - no synchronization active + assertTrue(ex.getCause() instanceof IllegalStateException); + } + + TestBean bean1 = null; + DerivedTestBean bean2 = null; + DerivedTestBean bean2a = null; + DerivedTestBean bean2b = null; + + TransactionSynchronizationManager.initSynchronization(); + try { + bean1 = context.getBean(TestBean.class); + assertSame(bean1, context.getBean(TestBean.class)); + + bean2 = context.getBean(DerivedTestBean.class); + assertSame(bean2, context.getBean(DerivedTestBean.class)); + context.getBeanFactory().destroyScopedBean("txScopedObject2"); + assertFalse(TransactionSynchronizationManager.hasResource("txScopedObject2")); + assertTrue(bean2.wasDestroyed()); + + bean2a = context.getBean(DerivedTestBean.class); + assertSame(bean2a, context.getBean(DerivedTestBean.class)); + assertNotSame(bean2, bean2a); + context.getBeanFactory().getRegisteredScope("tx").remove("txScopedObject2"); + assertFalse(TransactionSynchronizationManager.hasResource("txScopedObject2")); + assertFalse(bean2a.wasDestroyed()); + + bean2b = context.getBean(DerivedTestBean.class); + assertSame(bean2b, context.getBean(DerivedTestBean.class)); + assertNotSame(bean2, bean2b); + assertNotSame(bean2a, bean2b); + } + finally { + TransactionSynchronizationUtils.triggerAfterCompletion(TransactionSynchronization.STATUS_COMMITTED); + TransactionSynchronizationManager.clearSynchronization(); + } + + assertFalse(bean2a.wasDestroyed()); + assertTrue(bean2b.wasDestroyed()); + assertTrue(TransactionSynchronizationManager.getResourceMap().isEmpty()); + + try { + context.getBean(TestBean.class); + fail("Should have thrown IllegalStateException"); + } + catch (BeanCreationException ex) { + // expected - no synchronization active + assertTrue(ex.getCause() instanceof IllegalStateException); + } + + try { + context.getBean(DerivedTestBean.class); + fail("Should have thrown IllegalStateException"); + } + catch (BeanCreationException ex) { + // expected - no synchronization active + assertTrue(ex.getCause() instanceof IllegalStateException); + } + } + + @Test + public void getWithTransactionManager() throws Exception { + try (GenericApplicationContext context = new GenericApplicationContext()) { + context.getBeanFactory().registerScope("tx", new SimpleTransactionScope()); + + GenericBeanDefinition bd1 = new GenericBeanDefinition(); + bd1.setBeanClass(TestBean.class); + bd1.setScope("tx"); + bd1.setPrimary(true); + context.registerBeanDefinition("txScopedObject1", bd1); + + GenericBeanDefinition bd2 = new GenericBeanDefinition(); + bd2.setBeanClass(DerivedTestBean.class); + bd2.setScope("tx"); + context.registerBeanDefinition("txScopedObject2", bd2); + + context.refresh(); + + CallCountingTransactionManager tm = new CallCountingTransactionManager(); + TransactionTemplate tt = new TransactionTemplate(tm); + Set finallyDestroy = new HashSet<>(); + + tt.execute(status -> { + TestBean bean1 = context.getBean(TestBean.class); + assertSame(bean1, context.getBean(TestBean.class)); + + DerivedTestBean bean2 = context.getBean(DerivedTestBean.class); + assertSame(bean2, context.getBean(DerivedTestBean.class)); + context.getBeanFactory().destroyScopedBean("txScopedObject2"); + assertFalse(TransactionSynchronizationManager.hasResource("txScopedObject2")); + assertTrue(bean2.wasDestroyed()); + + DerivedTestBean bean2a = context.getBean(DerivedTestBean.class); + assertSame(bean2a, context.getBean(DerivedTestBean.class)); + assertNotSame(bean2, bean2a); + context.getBeanFactory().getRegisteredScope("tx").remove("txScopedObject2"); + assertFalse(TransactionSynchronizationManager.hasResource("txScopedObject2")); + assertFalse(bean2a.wasDestroyed()); + + DerivedTestBean bean2b = context.getBean(DerivedTestBean.class); + finallyDestroy.add(bean2b); + assertSame(bean2b, context.getBean(DerivedTestBean.class)); + assertNotSame(bean2, bean2b); + assertNotSame(bean2a, bean2b); + + Set immediatelyDestroy = new HashSet<>(); + TransactionTemplate tt2 = new TransactionTemplate(tm); + tt2.setPropagationBehavior(TransactionTemplate.PROPAGATION_REQUIRED); + tt2.execute(status2 -> { + DerivedTestBean bean2c = context.getBean(DerivedTestBean.class); + immediatelyDestroy.add(bean2c); + assertSame(bean2c, context.getBean(DerivedTestBean.class)); + assertNotSame(bean2, bean2c); + assertNotSame(bean2a, bean2c); + assertNotSame(bean2b, bean2c); + return null; + }); + assertTrue(immediatelyDestroy.iterator().next().wasDestroyed()); + assertFalse(bean2b.wasDestroyed()); + + return null; + }); + + assertTrue(finallyDestroy.iterator().next().wasDestroyed()); + } + } + +} diff --git a/spring-tx/src/test/resources/log4j2-test.xml b/spring-tx/src/test/resources/log4j2-test.xml new file mode 100644 index 0000000000000000000000000000000000000000..ff2f0402132b9e67532c2493cf89b9c5aa6ad2ce --- /dev/null +++ b/spring-tx/src/test/resources/log4j2-test.xml @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/spring-tx/src/test/resources/org/springframework/transaction/annotation/annotationTransactionNamespaceHandlerTests.xml b/spring-tx/src/test/resources/org/springframework/transaction/annotation/annotationTransactionNamespaceHandlerTests.xml new file mode 100644 index 0000000000000000000000000000000000000000..868ac0c6d9028f5e61cc7d65edc122251bc5a8f7 --- /dev/null +++ b/spring-tx/src/test/resources/org/springframework/transaction/annotation/annotationTransactionNamespaceHandlerTests.xml @@ -0,0 +1,19 @@ + + + + + + + + + + + + diff --git a/spring-tx/src/test/resources/org/springframework/transaction/config/annotationDrivenConfigurationClassTests.xml b/spring-tx/src/test/resources/org/springframework/transaction/config/annotationDrivenConfigurationClassTests.xml new file mode 100644 index 0000000000000000000000000000000000000000..6ead5633a805db5d18f5223c93c6cc59f1015db0 --- /dev/null +++ b/spring-tx/src/test/resources/org/springframework/transaction/config/annotationDrivenConfigurationClassTests.xml @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + diff --git a/spring-tx/src/test/resources/org/springframework/transaction/config/annotationDrivenProxyTargetClassTests.xml b/spring-tx/src/test/resources/org/springframework/transaction/config/annotationDrivenProxyTargetClassTests.xml new file mode 100644 index 0000000000000000000000000000000000000000..fc59ae34b5bff96152543a714fd92b57f33c209f --- /dev/null +++ b/spring-tx/src/test/resources/org/springframework/transaction/config/annotationDrivenProxyTargetClassTests.xml @@ -0,0 +1,26 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/spring-tx/src/test/resources/org/springframework/transaction/interceptor/noTransactionAttributeSource.xml b/spring-tx/src/test/resources/org/springframework/transaction/interceptor/noTransactionAttributeSource.xml new file mode 100644 index 0000000000000000000000000000000000000000..7c10400577de50240b7c11414a9986b2272df6b2 --- /dev/null +++ b/spring-tx/src/test/resources/org/springframework/transaction/interceptor/noTransactionAttributeSource.xml @@ -0,0 +1,22 @@ + + + + + + + + custom + 666 + + + + + + + + + + + diff --git a/spring-tx/src/test/resources/org/springframework/transaction/interceptor/transactionalBeanFactory.xml b/spring-tx/src/test/resources/org/springframework/transaction/interceptor/transactionalBeanFactory.xml new file mode 100644 index 0000000000000000000000000000000000000000..336045ef7a444416bef880baace67d7b85c70bcb --- /dev/null +++ b/spring-tx/src/test/resources/org/springframework/transaction/interceptor/transactionalBeanFactory.xml @@ -0,0 +1,137 @@ + + + + + + + dependency + + + + + custom + 666 + + + + + + + + + + + + org.springframework.tests.sample.beans.ITestBean.s*=PROPAGATION_MANDATORY + org.springframework.tests.sample.beans.AgeHolder.setAg*=PROPAGATION_REQUIRED + org.springframework.tests.sample.beans.ITestBean.set*= PROPAGATION_SUPPORTS , readOnly + + + + + + + org.springframework.tests.sample.beans.ITestBean + + + + txInterceptor + target + + + + + + + + + PROPAGATION_MANDATORY + PROPAGATION_REQUIRED , readOnly + PROPAGATION_SUPPORTS + + + + + + + + + + + true + + + + + + + + + + + + + + true + + + + + + + + + + + + + + + + + + + + + + + + + PROPAGATION_REQUIRED + + + + + + + + + + + + + + + + PROPAGATION_MANDATORY + PROPAGATION_REQUIRED + PROPAGATION_SUPPORTS + + + true + false + + + + + + + + + diff --git a/spring-tx/src/test/resources/org/springframework/transaction/txNamespaceHandlerTests.xml b/spring-tx/src/test/resources/org/springframework/transaction/txNamespaceHandlerTests.xml new file mode 100644 index 0000000000000000000000000000000000000000..2a2f922039af4cc06dab35855f5f24cbf71fe390 --- /dev/null +++ b/spring-tx/src/test/resources/org/springframework/transaction/txNamespaceHandlerTests.xml @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/spring-web/spring-web.gradle b/spring-web/spring-web.gradle new file mode 100644 index 0000000000000000000000000000000000000000..bf5a9d1a565eae7d87b0cd616af23270138162ac --- /dev/null +++ b/spring-web/spring-web.gradle @@ -0,0 +1,87 @@ +description = "Spring Web" + +dependencyManagement { + imports { + mavenBom "io.projectreactor:reactor-bom:${reactorVersion}" + mavenBom "io.netty:netty-bom:${nettyVersion}" + mavenBom "org.eclipse.jetty:jetty-bom:${jettyVersion}" + } +} + +dependencies { + compile(project(":spring-beans")) + compile(project(":spring-core")) + optional(project(":spring-aop")) + optional(project(":spring-context")) + optional(project(":spring-oxm")) + optional("javax.servlet:javax.servlet-api:3.1.0") + optional("javax.servlet.jsp:javax.servlet.jsp-api:2.3.2-b02") + optional("javax.el:javax.el-api:3.0.1-b04") + optional("javax.faces:javax.faces-api:2.2") + optional("javax.json.bind:javax.json.bind-api:1.0") + optional("javax.mail:javax.mail-api:1.6.2") + optional("javax.validation:validation-api:1.1.0.Final") + optional("javax.xml.bind:jaxb-api:2.3.1") + optional("javax.xml.ws:jaxws-api:2.3.1") + optional("org.glassfish.main:javax.jws:4.0-b33") + optional("io.reactivex:rxjava:${rxjavaVersion}") + optional("io.reactivex:rxjava-reactive-streams:${rxjavaAdapterVersion}") + optional("io.reactivex.rxjava2:rxjava:${rxjava2Version}") + optional("io.netty:netty-all") + optional("io.projectreactor.netty:reactor-netty") + optional("io.undertow:undertow-core:${undertowVersion}") + optional("org.apache.tomcat.embed:tomcat-embed-core:${tomcatVersion}") + optional("org.eclipse.jetty:jetty-server") { + exclude group: "javax.servlet", module: "javax.servlet-api" + } + optional("org.eclipse.jetty:jetty-servlet") { + exclude group: "javax.servlet", module: "javax.servlet-api" + } + optional("org.eclipse.jetty:jetty-reactive-httpclient:1.0.3") + optional("com.squareup.okhttp3:okhttp:3.14.7") + optional("org.apache.httpcomponents:httpclient:4.5.10") { + exclude group: "commons-logging", module: "commons-logging" + } + optional("org.apache.httpcomponents:httpasyncclient:4.1.4") { + exclude group: "commons-logging", module: "commons-logging" + } + optional("commons-fileupload:commons-fileupload:1.4") + optional("org.synchronoss.cloud:nio-multipart-parser:1.1.0") + optional("com.fasterxml.woodstox:woodstox-core:5.3.0") { // woodstox before aalto + exclude group: "stax", module: "stax-api" + } + optional("com.fasterxml:aalto-xml:1.1.1") + optional("com.fasterxml.jackson.core:jackson-databind:${jackson2Version}") + optional("com.fasterxml.jackson.dataformat:jackson-dataformat-xml:${jackson2Version}") + optional("com.fasterxml.jackson.dataformat:jackson-dataformat-smile:${jackson2Version}") + optional("com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:${jackson2Version}") + optional("com.google.code.gson:gson:2.8.5") + optional("com.google.protobuf:protobuf-java-util:3.6.1") + optional("com.googlecode.protobuf-java-format:protobuf-java-format:1.4") + optional("com.rometools:rome:1.12.2") + optional("com.caucho:hessian:4.0.51") + optional("org.codehaus.groovy:groovy:${groovyVersion}") + optional("org.jetbrains.kotlin:kotlin-reflect:${kotlinVersion}") + optional("org.jetbrains.kotlin:kotlin-stdlib:${kotlinVersion}") + testCompile("io.projectreactor:reactor-test") + testCompile("org.apache.taglibs:taglibs-standard-jstlel:1.2.5") { + exclude group: "org.apache.taglibs", module: "taglibs-standard-spec" + } + testCompile("com.fasterxml.jackson.datatype:jackson-datatype-jdk8:${jackson2Version}") + testCompile("com.fasterxml.jackson.datatype:jackson-datatype-joda:${jackson2Version}") + testCompile("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:${jackson2Version}") + testCompile("com.fasterxml.jackson.module:jackson-module-kotlin:${jackson2Version}") + testCompile("org.apache.tomcat:tomcat-util:${tomcatVersion}") + testCompile("org.apache.tomcat.embed:tomcat-embed-core:${tomcatVersion}") + testCompile("org.eclipse.jetty:jetty-server") + testCompile("org.eclipse.jetty:jetty-servlet") + testCompile("com.squareup.okhttp3:mockwebserver:3.14.7") + testCompile("org.jetbrains.kotlin:kotlin-reflect:${kotlinVersion}") + testCompile("org.skyscreamer:jsonassert:1.5.0") + testCompile("org.xmlunit:xmlunit-matchers:2.6.2") + testRuntime("com.sun.mail:javax.mail:1.6.2") + testRuntime("com.sun.xml.bind:jaxb-core:2.3.0.1") + testRuntime("com.sun.xml.bind:jaxb-impl:2.3.0.1") + testRuntime("javax.json:javax.json-api:1.1.4") + testRuntime("org.apache.johnzon:johnzon-jsonb:1.1.13") +} diff --git a/spring-web/src/main/java/org/springframework/http/CacheControl.java b/spring-web/src/main/java/org/springframework/http/CacheControl.java new file mode 100644 index 0000000000000000000000000000000000000000..0e72aa8fa3717dcee3b0e61ac44d500026bf0f12 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/CacheControl.java @@ -0,0 +1,320 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.util.concurrent.TimeUnit; + +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +/** + * A builder for creating "Cache-Control" HTTP response headers. + * + *

Adding Cache-Control directives to HTTP responses can significantly improve the client + * experience when interacting with a web application. This builder creates opinionated + * "Cache-Control" headers with response directives only, with several use cases in mind. + * + *

    + *
  • Caching HTTP responses with {@code CacheControl cc = CacheControl.maxAge(1, TimeUnit.HOURS)} + * will result in {@code Cache-Control: "max-age=3600"}
  • + *
  • Preventing cache with {@code CacheControl cc = CacheControl.noStore()} + * will result in {@code Cache-Control: "no-store"}
  • + *
  • Advanced cases like {@code CacheControl cc = CacheControl.maxAge(1, TimeUnit.HOURS).noTransform().cachePublic()} + * will result in {@code Cache-Control: "max-age=3600, no-transform, public"}
  • + *
+ * + *

Note that to be efficient, Cache-Control headers should be written along HTTP validators + * such as "Last-Modified" or "ETag" headers. + * + * @author Brian Clozel + * @author Juergen Hoeller + * @since 4.2 + * @see rfc7234 section 5.2.2 + * @see + * HTTP caching - Google developers reference + * @see Mark Nottingham's cache documentation + */ +public class CacheControl { + + private long maxAge = -1; + + private boolean noCache = false; + + private boolean noStore = false; + + private boolean mustRevalidate = false; + + private boolean noTransform = false; + + private boolean cachePublic = false; + + private boolean cachePrivate = false; + + private boolean proxyRevalidate = false; + + private long staleWhileRevalidate = -1; + + private long staleIfError = -1; + + private long sMaxAge = -1; + + + /** + * Create an empty CacheControl instance. + * @see #empty() + */ + protected CacheControl() { + } + + + /** + * Return an empty directive. + *

This is well suited for using other optional directives without "max-age", + * "no-cache" or "no-store". + * @return {@code this}, to facilitate method chaining + */ + public static CacheControl empty() { + return new CacheControl(); + } + + /** + * Add a "max-age=" directive. + *

This directive is well suited for publicly caching resources, knowing that + * they won't change within the configured amount of time. Additional directives + * can be also used, in case resources shouldn't be cached ({@link #cachePrivate()}) + * or transformed ({@link #noTransform()}) by shared caches. + *

In order to prevent caches to reuse the cached response even when it has + * become stale (i.e. the "max-age" delay is passed), the "must-revalidate" + * directive should be set ({@link #mustRevalidate()} + * @param maxAge the maximum time the response should be cached + * @param unit the time unit of the {@code maxAge} argument + * @return {@code this}, to facilitate method chaining + * @see rfc7234 section 5.2.2.8 + */ + public static CacheControl maxAge(long maxAge, TimeUnit unit) { + CacheControl cc = new CacheControl(); + cc.maxAge = unit.toSeconds(maxAge); + return cc; + } + + /** + * Add a "no-cache" directive. + *

This directive is well suited for telling caches that the response + * can be reused only if the client revalidates it with the server. + * This directive won't disable cache altogether and may result with clients + * sending conditional requests (with "ETag", "If-Modified-Since" headers) + * and the server responding with "304 - Not Modified" status. + *

In order to disable caching and minimize requests/responses exchanges, + * the {@link #noStore()} directive should be used instead of {@code #noCache()}. + * @return {@code this}, to facilitate method chaining + * @see rfc7234 section 5.2.2.2 + */ + public static CacheControl noCache() { + CacheControl cc = new CacheControl(); + cc.noCache = true; + return cc; + } + + /** + * Add a "no-store" directive. + *

This directive is well suited for preventing caches (browsers and proxies) + * to cache the content of responses. + * @return {@code this}, to facilitate method chaining + * @see rfc7234 section 5.2.2.3 + */ + public static CacheControl noStore() { + CacheControl cc = new CacheControl(); + cc.noStore = true; + return cc; + } + + + /** + * Add a "must-revalidate" directive. + *

This directive indicates that once it has become stale, a cache MUST NOT + * use the response to satisfy subsequent requests without successful validation + * on the origin server. + * @return {@code this}, to facilitate method chaining + * @see rfc7234 section 5.2.2.1 + */ + public CacheControl mustRevalidate() { + this.mustRevalidate = true; + return this; + } + + /** + * Add a "no-transform" directive. + *

This directive indicates that intermediaries (caches and others) should + * not transform the response content. This can be useful to force caches and + * CDNs not to automatically gzip or optimize the response content. + * @return {@code this}, to facilitate method chaining + * @see rfc7234 section 5.2.2.4 + */ + public CacheControl noTransform() { + this.noTransform = true; + return this; + } + + /** + * Add a "public" directive. + *

This directive indicates that any cache MAY store the response, + * even if the response would normally be non-cacheable or cacheable + * only within a private cache. + * @return {@code this}, to facilitate method chaining + * @see rfc7234 section 5.2.2.5 + */ + public CacheControl cachePublic() { + this.cachePublic = true; + return this; + } + + /** + * Add a "private" directive. + *

This directive indicates that the response message is intended + * for a single user and MUST NOT be stored by a shared cache. + * @return {@code this}, to facilitate method chaining + * @see rfc7234 section 5.2.2.6 + */ + public CacheControl cachePrivate() { + this.cachePrivate = true; + return this; + } + + /** + * Add a "proxy-revalidate" directive. + *

This directive has the same meaning as the "must-revalidate" directive, + * except that it does not apply to private caches (i.e. browsers, HTTP clients). + * @return {@code this}, to facilitate method chaining + * @see rfc7234 section 5.2.2.7 + */ + public CacheControl proxyRevalidate() { + this.proxyRevalidate = true; + return this; + } + + /** + * Add an "s-maxage" directive. + *

This directive indicates that, in shared caches, the maximum age specified + * by this directive overrides the maximum age specified by other directives. + * @param sMaxAge the maximum time the response should be cached + * @param unit the time unit of the {@code sMaxAge} argument + * @return {@code this}, to facilitate method chaining + * @see rfc7234 section 5.2.2.9 + */ + public CacheControl sMaxAge(long sMaxAge, TimeUnit unit) { + this.sMaxAge = unit.toSeconds(sMaxAge); + return this; + } + + /** + * Add a "stale-while-revalidate" directive. + *

This directive indicates that caches MAY serve the response in which it + * appears after it becomes stale, up to the indicated number of seconds. + * If a cached response is served stale due to the presence of this extension, + * the cache SHOULD attempt to revalidate it while still serving stale responses + * (i.e. without blocking). + * @param staleWhileRevalidate the maximum time the response should be used while being revalidated + * @param unit the time unit of the {@code staleWhileRevalidate} argument + * @return {@code this}, to facilitate method chaining + * @see rfc5861 section 3 + */ + public CacheControl staleWhileRevalidate(long staleWhileRevalidate, TimeUnit unit) { + this.staleWhileRevalidate = unit.toSeconds(staleWhileRevalidate); + return this; + } + + /** + * Add a "stale-if-error" directive. + *

This directive indicates that when an error is encountered, a cached stale response + * MAY be used to satisfy the request, regardless of other freshness information. + * @param staleIfError the maximum time the response should be used when errors are encountered + * @param unit the time unit of the {@code staleIfError} argument + * @return {@code this}, to facilitate method chaining + * @see rfc5861 section 4 + */ + public CacheControl staleIfError(long staleIfError, TimeUnit unit) { + this.staleIfError = unit.toSeconds(staleIfError); + return this; + } + + + /** + * Return the "Cache-Control" header value, if any. + * @return the header value, or {@code null} if no directive was added + */ + @Nullable + public String getHeaderValue() { + String headerValue = toHeaderValue(); + return (StringUtils.hasText(headerValue) ? headerValue : null); + } + + /** + * Return the "Cache-Control" header value. + * @return the header value (potentially empty) + */ + private String toHeaderValue() { + StringBuilder headerValue = new StringBuilder(); + if (this.maxAge != -1) { + appendDirective(headerValue, "max-age=" + this.maxAge); + } + if (this.noCache) { + appendDirective(headerValue, "no-cache"); + } + if (this.noStore) { + appendDirective(headerValue, "no-store"); + } + if (this.mustRevalidate) { + appendDirective(headerValue, "must-revalidate"); + } + if (this.noTransform) { + appendDirective(headerValue, "no-transform"); + } + if (this.cachePublic) { + appendDirective(headerValue, "public"); + } + if (this.cachePrivate) { + appendDirective(headerValue, "private"); + } + if (this.proxyRevalidate) { + appendDirective(headerValue, "proxy-revalidate"); + } + if (this.sMaxAge != -1) { + appendDirective(headerValue, "s-maxage=" + this.sMaxAge); + } + if (this.staleIfError != -1) { + appendDirective(headerValue, "stale-if-error=" + this.staleIfError); + } + if (this.staleWhileRevalidate != -1) { + appendDirective(headerValue, "stale-while-revalidate=" + this.staleWhileRevalidate); + } + return headerValue.toString(); + } + + private void appendDirective(StringBuilder builder, String value) { + if (builder.length() > 0) { + builder.append(", "); + } + builder.append(value); + } + + + @Override + public String toString() { + return "CacheControl [" + toHeaderValue() + "]"; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/ContentDisposition.java b/spring-web/src/main/java/org/springframework/http/ContentDisposition.java new file mode 100644 index 0000000000000000000000000000000000000000..068ec0dd934694f67aeee5752eedecbd172aeecc --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/ContentDisposition.java @@ -0,0 +1,605 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.io.ByteArrayOutputStream; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.time.ZonedDateTime; +import java.time.format.DateTimeParseException; +import java.util.ArrayList; +import java.util.List; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; + +import static java.nio.charset.StandardCharsets.ISO_8859_1; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.time.format.DateTimeFormatter.RFC_1123_DATE_TIME; + +/** + * Represent the Content-Disposition type and parameters as defined in RFC 2183. + * + * @author Sebastien Deleuze + * @author Juergen Hoeller + * @since 5.0 + * @see RFC 2183 + */ +public final class ContentDisposition { + + private static final String INVALID_HEADER_FIELD_PARAMETER_FORMAT = + "Invalid header field parameter format (as defined in RFC 5987)"; + + + @Nullable + private final String type; + + @Nullable + private final String name; + + @Nullable + private final String filename; + + @Nullable + private final Charset charset; + + @Nullable + private final Long size; + + @Nullable + private final ZonedDateTime creationDate; + + @Nullable + private final ZonedDateTime modificationDate; + + @Nullable + private final ZonedDateTime readDate; + + + /** + * Private constructor. See static factory methods in this class. + */ + private ContentDisposition(@Nullable String type, @Nullable String name, @Nullable String filename, + @Nullable Charset charset, @Nullable Long size, @Nullable ZonedDateTime creationDate, + @Nullable ZonedDateTime modificationDate, @Nullable ZonedDateTime readDate) { + + this.type = type; + this.name = name; + this.filename = filename; + this.charset = charset; + this.size = size; + this.creationDate = creationDate; + this.modificationDate = modificationDate; + this.readDate = readDate; + } + + + /** + * Return the disposition type, like for example {@literal inline}, {@literal attachment}, + * {@literal form-data}, or {@code null} if not defined. + */ + @Nullable + public String getType() { + return this.type; + } + + /** + * Return the value of the {@literal name} parameter, or {@code null} if not defined. + */ + @Nullable + public String getName() { + return this.name; + } + + /** + * Return the value of the {@literal filename} parameter (or the value of the + * {@literal filename*} one decoded as defined in the RFC 5987), or {@code null} if not defined. + */ + @Nullable + public String getFilename() { + return this.filename; + } + + /** + * Return the charset defined in {@literal filename*} parameter, or {@code null} if not defined. + */ + @Nullable + public Charset getCharset() { + return this.charset; + } + + /** + * Return the value of the {@literal size} parameter, or {@code null} if not defined. + */ + @Nullable + public Long getSize() { + return this.size; + } + + /** + * Return the value of the {@literal creation-date} parameter, or {@code null} if not defined. + */ + @Nullable + public ZonedDateTime getCreationDate() { + return this.creationDate; + } + + /** + * Return the value of the {@literal modification-date} parameter, or {@code null} if not defined. + */ + @Nullable + public ZonedDateTime getModificationDate() { + return this.modificationDate; + } + + /** + * Return the value of the {@literal read-date} parameter, or {@code null} if not defined. + */ + @Nullable + public ZonedDateTime getReadDate() { + return this.readDate; + } + + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof ContentDisposition)) { + return false; + } + ContentDisposition otherCd = (ContentDisposition) other; + return (ObjectUtils.nullSafeEquals(this.type, otherCd.type) && + ObjectUtils.nullSafeEquals(this.name, otherCd.name) && + ObjectUtils.nullSafeEquals(this.filename, otherCd.filename) && + ObjectUtils.nullSafeEquals(this.charset, otherCd.charset) && + ObjectUtils.nullSafeEquals(this.size, otherCd.size) && + ObjectUtils.nullSafeEquals(this.creationDate, otherCd.creationDate)&& + ObjectUtils.nullSafeEquals(this.modificationDate, otherCd.modificationDate)&& + ObjectUtils.nullSafeEquals(this.readDate, otherCd.readDate)); + } + + @Override + public int hashCode() { + int result = ObjectUtils.nullSafeHashCode(this.type); + result = 31 * result + ObjectUtils.nullSafeHashCode(this.name); + result = 31 * result + ObjectUtils.nullSafeHashCode(this.filename); + result = 31 * result + ObjectUtils.nullSafeHashCode(this.charset); + result = 31 * result + ObjectUtils.nullSafeHashCode(this.size); + result = 31 * result + (this.creationDate != null ? this.creationDate.hashCode() : 0); + result = 31 * result + (this.modificationDate != null ? this.modificationDate.hashCode() : 0); + result = 31 * result + (this.readDate != null ? this.readDate.hashCode() : 0); + return result; + } + + /** + * Return the header value for this content disposition as defined in RFC 2183. + * @see #parse(String) + */ + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + if (this.type != null) { + sb.append(this.type); + } + if (this.name != null) { + sb.append("; name=\""); + sb.append(this.name).append('\"'); + } + if (this.filename != null) { + if (this.charset == null || StandardCharsets.US_ASCII.equals(this.charset)) { + sb.append("; filename=\""); + sb.append(escapeQuotationsInFilename(this.filename)).append('\"'); + } + else { + sb.append("; filename*="); + sb.append(encodeFilename(this.filename, this.charset)); + } + } + if (this.size != null) { + sb.append("; size="); + sb.append(this.size); + } + if (this.creationDate != null) { + sb.append("; creation-date=\""); + sb.append(RFC_1123_DATE_TIME.format(this.creationDate)); + sb.append('\"'); + } + if (this.modificationDate != null) { + sb.append("; modification-date=\""); + sb.append(RFC_1123_DATE_TIME.format(this.modificationDate)); + sb.append('\"'); + } + if (this.readDate != null) { + sb.append("; read-date=\""); + sb.append(RFC_1123_DATE_TIME.format(this.readDate)); + sb.append('\"'); + } + return sb.toString(); + } + + + /** + * Return a builder for a {@code ContentDisposition}. + * @param type the disposition type like for example {@literal inline}, + * {@literal attachment}, or {@literal form-data} + * @return the builder + */ + public static Builder builder(String type) { + return new BuilderImpl(type); + } + + /** + * Return an empty content disposition. + */ + public static ContentDisposition empty() { + return new ContentDisposition("", null, null, null, null, null, null, null); + } + + /** + * Parse a {@literal Content-Disposition} header value as defined in RFC 2183. + * @param contentDisposition the {@literal Content-Disposition} header value + * @return the parsed content disposition + * @see #toString() + */ + public static ContentDisposition parse(String contentDisposition) { + List parts = tokenize(contentDisposition); + String type = parts.get(0); + String name = null; + String filename = null; + Charset charset = null; + Long size = null; + ZonedDateTime creationDate = null; + ZonedDateTime modificationDate = null; + ZonedDateTime readDate = null; + for (int i = 1; i < parts.size(); i++) { + String part = parts.get(i); + int eqIndex = part.indexOf('='); + if (eqIndex != -1) { + String attribute = part.substring(0, eqIndex); + String value = (part.startsWith("\"", eqIndex + 1) && part.endsWith("\"") ? + part.substring(eqIndex + 2, part.length() - 1) : + part.substring(eqIndex + 1)); + if (attribute.equals("name") ) { + name = value; + } + else if (attribute.equals("filename*") ) { + int idx1 = value.indexOf('\''); + int idx2 = value.indexOf('\'', idx1 + 1); + if (idx1 != -1 && idx2 != -1) { + charset = Charset.forName(value.substring(0, idx1).trim()); + Assert.isTrue(UTF_8.equals(charset) || ISO_8859_1.equals(charset), + "Charset should be UTF-8 or ISO-8859-1"); + filename = decodeFilename(value.substring(idx2 + 1), charset); + } + else { + // US ASCII + filename = decodeFilename(value, StandardCharsets.US_ASCII); + } + } + else if (attribute.equals("filename") && (filename == null)) { + filename = value; + } + else if (attribute.equals("size") ) { + size = Long.parseLong(value); + } + else if (attribute.equals("creation-date")) { + try { + creationDate = ZonedDateTime.parse(value, RFC_1123_DATE_TIME); + } + catch (DateTimeParseException ex) { + // ignore + } + } + else if (attribute.equals("modification-date")) { + try { + modificationDate = ZonedDateTime.parse(value, RFC_1123_DATE_TIME); + } + catch (DateTimeParseException ex) { + // ignore + } + } + else if (attribute.equals("read-date")) { + try { + readDate = ZonedDateTime.parse(value, RFC_1123_DATE_TIME); + } + catch (DateTimeParseException ex) { + // ignore + } + } + } + else { + throw new IllegalArgumentException("Invalid content disposition format"); + } + } + return new ContentDisposition(type, name, filename, charset, size, creationDate, modificationDate, readDate); + } + + private static List tokenize(String headerValue) { + int index = headerValue.indexOf(';'); + String type = (index >= 0 ? headerValue.substring(0, index) : headerValue).trim(); + if (type.isEmpty()) { + throw new IllegalArgumentException("Content-Disposition header must not be empty"); + } + List parts = new ArrayList<>(); + parts.add(type); + if (index >= 0) { + do { + int nextIndex = index + 1; + boolean quoted = false; + boolean escaped = false; + while (nextIndex < headerValue.length()) { + char ch = headerValue.charAt(nextIndex); + if (ch == ';') { + if (!quoted) { + break; + } + } + else if (!escaped && ch == '"') { + quoted = !quoted; + } + escaped = (!escaped && ch == '\\'); + nextIndex++; + } + String part = headerValue.substring(index + 1, nextIndex).trim(); + if (!part.isEmpty()) { + parts.add(part); + } + index = nextIndex; + } + while (index < headerValue.length()); + } + return parts; + } + + /** + * Decode the given header field param as described in RFC 5987. + *

Only the US-ASCII, UTF-8 and ISO-8859-1 charsets are supported. + * @param filename the filename + * @param charset the charset for the filename + * @return the encoded header field param + * @see RFC 5987 + */ + private static String decodeFilename(String filename, Charset charset) { + Assert.notNull(filename, "'input' String` should not be null"); + Assert.notNull(charset, "'charset' should not be null"); + byte[] value = filename.getBytes(charset); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + int index = 0; + while (index < value.length) { + byte b = value[index]; + if (isRFC5987AttrChar(b)) { + baos.write((char) b); + index++; + } + else if (b == '%' && index < value.length - 2) { + char[] array = new char[]{(char) value[index + 1], (char) value[index + 2]}; + try { + baos.write(Integer.parseInt(String.valueOf(array), 16)); + } + catch (NumberFormatException ex) { + throw new IllegalArgumentException(INVALID_HEADER_FIELD_PARAMETER_FORMAT, ex); + } + index+=3; + } + else { + throw new IllegalArgumentException(INVALID_HEADER_FIELD_PARAMETER_FORMAT); + } + } + return new String(baos.toByteArray(), charset); + } + + private static boolean isRFC5987AttrChar(byte c) { + return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || + c == '!' || c == '#' || c == '$' || c == '&' || c == '+' || c == '-' || + c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~'; + } + + private static String escapeQuotationsInFilename(String filename) { + if (filename.indexOf('"') == -1 && filename.indexOf('\\') == -1) { + return filename; + } + boolean escaped = false; + StringBuilder sb = new StringBuilder(); + for (char c : filename.toCharArray()) { + sb.append((c == '"' && !escaped) ? "\\\"" : c); + escaped = (!escaped && c == '\\'); + } + // Remove backslash at the end.. + if (escaped) { + sb.deleteCharAt(sb.length() - 1); + } + return sb.toString(); + } + + /** + * Encode the given header field param as describe in RFC 5987. + * @param input the header field param + * @param charset the charset of the header field param string, + * only the US-ASCII, UTF-8 and ISO-8859-1 charsets are supported + * @return the encoded header field param + * @see RFC 5987 + */ + private static String encodeFilename(String input, Charset charset) { + Assert.notNull(input, "`input` is required"); + Assert.notNull(charset, "`charset` is required"); + Assert.isTrue(!StandardCharsets.US_ASCII.equals(charset), "ASCII does not require encoding"); + Assert.isTrue(UTF_8.equals(charset) || ISO_8859_1.equals(charset), "Only UTF-8 and ISO-8859-1 supported."); + byte[] source = input.getBytes(charset); + int len = source.length; + StringBuilder sb = new StringBuilder(len << 1); + sb.append(charset.name()); + sb.append("''"); + for (byte b : source) { + if (isRFC5987AttrChar(b)) { + sb.append((char) b); + } + else { + sb.append('%'); + char hex1 = Character.toUpperCase(Character.forDigit((b >> 4) & 0xF, 16)); + char hex2 = Character.toUpperCase(Character.forDigit(b & 0xF, 16)); + sb.append(hex1); + sb.append(hex2); + } + } + return sb.toString(); + } + + + /** + * A mutable builder for {@code ContentDisposition}. + */ + public interface Builder { + + /** + * Set the value of the {@literal name} parameter. + */ + Builder name(String name); + + /** + * Set the value of the {@literal filename} parameter. The given + * filename will be formatted as quoted-string, as defined in RFC 2616, + * section 2.2, and any quote characters within the filename value will + * be escaped with a backslash, e.g. {@code "foo\"bar.txt"} becomes + * {@code "foo\\\"bar.txt"}. + */ + Builder filename(String filename); + + /** + * Set the value of the {@literal filename*} that will be encoded as + * defined in the RFC 5987. Only the US-ASCII, UTF-8 and ISO-8859-1 + * charsets are supported. + *

Note: Do not use this for a + * {@code "multipart/form-data"} requests as per + * RFC 7578, Section 4.2 + * and also RFC 5987 itself mentions it does not apply to multipart + * requests. + */ + Builder filename(String filename, Charset charset); + + /** + * Set the value of the {@literal size} parameter. + */ + Builder size(Long size); + + /** + * Set the value of the {@literal creation-date} parameter. + */ + Builder creationDate(ZonedDateTime creationDate); + + /** + * Set the value of the {@literal modification-date} parameter. + */ + Builder modificationDate(ZonedDateTime modificationDate); + + /** + * Set the value of the {@literal read-date} parameter. + */ + Builder readDate(ZonedDateTime readDate); + + /** + * Build the content disposition. + */ + ContentDisposition build(); + } + + + private static class BuilderImpl implements Builder { + + private final String type; + + @Nullable + private String name; + + @Nullable + private String filename; + + @Nullable + private Charset charset; + + @Nullable + private Long size; + + @Nullable + private ZonedDateTime creationDate; + + @Nullable + private ZonedDateTime modificationDate; + + @Nullable + private ZonedDateTime readDate; + + public BuilderImpl(String type) { + Assert.hasText(type, "'type' must not be not empty"); + this.type = type; + } + + @Override + public Builder name(String name) { + this.name = name; + return this; + } + + @Override + public Builder filename(String filename) { + Assert.hasText(filename, "No filename"); + this.filename = filename; + return this; + } + + @Override + public Builder filename(String filename, Charset charset) { + Assert.hasText(filename, "No filename"); + this.filename = filename; + this.charset = charset; + return this; + } + + @Override + public Builder size(Long size) { + this.size = size; + return this; + } + + @Override + public Builder creationDate(ZonedDateTime creationDate) { + this.creationDate = creationDate; + return this; + } + + @Override + public Builder modificationDate(ZonedDateTime modificationDate) { + this.modificationDate = modificationDate; + return this; + } + + @Override + public Builder readDate(ZonedDateTime readDate) { + this.readDate = readDate; + return this; + } + + @Override + public ContentDisposition build() { + return new ContentDisposition(this.type, this.name, this.filename, this.charset, + this.size, this.creationDate, this.modificationDate, this.readDate); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/HttpCookie.java b/spring-web/src/main/java/org/springframework/http/HttpCookie.java new file mode 100644 index 0000000000000000000000000000000000000000..7bf20e7e4d6cfc85b65cb373b58d46b1a620f46f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/HttpCookie.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Represents an HTTP cookie as a name-value pair consistent with the content of + * the "Cookie" request header. The {@link ResponseCookie} sub-class has the + * additional attributes expected in the "Set-Cookie" response header. + * + * @author Rossen Stoyanchev + * @since 5.0 + * @see RFC 6265 + */ +public class HttpCookie { + + private final String name; + + private final String value; + + + public HttpCookie(String name, @Nullable String value) { + Assert.hasLength(name, "'name' is required and must not be empty."); + this.name = name; + this.value = (value != null ? value : ""); + } + + /** + * Return the cookie name. + */ + public String getName() { + return this.name; + } + + /** + * Return the cookie value or an empty string (never {@code null}). + */ + public String getValue() { + return this.value; + } + + + @Override + public int hashCode() { + return this.name.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof HttpCookie)) { + return false; + } + HttpCookie otherCookie = (HttpCookie) other; + return (this.name.equalsIgnoreCase(otherCookie.getName())); + } + + @Override + public String toString() { + return this.name + '=' + this.value; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/HttpEntity.java b/spring-web/src/main/java/org/springframework/http/HttpEntity.java new file mode 100644 index 0000000000000000000000000000000000000000..35b2b97cafe4312c05329e0ba30d141b1159a5d8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/HttpEntity.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; + +/** + * Represents an HTTP request or response entity, consisting of headers and body. + * + *

Typically used in combination with the {@link org.springframework.web.client.RestTemplate}, + * like so: + *

+ * HttpHeaders headers = new HttpHeaders();
+ * headers.setContentType(MediaType.TEXT_PLAIN);
+ * HttpEntity<String> entity = new HttpEntity<String>(helloWorld, headers);
+ * URI location = template.postForLocation("https://example.com", entity);
+ * 
+ * or + *
+ * HttpEntity<String> entity = template.getForEntity("https://example.com", String.class);
+ * String body = entity.getBody();
+ * MediaType contentType = entity.getHeaders().getContentType();
+ * 
+ * Can also be used in Spring MVC, as a return value from a @Controller method: + *
+ * @RequestMapping("/handle")
+ * public HttpEntity<String> handle() {
+ *   HttpHeaders responseHeaders = new HttpHeaders();
+ *   responseHeaders.set("MyResponseHeader", "MyValue");
+ *   return new HttpEntity<String>("Hello World", responseHeaders);
+ * }
+ * 
+ * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.0.2 + * @param the body type + * @see org.springframework.web.client.RestTemplate + * @see #getBody() + * @see #getHeaders() + */ +public class HttpEntity { + + /** + * The empty {@code HttpEntity}, with no body or headers. + */ + public static final HttpEntity EMPTY = new HttpEntity<>(); + + + private final HttpHeaders headers; + + @Nullable + private final T body; + + + /** + * Create a new, empty {@code HttpEntity}. + */ + protected HttpEntity() { + this(null, null); + } + + /** + * Create a new {@code HttpEntity} with the given body and no headers. + * @param body the entity body + */ + public HttpEntity(T body) { + this(body, null); + } + + /** + * Create a new {@code HttpEntity} with the given headers and no body. + * @param headers the entity headers + */ + public HttpEntity(MultiValueMap headers) { + this(null, headers); + } + + /** + * Create a new {@code HttpEntity} with the given body and headers. + * @param body the entity body + * @param headers the entity headers + */ + public HttpEntity(@Nullable T body, @Nullable MultiValueMap headers) { + this.body = body; + HttpHeaders tempHeaders = new HttpHeaders(); + if (headers != null) { + tempHeaders.putAll(headers); + } + this.headers = HttpHeaders.readOnlyHttpHeaders(tempHeaders); + } + + + /** + * Returns the headers of this entity. + */ + public HttpHeaders getHeaders() { + return this.headers; + } + + /** + * Returns the body of this entity. + */ + @Nullable + public T getBody() { + return this.body; + } + + /** + * Indicates whether this entity has a body. + */ + public boolean hasBody() { + return (this.body != null); + } + + + @Override + public boolean equals(@Nullable Object other) { + if (this == other) { + return true; + } + if (other == null || other.getClass() != getClass()) { + return false; + } + HttpEntity otherEntity = (HttpEntity) other; + return (ObjectUtils.nullSafeEquals(this.headers, otherEntity.headers) && + ObjectUtils.nullSafeEquals(this.body, otherEntity.body)); + } + + @Override + public int hashCode() { + return (ObjectUtils.nullSafeHashCode(this.headers) * 29 + ObjectUtils.nullSafeHashCode(this.body)); + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder("<"); + if (this.body != null) { + builder.append(this.body); + builder.append(','); + } + builder.append(this.headers); + builder.append('>'); + return builder.toString(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java new file mode 100644 index 0000000000000000000000000000000000000000..53ecb5c15a9c8b05c3bd4a7656a3377ade4a548a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java @@ -0,0 +1,1756 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.io.Serializable; +import java.net.InetSocketAddress; +import java.net.URI; +import java.nio.charset.Charset; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.StandardCharsets; +import java.text.DecimalFormat; +import java.text.DecimalFormatSymbols; +import java.time.Instant; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeParseException; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collection; +import java.util.Collections; +import java.util.EnumSet; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * A data structure representing HTTP request or response headers, mapping String header names + * to a list of String values, also offering accessors for common application-level data types. + * + *

In addition to the regular methods defined by {@link Map}, this class offers many common + * convenience methods, for example: + *

    + *
  • {@link #getFirst(String)} returns the first value associated with a given header name
  • + *
  • {@link #add(String, String)} adds a header value to the list of values for a header name
  • + *
  • {@link #set(String, String)} sets the header value to a single string value
  • + *
+ * + *

Note that {@code HttpHeaders} generally treats header names in a case-insensitive manner. + * + * @author Arjen Poutsma + * @author Sebastien Deleuze + * @author Brian Clozel + * @author Juergen Hoeller + * @author Josh Long + * @since 3.0 + */ +public class HttpHeaders implements MultiValueMap, Serializable { + + private static final long serialVersionUID = -8578554704772377436L; + + + /** + * The HTTP {@code Accept} header field name. + * @see Section 5.3.2 of RFC 7231 + */ + public static final String ACCEPT = "Accept"; + /** + * The HTTP {@code Accept-Charset} header field name. + * @see Section 5.3.3 of RFC 7231 + */ + public static final String ACCEPT_CHARSET = "Accept-Charset"; + /** + * The HTTP {@code Accept-Encoding} header field name. + * @see Section 5.3.4 of RFC 7231 + */ + public static final String ACCEPT_ENCODING = "Accept-Encoding"; + /** + * The HTTP {@code Accept-Language} header field name. + * @see Section 5.3.5 of RFC 7231 + */ + public static final String ACCEPT_LANGUAGE = "Accept-Language"; + /** + * The HTTP {@code Accept-Ranges} header field name. + * @see Section 5.3.5 of RFC 7233 + */ + public static final String ACCEPT_RANGES = "Accept-Ranges"; + /** + * The CORS {@code Access-Control-Allow-Credentials} response header field name. + * @see CORS W3C recommendation + */ + public static final String ACCESS_CONTROL_ALLOW_CREDENTIALS = "Access-Control-Allow-Credentials"; + /** + * The CORS {@code Access-Control-Allow-Headers} response header field name. + * @see CORS W3C recommendation + */ + public static final String ACCESS_CONTROL_ALLOW_HEADERS = "Access-Control-Allow-Headers"; + /** + * The CORS {@code Access-Control-Allow-Methods} response header field name. + * @see CORS W3C recommendation + */ + public static final String ACCESS_CONTROL_ALLOW_METHODS = "Access-Control-Allow-Methods"; + /** + * The CORS {@code Access-Control-Allow-Origin} response header field name. + * @see CORS W3C recommendation + */ + public static final String ACCESS_CONTROL_ALLOW_ORIGIN = "Access-Control-Allow-Origin"; + /** + * The CORS {@code Access-Control-Expose-Headers} response header field name. + * @see CORS W3C recommendation + */ + public static final String ACCESS_CONTROL_EXPOSE_HEADERS = "Access-Control-Expose-Headers"; + /** + * The CORS {@code Access-Control-Max-Age} response header field name. + * @see CORS W3C recommendation + */ + public static final String ACCESS_CONTROL_MAX_AGE = "Access-Control-Max-Age"; + /** + * The CORS {@code Access-Control-Request-Headers} request header field name. + * @see CORS W3C recommendation + */ + public static final String ACCESS_CONTROL_REQUEST_HEADERS = "Access-Control-Request-Headers"; + /** + * The CORS {@code Access-Control-Request-Method} request header field name. + * @see CORS W3C recommendation + */ + public static final String ACCESS_CONTROL_REQUEST_METHOD = "Access-Control-Request-Method"; + /** + * The HTTP {@code Age} header field name. + * @see Section 5.1 of RFC 7234 + */ + public static final String AGE = "Age"; + /** + * The HTTP {@code Allow} header field name. + * @see Section 7.4.1 of RFC 7231 + */ + public static final String ALLOW = "Allow"; + /** + * The HTTP {@code Authorization} header field name. + * @see Section 4.2 of RFC 7235 + */ + public static final String AUTHORIZATION = "Authorization"; + /** + * The HTTP {@code Cache-Control} header field name. + * @see Section 5.2 of RFC 7234 + */ + public static final String CACHE_CONTROL = "Cache-Control"; + /** + * The HTTP {@code Connection} header field name. + * @see Section 6.1 of RFC 7230 + */ + public static final String CONNECTION = "Connection"; + /** + * The HTTP {@code Content-Encoding} header field name. + * @see Section 3.1.2.2 of RFC 7231 + */ + public static final String CONTENT_ENCODING = "Content-Encoding"; + /** + * The HTTP {@code Content-Disposition} header field name. + * @see RFC 6266 + */ + public static final String CONTENT_DISPOSITION = "Content-Disposition"; + /** + * The HTTP {@code Content-Language} header field name. + * @see Section 3.1.3.2 of RFC 7231 + */ + public static final String CONTENT_LANGUAGE = "Content-Language"; + /** + * The HTTP {@code Content-Length} header field name. + * @see Section 3.3.2 of RFC 7230 + */ + public static final String CONTENT_LENGTH = "Content-Length"; + /** + * The HTTP {@code Content-Location} header field name. + * @see Section 3.1.4.2 of RFC 7231 + */ + public static final String CONTENT_LOCATION = "Content-Location"; + /** + * The HTTP {@code Content-Range} header field name. + * @see Section 4.2 of RFC 7233 + */ + public static final String CONTENT_RANGE = "Content-Range"; + /** + * The HTTP {@code Content-Type} header field name. + * @see Section 3.1.1.5 of RFC 7231 + */ + public static final String CONTENT_TYPE = "Content-Type"; + /** + * The HTTP {@code Cookie} header field name. + * @see Section 4.3.4 of RFC 2109 + */ + public static final String COOKIE = "Cookie"; + /** + * The HTTP {@code Date} header field name. + * @see Section 7.1.1.2 of RFC 7231 + */ + public static final String DATE = "Date"; + /** + * The HTTP {@code ETag} header field name. + * @see Section 2.3 of RFC 7232 + */ + public static final String ETAG = "ETag"; + /** + * The HTTP {@code Expect} header field name. + * @see Section 5.1.1 of RFC 7231 + */ + public static final String EXPECT = "Expect"; + /** + * The HTTP {@code Expires} header field name. + * @see Section 5.3 of RFC 7234 + */ + public static final String EXPIRES = "Expires"; + /** + * The HTTP {@code From} header field name. + * @see Section 5.5.1 of RFC 7231 + */ + public static final String FROM = "From"; + /** + * The HTTP {@code Host} header field name. + * @see Section 5.4 of RFC 7230 + */ + public static final String HOST = "Host"; + /** + * The HTTP {@code If-Match} header field name. + * @see Section 3.1 of RFC 7232 + */ + public static final String IF_MATCH = "If-Match"; + /** + * The HTTP {@code If-Modified-Since} header field name. + * @see Section 3.3 of RFC 7232 + */ + public static final String IF_MODIFIED_SINCE = "If-Modified-Since"; + /** + * The HTTP {@code If-None-Match} header field name. + * @see Section 3.2 of RFC 7232 + */ + public static final String IF_NONE_MATCH = "If-None-Match"; + /** + * The HTTP {@code If-Range} header field name. + * @see Section 3.2 of RFC 7233 + */ + public static final String IF_RANGE = "If-Range"; + /** + * The HTTP {@code If-Unmodified-Since} header field name. + * @see Section 3.4 of RFC 7232 + */ + public static final String IF_UNMODIFIED_SINCE = "If-Unmodified-Since"; + /** + * The HTTP {@code Last-Modified} header field name. + * @see Section 2.2 of RFC 7232 + */ + public static final String LAST_MODIFIED = "Last-Modified"; + /** + * The HTTP {@code Link} header field name. + * @see RFC 5988 + */ + public static final String LINK = "Link"; + /** + * The HTTP {@code Location} header field name. + * @see Section 7.1.2 of RFC 7231 + */ + public static final String LOCATION = "Location"; + /** + * The HTTP {@code Max-Forwards} header field name. + * @see Section 5.1.2 of RFC 7231 + */ + public static final String MAX_FORWARDS = "Max-Forwards"; + /** + * The HTTP {@code Origin} header field name. + * @see RFC 6454 + */ + public static final String ORIGIN = "Origin"; + /** + * The HTTP {@code Pragma} header field name. + * @see Section 5.4 of RFC 7234 + */ + public static final String PRAGMA = "Pragma"; + /** + * The HTTP {@code Proxy-Authenticate} header field name. + * @see Section 4.3 of RFC 7235 + */ + public static final String PROXY_AUTHENTICATE = "Proxy-Authenticate"; + /** + * The HTTP {@code Proxy-Authorization} header field name. + * @see Section 4.4 of RFC 7235 + */ + public static final String PROXY_AUTHORIZATION = "Proxy-Authorization"; + /** + * The HTTP {@code Range} header field name. + * @see Section 3.1 of RFC 7233 + */ + public static final String RANGE = "Range"; + /** + * The HTTP {@code Referer} header field name. + * @see Section 5.5.2 of RFC 7231 + */ + public static final String REFERER = "Referer"; + /** + * The HTTP {@code Retry-After} header field name. + * @see Section 7.1.3 of RFC 7231 + */ + public static final String RETRY_AFTER = "Retry-After"; + /** + * The HTTP {@code Server} header field name. + * @see Section 7.4.2 of RFC 7231 + */ + public static final String SERVER = "Server"; + /** + * The HTTP {@code Set-Cookie} header field name. + * @see Section 4.2.2 of RFC 2109 + */ + public static final String SET_COOKIE = "Set-Cookie"; + /** + * The HTTP {@code Set-Cookie2} header field name. + * @see RFC 2965 + */ + public static final String SET_COOKIE2 = "Set-Cookie2"; + /** + * The HTTP {@code TE} header field name. + * @see Section 4.3 of RFC 7230 + */ + public static final String TE = "TE"; + /** + * The HTTP {@code Trailer} header field name. + * @see Section 4.4 of RFC 7230 + */ + public static final String TRAILER = "Trailer"; + /** + * The HTTP {@code Transfer-Encoding} header field name. + * @see Section 3.3.1 of RFC 7230 + */ + public static final String TRANSFER_ENCODING = "Transfer-Encoding"; + /** + * The HTTP {@code Upgrade} header field name. + * @see Section 6.7 of RFC 7230 + */ + public static final String UPGRADE = "Upgrade"; + /** + * The HTTP {@code User-Agent} header field name. + * @see Section 5.5.3 of RFC 7231 + */ + public static final String USER_AGENT = "User-Agent"; + /** + * The HTTP {@code Vary} header field name. + * @see Section 7.1.4 of RFC 7231 + */ + public static final String VARY = "Vary"; + /** + * The HTTP {@code Via} header field name. + * @see Section 5.7.1 of RFC 7230 + */ + public static final String VIA = "Via"; + /** + * The HTTP {@code Warning} header field name. + * @see Section 5.5 of RFC 7234 + */ + public static final String WARNING = "Warning"; + /** + * The HTTP {@code WWW-Authenticate} header field name. + * @see Section 4.1 of RFC 7235 + */ + public static final String WWW_AUTHENTICATE = "WWW-Authenticate"; + + + /** + * An empty {@code HttpHeaders} instance (immutable). + * @since 5.0 + */ + public static final HttpHeaders EMPTY = new ReadOnlyHttpHeaders(new LinkedMultiValueMap<>()); + + /** + * Pattern matching ETag multiple field values in headers such as "If-Match", "If-None-Match". + * @see Section 2.3 of RFC 7232 + */ + private static final Pattern ETAG_HEADER_VALUE_PATTERN = Pattern.compile("\\*|\\s*((W\\/)?(\"[^\"]*\"))\\s*,?"); + + private static final DecimalFormatSymbols DECIMAL_FORMAT_SYMBOLS = new DecimalFormatSymbols(Locale.ENGLISH); + + private static final ZoneId GMT = ZoneId.of("GMT"); + + /** + * Date formats with time zone as specified in the HTTP RFC to use for formatting. + * @see Section 7.1.1.1 of RFC 7231 + */ + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US).withZone(GMT); + + /** + * Date formats with time zone as specified in the HTTP RFC to use for parsing. + * @see Section 7.1.1.1 of RFC 7231 + */ + private static final DateTimeFormatter[] DATE_PARSERS = new DateTimeFormatter[] { + DateTimeFormatter.RFC_1123_DATE_TIME, + DateTimeFormatter.ofPattern("EEEE, dd-MMM-yy HH:mm:ss zzz", Locale.US), + DateTimeFormatter.ofPattern("EEE MMM dd HH:mm:ss yyyy", Locale.US).withZone(GMT) + }; + + + final MultiValueMap headers; + + + /** + * Construct a new, empty instance of the {@code HttpHeaders} object. + *

This is the common constructor, using a case-insensitive map structure. + */ + public HttpHeaders() { + this(CollectionUtils.toMultiValueMap(new LinkedCaseInsensitiveMap<>(8, Locale.ENGLISH))); + } + + /** + * Construct a new {@code HttpHeaders} instance backed by an existing map. + *

This constructor is available as an optimization for adapting to existing + * headers map structures, primarily for internal use within the framework. + * @param headers the headers map (expected to operate with case-insensitive keys) + * @since 5.1 + */ + public HttpHeaders(MultiValueMap headers) { + Assert.notNull(headers, "MultiValueMap must not be null"); + this.headers = headers; + } + + + /** + * Set the list of acceptable {@linkplain MediaType media types}, + * as specified by the {@code Accept} header. + */ + public void setAccept(List acceptableMediaTypes) { + set(ACCEPT, MediaType.toString(acceptableMediaTypes)); + } + + /** + * Return the list of acceptable {@linkplain MediaType media types}, + * as specified by the {@code Accept} header. + *

Returns an empty list when the acceptable media types are unspecified. + */ + public List getAccept() { + return MediaType.parseMediaTypes(get(ACCEPT)); + } + + /** + * Set the acceptable language ranges, as specified by the + * {@literal Accept-Language} header. + * @since 5.0 + */ + public void setAcceptLanguage(List languages) { + Assert.notNull(languages, "LanguageRange List must not be null"); + DecimalFormat decimal = new DecimalFormat("0.0", DECIMAL_FORMAT_SYMBOLS); + List values = languages.stream() + .map(range -> + range.getWeight() == Locale.LanguageRange.MAX_WEIGHT ? + range.getRange() : + range.getRange() + ";q=" + decimal.format(range.getWeight())) + .collect(Collectors.toList()); + set(ACCEPT_LANGUAGE, toCommaDelimitedString(values)); + } + + /** + * Return the language ranges from the {@literal "Accept-Language"} header. + *

If you only need sorted, preferred locales only use + * {@link #getAcceptLanguageAsLocales()} or if you need to filter based on + * a list of supported locales you can pass the returned list to + * {@link Locale#filter(List, Collection)}. + * @throws IllegalArgumentException if the value cannot be converted to a language range + * @since 5.0 + */ + public List getAcceptLanguage() { + String value = getFirst(ACCEPT_LANGUAGE); + return (StringUtils.hasText(value) ? Locale.LanguageRange.parse(value) : Collections.emptyList()); + } + + /** + * Variant of {@link #setAcceptLanguage(List)} using {@link Locale}'s. + * @since 5.0 + */ + public void setAcceptLanguageAsLocales(List locales) { + setAcceptLanguage(locales.stream() + .map(locale -> new Locale.LanguageRange(locale.toLanguageTag())) + .collect(Collectors.toList())); + } + + /** + * A variant of {@link #getAcceptLanguage()} that converts each + * {@link java.util.Locale.LanguageRange} to a {@link Locale}. + * @return the locales or an empty list + * @throws IllegalArgumentException if the value cannot be converted to a locale + * @since 5.0 + */ + public List getAcceptLanguageAsLocales() { + List ranges = getAcceptLanguage(); + if (ranges.isEmpty()) { + return Collections.emptyList(); + } + return ranges.stream() + .map(range -> Locale.forLanguageTag(range.getRange())) + .filter(locale -> StringUtils.hasText(locale.getDisplayName())) + .collect(Collectors.toList()); + } + + /** + * Set the (new) value of the {@code Access-Control-Allow-Credentials} response header. + */ + public void setAccessControlAllowCredentials(boolean allowCredentials) { + set(ACCESS_CONTROL_ALLOW_CREDENTIALS, Boolean.toString(allowCredentials)); + } + + /** + * Return the value of the {@code Access-Control-Allow-Credentials} response header. + */ + public boolean getAccessControlAllowCredentials() { + return Boolean.parseBoolean(getFirst(ACCESS_CONTROL_ALLOW_CREDENTIALS)); + } + + /** + * Set the (new) value of the {@code Access-Control-Allow-Headers} response header. + */ + public void setAccessControlAllowHeaders(List allowedHeaders) { + set(ACCESS_CONTROL_ALLOW_HEADERS, toCommaDelimitedString(allowedHeaders)); + } + + /** + * Return the value of the {@code Access-Control-Allow-Headers} response header. + */ + public List getAccessControlAllowHeaders() { + return getValuesAsList(ACCESS_CONTROL_ALLOW_HEADERS); + } + + /** + * Set the (new) value of the {@code Access-Control-Allow-Methods} response header. + */ + public void setAccessControlAllowMethods(List allowedMethods) { + set(ACCESS_CONTROL_ALLOW_METHODS, StringUtils.collectionToCommaDelimitedString(allowedMethods)); + } + + /** + * Return the value of the {@code Access-Control-Allow-Methods} response header. + */ + public List getAccessControlAllowMethods() { + List result = new ArrayList<>(); + String value = getFirst(ACCESS_CONTROL_ALLOW_METHODS); + if (value != null) { + String[] tokens = StringUtils.tokenizeToStringArray(value, ","); + for (String token : tokens) { + HttpMethod resolved = HttpMethod.resolve(token); + if (resolved != null) { + result.add(resolved); + } + } + } + return result; + } + + /** + * Set the (new) value of the {@code Access-Control-Allow-Origin} response header. + */ + public void setAccessControlAllowOrigin(@Nullable String allowedOrigin) { + setOrRemove(ACCESS_CONTROL_ALLOW_ORIGIN, allowedOrigin); + } + + /** + * Return the value of the {@code Access-Control-Allow-Origin} response header. + */ + @Nullable + public String getAccessControlAllowOrigin() { + return getFieldValues(ACCESS_CONTROL_ALLOW_ORIGIN); + } + + /** + * Set the (new) value of the {@code Access-Control-Expose-Headers} response header. + */ + public void setAccessControlExposeHeaders(List exposedHeaders) { + set(ACCESS_CONTROL_EXPOSE_HEADERS, toCommaDelimitedString(exposedHeaders)); + } + + /** + * Return the value of the {@code Access-Control-Expose-Headers} response header. + */ + public List getAccessControlExposeHeaders() { + return getValuesAsList(ACCESS_CONTROL_EXPOSE_HEADERS); + } + + /** + * Set the (new) value of the {@code Access-Control-Max-Age} response header. + */ + public void setAccessControlMaxAge(long maxAge) { + set(ACCESS_CONTROL_MAX_AGE, Long.toString(maxAge)); + } + + /** + * Return the value of the {@code Access-Control-Max-Age} response header. + *

Returns -1 when the max age is unknown. + */ + public long getAccessControlMaxAge() { + String value = getFirst(ACCESS_CONTROL_MAX_AGE); + return (value != null ? Long.parseLong(value) : -1); + } + + /** + * Set the (new) value of the {@code Access-Control-Request-Headers} request header. + */ + public void setAccessControlRequestHeaders(List requestHeaders) { + set(ACCESS_CONTROL_REQUEST_HEADERS, toCommaDelimitedString(requestHeaders)); + } + + /** + * Return the value of the {@code Access-Control-Request-Headers} request header. + */ + public List getAccessControlRequestHeaders() { + return getValuesAsList(ACCESS_CONTROL_REQUEST_HEADERS); + } + + /** + * Set the (new) value of the {@code Access-Control-Request-Method} request header. + */ + public void setAccessControlRequestMethod(@Nullable HttpMethod requestMethod) { + setOrRemove(ACCESS_CONTROL_REQUEST_METHOD, (requestMethod != null ? requestMethod.name() : null)); + } + + /** + * Return the value of the {@code Access-Control-Request-Method} request header. + */ + @Nullable + public HttpMethod getAccessControlRequestMethod() { + return HttpMethod.resolve(getFirst(ACCESS_CONTROL_REQUEST_METHOD)); + } + + /** + * Set the list of acceptable {@linkplain Charset charsets}, + * as specified by the {@code Accept-Charset} header. + */ + public void setAcceptCharset(List acceptableCharsets) { + StringBuilder builder = new StringBuilder(); + for (Iterator iterator = acceptableCharsets.iterator(); iterator.hasNext();) { + Charset charset = iterator.next(); + builder.append(charset.name().toLowerCase(Locale.ENGLISH)); + if (iterator.hasNext()) { + builder.append(", "); + } + } + set(ACCEPT_CHARSET, builder.toString()); + } + + /** + * Return the list of acceptable {@linkplain Charset charsets}, + * as specified by the {@code Accept-Charset} header. + */ + public List getAcceptCharset() { + String value = getFirst(ACCEPT_CHARSET); + if (value != null) { + String[] tokens = StringUtils.tokenizeToStringArray(value, ","); + List result = new ArrayList<>(tokens.length); + for (String token : tokens) { + int paramIdx = token.indexOf(';'); + String charsetName; + if (paramIdx == -1) { + charsetName = token; + } + else { + charsetName = token.substring(0, paramIdx); + } + if (!charsetName.equals("*")) { + result.add(Charset.forName(charsetName)); + } + } + return result; + } + else { + return Collections.emptyList(); + } + } + + /** + * Set the set of allowed {@link HttpMethod HTTP methods}, + * as specified by the {@code Allow} header. + */ + public void setAllow(Set allowedMethods) { + set(ALLOW, StringUtils.collectionToCommaDelimitedString(allowedMethods)); + } + + /** + * Return the set of allowed {@link HttpMethod HTTP methods}, + * as specified by the {@code Allow} header. + *

Returns an empty set when the allowed methods are unspecified. + */ + public Set getAllow() { + String value = getFirst(ALLOW); + if (StringUtils.hasLength(value)) { + String[] tokens = StringUtils.tokenizeToStringArray(value, ","); + List result = new ArrayList<>(tokens.length); + for (String token : tokens) { + HttpMethod resolved = HttpMethod.resolve(token); + if (resolved != null) { + result.add(resolved); + } + } + return EnumSet.copyOf(result); + } + else { + return EnumSet.noneOf(HttpMethod.class); + } + } + + /** + * Set the value of the {@linkplain #AUTHORIZATION Authorization} header to + * Basic Authentication based on the given username and password. + *

Note that this method only supports characters in the + * {@link StandardCharsets#ISO_8859_1 ISO-8859-1} character set. + * @param username the username + * @param password the password + * @throws IllegalArgumentException if either {@code user} or + * {@code password} contain characters that cannot be encoded to ISO-8859-1 + * @since 5.1 + * @see #setBasicAuth(String, String, Charset) + * @see RFC 7617 + */ + public void setBasicAuth(String username, String password) { + setBasicAuth(username, password, null); + } + + /** + * Set the value of the {@linkplain #AUTHORIZATION Authorization} header to + * Basic Authentication based on the given username and password. + * @param username the username + * @param password the password + * @param charset the charset to use to convert the credentials into an octet + * sequence. Defaults to {@linkplain StandardCharsets#ISO_8859_1 ISO-8859-1}. + * @throws IllegalArgumentException if {@code username} or {@code password} + * contains characters that cannot be encoded to the given charset + * @since 5.1 + * @see RFC 7617 + */ + public void setBasicAuth(String username, String password, @Nullable Charset charset) { + Assert.notNull(username, "Username must not be null"); + Assert.notNull(password, "Password must not be null"); + if (charset == null) { + charset = StandardCharsets.ISO_8859_1; + } + + CharsetEncoder encoder = charset.newEncoder(); + if (!encoder.canEncode(username) || !encoder.canEncode(password)) { + throw new IllegalArgumentException( + "Username or password contains characters that cannot be encoded to " + charset.displayName()); + } + + String credentialsString = username + ":" + password; + byte[] encodedBytes = Base64.getEncoder().encode(credentialsString.getBytes(charset)); + String encodedCredentials = new String(encodedBytes, charset); + set(AUTHORIZATION, "Basic " + encodedCredentials); + } + + /** + * Set the value of the {@linkplain #AUTHORIZATION Authorization} header to + * the given Bearer token. + * @param token the Base64 encoded token + * @since 5.1 + * @see RFC 6750 + */ + public void setBearerAuth(String token) { + set(AUTHORIZATION, "Bearer " + token); + } + + /** + * Set a configured {@link CacheControl} instance as the + * new value of the {@code Cache-Control} header. + * @since 5.0.5 + */ + public void setCacheControl(CacheControl cacheControl) { + setOrRemove(CACHE_CONTROL, cacheControl.getHeaderValue()); + } + + /** + * Set the (new) value of the {@code Cache-Control} header. + */ + public void setCacheControl(@Nullable String cacheControl) { + setOrRemove(CACHE_CONTROL, cacheControl); + } + + /** + * Return the value of the {@code Cache-Control} header. + */ + @Nullable + public String getCacheControl() { + return getFieldValues(CACHE_CONTROL); + } + + /** + * Set the (new) value of the {@code Connection} header. + */ + public void setConnection(String connection) { + set(CONNECTION, connection); + } + + /** + * Set the (new) value of the {@code Connection} header. + */ + public void setConnection(List connection) { + set(CONNECTION, toCommaDelimitedString(connection)); + } + + /** + * Return the value of the {@code Connection} header. + */ + public List getConnection() { + return getValuesAsList(CONNECTION); + } + + /** + * Set the {@code Content-Disposition} header when creating a + * {@code "multipart/form-data"} request. + *

Applications typically would not set this header directly but + * rather prepare a {@code MultiValueMap}, containing an + * Object or a {@link org.springframework.core.io.Resource} for each part, + * and then pass that to the {@code RestTemplate} or {@code WebClient}. + * @param name the control name + * @param filename the filename (may be {@code null}) + * @see #getContentDisposition() + */ + public void setContentDispositionFormData(String name, @Nullable String filename) { + Assert.notNull(name, "Name must not be null"); + ContentDisposition.Builder disposition = ContentDisposition.builder("form-data").name(name); + if (filename != null) { + disposition.filename(filename); + } + setContentDisposition(disposition.build()); + } + + /** + * Set the {@literal Content-Disposition} header. + *

This could be used on a response to indicate if the content is + * expected to be displayed inline in the browser or as an attachment to be + * saved locally. + *

It can also be used for a {@code "multipart/form-data"} request. + * For more details see notes on {@link #setContentDispositionFormData}. + * @since 5.0 + * @see #getContentDisposition() + */ + public void setContentDisposition(ContentDisposition contentDisposition) { + set(CONTENT_DISPOSITION, contentDisposition.toString()); + } + + /** + * Return a parsed representation of the {@literal Content-Disposition} header. + * @since 5.0 + * @see #setContentDisposition(ContentDisposition) + */ + public ContentDisposition getContentDisposition() { + String contentDisposition = getFirst(CONTENT_DISPOSITION); + if (contentDisposition != null) { + return ContentDisposition.parse(contentDisposition); + } + return ContentDisposition.empty(); + } + + /** + * Set the {@link Locale} of the content language, + * as specified by the {@literal Content-Language} header. + *

Use {@code put(CONTENT_LANGUAGE, list)} if you need + * to set multiple content languages.

+ * @since 5.0 + */ + public void setContentLanguage(@Nullable Locale locale) { + setOrRemove(CONTENT_LANGUAGE, (locale != null ? locale.toLanguageTag() : null)); + } + + /** + * Return the first {@link Locale} of the content languages, + * as specified by the {@literal Content-Language} header. + *

Returns {@code null} when the content language is unknown. + *

Use {@code getValuesAsList(CONTENT_LANGUAGE)} if you need + * to get multiple content languages.

+ * @since 5.0 + */ + @Nullable + public Locale getContentLanguage() { + return getValuesAsList(CONTENT_LANGUAGE) + .stream() + .findFirst() + .map(Locale::forLanguageTag) + .orElse(null); + } + + /** + * Set the length of the body in bytes, as specified by the + * {@code Content-Length} header. + */ + public void setContentLength(long contentLength) { + set(CONTENT_LENGTH, Long.toString(contentLength)); + } + + /** + * Return the length of the body in bytes, as specified by the + * {@code Content-Length} header. + *

Returns -1 when the content-length is unknown. + */ + public long getContentLength() { + String value = getFirst(CONTENT_LENGTH); + return (value != null ? Long.parseLong(value) : -1); + } + + /** + * Set the {@linkplain MediaType media type} of the body, + * as specified by the {@code Content-Type} header. + */ + public void setContentType(@Nullable MediaType mediaType) { + if (mediaType != null) { + Assert.isTrue(!mediaType.isWildcardType(), "Content-Type cannot contain wildcard type '*'"); + Assert.isTrue(!mediaType.isWildcardSubtype(), "Content-Type cannot contain wildcard subtype '*'"); + set(CONTENT_TYPE, mediaType.toString()); + } + else { + remove(CONTENT_TYPE); + } + } + + /** + * Return the {@linkplain MediaType media type} of the body, as specified + * by the {@code Content-Type} header. + *

Returns {@code null} when the content-type is unknown. + */ + @Nullable + public MediaType getContentType() { + String value = getFirst(CONTENT_TYPE); + return (StringUtils.hasLength(value) ? MediaType.parseMediaType(value) : null); + } + + /** + * Set the date and time at which the message was created, as specified + * by the {@code Date} header. + *

The date should be specified as the number of milliseconds since + * January 1, 1970 GMT. + */ + public void setDate(long date) { + setDate(DATE, date); + } + + /** + * Return the date and time at which the message was created, as specified + * by the {@code Date} header. + *

The date is returned as the number of milliseconds since + * January 1, 1970 GMT. Returns -1 when the date is unknown. + * @throws IllegalArgumentException if the value cannot be converted to a date + */ + public long getDate() { + return getFirstDate(DATE); + } + + /** + * Set the (new) entity tag of the body, as specified by the {@code ETag} header. + */ + public void setETag(@Nullable String etag) { + if (etag != null) { + Assert.isTrue(etag.startsWith("\"") || etag.startsWith("W/"), + "Invalid ETag: does not start with W/ or \""); + Assert.isTrue(etag.endsWith("\""), "Invalid ETag: does not end with \""); + set(ETAG, etag); + } + else { + remove(ETAG); + } + } + + /** + * Return the entity tag of the body, as specified by the {@code ETag} header. + */ + @Nullable + public String getETag() { + return getFirst(ETAG); + } + + /** + * Set the duration after which the message is no longer valid, + * as specified by the {@code Expires} header. + * @since 5.0.5 + */ + public void setExpires(ZonedDateTime expires) { + setZonedDateTime(EXPIRES, expires); + } + + /** + * Set the date and time at which the message is no longer valid, + * as specified by the {@code Expires} header. + *

The date should be specified as the number of milliseconds since + * January 1, 1970 GMT. + */ + public void setExpires(long expires) { + setDate(EXPIRES, expires); + } + + /** + * Return the date and time at which the message is no longer valid, + * as specified by the {@code Expires} header. + *

The date is returned as the number of milliseconds since + * January 1, 1970 GMT. Returns -1 when the date is unknown. + * @see #getFirstZonedDateTime(String) + */ + public long getExpires() { + return getFirstDate(EXPIRES, false); + } + + /** + * Set the (new) value of the {@code Host} header. + *

If the given {@linkplain InetSocketAddress#getPort() port} is {@code 0}, + * the host header will only contain the + * {@linkplain InetSocketAddress#getHostString() host name}. + * @since 5.0 + */ + public void setHost(@Nullable InetSocketAddress host) { + if (host != null) { + String value = host.getHostString(); + int port = host.getPort(); + if (port != 0) { + value = value + ":" + port; + } + set(HOST, value); + } + else { + remove(HOST, null); + } + } + + /** + * Return the value of the {@code Host} header, if available. + *

If the header value does not contain a port, the + * {@linkplain InetSocketAddress#getPort() port} in the returned address will + * be {@code 0}. + * @since 5.0 + */ + @Nullable + public InetSocketAddress getHost() { + String value = getFirst(HOST); + if (value == null) { + return null; + } + + String host = null; + int port = 0; + int separator = (value.startsWith("[") ? value.indexOf(':', value.indexOf(']')) : value.lastIndexOf(':')); + if (separator != -1) { + host = value.substring(0, separator); + String portString = value.substring(separator + 1); + try { + port = Integer.parseInt(portString); + } + catch (NumberFormatException ex) { + // ignore + } + } + + if (host == null) { + host = value; + } + return InetSocketAddress.createUnresolved(host, port); + } + + /** + * Set the (new) value of the {@code If-Match} header. + * @since 4.3 + */ + public void setIfMatch(String ifMatch) { + set(IF_MATCH, ifMatch); + } + + /** + * Set the (new) value of the {@code If-Match} header. + * @since 4.3 + */ + public void setIfMatch(List ifMatchList) { + set(IF_MATCH, toCommaDelimitedString(ifMatchList)); + } + + /** + * Return the value of the {@code If-Match} header. + * @throws IllegalArgumentException if parsing fails + * @since 4.3 + */ + public List getIfMatch() { + return getETagValuesAsList(IF_MATCH); + } + + /** + * Set the time the resource was last changed, as specified by the + * {@code Last-Modified} header. + * @since 5.1.4 + */ + public void setIfModifiedSince(ZonedDateTime ifModifiedSince) { + setZonedDateTime(IF_MODIFIED_SINCE, ifModifiedSince.withZoneSameInstant(GMT)); + } + + /** + * Set the time the resource was last changed, as specified by the + * {@code Last-Modified} header. + * @since 5.1.4 + */ + public void setIfModifiedSince(Instant ifModifiedSince) { + setInstant(IF_MODIFIED_SINCE, ifModifiedSince); + } + + /** + * Set the (new) value of the {@code If-Modified-Since} header. + *

The date should be specified as the number of milliseconds since + * January 1, 1970 GMT. + */ + public void setIfModifiedSince(long ifModifiedSince) { + setDate(IF_MODIFIED_SINCE, ifModifiedSince); + } + + /** + * Return the value of the {@code If-Modified-Since} header. + *

The date is returned as the number of milliseconds since + * January 1, 1970 GMT. Returns -1 when the date is unknown. + * @see #getFirstZonedDateTime(String) + */ + public long getIfModifiedSince() { + return getFirstDate(IF_MODIFIED_SINCE, false); + } + + /** + * Set the (new) value of the {@code If-None-Match} header. + */ + public void setIfNoneMatch(String ifNoneMatch) { + set(IF_NONE_MATCH, ifNoneMatch); + } + + /** + * Set the (new) values of the {@code If-None-Match} header. + */ + public void setIfNoneMatch(List ifNoneMatchList) { + set(IF_NONE_MATCH, toCommaDelimitedString(ifNoneMatchList)); + } + + /** + * Return the value of the {@code If-None-Match} header. + * @throws IllegalArgumentException if parsing fails + */ + public List getIfNoneMatch() { + return getETagValuesAsList(IF_NONE_MATCH); + } + + /** + * Set the time the resource was last changed, as specified by the + * {@code Last-Modified} header. + * @since 5.1.4 + */ + public void setIfUnmodifiedSince(ZonedDateTime ifUnmodifiedSince) { + setZonedDateTime(IF_UNMODIFIED_SINCE, ifUnmodifiedSince.withZoneSameInstant(GMT)); + } + + /** + * Set the time the resource was last changed, as specified by the + * {@code Last-Modified} header. + * @since 5.1.4 + */ + public void setIfUnmodifiedSince(Instant ifUnmodifiedSince) { + setInstant(IF_UNMODIFIED_SINCE, ifUnmodifiedSince); + } + + /** + * Set the (new) value of the {@code If-Unmodified-Since} header. + *

The date should be specified as the number of milliseconds since + * January 1, 1970 GMT. + * @since 4.3 + */ + public void setIfUnmodifiedSince(long ifUnmodifiedSince) { + setDate(IF_UNMODIFIED_SINCE, ifUnmodifiedSince); + } + + /** + * Return the value of the {@code If-Unmodified-Since} header. + *

The date is returned as the number of milliseconds since + * January 1, 1970 GMT. Returns -1 when the date is unknown. + * @since 4.3 + * @see #getFirstZonedDateTime(String) + */ + public long getIfUnmodifiedSince() { + return getFirstDate(IF_UNMODIFIED_SINCE, false); + } + + /** + * Set the time the resource was last changed, as specified by the + * {@code Last-Modified} header. + * @since 5.1.4 + */ + public void setLastModified(ZonedDateTime lastModified) { + setZonedDateTime(LAST_MODIFIED, lastModified.withZoneSameInstant(GMT)); + } + + /** + * Set the time the resource was last changed, as specified by the + * {@code Last-Modified} header. + * @since 5.1.4 + */ + public void setLastModified(Instant lastModified) { + setInstant(LAST_MODIFIED, lastModified); + } + + /** + * Set the time the resource was last changed, as specified by the + * {@code Last-Modified} header. + *

The date should be specified as the number of milliseconds since + * January 1, 1970 GMT. + */ + public void setLastModified(long lastModified) { + setDate(LAST_MODIFIED, lastModified); + } + + /** + * Return the time the resource was last changed, as specified by the + * {@code Last-Modified} header. + *

The date is returned as the number of milliseconds since + * January 1, 1970 GMT. Returns -1 when the date is unknown. + * @see #getFirstZonedDateTime(String) + */ + public long getLastModified() { + return getFirstDate(LAST_MODIFIED, false); + } + + /** + * Set the (new) location of a resource, + * as specified by the {@code Location} header. + */ + public void setLocation(@Nullable URI location) { + setOrRemove(LOCATION, (location != null ? location.toASCIIString() : null)); + } + + /** + * Return the (new) location of a resource + * as specified by the {@code Location} header. + *

Returns {@code null} when the location is unknown. + */ + @Nullable + public URI getLocation() { + String value = getFirst(LOCATION); + return (value != null ? URI.create(value) : null); + } + + /** + * Set the (new) value of the {@code Origin} header. + */ + public void setOrigin(@Nullable String origin) { + setOrRemove(ORIGIN, origin); + } + + /** + * Return the value of the {@code Origin} header. + */ + @Nullable + public String getOrigin() { + return getFirst(ORIGIN); + } + + /** + * Set the (new) value of the {@code Pragma} header. + */ + public void setPragma(@Nullable String pragma) { + setOrRemove(PRAGMA, pragma); + } + + /** + * Return the value of the {@code Pragma} header. + */ + @Nullable + public String getPragma() { + return getFirst(PRAGMA); + } + + /** + * Sets the (new) value of the {@code Range} header. + */ + public void setRange(List ranges) { + String value = HttpRange.toString(ranges); + set(RANGE, value); + } + + /** + * Return the value of the {@code Range} header. + *

Returns an empty list when the range is unknown. + */ + public List getRange() { + String value = getFirst(RANGE); + return HttpRange.parseRanges(value); + } + + /** + * Set the (new) value of the {@code Upgrade} header. + */ + public void setUpgrade(@Nullable String upgrade) { + setOrRemove(UPGRADE, upgrade); + } + + /** + * Return the value of the {@code Upgrade} header. + */ + @Nullable + public String getUpgrade() { + return getFirst(UPGRADE); + } + + /** + * Set the request header names (e.g. "Accept-Language") for which the + * response is subject to content negotiation and variances based on the + * value of those request headers. + * @param requestHeaders the request header names + * @since 4.3 + */ + public void setVary(List requestHeaders) { + set(VARY, toCommaDelimitedString(requestHeaders)); + } + + /** + * Return the request header names subject to content negotiation. + * @since 4.3 + */ + public List getVary() { + return getValuesAsList(VARY); + } + + /** + * Set the given date under the given header name after formatting it as a string + * using the RFC-1123 date-time formatter. The equivalent of + * {@link #set(String, String)} but for date headers. + * @since 5.0 + */ + public void setZonedDateTime(String headerName, ZonedDateTime date) { + set(headerName, DATE_FORMATTER.format(date)); + } + + /** + * Set the given date under the given header name after formatting it as a string + * using the RFC-1123 date-time formatter. The equivalent of + * {@link #set(String, String)} but for date headers. + * @since 5.1.4 + */ + public void setInstant(String headerName, Instant date) { + setZonedDateTime(headerName, ZonedDateTime.ofInstant(date, GMT)); + } + + /** + * Set the given date under the given header name after formatting it as a string + * using the RFC-1123 date-time formatter. The equivalent of + * {@link #set(String, String)} but for date headers. + * @since 3.2.4 + * @see #setZonedDateTime(String, ZonedDateTime) + */ + public void setDate(String headerName, long date) { + setInstant(headerName, Instant.ofEpochMilli(date)); + } + + /** + * Parse the first header value for the given header name as a date, + * return -1 if there is no value, or raise {@link IllegalArgumentException} + * if the value cannot be parsed as a date. + * @param headerName the header name + * @return the parsed date header, or -1 if none + * @since 3.2.4 + * @see #getFirstZonedDateTime(String) + */ + public long getFirstDate(String headerName) { + return getFirstDate(headerName, true); + } + + /** + * Parse the first header value for the given header name as a date, + * return -1 if there is no value or also in case of an invalid value + * (if {@code rejectInvalid=false}), or raise {@link IllegalArgumentException} + * if the value cannot be parsed as a date. + * @param headerName the header name + * @param rejectInvalid whether to reject invalid values with an + * {@link IllegalArgumentException} ({@code true}) or rather return -1 + * in that case ({@code false}) + * @return the parsed date header, or -1 if none (or invalid) + * @see #getFirstZonedDateTime(String, boolean) + */ + private long getFirstDate(String headerName, boolean rejectInvalid) { + ZonedDateTime zonedDateTime = getFirstZonedDateTime(headerName, rejectInvalid); + return (zonedDateTime != null ? zonedDateTime.toInstant().toEpochMilli() : -1); + } + + /** + * Parse the first header value for the given header name as a date, + * return {@code null} if there is no value, or raise {@link IllegalArgumentException} + * if the value cannot be parsed as a date. + * @param headerName the header name + * @return the parsed date header, or {@code null} if none + * @since 5.0 + */ + @Nullable + public ZonedDateTime getFirstZonedDateTime(String headerName) { + return getFirstZonedDateTime(headerName, true); + } + + /** + * Parse the first header value for the given header name as a date, + * return {@code null} if there is no value or also in case of an invalid value + * (if {@code rejectInvalid=false}), or raise {@link IllegalArgumentException} + * if the value cannot be parsed as a date. + * @param headerName the header name + * @param rejectInvalid whether to reject invalid values with an + * {@link IllegalArgumentException} ({@code true}) or rather return {@code null} + * in that case ({@code false}) + * @return the parsed date header, or {@code null} if none (or invalid) + */ + @Nullable + private ZonedDateTime getFirstZonedDateTime(String headerName, boolean rejectInvalid) { + String headerValue = getFirst(headerName); + if (headerValue == null) { + // No header value sent at all + return null; + } + if (headerValue.length() >= 3) { + // Short "0" or "-1" like values are never valid HTTP date headers... + // Let's only bother with DateTimeFormatter parsing for long enough values. + + // See https://stackoverflow.com/questions/12626699/if-modified-since-http-header-passed-by-ie9-includes-length + int parametersIndex = headerValue.indexOf(';'); + if (parametersIndex != -1) { + headerValue = headerValue.substring(0, parametersIndex); + } + + for (DateTimeFormatter dateFormatter : DATE_PARSERS) { + try { + return ZonedDateTime.parse(headerValue, dateFormatter); + } + catch (DateTimeParseException ex) { + // ignore + } + } + + } + if (rejectInvalid) { + throw new IllegalArgumentException("Cannot parse date value \"" + headerValue + + "\" for \"" + headerName + "\" header"); + } + return null; + } + + /** + * Return all values of a given header name, + * even if this header is set multiple times. + * @param headerName the header name + * @return all associated values + * @since 4.3 + */ + public List getValuesAsList(String headerName) { + List values = get(headerName); + if (values != null) { + List result = new ArrayList<>(); + for (String value : values) { + if (value != null) { + Collections.addAll(result, StringUtils.tokenizeToStringArray(value, ",")); + } + } + return result; + } + return Collections.emptyList(); + } + + /** + * Retrieve a combined result from the field values of the ETag header. + * @param headerName the header name + * @return the combined result + * @throws IllegalArgumentException if parsing fails + * @since 4.3 + */ + protected List getETagValuesAsList(String headerName) { + List values = get(headerName); + if (values != null) { + List result = new ArrayList<>(); + for (String value : values) { + if (value != null) { + Matcher matcher = ETAG_HEADER_VALUE_PATTERN.matcher(value); + while (matcher.find()) { + if ("*".equals(matcher.group())) { + result.add(matcher.group()); + } + else { + result.add(matcher.group(1)); + } + } + if (result.isEmpty()) { + throw new IllegalArgumentException( + "Could not parse header '" + headerName + "' with value '" + value + "'"); + } + } + } + return result; + } + return Collections.emptyList(); + } + + /** + * Retrieve a combined result from the field values of multi-valued headers. + * @param headerName the header name + * @return the combined result + * @since 4.3 + */ + @Nullable + protected String getFieldValues(String headerName) { + List headerValues = get(headerName); + return (headerValues != null ? toCommaDelimitedString(headerValues) : null); + } + + /** + * Turn the given list of header values into a comma-delimited result. + * @param headerValues the list of header values + * @return a combined result with comma delimitation + */ + protected String toCommaDelimitedString(List headerValues) { + StringBuilder builder = new StringBuilder(); + for (Iterator it = headerValues.iterator(); it.hasNext();) { + String val = it.next(); + if (val != null) { + builder.append(val); + if (it.hasNext()) { + builder.append(", "); + } + } + } + return builder.toString(); + } + + /** + * Set the given header value, or remove the header if {@code null}. + * @param headerName the header name + * @param headerValue the header value, or {@code null} for none + */ + private void setOrRemove(String headerName, @Nullable String headerValue) { + if (headerValue != null) { + set(headerName, headerValue); + } + else { + remove(headerName); + } + } + + + // MultiValueMap implementation + + /** + * Return the first header value for the given header name, if any. + * @param headerName the header name + * @return the first header value, or {@code null} if none + */ + @Override + @Nullable + public String getFirst(String headerName) { + return this.headers.getFirst(headerName); + } + + /** + * Add the given, single header value under the given name. + * @param headerName the header name + * @param headerValue the header value + * @throws UnsupportedOperationException if adding headers is not supported + * @see #put(String, List) + * @see #set(String, String) + */ + @Override + public void add(String headerName, @Nullable String headerValue) { + this.headers.add(headerName, headerValue); + } + + @Override + public void addAll(String key, List values) { + this.headers.addAll(key, values); + } + + @Override + public void addAll(MultiValueMap values) { + this.headers.addAll(values); + } + + /** + * Set the given, single header value under the given name. + * @param headerName the header name + * @param headerValue the header value + * @throws UnsupportedOperationException if adding headers is not supported + * @see #put(String, List) + * @see #add(String, String) + */ + @Override + public void set(String headerName, @Nullable String headerValue) { + this.headers.set(headerName, headerValue); + } + + @Override + public void setAll(Map values) { + this.headers.setAll(values); + } + + @Override + public Map toSingleValueMap() { + return this.headers.toSingleValueMap(); + } + + + // Map implementation + + @Override + public int size() { + return this.headers.size(); + } + + @Override + public boolean isEmpty() { + return this.headers.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + return this.headers.containsKey(key); + } + + @Override + public boolean containsValue(Object value) { + return this.headers.containsValue(value); + } + + @Override + @Nullable + public List get(Object key) { + return this.headers.get(key); + } + + @Override + public List put(String key, List value) { + return this.headers.put(key, value); + } + + @Override + public List remove(Object key) { + return this.headers.remove(key); + } + + @Override + public void putAll(Map> map) { + this.headers.putAll(map); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + return this.headers.keySet(); + } + + @Override + public Collection> values() { + return this.headers.values(); + } + + @Override + public Set>> entrySet() { + return this.headers.entrySet(); + } + + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof HttpHeaders)) { + return false; + } + return unwrap(this).equals(unwrap((HttpHeaders) other)); + } + + private static MultiValueMap unwrap(HttpHeaders headers) { + while (headers.headers instanceof HttpHeaders) { + headers = (HttpHeaders) headers.headers; + } + return headers.headers; + } + + @Override + public int hashCode() { + return this.headers.hashCode(); + } + + @Override + public String toString() { + return formatHeaders(this.headers); + } + + + /** + * Apply a read-only {@code HttpHeaders} wrapper around the given headers, + * if necessary. + * @param headers the headers to expose + * @return a read-only variant of the headers, or the original headers as-is + */ + public static HttpHeaders readOnlyHttpHeaders(HttpHeaders headers) { + Assert.notNull(headers, "HttpHeaders must not be null"); + return (headers instanceof ReadOnlyHttpHeaders ? headers : new ReadOnlyHttpHeaders(headers.headers)); + } + + /** + * Remove any read-only wrapper that may have been previously applied around + * the given headers via {@link #readOnlyHttpHeaders(HttpHeaders)}. + * @param headers the headers to expose + * @return a writable variant of the headers, or the original headers as-is + * @since 5.1.1 + */ + public static HttpHeaders writableHttpHeaders(HttpHeaders headers) { + Assert.notNull(headers, "HttpHeaders must not be null"); + if (headers == EMPTY) { + return new HttpHeaders(); + } + return (headers instanceof ReadOnlyHttpHeaders ? new HttpHeaders(headers.headers) : headers); + } + + /** + * Helps to format HTTP header values, as HTTP header values themselves can + * contain comma-separated values, can become confusing with regular + * {@link Map} formatting that also uses commas between entries. + * @param headers the headers to format + * @return the headers to a String + * @since 5.1.4 + */ + public static String formatHeaders(MultiValueMap headers) { + return headers.entrySet().stream() + .map(entry -> { + List values = entry.getValue(); + return entry.getKey() + ":" + (values.size() == 1 ? + "\"" + values.get(0) + "\"" : + values.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(", "))); + }) + .collect(Collectors.joining(", ", "[", "]")); + } + + // Package-private: used in ResponseCookie + static String formatDate(long date) { + Instant instant = Instant.ofEpochMilli(date); + ZonedDateTime time = ZonedDateTime.ofInstant(instant, GMT); + return DATE_FORMATTER.format(time); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/HttpInputMessage.java b/spring-web/src/main/java/org/springframework/http/HttpInputMessage.java new file mode 100644 index 0000000000000000000000000000000000000000..e499944e7902eb03e6045db215f658740e074113 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/HttpInputMessage.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.io.IOException; +import java.io.InputStream; + +/** + * Represents an HTTP input message, consisting of {@linkplain #getHeaders() headers} + * and a readable {@linkplain #getBody() body}. + * + *

Typically implemented by an HTTP request handle on the server side, + * or an HTTP response handle on the client side. + * + * @author Arjen Poutsma + * @since 3.0 + */ +public interface HttpInputMessage extends HttpMessage { + + /** + * Return the body of the message as an input stream. + * @return the input stream body (never {@code null}) + * @throws IOException in case of I/O errors + */ + InputStream getBody() throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/HttpLogging.java b/spring-web/src/main/java/org/springframework/http/HttpLogging.java new file mode 100644 index 0000000000000000000000000000000000000000..b247492cbd07810f88689be2e4811e018627658a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/HttpLogging.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogDelegateFactory; + +/** + * Holds the shared logger named "org.springframework.web.HttpLogging" for HTTP + * related logging when "org.springframework.http" is not enabled but + * "org.springframework.web" is. + * + *

That means "org.springframework.web" enables all web logging including + * from lower level packages such as "org.springframework.http" and modules + * such as codecs from {@literal "spring-core"} when those are wrapped with + * {@link org.springframework.http.codec.EncoderHttpMessageWriter EncoderHttpMessageWriter} or + * {@link org.springframework.http.codec.DecoderHttpMessageReader DecoderHttpMessageReader}. + * + *

To see logging from the primary class loggers simply enable logging for + * "org.springframework.http" and "org.springframework.codec". + * + * @author Rossen Stoyanchev + * @since 5.1 + * @see LogDelegateFactory + */ +public abstract class HttpLogging { + + private static final Log fallbackLogger = + LogFactory.getLog("org.springframework.web." + HttpLogging.class.getSimpleName()); + + + /** + * Create a primary logger for the given class and wrap it with a composite + * that delegates to it or to the fallback logger + * "org.springframework.web.HttpLogging", if the primary is not enabled. + * @param primaryLoggerClass the class for the name of the primary logger + * @return the resulting composite logger + */ + public static Log forLogName(Class primaryLoggerClass) { + Log primaryLogger = LogFactory.getLog(primaryLoggerClass); + return forLog(primaryLogger); + } + + /** + * Wrap the given primary logger with a composite logger that delegates to + * it or to the fallback logger "org.springframework.web.HttpLogging", + * if the primary is not enabled. + * @param primaryLogger the primary logger to use + * @return the resulting composite logger + */ + public static Log forLog(Log primaryLogger) { + return LogDelegateFactory.getCompositeLog(primaryLogger, fallbackLogger); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/HttpMessage.java b/spring-web/src/main/java/org/springframework/http/HttpMessage.java new file mode 100644 index 0000000000000000000000000000000000000000..336f666ccfd1e95c931d300d6429d3318558f73a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/HttpMessage.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +/** + * Represents the base interface for HTTP request and response messages. + * Consists of {@link HttpHeaders}, retrievable via {@link #getHeaders()}. + * + * @author Arjen Poutsma + * @since 3.0 + */ +public interface HttpMessage { + + /** + * Return the headers of this message. + * @return a corresponding HttpHeaders object (never {@code null}) + */ + HttpHeaders getHeaders(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/HttpMethod.java b/spring-web/src/main/java/org/springframework/http/HttpMethod.java new file mode 100644 index 0000000000000000000000000000000000000000..b39b314c09b3efcda26261ec739d3d6cf486b3f6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/HttpMethod.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.util.HashMap; +import java.util.Map; + +import org.springframework.lang.Nullable; + +/** + * Java 5 enumeration of HTTP request methods. Intended for use + * with {@link org.springframework.http.client.ClientHttpRequest} + * and {@link org.springframework.web.client.RestTemplate}. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.0 + */ +public enum HttpMethod { + + GET, HEAD, POST, PUT, PATCH, DELETE, OPTIONS, TRACE; + + + private static final Map mappings = new HashMap<>(16); + + static { + for (HttpMethod httpMethod : values()) { + mappings.put(httpMethod.name(), httpMethod); + } + } + + + /** + * Resolve the given method value to an {@code HttpMethod}. + * @param method the method value as a String + * @return the corresponding {@code HttpMethod}, or {@code null} if not found + * @since 4.2.4 + */ + @Nullable + public static HttpMethod resolve(@Nullable String method) { + return (method != null ? mappings.get(method) : null); + } + + + /** + * Determine whether this {@code HttpMethod} matches the given + * method value. + * @param method the method value as a String + * @return {@code true} if it matches, {@code false} otherwise + * @since 4.2.4 + */ + public boolean matches(String method) { + return (this == resolve(method)); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/HttpOutputMessage.java b/spring-web/src/main/java/org/springframework/http/HttpOutputMessage.java new file mode 100644 index 0000000000000000000000000000000000000000..b379afb435c9588c7bd8f17644da1df1f1a41f80 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/HttpOutputMessage.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.io.IOException; +import java.io.OutputStream; + +/** + * Represents an HTTP output message, consisting of {@linkplain #getHeaders() headers} + * and a writable {@linkplain #getBody() body}. + * + *

Typically implemented by an HTTP request handle on the client side, + * or an HTTP response handle on the server side. + * + * @author Arjen Poutsma + * @since 3.0 + */ +public interface HttpOutputMessage extends HttpMessage { + + /** + * Return the body of the message as an output stream. + * @return the output stream body (never {@code null}) + * @throws IOException in case of I/O errors + */ + OutputStream getBody() throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/HttpRange.java b/spring-web/src/main/java/org/springframework/http/HttpRange.java new file mode 100644 index 0000000000000000000000000000000000000000..3e0fbe3c123b80121204aceda6bdd778fbca51ba --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/HttpRange.java @@ -0,0 +1,359 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +import org.springframework.core.io.InputStreamResource; +import org.springframework.core.io.Resource; +import org.springframework.core.io.support.ResourceRegion; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; + +/** + * Represents an HTTP (byte) range for use with the HTTP {@code "Range"} header. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 4.2 + * @see HTTP/1.1: Range Requests + * @see HttpHeaders#setRange(List) + * @see HttpHeaders#getRange() + */ +public abstract class HttpRange { + + /** Maximum ranges per request. */ + private static final int MAX_RANGES = 100; + + private static final String BYTE_RANGE_PREFIX = "bytes="; + + + /** + * Turn a {@code Resource} into a {@link ResourceRegion} using the range + * information contained in the current {@code HttpRange}. + * @param resource the {@code Resource} to select the region from + * @return the selected region of the given {@code Resource} + * @since 4.3 + */ + public ResourceRegion toResourceRegion(Resource resource) { + // Don't try to determine contentLength on InputStreamResource - cannot be read afterwards... + // Note: custom InputStreamResource subclasses could provide a pre-calculated content length! + Assert.isTrue(resource.getClass() != InputStreamResource.class, + "Cannot convert an InputStreamResource to a ResourceRegion"); + long contentLength = getLengthFor(resource); + long start = getRangeStart(contentLength); + long end = getRangeEnd(contentLength); + return new ResourceRegion(resource, start, end - start + 1); + } + + /** + * Return the start of the range given the total length of a representation. + * @param length the length of the representation + * @return the start of this range for the representation + */ + public abstract long getRangeStart(long length); + + /** + * Return the end of the range (inclusive) given the total length of a representation. + * @param length the length of the representation + * @return the end of the range for the representation + */ + public abstract long getRangeEnd(long length); + + + /** + * Create an {@code HttpRange} from the given position to the end. + * @param firstBytePos the first byte position + * @return a byte range that ranges from {@code firstPos} till the end + * @see Byte Ranges + */ + public static HttpRange createByteRange(long firstBytePos) { + return new ByteRange(firstBytePos, null); + } + + /** + * Create a {@code HttpRange} from the given fist to last position. + * @param firstBytePos the first byte position + * @param lastBytePos the last byte position + * @return a byte range that ranges from {@code firstPos} till {@code lastPos} + * @see Byte Ranges + */ + public static HttpRange createByteRange(long firstBytePos, long lastBytePos) { + return new ByteRange(firstBytePos, lastBytePos); + } + + /** + * Create an {@code HttpRange} that ranges over the last given number of bytes. + * @param suffixLength the number of bytes for the range + * @return a byte range that ranges over the last {@code suffixLength} number of bytes + * @see Byte Ranges + */ + public static HttpRange createSuffixRange(long suffixLength) { + return new SuffixByteRange(suffixLength); + } + + /** + * Parse the given, comma-separated string into a list of {@code HttpRange} objects. + *

This method can be used to parse an {@code Range} header. + * @param ranges the string to parse + * @return the list of ranges + * @throws IllegalArgumentException if the string cannot be parsed + * or if the number of ranges is greater than 100 + */ + public static List parseRanges(@Nullable String ranges) { + if (!StringUtils.hasLength(ranges)) { + return Collections.emptyList(); + } + if (!ranges.startsWith(BYTE_RANGE_PREFIX)) { + throw new IllegalArgumentException("Range '" + ranges + "' does not start with 'bytes='"); + } + ranges = ranges.substring(BYTE_RANGE_PREFIX.length()); + + String[] tokens = StringUtils.tokenizeToStringArray(ranges, ","); + if (tokens.length > MAX_RANGES) { + throw new IllegalArgumentException("Too many ranges: " + tokens.length); + } + List result = new ArrayList<>(tokens.length); + for (String token : tokens) { + result.add(parseRange(token)); + } + return result; + } + + private static HttpRange parseRange(String range) { + Assert.hasLength(range, "Range String must not be empty"); + int dashIdx = range.indexOf('-'); + if (dashIdx > 0) { + long firstPos = Long.parseLong(range.substring(0, dashIdx)); + if (dashIdx < range.length() - 1) { + Long lastPos = Long.parseLong(range.substring(dashIdx + 1)); + return new ByteRange(firstPos, lastPos); + } + else { + return new ByteRange(firstPos, null); + } + } + else if (dashIdx == 0) { + long suffixLength = Long.parseLong(range.substring(1)); + return new SuffixByteRange(suffixLength); + } + else { + throw new IllegalArgumentException("Range '" + range + "' does not contain \"-\""); + } + } + + /** + * Convert each {@code HttpRange} into a {@code ResourceRegion}, selecting the + * appropriate segment of the given {@code Resource} using HTTP Range information. + * @param ranges the list of ranges + * @param resource the resource to select the regions from + * @return the list of regions for the given resource + * @throws IllegalArgumentException if the sum of all ranges exceeds the resource length + * @since 4.3 + */ + public static List toResourceRegions(List ranges, Resource resource) { + if (CollectionUtils.isEmpty(ranges)) { + return Collections.emptyList(); + } + List regions = new ArrayList<>(ranges.size()); + for (HttpRange range : ranges) { + regions.add(range.toResourceRegion(resource)); + } + if (ranges.size() > 1) { + long length = getLengthFor(resource); + long total = 0; + for (ResourceRegion region : regions) { + total += region.getCount(); + } + if (total >= length) { + throw new IllegalArgumentException("The sum of all ranges (" + total + + ") should be less than the resource length (" + length + ")"); + } + } + return regions; + } + + private static long getLengthFor(Resource resource) { + try { + long contentLength = resource.contentLength(); + Assert.isTrue(contentLength > 0, "Resource content length should be > 0"); + return contentLength; + } + catch (IOException ex) { + throw new IllegalArgumentException("Failed to obtain Resource content length", ex); + } + } + + /** + * Return a string representation of the given list of {@code HttpRange} objects. + *

This method can be used to for an {@code Range} header. + * @param ranges the ranges to create a string of + * @return the string representation + */ + public static String toString(Collection ranges) { + Assert.notEmpty(ranges, "Ranges Collection must not be empty"); + StringBuilder builder = new StringBuilder(BYTE_RANGE_PREFIX); + for (Iterator iterator = ranges.iterator(); iterator.hasNext(); ) { + HttpRange range = iterator.next(); + builder.append(range); + if (iterator.hasNext()) { + builder.append(", "); + } + } + return builder.toString(); + } + + + /** + * Represents an HTTP/1.1 byte range, with a first and optional last position. + * @see Byte Ranges + * @see HttpRange#createByteRange(long) + * @see HttpRange#createByteRange(long, long) + */ + private static class ByteRange extends HttpRange { + + private final long firstPos; + + @Nullable + private final Long lastPos; + + public ByteRange(long firstPos, @Nullable Long lastPos) { + assertPositions(firstPos, lastPos); + this.firstPos = firstPos; + this.lastPos = lastPos; + } + + private void assertPositions(long firstBytePos, @Nullable Long lastBytePos) { + if (firstBytePos < 0) { + throw new IllegalArgumentException("Invalid first byte position: " + firstBytePos); + } + if (lastBytePos != null && lastBytePos < firstBytePos) { + throw new IllegalArgumentException("firstBytePosition=" + firstBytePos + + " should be less then or equal to lastBytePosition=" + lastBytePos); + } + } + + @Override + public long getRangeStart(long length) { + return this.firstPos; + } + + @Override + public long getRangeEnd(long length) { + if (this.lastPos != null && this.lastPos < length) { + return this.lastPos; + } + else { + return length - 1; + } + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof ByteRange)) { + return false; + } + ByteRange otherRange = (ByteRange) other; + return (this.firstPos == otherRange.firstPos && + ObjectUtils.nullSafeEquals(this.lastPos, otherRange.lastPos)); + } + + @Override + public int hashCode() { + return (ObjectUtils.nullSafeHashCode(this.firstPos) * 31 + + ObjectUtils.nullSafeHashCode(this.lastPos)); + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + builder.append(this.firstPos); + builder.append('-'); + if (this.lastPos != null) { + builder.append(this.lastPos); + } + return builder.toString(); + } + } + + + /** + * Represents an HTTP/1.1 suffix byte range, with a number of suffix bytes. + * @see Byte Ranges + * @see HttpRange#createSuffixRange(long) + */ + private static class SuffixByteRange extends HttpRange { + + private final long suffixLength; + + public SuffixByteRange(long suffixLength) { + if (suffixLength < 0) { + throw new IllegalArgumentException("Invalid suffix length: " + suffixLength); + } + this.suffixLength = suffixLength; + } + + @Override + public long getRangeStart(long length) { + if (this.suffixLength < length) { + return length - this.suffixLength; + } + else { + return 0; + } + } + + @Override + public long getRangeEnd(long length) { + return length - 1; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof SuffixByteRange)) { + return false; + } + SuffixByteRange otherRange = (SuffixByteRange) other; + return (this.suffixLength == otherRange.suffixLength); + } + + @Override + public int hashCode() { + return Long.hashCode(this.suffixLength); + } + + @Override + public String toString() { + return "-" + this.suffixLength; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/HttpRequest.java b/spring-web/src/main/java/org/springframework/http/HttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..ed82012166e2ba06fa790fde192083e9924bdb2c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/HttpRequest.java @@ -0,0 +1,59 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.net.URI; + +import org.springframework.lang.Nullable; + +/** + * Represents an HTTP request message, consisting of + * {@linkplain #getMethod() method} and {@linkplain #getURI() uri}. + * + * @author Arjen Poutsma + * @since 3.1 + */ +public interface HttpRequest extends HttpMessage { + + /** + * Return the HTTP method of the request. + * @return the HTTP method as an HttpMethod enum value, or {@code null} + * if not resolvable (e.g. in case of a non-standard HTTP method) + * @see #getMethodValue() + * @see HttpMethod#resolve(String) + */ + @Nullable + default HttpMethod getMethod() { + return HttpMethod.resolve(getMethodValue()); + } + + /** + * Return the HTTP method of the request as a String value. + * @return the HTTP method as a plain String + * @since 5.0 + * @see #getMethod() + */ + String getMethodValue(); + + /** + * Return the URI of the request (including a query string if any, + * but only if it is well-formed for a URI representation). + * @return the URI of the request (never {@code null}) + */ + URI getURI(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/HttpStatus.java b/spring-web/src/main/java/org/springframework/http/HttpStatus.java new file mode 100644 index 0000000000000000000000000000000000000000..13fc01036f265583d5392c8b509f1fecddbe1c86 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/HttpStatus.java @@ -0,0 +1,621 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import org.springframework.lang.Nullable; + +/** + * Enumeration of HTTP status codes. + * + *

The HTTP status code series can be retrieved via {@link #series()}. + * + * @author Arjen Poutsma + * @author Sebastien Deleuze + * @author Brian Clozel + * @since 3.0 + * @see HttpStatus.Series + * @see HTTP Status Code Registry + * @see List of HTTP status codes - Wikipedia + */ +public enum HttpStatus { + + // 1xx Informational + + /** + * {@code 100 Continue}. + * @see HTTP/1.1: Semantics and Content, section 6.2.1 + */ + CONTINUE(100, "Continue"), + /** + * {@code 101 Switching Protocols}. + * @see HTTP/1.1: Semantics and Content, section 6.2.2 + */ + SWITCHING_PROTOCOLS(101, "Switching Protocols"), + /** + * {@code 102 Processing}. + * @see WebDAV + */ + PROCESSING(102, "Processing"), + /** + * {@code 103 Checkpoint}. + * @see A proposal for supporting + * resumable POST/PUT HTTP requests in HTTP/1.0 + */ + CHECKPOINT(103, "Checkpoint"), + + // 2xx Success + + /** + * {@code 200 OK}. + * @see HTTP/1.1: Semantics and Content, section 6.3.1 + */ + OK(200, "OK"), + /** + * {@code 201 Created}. + * @see HTTP/1.1: Semantics and Content, section 6.3.2 + */ + CREATED(201, "Created"), + /** + * {@code 202 Accepted}. + * @see HTTP/1.1: Semantics and Content, section 6.3.3 + */ + ACCEPTED(202, "Accepted"), + /** + * {@code 203 Non-Authoritative Information}. + * @see HTTP/1.1: Semantics and Content, section 6.3.4 + */ + NON_AUTHORITATIVE_INFORMATION(203, "Non-Authoritative Information"), + /** + * {@code 204 No Content}. + * @see HTTP/1.1: Semantics and Content, section 6.3.5 + */ + NO_CONTENT(204, "No Content"), + /** + * {@code 205 Reset Content}. + * @see HTTP/1.1: Semantics and Content, section 6.3.6 + */ + RESET_CONTENT(205, "Reset Content"), + /** + * {@code 206 Partial Content}. + * @see HTTP/1.1: Range Requests, section 4.1 + */ + PARTIAL_CONTENT(206, "Partial Content"), + /** + * {@code 207 Multi-Status}. + * @see WebDAV + */ + MULTI_STATUS(207, "Multi-Status"), + /** + * {@code 208 Already Reported}. + * @see WebDAV Binding Extensions + */ + ALREADY_REPORTED(208, "Already Reported"), + /** + * {@code 226 IM Used}. + * @see Delta encoding in HTTP + */ + IM_USED(226, "IM Used"), + + // 3xx Redirection + + /** + * {@code 300 Multiple Choices}. + * @see HTTP/1.1: Semantics and Content, section 6.4.1 + */ + MULTIPLE_CHOICES(300, "Multiple Choices"), + /** + * {@code 301 Moved Permanently}. + * @see HTTP/1.1: Semantics and Content, section 6.4.2 + */ + MOVED_PERMANENTLY(301, "Moved Permanently"), + /** + * {@code 302 Found}. + * @see HTTP/1.1: Semantics and Content, section 6.4.3 + */ + FOUND(302, "Found"), + /** + * {@code 302 Moved Temporarily}. + * @see HTTP/1.0, section 9.3 + * @deprecated in favor of {@link #FOUND} which will be returned from {@code HttpStatus.valueOf(302)} + */ + @Deprecated + MOVED_TEMPORARILY(302, "Moved Temporarily"), + /** + * {@code 303 See Other}. + * @see HTTP/1.1: Semantics and Content, section 6.4.4 + */ + SEE_OTHER(303, "See Other"), + /** + * {@code 304 Not Modified}. + * @see HTTP/1.1: Conditional Requests, section 4.1 + */ + NOT_MODIFIED(304, "Not Modified"), + /** + * {@code 305 Use Proxy}. + * @see HTTP/1.1: Semantics and Content, section 6.4.5 + * @deprecated due to security concerns regarding in-band configuration of a proxy + */ + @Deprecated + USE_PROXY(305, "Use Proxy"), + /** + * {@code 307 Temporary Redirect}. + * @see HTTP/1.1: Semantics and Content, section 6.4.7 + */ + TEMPORARY_REDIRECT(307, "Temporary Redirect"), + /** + * {@code 308 Permanent Redirect}. + * @see RFC 7238 + */ + PERMANENT_REDIRECT(308, "Permanent Redirect"), + + // --- 4xx Client Error --- + + /** + * {@code 400 Bad Request}. + * @see HTTP/1.1: Semantics and Content, section 6.5.1 + */ + BAD_REQUEST(400, "Bad Request"), + /** + * {@code 401 Unauthorized}. + * @see HTTP/1.1: Authentication, section 3.1 + */ + UNAUTHORIZED(401, "Unauthorized"), + /** + * {@code 402 Payment Required}. + * @see HTTP/1.1: Semantics and Content, section 6.5.2 + */ + PAYMENT_REQUIRED(402, "Payment Required"), + /** + * {@code 403 Forbidden}. + * @see HTTP/1.1: Semantics and Content, section 6.5.3 + */ + FORBIDDEN(403, "Forbidden"), + /** + * {@code 404 Not Found}. + * @see HTTP/1.1: Semantics and Content, section 6.5.4 + */ + NOT_FOUND(404, "Not Found"), + /** + * {@code 405 Method Not Allowed}. + * @see HTTP/1.1: Semantics and Content, section 6.5.5 + */ + METHOD_NOT_ALLOWED(405, "Method Not Allowed"), + /** + * {@code 406 Not Acceptable}. + * @see HTTP/1.1: Semantics and Content, section 6.5.6 + */ + NOT_ACCEPTABLE(406, "Not Acceptable"), + /** + * {@code 407 Proxy Authentication Required}. + * @see HTTP/1.1: Authentication, section 3.2 + */ + PROXY_AUTHENTICATION_REQUIRED(407, "Proxy Authentication Required"), + /** + * {@code 408 Request Timeout}. + * @see HTTP/1.1: Semantics and Content, section 6.5.7 + */ + REQUEST_TIMEOUT(408, "Request Timeout"), + /** + * {@code 409 Conflict}. + * @see HTTP/1.1: Semantics and Content, section 6.5.8 + */ + CONFLICT(409, "Conflict"), + /** + * {@code 410 Gone}. + * @see + * HTTP/1.1: Semantics and Content, section 6.5.9 + */ + GONE(410, "Gone"), + /** + * {@code 411 Length Required}. + * @see + * HTTP/1.1: Semantics and Content, section 6.5.10 + */ + LENGTH_REQUIRED(411, "Length Required"), + /** + * {@code 412 Precondition failed}. + * @see + * HTTP/1.1: Conditional Requests, section 4.2 + */ + PRECONDITION_FAILED(412, "Precondition Failed"), + /** + * {@code 413 Payload Too Large}. + * @since 4.1 + * @see + * HTTP/1.1: Semantics and Content, section 6.5.11 + */ + PAYLOAD_TOO_LARGE(413, "Payload Too Large"), + /** + * {@code 413 Request Entity Too Large}. + * @see HTTP/1.1, section 10.4.14 + * @deprecated in favor of {@link #PAYLOAD_TOO_LARGE} which will be + * returned from {@code HttpStatus.valueOf(413)} + */ + @Deprecated + REQUEST_ENTITY_TOO_LARGE(413, "Request Entity Too Large"), + /** + * {@code 414 URI Too Long}. + * @since 4.1 + * @see + * HTTP/1.1: Semantics and Content, section 6.5.12 + */ + URI_TOO_LONG(414, "URI Too Long"), + /** + * {@code 414 Request-URI Too Long}. + * @see HTTP/1.1, section 10.4.15 + * @deprecated in favor of {@link #URI_TOO_LONG} which will be returned from {@code HttpStatus.valueOf(414)} + */ + @Deprecated + REQUEST_URI_TOO_LONG(414, "Request-URI Too Long"), + /** + * {@code 415 Unsupported Media Type}. + * @see + * HTTP/1.1: Semantics and Content, section 6.5.13 + */ + UNSUPPORTED_MEDIA_TYPE(415, "Unsupported Media Type"), + /** + * {@code 416 Requested Range Not Satisfiable}. + * @see HTTP/1.1: Range Requests, section 4.4 + */ + REQUESTED_RANGE_NOT_SATISFIABLE(416, "Requested range not satisfiable"), + /** + * {@code 417 Expectation Failed}. + * @see + * HTTP/1.1: Semantics and Content, section 6.5.14 + */ + EXPECTATION_FAILED(417, "Expectation Failed"), + /** + * {@code 418 I'm a teapot}. + * @see HTCPCP/1.0 + */ + I_AM_A_TEAPOT(418, "I'm a teapot"), + /** + * @deprecated See + * + * WebDAV Draft Changes + */ + @Deprecated + INSUFFICIENT_SPACE_ON_RESOURCE(419, "Insufficient Space On Resource"), + /** + * @deprecated See + * + * WebDAV Draft Changes + */ + @Deprecated + METHOD_FAILURE(420, "Method Failure"), + /** + * @deprecated + * See + * WebDAV Draft Changes + */ + @Deprecated + DESTINATION_LOCKED(421, "Destination Locked"), + /** + * {@code 422 Unprocessable Entity}. + * @see WebDAV + */ + UNPROCESSABLE_ENTITY(422, "Unprocessable Entity"), + /** + * {@code 423 Locked}. + * @see WebDAV + */ + LOCKED(423, "Locked"), + /** + * {@code 424 Failed Dependency}. + * @see WebDAV + */ + FAILED_DEPENDENCY(424, "Failed Dependency"), + /** + * {@code 426 Upgrade Required}. + * @see Upgrading to TLS Within HTTP/1.1 + */ + UPGRADE_REQUIRED(426, "Upgrade Required"), + /** + * {@code 428 Precondition Required}. + * @see Additional HTTP Status Codes + */ + PRECONDITION_REQUIRED(428, "Precondition Required"), + /** + * {@code 429 Too Many Requests}. + * @see Additional HTTP Status Codes + */ + TOO_MANY_REQUESTS(429, "Too Many Requests"), + /** + * {@code 431 Request Header Fields Too Large}. + * @see Additional HTTP Status Codes + */ + REQUEST_HEADER_FIELDS_TOO_LARGE(431, "Request Header Fields Too Large"), + /** + * {@code 451 Unavailable For Legal Reasons}. + * @see + * An HTTP Status Code to Report Legal Obstacles + * @since 4.3 + */ + UNAVAILABLE_FOR_LEGAL_REASONS(451, "Unavailable For Legal Reasons"), + + // --- 5xx Server Error --- + + /** + * {@code 500 Internal Server Error}. + * @see HTTP/1.1: Semantics and Content, section 6.6.1 + */ + INTERNAL_SERVER_ERROR(500, "Internal Server Error"), + /** + * {@code 501 Not Implemented}. + * @see HTTP/1.1: Semantics and Content, section 6.6.2 + */ + NOT_IMPLEMENTED(501, "Not Implemented"), + /** + * {@code 502 Bad Gateway}. + * @see HTTP/1.1: Semantics and Content, section 6.6.3 + */ + BAD_GATEWAY(502, "Bad Gateway"), + /** + * {@code 503 Service Unavailable}. + * @see HTTP/1.1: Semantics and Content, section 6.6.4 + */ + SERVICE_UNAVAILABLE(503, "Service Unavailable"), + /** + * {@code 504 Gateway Timeout}. + * @see HTTP/1.1: Semantics and Content, section 6.6.5 + */ + GATEWAY_TIMEOUT(504, "Gateway Timeout"), + /** + * {@code 505 HTTP Version Not Supported}. + * @see HTTP/1.1: Semantics and Content, section 6.6.6 + */ + HTTP_VERSION_NOT_SUPPORTED(505, "HTTP Version not supported"), + /** + * {@code 506 Variant Also Negotiates} + * @see Transparent Content Negotiation + */ + VARIANT_ALSO_NEGOTIATES(506, "Variant Also Negotiates"), + /** + * {@code 507 Insufficient Storage} + * @see WebDAV + */ + INSUFFICIENT_STORAGE(507, "Insufficient Storage"), + /** + * {@code 508 Loop Detected} + * @see WebDAV Binding Extensions + */ + LOOP_DETECTED(508, "Loop Detected"), + /** + * {@code 509 Bandwidth Limit Exceeded} + */ + BANDWIDTH_LIMIT_EXCEEDED(509, "Bandwidth Limit Exceeded"), + /** + * {@code 510 Not Extended} + * @see HTTP Extension Framework + */ + NOT_EXTENDED(510, "Not Extended"), + /** + * {@code 511 Network Authentication Required}. + * @see Additional HTTP Status Codes + */ + NETWORK_AUTHENTICATION_REQUIRED(511, "Network Authentication Required"); + + + private final int value; + + private final String reasonPhrase; + + + HttpStatus(int value, String reasonPhrase) { + this.value = value; + this.reasonPhrase = reasonPhrase; + } + + + /** + * Return the integer value of this status code. + */ + public int value() { + return this.value; + } + + /** + * Return the reason phrase of this status code. + */ + public String getReasonPhrase() { + return this.reasonPhrase; + } + + /** + * Return the HTTP status series of this status code. + * @see HttpStatus.Series + */ + public Series series() { + return Series.valueOf(this); + } + + /** + * Whether this status code is in the HTTP series + * {@link org.springframework.http.HttpStatus.Series#INFORMATIONAL}. + * This is a shortcut for checking the value of {@link #series()}. + * @since 4.0 + * @see #series() + */ + public boolean is1xxInformational() { + return (series() == Series.INFORMATIONAL); + } + + /** + * Whether this status code is in the HTTP series + * {@link org.springframework.http.HttpStatus.Series#SUCCESSFUL}. + * This is a shortcut for checking the value of {@link #series()}. + * @since 4.0 + * @see #series() + */ + public boolean is2xxSuccessful() { + return (series() == Series.SUCCESSFUL); + } + + /** + * Whether this status code is in the HTTP series + * {@link org.springframework.http.HttpStatus.Series#REDIRECTION}. + * This is a shortcut for checking the value of {@link #series()}. + * @since 4.0 + * @see #series() + */ + public boolean is3xxRedirection() { + return (series() == Series.REDIRECTION); + } + + /** + * Whether this status code is in the HTTP series + * {@link org.springframework.http.HttpStatus.Series#CLIENT_ERROR}. + * This is a shortcut for checking the value of {@link #series()}. + * @since 4.0 + * @see #series() + */ + public boolean is4xxClientError() { + return (series() == Series.CLIENT_ERROR); + } + + /** + * Whether this status code is in the HTTP series + * {@link org.springframework.http.HttpStatus.Series#SERVER_ERROR}. + * This is a shortcut for checking the value of {@link #series()}. + * @since 4.0 + * @see #series() + */ + public boolean is5xxServerError() { + return (series() == Series.SERVER_ERROR); + } + + /** + * Whether this status code is in the HTTP series + * {@link org.springframework.http.HttpStatus.Series#CLIENT_ERROR} or + * {@link org.springframework.http.HttpStatus.Series#SERVER_ERROR}. + * This is a shortcut for checking the value of {@link #series()}. + * @since 5.0 + * @see #is4xxClientError() + * @see #is5xxServerError() + */ + public boolean isError() { + return (is4xxClientError() || is5xxServerError()); + } + + /** + * Return a string representation of this status code. + */ + @Override + public String toString() { + return this.value + " " + name(); + } + + + /** + * Return the enum constant of this type with the specified numeric value. + * @param statusCode the numeric value of the enum to be returned + * @return the enum constant with the specified numeric value + * @throws IllegalArgumentException if this enum has no constant for the specified numeric value + */ + public static HttpStatus valueOf(int statusCode) { + HttpStatus status = resolve(statusCode); + if (status == null) { + throw new IllegalArgumentException("No matching constant for [" + statusCode + "]"); + } + return status; + } + + /** + * Resolve the given status code to an {@code HttpStatus}, if possible. + * @param statusCode the HTTP status code (potentially non-standard) + * @return the corresponding {@code HttpStatus}, or {@code null} if not found + * @since 5.0 + */ + @Nullable + public static HttpStatus resolve(int statusCode) { + for (HttpStatus status : values()) { + if (status.value == statusCode) { + return status; + } + } + return null; + } + + + /** + * Enumeration of HTTP status series. + *

Retrievable via {@link HttpStatus#series()}. + */ + public enum Series { + + INFORMATIONAL(1), + SUCCESSFUL(2), + REDIRECTION(3), + CLIENT_ERROR(4), + SERVER_ERROR(5); + + private final int value; + + Series(int value) { + this.value = value; + } + + /** + * Return the integer value of this status series. Ranges from 1 to 5. + */ + public int value() { + return this.value; + } + + /** + * Return the enum constant of this type with the corresponding series. + * @param status a standard HTTP status enum value + * @return the enum constant of this type with the corresponding series + * @throws IllegalArgumentException if this enum has no corresponding constant + */ + public static Series valueOf(HttpStatus status) { + return valueOf(status.value); + } + + /** + * Return the enum constant of this type with the corresponding series. + * @param statusCode the HTTP status code (potentially non-standard) + * @return the enum constant of this type with the corresponding series + * @throws IllegalArgumentException if this enum has no corresponding constant + */ + public static Series valueOf(int statusCode) { + Series series = resolve(statusCode); + if (series == null) { + throw new IllegalArgumentException("No matching constant for [" + statusCode + "]"); + } + return series; + } + + /** + * Resolve the given status code to an {@code HttpStatus.Series}, if possible. + * @param statusCode the HTTP status code (potentially non-standard) + * @return the corresponding {@code Series}, or {@code null} if not found + * @since 5.1.3 + */ + @Nullable + public static Series resolve(int statusCode) { + int seriesCode = statusCode / 100; + for (Series series : values()) { + if (series.value == seriesCode) { + return series; + } + } + return null; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/InvalidMediaTypeException.java b/spring-web/src/main/java/org/springframework/http/InvalidMediaTypeException.java new file mode 100644 index 0000000000000000000000000000000000000000..ec324b8353d03bf9444ec64d53604ff0207c8739 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/InvalidMediaTypeException.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import org.springframework.util.InvalidMimeTypeException; + +/** + * Exception thrown from {@link MediaType#parseMediaType(String)} in case of + * encountering an invalid media type specification String. + * + * @author Juergen Hoeller + * @since 3.2.2 + */ +@SuppressWarnings("serial") +public class InvalidMediaTypeException extends IllegalArgumentException { + + private final String mediaType; + + + /** + * Create a new InvalidMediaTypeException for the given media type. + * @param mediaType the offending media type + * @param message a detail message indicating the invalid part + */ + public InvalidMediaTypeException(String mediaType, String message) { + super("Invalid media type \"" + mediaType + "\": " + message); + this.mediaType = mediaType; + } + + /** + * Constructor that allows wrapping {@link InvalidMimeTypeException}. + */ + InvalidMediaTypeException(InvalidMimeTypeException ex) { + super(ex.getMessage(), ex); + this.mediaType = ex.getMimeType(); + } + + + /** + * Return the offending media type. + */ + public String getMediaType() { + return this.mediaType; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/MediaType.java b/spring-web/src/main/java/org/springframework/http/MediaType.java new file mode 100644 index 0000000000000000000000000000000000000000..1b096b014aaaedc65a3226cd550c5fd64679ed98 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/MediaType.java @@ -0,0 +1,743 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.io.Serializable; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.InvalidMimeTypeException; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.StringUtils; + +/** + * A subclass of {@link MimeType} that adds support for quality parameters + * as defined in the HTTP specification. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + * @author Kazuki Shimizu + * @author Sam Brannen + * @since 3.0 + * @see + * HTTP 1.1: Semantics and Content, section 3.1.1.1 + */ +public class MediaType extends MimeType implements Serializable { + + private static final long serialVersionUID = 2069937152339670231L; + + /** + * Public constant media type that includes all media ranges (i.e. "*/*"). + */ + public static final MediaType ALL; + + /** + * A String equivalent of {@link MediaType#ALL}. + */ + public static final String ALL_VALUE = "*/*"; + + /** + * Public constant media type for {@code application/atom+xml}. + */ + public static final MediaType APPLICATION_ATOM_XML; + + /** + * A String equivalent of {@link MediaType#APPLICATION_ATOM_XML}. + */ + public static final String APPLICATION_ATOM_XML_VALUE = "application/atom+xml"; + + /** + * Public constant media type for {@code application/x-www-form-urlencoded}. + */ + public static final MediaType APPLICATION_FORM_URLENCODED; + + /** + * A String equivalent of {@link MediaType#APPLICATION_FORM_URLENCODED}. + */ + public static final String APPLICATION_FORM_URLENCODED_VALUE = "application/x-www-form-urlencoded"; + + /** + * Public constant media type for {@code application/json}. + */ + public static final MediaType APPLICATION_JSON; + + /** + * A String equivalent of {@link MediaType#APPLICATION_JSON}. + * @see #APPLICATION_JSON_UTF8_VALUE + */ + public static final String APPLICATION_JSON_VALUE = "application/json"; + + /** + * Public constant media type for {@code application/json;charset=UTF-8}. + *

This {@link MediaType#APPLICATION_JSON} variant should be used to set JSON + * content type because while + * RFC7159 + * clearly states that "no charset parameter is defined for this registration", some + * browsers require it for interpreting correctly UTF-8 special characters. + */ + public static final MediaType APPLICATION_JSON_UTF8; + + /** + * A String equivalent of {@link MediaType#APPLICATION_JSON_UTF8}. + *

This {@link MediaType#APPLICATION_JSON_VALUE} variant should be used to set JSON + * content type because while + * RFC7159 + * clearly states that "no charset parameter is defined for this registration", some + * browsers require it for interpreting correctly UTF-8 special characters. + */ + public static final String APPLICATION_JSON_UTF8_VALUE = "application/json;charset=UTF-8"; + + /** + * Public constant media type for {@code application/octet-stream}. + */ + public static final MediaType APPLICATION_OCTET_STREAM; + + /** + * A String equivalent of {@link MediaType#APPLICATION_OCTET_STREAM}. + */ + public static final String APPLICATION_OCTET_STREAM_VALUE = "application/octet-stream"; + + /** + * Public constant media type for {@code application/pdf}. + * @since 4.3 + */ + public static final MediaType APPLICATION_PDF; + + /** + * A String equivalent of {@link MediaType#APPLICATION_PDF}. + * @since 4.3 + */ + public static final String APPLICATION_PDF_VALUE = "application/pdf"; + + /** + * Public constant media type for {@code application/problem+json}. + * @since 5.0 + * @see + * Problem Details for HTTP APIs, 6.1. application/problem+json + */ + public static final MediaType APPLICATION_PROBLEM_JSON; + + /** + * A String equivalent of {@link MediaType#APPLICATION_PROBLEM_JSON}. + * @since 5.0 + */ + public static final String APPLICATION_PROBLEM_JSON_VALUE = "application/problem+json"; + + /** + * Public constant media type for {@code application/problem+json}. + * @since 5.0 + * @see + * Problem Details for HTTP APIs, 6.1. application/problem+json + */ + public static final MediaType APPLICATION_PROBLEM_JSON_UTF8; + + /** + * A String equivalent of {@link MediaType#APPLICATION_PROBLEM_JSON_UTF8}. + * @since 5.0 + */ + public static final String APPLICATION_PROBLEM_JSON_UTF8_VALUE = "application/problem+json;charset=UTF-8"; + + /** + * Public constant media type for {@code application/problem+xml}. + * @since 5.0 + * @see + * Problem Details for HTTP APIs, 6.2. application/problem+xml + */ + public static final MediaType APPLICATION_PROBLEM_XML; + + /** + * A String equivalent of {@link MediaType#APPLICATION_PROBLEM_XML}. + * @since 5.0 + */ + public static final String APPLICATION_PROBLEM_XML_VALUE = "application/problem+xml"; + + /** + * Public constant media type for {@code application/rss+xml}. + * @since 4.3.6 + */ + public static final MediaType APPLICATION_RSS_XML; + + /** + * A String equivalent of {@link MediaType#APPLICATION_RSS_XML}. + * @since 4.3.6 + */ + public static final String APPLICATION_RSS_XML_VALUE = "application/rss+xml"; + + /** + * Public constant media type for {@code application/stream+json}. + * @since 5.0 + */ + public static final MediaType APPLICATION_STREAM_JSON; + + /** + * A String equivalent of {@link MediaType#APPLICATION_STREAM_JSON}. + * @since 5.0 + */ + public static final String APPLICATION_STREAM_JSON_VALUE = "application/stream+json"; + + /** + * Public constant media type for {@code application/xhtml+xml}. + */ + public static final MediaType APPLICATION_XHTML_XML; + + /** + * A String equivalent of {@link MediaType#APPLICATION_XHTML_XML}. + */ + public static final String APPLICATION_XHTML_XML_VALUE = "application/xhtml+xml"; + + /** + * Public constant media type for {@code application/xml}. + */ + public static final MediaType APPLICATION_XML; + + /** + * A String equivalent of {@link MediaType#APPLICATION_XML}. + */ + public static final String APPLICATION_XML_VALUE = "application/xml"; + + /** + * Public constant media type for {@code image/gif}. + */ + public static final MediaType IMAGE_GIF; + + /** + * A String equivalent of {@link MediaType#IMAGE_GIF}. + */ + public static final String IMAGE_GIF_VALUE = "image/gif"; + + /** + * Public constant media type for {@code image/jpeg}. + */ + public static final MediaType IMAGE_JPEG; + + /** + * A String equivalent of {@link MediaType#IMAGE_JPEG}. + */ + public static final String IMAGE_JPEG_VALUE = "image/jpeg"; + + /** + * Public constant media type for {@code image/png}. + */ + public static final MediaType IMAGE_PNG; + + /** + * A String equivalent of {@link MediaType#IMAGE_PNG}. + */ + public static final String IMAGE_PNG_VALUE = "image/png"; + + /** + * Public constant media type for {@code multipart/form-data}. + */ + public static final MediaType MULTIPART_FORM_DATA; + + /** + * A String equivalent of {@link MediaType#MULTIPART_FORM_DATA}. + */ + public static final String MULTIPART_FORM_DATA_VALUE = "multipart/form-data"; + + /** + * Public constant media type for {@code text/event-stream}. + * @since 4.3.6 + * @see Server-Sent Events W3C recommendation + */ + public static final MediaType TEXT_EVENT_STREAM; + + /** + * A String equivalent of {@link MediaType#TEXT_EVENT_STREAM}. + * @since 4.3.6 + */ + public static final String TEXT_EVENT_STREAM_VALUE = "text/event-stream"; + + /** + * Public constant media type for {@code text/html}. + */ + public static final MediaType TEXT_HTML; + + /** + * A String equivalent of {@link MediaType#TEXT_HTML}. + */ + public static final String TEXT_HTML_VALUE = "text/html"; + + /** + * Public constant media type for {@code text/markdown}. + * @since 4.3 + */ + public static final MediaType TEXT_MARKDOWN; + + /** + * A String equivalent of {@link MediaType#TEXT_MARKDOWN}. + * @since 4.3 + */ + public static final String TEXT_MARKDOWN_VALUE = "text/markdown"; + + /** + * Public constant media type for {@code text/plain}. + */ + public static final MediaType TEXT_PLAIN; + + /** + * A String equivalent of {@link MediaType#TEXT_PLAIN}. + */ + public static final String TEXT_PLAIN_VALUE = "text/plain"; + + /** + * Public constant media type for {@code text/xml}. + */ + public static final MediaType TEXT_XML; + + /** + * A String equivalent of {@link MediaType#TEXT_XML}. + */ + public static final String TEXT_XML_VALUE = "text/xml"; + + private static final String PARAM_QUALITY_FACTOR = "q"; + + + static { + ALL = valueOf(ALL_VALUE); + APPLICATION_ATOM_XML = valueOf(APPLICATION_ATOM_XML_VALUE); + APPLICATION_FORM_URLENCODED = valueOf(APPLICATION_FORM_URLENCODED_VALUE); + APPLICATION_JSON = valueOf(APPLICATION_JSON_VALUE); + APPLICATION_JSON_UTF8 = valueOf(APPLICATION_JSON_UTF8_VALUE); + APPLICATION_OCTET_STREAM = valueOf(APPLICATION_OCTET_STREAM_VALUE); + APPLICATION_PDF = valueOf(APPLICATION_PDF_VALUE); + APPLICATION_PROBLEM_JSON = valueOf(APPLICATION_PROBLEM_JSON_VALUE); + APPLICATION_PROBLEM_JSON_UTF8 = valueOf(APPLICATION_PROBLEM_JSON_UTF8_VALUE); + APPLICATION_PROBLEM_XML = valueOf(APPLICATION_PROBLEM_XML_VALUE); + APPLICATION_RSS_XML = valueOf(APPLICATION_RSS_XML_VALUE); + APPLICATION_STREAM_JSON = valueOf(APPLICATION_STREAM_JSON_VALUE); + APPLICATION_XHTML_XML = valueOf(APPLICATION_XHTML_XML_VALUE); + APPLICATION_XML = valueOf(APPLICATION_XML_VALUE); + IMAGE_GIF = valueOf(IMAGE_GIF_VALUE); + IMAGE_JPEG = valueOf(IMAGE_JPEG_VALUE); + IMAGE_PNG = valueOf(IMAGE_PNG_VALUE); + MULTIPART_FORM_DATA = valueOf(MULTIPART_FORM_DATA_VALUE); + TEXT_EVENT_STREAM = valueOf(TEXT_EVENT_STREAM_VALUE); + TEXT_HTML = valueOf(TEXT_HTML_VALUE); + TEXT_MARKDOWN = valueOf(TEXT_MARKDOWN_VALUE); + TEXT_PLAIN = valueOf(TEXT_PLAIN_VALUE); + TEXT_XML = valueOf(TEXT_XML_VALUE); + } + + + /** + * Create a new {@code MediaType} for the given primary type. + *

The {@linkplain #getSubtype() subtype} is set to "*", parameters empty. + * @param type the primary type + * @throws IllegalArgumentException if any of the parameters contain illegal characters + */ + public MediaType(String type) { + super(type); + } + + /** + * Create a new {@code MediaType} for the given primary type and subtype. + *

The parameters are empty. + * @param type the primary type + * @param subtype the subtype + * @throws IllegalArgumentException if any of the parameters contain illegal characters + */ + public MediaType(String type, String subtype) { + super(type, subtype, Collections.emptyMap()); + } + + /** + * Create a new {@code MediaType} for the given type, subtype, and character set. + * @param type the primary type + * @param subtype the subtype + * @param charset the character set + * @throws IllegalArgumentException if any of the parameters contain illegal characters + */ + public MediaType(String type, String subtype, Charset charset) { + super(type, subtype, charset); + } + + /** + * Create a new {@code MediaType} for the given type, subtype, and quality value. + * @param type the primary type + * @param subtype the subtype + * @param qualityValue the quality value + * @throws IllegalArgumentException if any of the parameters contain illegal characters + */ + public MediaType(String type, String subtype, double qualityValue) { + this(type, subtype, Collections.singletonMap(PARAM_QUALITY_FACTOR, Double.toString(qualityValue))); + } + + /** + * Copy-constructor that copies the type, subtype and parameters of the given + * {@code MediaType}, and allows to set the specified character set. + * @param other the other media type + * @param charset the character set + * @throws IllegalArgumentException if any of the parameters contain illegal characters + * @since 4.3 + */ + public MediaType(MediaType other, Charset charset) { + super(other, charset); + } + + /** + * Copy-constructor that copies the type and subtype of the given {@code MediaType}, + * and allows for different parameters. + * @param other the other media type + * @param parameters the parameters, may be {@code null} + * @throws IllegalArgumentException if any of the parameters contain illegal characters + */ + public MediaType(MediaType other, @Nullable Map parameters) { + super(other.getType(), other.getSubtype(), parameters); + } + + /** + * Create a new {@code MediaType} for the given type, subtype, and parameters. + * @param type the primary type + * @param subtype the subtype + * @param parameters the parameters, may be {@code null} + * @throws IllegalArgumentException if any of the parameters contain illegal characters + */ + public MediaType(String type, String subtype, @Nullable Map parameters) { + super(type, subtype, parameters); + } + + + @Override + protected void checkParameters(String attribute, String value) { + super.checkParameters(attribute, value); + if (PARAM_QUALITY_FACTOR.equals(attribute)) { + value = unquote(value); + double d = Double.parseDouble(value); + Assert.isTrue(d >= 0D && d <= 1D, + "Invalid quality value \"" + value + "\": should be between 0.0 and 1.0"); + } + } + + /** + * Return the quality factor, as indicated by a {@code q} parameter, if any. + * Defaults to {@code 1.0}. + * @return the quality factor as double value + */ + public double getQualityValue() { + String qualityFactor = getParameter(PARAM_QUALITY_FACTOR); + return (qualityFactor != null ? Double.parseDouble(unquote(qualityFactor)) : 1D); + } + + /** + * Indicate whether this {@code MediaType} includes the given media type. + *

For instance, {@code text/*} includes {@code text/plain} and {@code text/html}, + * and {@code application/*+xml} includes {@code application/soap+xml}, etc. + * This method is not symmetric. + *

Simply calls {@link MimeType#includes(MimeType)} but declared with a + * {@code MediaType} parameter for binary backwards compatibility. + * @param other the reference media type with which to compare + * @return {@code true} if this media type includes the given media type; + * {@code false} otherwise + */ + public boolean includes(@Nullable MediaType other) { + return super.includes(other); + } + + /** + * Indicate whether this {@code MediaType} is compatible with the given media type. + *

For instance, {@code text/*} is compatible with {@code text/plain}, + * {@code text/html}, and vice versa. In effect, this method is similar to + * {@link #includes}, except that it is symmetric. + *

Simply calls {@link MimeType#isCompatibleWith(MimeType)} but declared with a + * {@code MediaType} parameter for binary backwards compatibility. + * @param other the reference media type with which to compare + * @return {@code true} if this media type is compatible with the given media type; + * {@code false} otherwise + */ + public boolean isCompatibleWith(@Nullable MediaType other) { + return super.isCompatibleWith(other); + } + + /** + * Return a replica of this instance with the quality value of the given {@code MediaType}. + * @return the same instance if the given MediaType doesn't have a quality value, + * or a new one otherwise + */ + public MediaType copyQualityValue(MediaType mediaType) { + if (!mediaType.getParameters().containsKey(PARAM_QUALITY_FACTOR)) { + return this; + } + Map params = new LinkedHashMap<>(getParameters()); + params.put(PARAM_QUALITY_FACTOR, mediaType.getParameters().get(PARAM_QUALITY_FACTOR)); + return new MediaType(this, params); + } + + /** + * Return a replica of this instance with its quality value removed. + * @return the same instance if the media type doesn't contain a quality value, + * or a new one otherwise + */ + public MediaType removeQualityValue() { + if (!getParameters().containsKey(PARAM_QUALITY_FACTOR)) { + return this; + } + Map params = new LinkedHashMap<>(getParameters()); + params.remove(PARAM_QUALITY_FACTOR); + return new MediaType(this, params); + } + + + /** + * Parse the given String value into a {@code MediaType} object, + * with this method name following the 'valueOf' naming convention + * (as supported by {@link org.springframework.core.convert.ConversionService}. + * @param value the string to parse + * @throws InvalidMediaTypeException if the media type value cannot be parsed + * @see #parseMediaType(String) + */ + public static MediaType valueOf(String value) { + return parseMediaType(value); + } + + /** + * Parse the given String into a single {@code MediaType}. + * @param mediaType the string to parse + * @return the media type + * @throws InvalidMediaTypeException if the media type value cannot be parsed + */ + public static MediaType parseMediaType(String mediaType) { + MimeType type; + try { + type = MimeTypeUtils.parseMimeType(mediaType); + } + catch (InvalidMimeTypeException ex) { + throw new InvalidMediaTypeException(ex); + } + try { + return new MediaType(type.getType(), type.getSubtype(), type.getParameters()); + } + catch (IllegalArgumentException ex) { + throw new InvalidMediaTypeException(mediaType, ex.getMessage()); + } + } + + /** + * Parse the comma-separated string into a list of {@code MediaType} objects. + *

This method can be used to parse an Accept or Content-Type header. + * @param mediaTypes the string to parse + * @return the list of media types + * @throws InvalidMediaTypeException if the media type value cannot be parsed + */ + public static List parseMediaTypes(@Nullable String mediaTypes) { + if (!StringUtils.hasLength(mediaTypes)) { + return Collections.emptyList(); + } + return MimeTypeUtils.tokenize(mediaTypes).stream() + .filter(StringUtils::hasText) + .map(MediaType::parseMediaType) + .collect(Collectors.toList()); + } + + /** + * Parse the given list of (potentially) comma-separated strings into a + * list of {@code MediaType} objects. + *

This method can be used to parse an Accept or Content-Type header. + * @param mediaTypes the string to parse + * @return the list of media types + * @throws InvalidMediaTypeException if the media type value cannot be parsed + * @since 4.3.2 + */ + public static List parseMediaTypes(@Nullable List mediaTypes) { + if (CollectionUtils.isEmpty(mediaTypes)) { + return Collections.emptyList(); + } + else if (mediaTypes.size() == 1) { + return parseMediaTypes(mediaTypes.get(0)); + } + else { + List result = new ArrayList<>(8); + for (String mediaType : mediaTypes) { + result.addAll(parseMediaTypes(mediaType)); + } + return result; + } + } + + /** + * Re-create the given mime types as media types. + * @since 5.0 + */ + public static List asMediaTypes(List mimeTypes) { + return mimeTypes.stream().map(MediaType::asMediaType).collect(Collectors.toList()); + } + + /** + * Re-create the given mime type as a media type. + * @since 5.0 + */ + public static MediaType asMediaType(MimeType mimeType) { + if (mimeType instanceof MediaType) { + return (MediaType) mimeType; + } + return new MediaType(mimeType.getType(), mimeType.getSubtype(), mimeType.getParameters()); + } + + /** + * Return a string representation of the given list of {@code MediaType} objects. + *

This method can be used to for an {@code Accept} or {@code Content-Type} header. + * @param mediaTypes the media types to create a string representation for + * @return the string representation + */ + public static String toString(Collection mediaTypes) { + return MimeTypeUtils.toString(mediaTypes); + } + + /** + * Sorts the given list of {@code MediaType} objects by specificity. + *

Given two media types: + *

    + *
  1. if either media type has a {@linkplain #isWildcardType() wildcard type}, then the media type without the + * wildcard is ordered before the other.
  2. + *
  3. if the two media types have different {@linkplain #getType() types}, then they are considered equal and + * remain their current order.
  4. + *
  5. if either media type has a {@linkplain #isWildcardSubtype() wildcard subtype}, then the media type without + * the wildcard is sorted before the other.
  6. + *
  7. if the two media types have different {@linkplain #getSubtype() subtypes}, then they are considered equal + * and remain their current order.
  8. + *
  9. if the two media types have different {@linkplain #getQualityValue() quality value}, then the media type + * with the highest quality value is ordered before the other.
  10. + *
  11. if the two media types have a different amount of {@linkplain #getParameter(String) parameters}, then the + * media type with the most parameters is ordered before the other.
  12. + *
+ *

For example: + *

audio/basic < audio/* < */*
+ *
audio/* < audio/*;q=0.7; audio/*;q=0.3
+ *
audio/basic;level=1 < audio/basic
+ *
audio/basic == text/html
+ *
audio/basic == audio/wave
+ * @param mediaTypes the list of media types to be sorted + * @see HTTP 1.1: Semantics + * and Content, section 5.3.2 + */ + public static void sortBySpecificity(List mediaTypes) { + Assert.notNull(mediaTypes, "'mediaTypes' must not be null"); + if (mediaTypes.size() > 1) { + mediaTypes.sort(SPECIFICITY_COMPARATOR); + } + } + + /** + * Sorts the given list of {@code MediaType} objects by quality value. + *

Given two media types: + *

    + *
  1. if the two media types have different {@linkplain #getQualityValue() quality value}, then the media type + * with the highest quality value is ordered before the other.
  2. + *
  3. if either media type has a {@linkplain #isWildcardType() wildcard type}, then the media type without the + * wildcard is ordered before the other.
  4. + *
  5. if the two media types have different {@linkplain #getType() types}, then they are considered equal and + * remain their current order.
  6. + *
  7. if either media type has a {@linkplain #isWildcardSubtype() wildcard subtype}, then the media type without + * the wildcard is sorted before the other.
  8. + *
  9. if the two media types have different {@linkplain #getSubtype() subtypes}, then they are considered equal + * and remain their current order.
  10. + *
  11. if the two media types have a different amount of {@linkplain #getParameter(String) parameters}, then the + * media type with the most parameters is ordered before the other.
  12. + *
+ * @param mediaTypes the list of media types to be sorted + * @see #getQualityValue() + */ + public static void sortByQualityValue(List mediaTypes) { + Assert.notNull(mediaTypes, "'mediaTypes' must not be null"); + if (mediaTypes.size() > 1) { + mediaTypes.sort(QUALITY_VALUE_COMPARATOR); + } + } + + /** + * Sorts the given list of {@code MediaType} objects by specificity as the + * primary criteria and quality value the secondary. + * @see MediaType#sortBySpecificity(List) + * @see MediaType#sortByQualityValue(List) + */ + public static void sortBySpecificityAndQuality(List mediaTypes) { + Assert.notNull(mediaTypes, "'mediaTypes' must not be null"); + if (mediaTypes.size() > 1) { + mediaTypes.sort(MediaType.SPECIFICITY_COMPARATOR.thenComparing(MediaType.QUALITY_VALUE_COMPARATOR)); + } + } + + + /** + * Comparator used by {@link #sortByQualityValue(List)}. + */ + public static final Comparator QUALITY_VALUE_COMPARATOR = (mediaType1, mediaType2) -> { + double quality1 = mediaType1.getQualityValue(); + double quality2 = mediaType2.getQualityValue(); + int qualityComparison = Double.compare(quality2, quality1); + if (qualityComparison != 0) { + return qualityComparison; // audio/*;q=0.7 < audio/*;q=0.3 + } + else if (mediaType1.isWildcardType() && !mediaType2.isWildcardType()) { // */* < audio/* + return 1; + } + else if (mediaType2.isWildcardType() && !mediaType1.isWildcardType()) { // audio/* > */* + return -1; + } + else if (!mediaType1.getType().equals(mediaType2.getType())) { // audio/basic == text/html + return 0; + } + else { // mediaType1.getType().equals(mediaType2.getType()) + if (mediaType1.isWildcardSubtype() && !mediaType2.isWildcardSubtype()) { // audio/* < audio/basic + return 1; + } + else if (mediaType2.isWildcardSubtype() && !mediaType1.isWildcardSubtype()) { // audio/basic > audio/* + return -1; + } + else if (!mediaType1.getSubtype().equals(mediaType2.getSubtype())) { // audio/basic == audio/wave + return 0; + } + else { + int paramsSize1 = mediaType1.getParameters().size(); + int paramsSize2 = mediaType2.getParameters().size(); + return Integer.compare(paramsSize2, paramsSize1); // audio/basic;level=1 < audio/basic + } + } + }; + + + /** + * Comparator used by {@link #sortBySpecificity(List)}. + */ + public static final Comparator SPECIFICITY_COMPARATOR = new SpecificityComparator() { + + @Override + protected int compareParameters(MediaType mediaType1, MediaType mediaType2) { + double quality1 = mediaType1.getQualityValue(); + double quality2 = mediaType2.getQualityValue(); + int qualityComparison = Double.compare(quality2, quality1); + if (qualityComparison != 0) { + return qualityComparison; // audio/*;q=0.7 < audio/*;q=0.3 + } + return super.compareParameters(mediaType1, mediaType2); + } + }; + +} diff --git a/spring-web/src/main/java/org/springframework/http/MediaTypeEditor.java b/spring-web/src/main/java/org/springframework/http/MediaTypeEditor.java new file mode 100644 index 0000000000000000000000000000000000000000..8a6971909fbbd1b0f0a4f81b091ba220f51ed776 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/MediaTypeEditor.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.beans.PropertyEditorSupport; + +import org.springframework.util.StringUtils; + +/** + * {@link java.beans.PropertyEditor Editor} for {@link MediaType} + * descriptors, to automatically convert {@code String} specifications + * (e.g. {@code "text/html"}) to {@code MediaType} properties. + * + * @author Juergen Hoeller + * @since 3.0 + * @see MediaType + */ +public class MediaTypeEditor extends PropertyEditorSupport { + + @Override + public void setAsText(String text) { + if (StringUtils.hasText(text)) { + setValue(MediaType.parseMediaType(text)); + } + else { + setValue(null); + } + } + + @Override + public String getAsText() { + MediaType mediaType = (MediaType) getValue(); + return (mediaType != null ? mediaType.toString() : ""); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/MediaTypeFactory.java b/spring-web/src/main/java/org/springframework/http/MediaTypeFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..7ad0969f2b9fde6da1be5f757f611a63191607f7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/MediaTypeFactory.java @@ -0,0 +1,121 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Optional; + +import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * A factory delegate for resolving {@link MediaType} objects + * from {@link Resource} handles or filenames. + * + * @author Juergen Hoeller + * @author Arjen Poutsma + * @since 5.0 + */ +public final class MediaTypeFactory { + + private static final String MIME_TYPES_FILE_NAME = "/org/springframework/http/mime.types"; + + private static final MultiValueMap fileExtensionToMediaTypes = parseMimeTypes(); + + + private MediaTypeFactory() { + } + + + /** + * Parse the {@code mime.types} file found in the resources. Format is: + * + * # comments begin with a '#'
+ * # the format is <mime type> <space separated file extensions>
+ * # for example:
+ * text/plain txt text
+ * # this would map file.txt and file.text to
+ * # the mime type "text/plain"
+ *
+ * @return a multi-value map, mapping media types to file extensions. + */ + private static MultiValueMap parseMimeTypes() { + InputStream is = MediaTypeFactory.class.getResourceAsStream(MIME_TYPES_FILE_NAME); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.US_ASCII))) { + MultiValueMap result = new LinkedMultiValueMap<>(); + String line; + while ((line = reader.readLine()) != null) { + if (line.isEmpty() || line.charAt(0) == '#') { + continue; + } + String[] tokens = StringUtils.tokenizeToStringArray(line, " \t\n\r\f"); + MediaType mediaType = MediaType.parseMediaType(tokens[0]); + for (int i = 1; i < tokens.length; i++) { + String fileExtension = tokens[i].toLowerCase(Locale.ENGLISH); + result.add(fileExtension, mediaType); + } + } + return result; + } + catch (IOException ex) { + throw new IllegalStateException("Could not load '" + MIME_TYPES_FILE_NAME + "'", ex); + } + } + + /** + * Determine a media type for the given resource, if possible. + * @param resource the resource to introspect + * @return the corresponding media type, or {@code null} if none found + */ + public static Optional getMediaType(@Nullable Resource resource) { + return Optional.ofNullable(resource) + .map(Resource::getFilename) + .flatMap(MediaTypeFactory::getMediaType); + } + + /** + * Determine a media type for the given file name, if possible. + * @param filename the file name plus extension + * @return the corresponding media type, or {@code null} if none found + */ + public static Optional getMediaType(@Nullable String filename) { + return getMediaTypes(filename).stream().findFirst(); + } + + /** + * Determine the media types for the given file name, if possible. + * @param filename the file name plus extension + * @return the corresponding media types, or an empty list if none found + */ + public static List getMediaTypes(@Nullable String filename) { + return Optional.ofNullable(StringUtils.getFilenameExtension(filename)) + .map(s -> s.toLowerCase(Locale.ENGLISH)) + .map(fileExtensionToMediaTypes::get) + .orElse(Collections.emptyList()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/ReactiveHttpInputMessage.java b/spring-web/src/main/java/org/springframework/http/ReactiveHttpInputMessage.java new file mode 100644 index 0000000000000000000000000000000000000000..f33adf34fc6dc3a26ecc843b5773bac424fe0028 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/ReactiveHttpInputMessage.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + +import org.springframework.core.io.buffer.DataBuffer; + +/** + * An "reactive" HTTP input message that exposes the input as {@link Publisher}. + * + *

Typically implemented by an HTTP request on the server-side or a response + * on the client-side. + * + * @author Arjen Poutsma + * @since 5.0 + */ +public interface ReactiveHttpInputMessage extends HttpMessage { + + /** + * Return the body of the message as a {@link Publisher}. + * @return the body content publisher + */ + Flux getBody(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/ReactiveHttpOutputMessage.java b/spring-web/src/main/java/org/springframework/http/ReactiveHttpOutputMessage.java new file mode 100644 index 0000000000000000000000000000000000000000..b15f84c898805e90494d56bc477290466393ace0 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/ReactiveHttpOutputMessage.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.util.function.Supplier; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; + +/** + * A "reactive" HTTP output message that accepts output as a {@link Publisher}. + * + *

Typically implemented by an HTTP request on the client-side or an + * HTTP response on the server-side. + * + * @author Arjen Poutsma + * @author Sebastien Deleuze + * @since 5.0 + */ +public interface ReactiveHttpOutputMessage extends HttpMessage { + + /** + * Return a {@link DataBufferFactory} that can be used to create the body. + * @return a buffer factory + * @see #writeWith(Publisher) + */ + DataBufferFactory bufferFactory(); + + /** + * Register an action to apply just before the HttpOutputMessage is committed. + *

Note: the supplied action must be properly deferred, + * e.g. via {@link Mono#defer} or {@link Mono#fromRunnable}, to ensure it's + * executed in the right order, relative to other actions. + * @param action the action to apply + */ + void beforeCommit(Supplier> action); + + /** + * Whether the HttpOutputMessage is committed. + */ + boolean isCommitted(); + + /** + * Use the given {@link Publisher} to write the body of the message to the + * underlying HTTP layer. + * @param body the body content publisher + * @return a {@link Mono} that indicates completion or error + */ + + Mono writeWith(Publisher body); + + /** + * Use the given {@link Publisher} of {@code Publishers} to write the body + * of the HttpOutputMessage to the underlying HTTP layer, flushing after + * each {@code Publisher}. + * @param body the body content publisher + * @return a {@link Mono} that indicates completion or error + */ + Mono writeAndFlushWith(Publisher> body); + + /** + * Indicate that message handling is complete, allowing for any cleanup or + * end-of-processing tasks to be performed such as applying header changes + * made via {@link #getHeaders()} to the underlying HTTP message (if not + * applied already). + *

This method should be automatically invoked at the end of message + * processing so typically applications should not have to invoke it. + * If invoked multiple times it should have no side effects. + * @return a {@link Mono} that indicates completion or error + */ + Mono setComplete(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java b/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java new file mode 100644 index 0000000000000000000000000000000000000000..1ac1c16da110a617bdef6957847614d51c2c4395 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java @@ -0,0 +1,152 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code HttpHeaders} object that can only be read, not written to. + * + * @author Brian Clozel + * @author Sam Brannen + * @since 5.1.1 + */ +class ReadOnlyHttpHeaders extends HttpHeaders { + + private static final long serialVersionUID = -8578554704772377436L; + + @Nullable + private MediaType cachedContentType; + + @Nullable + private List cachedAccept; + + + ReadOnlyHttpHeaders(MultiValueMap headers) { + super(headers); + } + + + @Override + public MediaType getContentType() { + if (this.cachedContentType != null) { + return this.cachedContentType; + } + else { + MediaType contentType = super.getContentType(); + this.cachedContentType = contentType; + return contentType; + } + } + + @Override + public List getAccept() { + if (this.cachedAccept != null) { + return this.cachedAccept; + } + else { + List accept = super.getAccept(); + this.cachedAccept = accept; + return accept; + } + } + + @Override + public List get(Object key) { + List values = this.headers.get(key); + return (values != null ? Collections.unmodifiableList(values) : null); + } + + @Override + public void add(String headerName, @Nullable String headerValue) { + throw new UnsupportedOperationException(); + } + + @Override + public void addAll(String key, List values) { + throw new UnsupportedOperationException(); + } + + @Override + public void addAll(MultiValueMap values) { + throw new UnsupportedOperationException(); + } + + @Override + public void set(String headerName, @Nullable String headerValue) { + throw new UnsupportedOperationException(); + } + + @Override + public void setAll(Map values) { + throw new UnsupportedOperationException(); + } + + @Override + public Map toSingleValueMap() { + return Collections.unmodifiableMap(this.headers.toSingleValueMap()); + } + + @Override + public Set keySet() { + return Collections.unmodifiableSet(this.headers.keySet()); + } + + @Override + public List put(String key, List value) { + throw new UnsupportedOperationException(); + } + + @Override + public List remove(Object key) { + throw new UnsupportedOperationException(); + } + + @Override + public void putAll(Map> map) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(); + } + + @Override + public Collection> values() { + return Collections.unmodifiableCollection(this.headers.values()); + } + + @Override + public Set>> entrySet() { + return this.headers.entrySet().stream().map(SimpleImmutableEntry::new) + .collect(Collectors.collectingAndThen( + Collectors.toCollection(LinkedHashSet::new), // Retain original ordering of entries + Collections::unmodifiableSet)); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/RequestEntity.java b/spring-web/src/main/java/org/springframework/http/RequestEntity.java new file mode 100644 index 0000000000000000000000000000000000000000..c0caca3706ee82f8eef9e798b5c9bf7fe4a906d6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/RequestEntity.java @@ -0,0 +1,499 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.lang.reflect.Type; +import java.net.URI; +import java.nio.charset.Charset; +import java.time.Instant; +import java.time.ZonedDateTime; +import java.util.Arrays; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; + +/** + * Extension of {@link HttpEntity} that adds a {@linkplain HttpMethod method} and + * {@linkplain URI uri}. Used in {@code RestTemplate} and {@code @Controller} methods. + * + *

In {@code RestTemplate}, this class is used as parameter in + * {@link org.springframework.web.client.RestTemplate#exchange(RequestEntity, Class) exchange()}: + *

+ * MyRequest body = ...
+ * RequestEntity<MyRequest> request = RequestEntity
+ *     .post(new URI("https://example.com/bar"))
+ *     .accept(MediaType.APPLICATION_JSON)
+ *     .body(body);
+ * ResponseEntity<MyResponse> response = template.exchange(request, MyResponse.class);
+ * 
+ * + *

If you would like to provide a URI template with variables, consider using + * {@link org.springframework.web.util.UriTemplate}: + *

+ * URI uri = new UriTemplate("https://example.com/{foo}").expand("bar");
+ * RequestEntity<MyRequest> request = RequestEntity.post(uri).accept(MediaType.APPLICATION_JSON).body(body);
+ * 
+ * + *

Can also be used in Spring MVC, as a parameter in a @Controller method: + *

+ * @RequestMapping("/handle")
+ * public void handle(RequestEntity<String> request) {
+ *   HttpMethod method = request.getMethod();
+ *   URI url = request.getUrl();
+ *   String body = request.getBody();
+ * }
+ * 
+ * + * @author Arjen Poutsma + * @author Sebastien Deleuze + * @since 4.1 + * @param the body type + * @see #getMethod() + * @see #getUrl() + * @see org.springframework.web.client.RestOperations#exchange(RequestEntity, Class) + * @see ResponseEntity + */ +public class RequestEntity extends HttpEntity { + + @Nullable + private final HttpMethod method; + + private final URI url; + + @Nullable + private final Type type; + + + /** + * Constructor with method and URL but without body nor headers. + * @param method the method + * @param url the URL + */ + public RequestEntity(HttpMethod method, URI url) { + this(null, null, method, url, null); + } + + /** + * Constructor with method, URL and body but without headers. + * @param body the body + * @param method the method + * @param url the URL + */ + public RequestEntity(@Nullable T body, HttpMethod method, URI url) { + this(body, null, method, url, null); + } + + /** + * Constructor with method, URL, body and type but without headers. + * @param body the body + * @param method the method + * @param url the URL + * @param type the type used for generic type resolution + * @since 4.3 + */ + public RequestEntity(@Nullable T body, HttpMethod method, URI url, Type type) { + this(body, null, method, url, type); + } + + /** + * Constructor with method, URL and headers but without body. + * @param headers the headers + * @param method the method + * @param url the URL + */ + public RequestEntity(MultiValueMap headers, HttpMethod method, URI url) { + this(null, headers, method, url, null); + } + + /** + * Constructor with method, URL, headers and body. + * @param body the body + * @param headers the headers + * @param method the method + * @param url the URL + */ + public RequestEntity(@Nullable T body, @Nullable MultiValueMap headers, + @Nullable HttpMethod method, URI url) { + + this(body, headers, method, url, null); + } + + /** + * Constructor with method, URL, headers, body and type. + * @param body the body + * @param headers the headers + * @param method the method + * @param url the URL + * @param type the type used for generic type resolution + * @since 4.3 + */ + public RequestEntity(@Nullable T body, @Nullable MultiValueMap headers, + @Nullable HttpMethod method, URI url, @Nullable Type type) { + + super(body, headers); + this.method = method; + this.url = url; + this.type = type; + } + + + /** + * Return the HTTP method of the request. + * @return the HTTP method as an {@code HttpMethod} enum value + */ + @Nullable + public HttpMethod getMethod() { + return this.method; + } + + /** + * Return the URL of the request. + * @return the URL as a {@code URI} + */ + public URI getUrl() { + return this.url; + } + + /** + * Return the type of the request's body. + * @return the request's body type, or {@code null} if not known + * @since 4.3 + */ + @Nullable + public Type getType() { + if (this.type == null) { + T body = getBody(); + if (body != null) { + return body.getClass(); + } + } + return this.type; + } + + + @Override + public boolean equals(@Nullable Object other) { + if (this == other) { + return true; + } + if (!super.equals(other)) { + return false; + } + RequestEntity otherEntity = (RequestEntity) other; + return (ObjectUtils.nullSafeEquals(getMethod(), otherEntity.getMethod()) && + ObjectUtils.nullSafeEquals(getUrl(), otherEntity.getUrl())); + } + + @Override + public int hashCode() { + int hashCode = super.hashCode(); + hashCode = 29 * hashCode + ObjectUtils.nullSafeHashCode(this.method); + hashCode = 29 * hashCode + ObjectUtils.nullSafeHashCode(this.url); + return hashCode; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder("<"); + builder.append(getMethod()); + builder.append(' '); + builder.append(getUrl()); + builder.append(','); + T body = getBody(); + HttpHeaders headers = getHeaders(); + if (body != null) { + builder.append(body); + builder.append(','); + } + builder.append(headers); + builder.append('>'); + return builder.toString(); + } + + + // Static builder methods + + /** + * Create a builder with the given method and url. + * @param method the HTTP method (GET, POST, etc) + * @param url the URL + * @return the created builder + */ + public static BodyBuilder method(HttpMethod method, URI url) { + return new DefaultBodyBuilder(method, url); + } + + /** + * Create an HTTP GET builder with the given url. + * @param url the URL + * @return the created builder + */ + public static HeadersBuilder get(URI url) { + return method(HttpMethod.GET, url); + } + + /** + * Create an HTTP HEAD builder with the given url. + * @param url the URL + * @return the created builder + */ + public static HeadersBuilder head(URI url) { + return method(HttpMethod.HEAD, url); + } + + /** + * Create an HTTP POST builder with the given url. + * @param url the URL + * @return the created builder + */ + public static BodyBuilder post(URI url) { + return method(HttpMethod.POST, url); + } + + /** + * Create an HTTP PUT builder with the given url. + * @param url the URL + * @return the created builder + */ + public static BodyBuilder put(URI url) { + return method(HttpMethod.PUT, url); + } + + /** + * Create an HTTP PATCH builder with the given url. + * @param url the URL + * @return the created builder + */ + public static BodyBuilder patch(URI url) { + return method(HttpMethod.PATCH, url); + } + + /** + * Create an HTTP DELETE builder with the given url. + * @param url the URL + * @return the created builder + */ + public static HeadersBuilder delete(URI url) { + return method(HttpMethod.DELETE, url); + } + + /** + * Creates an HTTP OPTIONS builder with the given url. + * @param url the URL + * @return the created builder + */ + public static HeadersBuilder options(URI url) { + return method(HttpMethod.OPTIONS, url); + } + + + /** + * Defines a builder that adds headers to the request entity. + * @param the builder subclass + */ + public interface HeadersBuilder> { + + /** + * Add the given, single header value under the given name. + * @param headerName the header name + * @param headerValues the header value(s) + * @return this builder + * @see HttpHeaders#add(String, String) + */ + B header(String headerName, String... headerValues); + + /** + * Set the list of acceptable {@linkplain MediaType media types}, as + * specified by the {@code Accept} header. + * @param acceptableMediaTypes the acceptable media types + */ + B accept(MediaType... acceptableMediaTypes); + + /** + * Set the list of acceptable {@linkplain Charset charsets}, as specified + * by the {@code Accept-Charset} header. + * @param acceptableCharsets the acceptable charsets + */ + B acceptCharset(Charset... acceptableCharsets); + + /** + * Set the value of the {@code If-Modified-Since} header. + * @param ifModifiedSince the new value of the header + * @since 5.1.4 + */ + B ifModifiedSince(ZonedDateTime ifModifiedSince); + + /** + * Set the value of the {@code If-Modified-Since} header. + * @param ifModifiedSince the new value of the header + * @since 5.1.4 + */ + B ifModifiedSince(Instant ifModifiedSince); + + /** + * Set the value of the {@code If-Modified-Since} header. + *

The date should be specified as the number of milliseconds since + * January 1, 1970 GMT. + * @param ifModifiedSince the new value of the header + */ + B ifModifiedSince(long ifModifiedSince); + + /** + * Set the values of the {@code If-None-Match} header. + * @param ifNoneMatches the new value of the header + */ + B ifNoneMatch(String... ifNoneMatches); + + /** + * Builds the request entity with no body. + * @return the request entity + * @see BodyBuilder#body(Object) + */ + RequestEntity build(); + } + + + /** + * Defines a builder that adds a body to the response entity. + */ + public interface BodyBuilder extends HeadersBuilder { + + /** + * Set the length of the body in bytes, as specified by the + * {@code Content-Length} header. + * @param contentLength the content length + * @return this builder + * @see HttpHeaders#setContentLength(long) + */ + BodyBuilder contentLength(long contentLength); + + /** + * Set the {@linkplain MediaType media type} of the body, as specified + * by the {@code Content-Type} header. + * @param contentType the content type + * @return this builder + * @see HttpHeaders#setContentType(MediaType) + */ + BodyBuilder contentType(MediaType contentType); + + /** + * Set the body of the request entity and build the RequestEntity. + * @param the type of the body + * @param body the body of the request entity + * @return the built request entity + */ + RequestEntity body(T body); + + /** + * Set the body and type of the request entity and build the RequestEntity. + * @param the type of the body + * @param body the body of the request entity + * @param type the type of the body, useful for generic type resolution + * @return the built request entity + * @since 4.3 + */ + RequestEntity body(T body, Type type); + } + + + private static class DefaultBodyBuilder implements BodyBuilder { + + private final HttpMethod method; + + private final URI url; + + private final HttpHeaders headers = new HttpHeaders(); + + public DefaultBodyBuilder(HttpMethod method, URI url) { + this.method = method; + this.url = url; + } + + @Override + public BodyBuilder header(String headerName, String... headerValues) { + for (String headerValue : headerValues) { + this.headers.add(headerName, headerValue); + } + return this; + } + + @Override + public BodyBuilder accept(MediaType... acceptableMediaTypes) { + this.headers.setAccept(Arrays.asList(acceptableMediaTypes)); + return this; + } + + @Override + public BodyBuilder acceptCharset(Charset... acceptableCharsets) { + this.headers.setAcceptCharset(Arrays.asList(acceptableCharsets)); + return this; + } + + @Override + public BodyBuilder contentLength(long contentLength) { + this.headers.setContentLength(contentLength); + return this; + } + + @Override + public BodyBuilder contentType(MediaType contentType) { + this.headers.setContentType(contentType); + return this; + } + + @Override + public BodyBuilder ifModifiedSince(ZonedDateTime ifModifiedSince) { + this.headers.setIfModifiedSince(ifModifiedSince); + return this; + } + + @Override + public BodyBuilder ifModifiedSince(Instant ifModifiedSince) { + this.headers.setIfModifiedSince(ifModifiedSince); + return this; + } + + @Override + public BodyBuilder ifModifiedSince(long ifModifiedSince) { + this.headers.setIfModifiedSince(ifModifiedSince); + return this; + } + + @Override + public BodyBuilder ifNoneMatch(String... ifNoneMatches) { + this.headers.setIfNoneMatch(Arrays.asList(ifNoneMatches)); + return this; + } + + @Override + public RequestEntity build() { + return new RequestEntity<>(this.headers, this.method, this.url); + } + + @Override + public RequestEntity body(T body) { + return new RequestEntity<>(body, this.headers, this.method, this.url); + } + + @Override + public RequestEntity body(T body, Type type) { + return new RequestEntity<>(body, this.headers, this.method, this.url, type); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/ResponseCookie.java b/spring-web/src/main/java/org/springframework/http/ResponseCookie.java new file mode 100644 index 0000000000000000000000000000000000000000..d11b5dad0edf16b53ce158020423025c0365cb2c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/ResponseCookie.java @@ -0,0 +1,402 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.time.Duration; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; + +/** + * An {@code HttpCookie} subclass with the additional attributes allowed in + * the "Set-Cookie" response header. To build an instance use the {@link #from} + * static method. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + * @see RFC 6265 + */ +public final class ResponseCookie extends HttpCookie { + + private final Duration maxAge; + + @Nullable + private final String domain; + + @Nullable + private final String path; + + private final boolean secure; + + private final boolean httpOnly; + + @Nullable + private final String sameSite; + + + /** + * Private constructor. See {@link #from(String, String)}. + */ + private ResponseCookie(String name, String value, Duration maxAge, @Nullable String domain, + @Nullable String path, boolean secure, boolean httpOnly, @Nullable String sameSite) { + + super(name, value); + Assert.notNull(maxAge, "Max age must not be null"); + + this.maxAge = maxAge; + this.domain = domain; + this.path = path; + this.secure = secure; + this.httpOnly = httpOnly; + this.sameSite = sameSite; + + Rfc6265Utils.validateCookieName(name); + Rfc6265Utils.validateCookieValue(value); + Rfc6265Utils.validateDomain(domain); + Rfc6265Utils.validatePath(path); + } + + + /** + * Return the cookie "Max-Age" attribute in seconds. + *

A positive value indicates when the cookie expires relative to the + * current time. A value of 0 means the cookie should expire immediately. + * A negative value means no "Max-Age" attribute in which case the cookie + * is removed when the browser is closed. + */ + public Duration getMaxAge() { + return this.maxAge; + } + + /** + * Return the cookie "Domain" attribute, or {@code null} if not set. + */ + @Nullable + public String getDomain() { + return this.domain; + } + + /** + * Return the cookie "Path" attribute, or {@code null} if not set. + */ + @Nullable + public String getPath() { + return this.path; + } + + /** + * Return {@code true} if the cookie has the "Secure" attribute. + */ + public boolean isSecure() { + return this.secure; + } + + /** + * Return {@code true} if the cookie has the "HttpOnly" attribute. + * @see https://www.owasp.org/index.php/HTTPOnly + */ + public boolean isHttpOnly() { + return this.httpOnly; + } + + /** + * Return the cookie "SameSite" attribute, or {@code null} if not set. + *

This limits the scope of the cookie such that it will only be attached to + * same site requests if {@code "Strict"} or cross-site requests if {@code "Lax"}. + * @see RFC6265 bis + * @since 5.1 + */ + @Nullable + public String getSameSite() { + return this.sameSite; + } + + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof ResponseCookie)) { + return false; + } + ResponseCookie otherCookie = (ResponseCookie) other; + return (getName().equalsIgnoreCase(otherCookie.getName()) && + ObjectUtils.nullSafeEquals(this.path, otherCookie.getPath()) && + ObjectUtils.nullSafeEquals(this.domain, otherCookie.getDomain())); + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + ObjectUtils.nullSafeHashCode(this.domain); + result = 31 * result + ObjectUtils.nullSafeHashCode(this.path); + return result; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getName()).append('=').append(getValue()); + if (StringUtils.hasText(getPath())) { + sb.append("; Path=").append(getPath()); + } + if (StringUtils.hasText(this.domain)) { + sb.append("; Domain=").append(this.domain); + } + if (!this.maxAge.isNegative()) { + sb.append("; Max-Age=").append(this.maxAge.getSeconds()); + sb.append("; Expires="); + long millis = this.maxAge.getSeconds() > 0 ? System.currentTimeMillis() + this.maxAge.toMillis() : 0; + sb.append(HttpHeaders.formatDate(millis)); + } + if (this.secure) { + sb.append("; Secure"); + } + if (this.httpOnly) { + sb.append("; HttpOnly"); + } + if (StringUtils.hasText(this.sameSite)) { + sb.append("; SameSite=").append(this.sameSite); + } + return sb.toString(); + } + + + /** + * Factory method to obtain a builder for a server-defined cookie that starts + * with a name-value pair and may also include attributes. + * @param name the cookie name + * @param value the cookie value + * @return a builder to create the cookie with + */ + public static ResponseCookieBuilder from(final String name, final String value) { + + return new ResponseCookieBuilder() { + + private Duration maxAge = Duration.ofSeconds(-1); + + @Nullable + private String domain; + + @Nullable + private String path; + + private boolean secure; + + private boolean httpOnly; + + @Nullable + private String sameSite; + + @Override + public ResponseCookieBuilder maxAge(Duration maxAge) { + this.maxAge = maxAge; + return this; + } + + @Override + public ResponseCookieBuilder maxAge(long maxAgeSeconds) { + this.maxAge = maxAgeSeconds >= 0 ? Duration.ofSeconds(maxAgeSeconds) : Duration.ofSeconds(-1); + return this; + } + + @Override + public ResponseCookieBuilder domain(String domain) { + this.domain = domain; + return this; + } + + @Override + public ResponseCookieBuilder path(String path) { + this.path = path; + return this; + } + + @Override + public ResponseCookieBuilder secure(boolean secure) { + this.secure = secure; + return this; + } + + @Override + public ResponseCookieBuilder httpOnly(boolean httpOnly) { + this.httpOnly = httpOnly; + return this; + } + + @Override + public ResponseCookieBuilder sameSite(@Nullable String sameSite) { + this.sameSite = sameSite; + return this; + } + + @Override + public ResponseCookie build() { + return new ResponseCookie(name, value, this.maxAge, this.domain, this.path, + this.secure, this.httpOnly, this.sameSite); + } + }; + } + + + /** + * A builder for a server-defined HttpCookie with attributes. + */ + public interface ResponseCookieBuilder { + + /** + * Set the cookie "Max-Age" attribute. + * + *

A positive value indicates when the cookie should expire relative + * to the current time. A value of 0 means the cookie should expire + * immediately. A negative value results in no "Max-Age" attribute in + * which case the cookie is removed when the browser is closed. + */ + ResponseCookieBuilder maxAge(Duration maxAge); + + /** + * Variant of {@link #maxAge(Duration)} accepting a value in seconds. + */ + ResponseCookieBuilder maxAge(long maxAgeSeconds); + + /** + * Set the cookie "Path" attribute. + */ + ResponseCookieBuilder path(String path); + + /** + * Set the cookie "Domain" attribute. + */ + ResponseCookieBuilder domain(String domain); + + /** + * Add the "Secure" attribute to the cookie. + */ + ResponseCookieBuilder secure(boolean secure); + + /** + * Add the "HttpOnly" attribute to the cookie. + * @see https://www.owasp.org/index.php/HTTPOnly + */ + ResponseCookieBuilder httpOnly(boolean httpOnly); + + /** + * Add the "SameSite" attribute to the cookie. + *

This limits the scope of the cookie such that it will only be + * attached to same site requests if {@code "Strict"} or cross-site + * requests if {@code "Lax"}. + * @since 5.1 + * @see RFC6265 bis + */ + ResponseCookieBuilder sameSite(@Nullable String sameSite); + + /** + * Create the HttpCookie. + */ + ResponseCookie build(); + } + + + private static class Rfc6265Utils { + + private static final String SEPARATOR_CHARS = new String(new char[] { + '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ' + }); + + private static final String DOMAIN_CHARS = + "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.-"; + + + public static void validateCookieName(String name) { + for (int i = 0; i < name.length(); i++) { + char c = name.charAt(i); + // CTL = + if (c <= 0x1F || c == 0x7F) { + throw new IllegalArgumentException( + name + ": RFC2616 token cannot have control chars"); + } + if (SEPARATOR_CHARS.indexOf(c) >= 0) { + throw new IllegalArgumentException( + name + ": RFC2616 token cannot have separator chars such as '" + c + "'"); + } + if (c >= 0x80) { + throw new IllegalArgumentException( + name + ": RFC2616 token can only have US-ASCII: 0x" + Integer.toHexString(c)); + } + } + } + + public static void validateCookieValue(@Nullable String value) { + if (value == null) { + return; + } + int start = 0; + int end = value.length(); + if (end > 1 && value.charAt(0) == '"' && value.charAt(end - 1) == '"') { + start = 1; + end--; + } + char[] chars = value.toCharArray(); + for (int i = start; i < end; i++) { + char c = chars[i]; + if (c < 0x21 || c == 0x22 || c == 0x2c || c == 0x3b || c == 0x5c || c == 0x7f) { + throw new IllegalArgumentException( + "RFC2616 cookie value cannot have '" + c + "'"); + } + if (c >= 0x80) { + throw new IllegalArgumentException( + "RFC2616 cookie value can only have US-ASCII chars: 0x" + Integer.toHexString(c)); + } + } + } + + public static void validateDomain(@Nullable String domain) { + if (!StringUtils.hasLength(domain)) { + return; + } + int char1 = domain.charAt(0); + int charN = domain.charAt(domain.length() - 1); + if (char1 == '-' || charN == '.' || charN == '-') { + throw new IllegalArgumentException("Invalid first/last char in cookie domain: " + domain); + } + for (int i = 0, c = -1; i < domain.length(); i++) { + int p = c; + c = domain.charAt(i); + if (DOMAIN_CHARS.indexOf(c) == -1 || (p == '.' && (c == '.' || c == '-')) || (p == '-' && c == '.')) { + throw new IllegalArgumentException(domain + ": invalid cookie domain char '" + c + "'"); + } + } + } + + public static void validatePath(@Nullable String path) { + if (path == null) { + return; + } + for (int i = 0; i < path.length(); i++) { + char c = path.charAt(i); + if (c < 0x20 || c > 0x7E || c == ';') { + throw new IllegalArgumentException(path + ": Invalid cookie path char '" + c + "'"); + } + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/ResponseEntity.java b/spring-web/src/main/java/org/springframework/http/ResponseEntity.java new file mode 100644 index 0000000000000000000000000000000000000000..96c229c100c58966b861c365a68a36e966eb6d64 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/ResponseEntity.java @@ -0,0 +1,563 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.net.URI; +import java.time.Instant; +import java.time.ZonedDateTime; +import java.util.Arrays; +import java.util.LinkedHashSet; +import java.util.Optional; +import java.util.Set; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; + +/** + * Extension of {@link HttpEntity} that adds a {@link HttpStatus} status code. + * Used in {@code RestTemplate} as well {@code @Controller} methods. + * + *

In {@code RestTemplate}, this class is returned by + * {@link org.springframework.web.client.RestTemplate#getForEntity getForEntity()} and + * {@link org.springframework.web.client.RestTemplate#exchange exchange()}: + *

+ * ResponseEntity<String> entity = template.getForEntity("https://example.com", String.class);
+ * String body = entity.getBody();
+ * MediaType contentType = entity.getHeaders().getContentType();
+ * HttpStatus statusCode = entity.getStatusCode();
+ * 
+ * + *

Can also be used in Spring MVC, as the return value from a @Controller method: + *

+ * @RequestMapping("/handle")
+ * public ResponseEntity<String> handle() {
+ *   URI location = ...;
+ *   HttpHeaders responseHeaders = new HttpHeaders();
+ *   responseHeaders.setLocation(location);
+ *   responseHeaders.set("MyResponseHeader", "MyValue");
+ *   return new ResponseEntity<String>("Hello World", responseHeaders, HttpStatus.CREATED);
+ * }
+ * 
+ * + * Or, by using a builder accessible via static methods: + *
+ * @RequestMapping("/handle")
+ * public ResponseEntity<String> handle() {
+ *   URI location = ...;
+ *   return ResponseEntity.created(location).header("MyResponseHeader", "MyValue").body("Hello World");
+ * }
+ * 
+ * + * @author Arjen Poutsma + * @author Brian Clozel + * @since 3.0.2 + * @param the body type + * @see #getStatusCode() + * @see org.springframework.web.client.RestOperations#getForEntity(String, Class, Object...) + * @see org.springframework.web.client.RestOperations#getForEntity(String, Class, java.util.Map) + * @see org.springframework.web.client.RestOperations#getForEntity(URI, Class) + * @see RequestEntity + */ +public class ResponseEntity extends HttpEntity { + + private final Object status; + + + /** + * Create a new {@code ResponseEntity} with the given status code, and no body nor headers. + * @param status the status code + */ + public ResponseEntity(HttpStatus status) { + this(null, null, status); + } + + /** + * Create a new {@code ResponseEntity} with the given body and status code, and no headers. + * @param body the entity body + * @param status the status code + */ + public ResponseEntity(@Nullable T body, HttpStatus status) { + this(body, null, status); + } + + /** + * Create a new {@code HttpEntity} with the given headers and status code, and no body. + * @param headers the entity headers + * @param status the status code + */ + public ResponseEntity(MultiValueMap headers, HttpStatus status) { + this(null, headers, status); + } + + /** + * Create a new {@code HttpEntity} with the given body, headers, and status code. + * @param body the entity body + * @param headers the entity headers + * @param status the status code + */ + public ResponseEntity(@Nullable T body, @Nullable MultiValueMap headers, HttpStatus status) { + super(body, headers); + Assert.notNull(status, "HttpStatus must not be null"); + this.status = status; + } + + /** + * Create a new {@code HttpEntity} with the given body, headers, and status code. + * Just used behind the nested builder API. + * @param body the entity body + * @param headers the entity headers + * @param status the status code (as {@code HttpStatus} or as {@code Integer} value) + */ + private ResponseEntity(@Nullable T body, @Nullable MultiValueMap headers, Object status) { + super(body, headers); + Assert.notNull(status, "HttpStatus must not be null"); + this.status = status; + } + + + /** + * Return the HTTP status code of the response. + * @return the HTTP status as an HttpStatus enum entry + */ + public HttpStatus getStatusCode() { + if (this.status instanceof HttpStatus) { + return (HttpStatus) this.status; + } + else { + return HttpStatus.valueOf((Integer) this.status); + } + } + + /** + * Return the HTTP status code of the response. + * @return the HTTP status as an int value + * @since 4.3 + */ + public int getStatusCodeValue() { + if (this.status instanceof HttpStatus) { + return ((HttpStatus) this.status).value(); + } + else { + return (Integer) this.status; + } + } + + + @Override + public boolean equals(@Nullable Object other) { + if (this == other) { + return true; + } + if (!super.equals(other)) { + return false; + } + ResponseEntity otherEntity = (ResponseEntity) other; + return ObjectUtils.nullSafeEquals(this.status, otherEntity.status); + } + + @Override + public int hashCode() { + return (29 * super.hashCode() + ObjectUtils.nullSafeHashCode(this.status)); + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder("<"); + builder.append(this.status.toString()); + if (this.status instanceof HttpStatus) { + builder.append(' '); + builder.append(((HttpStatus) this.status).getReasonPhrase()); + } + builder.append(','); + T body = getBody(); + HttpHeaders headers = getHeaders(); + if (body != null) { + builder.append(body); + builder.append(','); + } + builder.append(headers); + builder.append('>'); + return builder.toString(); + } + + + // Static builder methods + + /** + * Create a builder with the given status. + * @param status the response status + * @return the created builder + * @since 4.1 + */ + public static BodyBuilder status(HttpStatus status) { + Assert.notNull(status, "HttpStatus must not be null"); + return new DefaultBuilder(status); + } + + /** + * Create a builder with the given status. + * @param status the response status + * @return the created builder + * @since 4.1 + */ + public static BodyBuilder status(int status) { + return new DefaultBuilder(status); + } + + /** + * Create a builder with the status set to {@linkplain HttpStatus#OK OK}. + * @return the created builder + * @since 4.1 + */ + public static BodyBuilder ok() { + return status(HttpStatus.OK); + } + + /** + * A shortcut for creating a {@code ResponseEntity} with the given body and + * the status set to {@linkplain HttpStatus#OK OK}. + * @return the created {@code ResponseEntity} + * @since 4.1 + */ + public static ResponseEntity ok(T body) { + return ok().body(body); + } + + /** + * A shortcut for creating a {@code ResponseEntity} with the given body + * and the {@linkplain HttpStatus#OK OK} status, or an empty body and a + * {@linkplain HttpStatus#NOT_FOUND NOT FOUND} status in case of an + * {@linkplain Optional#empty()} parameter. + * @return the created {@code ResponseEntity} + * @since 5.1 + */ + public static ResponseEntity of(Optional body) { + Assert.notNull(body, "Body must not be null"); + return body.map(ResponseEntity::ok).orElseGet(() -> notFound().build()); + } + + /** + * Create a new builder with a {@linkplain HttpStatus#CREATED CREATED} status + * and a location header set to the given URI. + * @param location the location URI + * @return the created builder + * @since 4.1 + */ + public static BodyBuilder created(URI location) { + return status(HttpStatus.CREATED).location(location); + } + + /** + * Create a builder with an {@linkplain HttpStatus#ACCEPTED ACCEPTED} status. + * @return the created builder + * @since 4.1 + */ + public static BodyBuilder accepted() { + return status(HttpStatus.ACCEPTED); + } + + /** + * Create a builder with a {@linkplain HttpStatus#NO_CONTENT NO_CONTENT} status. + * @return the created builder + * @since 4.1 + */ + public static HeadersBuilder noContent() { + return status(HttpStatus.NO_CONTENT); + } + + /** + * Create a builder with a {@linkplain HttpStatus#BAD_REQUEST BAD_REQUEST} status. + * @return the created builder + * @since 4.1 + */ + public static BodyBuilder badRequest() { + return status(HttpStatus.BAD_REQUEST); + } + + /** + * Create a builder with a {@linkplain HttpStatus#NOT_FOUND NOT_FOUND} status. + * @return the created builder + * @since 4.1 + */ + public static HeadersBuilder notFound() { + return status(HttpStatus.NOT_FOUND); + } + + /** + * Create a builder with an + * {@linkplain HttpStatus#UNPROCESSABLE_ENTITY UNPROCESSABLE_ENTITY} status. + * @return the created builder + * @since 4.1.3 + */ + public static BodyBuilder unprocessableEntity() { + return status(HttpStatus.UNPROCESSABLE_ENTITY); + } + + + /** + * Defines a builder that adds headers to the response entity. + * @since 4.1 + * @param the builder subclass + */ + public interface HeadersBuilder> { + + /** + * Add the given, single header value under the given name. + * @param headerName the header name + * @param headerValues the header value(s) + * @return this builder + * @see HttpHeaders#add(String, String) + */ + B header(String headerName, String... headerValues); + + /** + * Copy the given headers into the entity's headers map. + * @param headers the existing HttpHeaders to copy from + * @return this builder + * @since 4.1.2 + * @see HttpHeaders#add(String, String) + */ + B headers(@Nullable HttpHeaders headers); + + /** + * Set the set of allowed {@link HttpMethod HTTP methods}, as specified + * by the {@code Allow} header. + * @param allowedMethods the allowed methods + * @return this builder + * @see HttpHeaders#setAllow(Set) + */ + B allow(HttpMethod... allowedMethods); + + /** + * Set the entity tag of the body, as specified by the {@code ETag} header. + * @param etag the new entity tag + * @return this builder + * @see HttpHeaders#setETag(String) + */ + B eTag(String etag); + + /** + * Set the time the resource was last changed, as specified by the + * {@code Last-Modified} header. + * @param lastModified the last modified date + * @return this builder + * @since 5.1.4 + * @see HttpHeaders#setLastModified(ZonedDateTime) + */ + B lastModified(ZonedDateTime lastModified); + + /** + * Set the time the resource was last changed, as specified by the + * {@code Last-Modified} header. + * @param lastModified the last modified date + * @return this builder + * @since 5.1.4 + * @see HttpHeaders#setLastModified(Instant) + */ + B lastModified(Instant lastModified); + + /** + * Set the time the resource was last changed, as specified by the + * {@code Last-Modified} header. + *

The date should be specified as the number of milliseconds since + * January 1, 1970 GMT. + * @param lastModified the last modified date + * @return this builder + * @see HttpHeaders#setLastModified(long) + */ + B lastModified(long lastModified); + + /** + * Set the location of a resource, as specified by the {@code Location} header. + * @param location the location + * @return this builder + * @see HttpHeaders#setLocation(URI) + */ + B location(URI location); + + /** + * Set the caching directives for the resource, as specified by the HTTP 1.1 + * {@code Cache-Control} header. + *

A {@code CacheControl} instance can be built like + * {@code CacheControl.maxAge(3600).cachePublic().noTransform()}. + * @param cacheControl a builder for cache-related HTTP response headers + * @return this builder + * @since 4.2 + * @see RFC-7234 Section 5.2 + */ + B cacheControl(CacheControl cacheControl); + + /** + * Configure one or more request header names (e.g. "Accept-Language") to + * add to the "Vary" response header to inform clients that the response is + * subject to content negotiation and variances based on the value of the + * given request headers. The configured request header names are added only + * if not already present in the response "Vary" header. + * @param requestHeaders request header names + * @since 4.3 + */ + B varyBy(String... requestHeaders); + + /** + * Build the response entity with no body. + * @return the response entity + * @see BodyBuilder#body(Object) + */ + ResponseEntity build(); + } + + + /** + * Defines a builder that adds a body to the response entity. + * @since 4.1 + */ + public interface BodyBuilder extends HeadersBuilder { + + /** + * Set the length of the body in bytes, as specified by the + * {@code Content-Length} header. + * @param contentLength the content length + * @return this builder + * @see HttpHeaders#setContentLength(long) + */ + BodyBuilder contentLength(long contentLength); + + /** + * Set the {@linkplain MediaType media type} of the body, as specified by the + * {@code Content-Type} header. + * @param contentType the content type + * @return this builder + * @see HttpHeaders#setContentType(MediaType) + */ + BodyBuilder contentType(MediaType contentType); + + /** + * Set the body of the response entity and returns it. + * @param the type of the body + * @param body the body of the response entity + * @return the built response entity + */ + ResponseEntity body(@Nullable T body); + } + + + private static class DefaultBuilder implements BodyBuilder { + + private final Object statusCode; + + private final HttpHeaders headers = new HttpHeaders(); + + public DefaultBuilder(Object statusCode) { + this.statusCode = statusCode; + } + + @Override + public BodyBuilder header(String headerName, String... headerValues) { + for (String headerValue : headerValues) { + this.headers.add(headerName, headerValue); + } + return this; + } + + @Override + public BodyBuilder headers(@Nullable HttpHeaders headers) { + if (headers != null) { + this.headers.putAll(headers); + } + return this; + } + + @Override + public BodyBuilder allow(HttpMethod... allowedMethods) { + this.headers.setAllow(new LinkedHashSet<>(Arrays.asList(allowedMethods))); + return this; + } + + @Override + public BodyBuilder contentLength(long contentLength) { + this.headers.setContentLength(contentLength); + return this; + } + + @Override + public BodyBuilder contentType(MediaType contentType) { + this.headers.setContentType(contentType); + return this; + } + + @Override + public BodyBuilder eTag(String etag) { + if (!etag.startsWith("\"") && !etag.startsWith("W/\"")) { + etag = "\"" + etag; + } + if (!etag.endsWith("\"")) { + etag = etag + "\""; + } + this.headers.setETag(etag); + return this; + } + + @Override + public BodyBuilder lastModified(ZonedDateTime date) { + this.headers.setLastModified(date); + return this; + } + + @Override + public BodyBuilder lastModified(Instant date) { + this.headers.setLastModified(date); + return this; + } + + @Override + public BodyBuilder lastModified(long date) { + this.headers.setLastModified(date); + return this; + } + + @Override + public BodyBuilder location(URI location) { + this.headers.setLocation(location); + return this; + } + + @Override + public BodyBuilder cacheControl(CacheControl cacheControl) { + this.headers.setCacheControl(cacheControl); + return this; + } + + @Override + public BodyBuilder varyBy(String... requestHeaders) { + this.headers.setVary(Arrays.asList(requestHeaders)); + return this; + } + + @Override + public ResponseEntity build() { + return body(null); + } + + @Override + public ResponseEntity body(@Nullable T body) { + return new ResponseEntity<>(body, this.headers, this.statusCode); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/StreamingHttpOutputMessage.java b/spring-web/src/main/java/org/springframework/http/StreamingHttpOutputMessage.java new file mode 100644 index 0000000000000000000000000000000000000000..8085cc2ebf05142dcc46673eb82e0b40e0ea348e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/StreamingHttpOutputMessage.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.io.IOException; +import java.io.OutputStream; + +/** + * Represents a HTTP output message that allows for setting a streaming body. + * Note that such messages typically do not support {@link #getBody()} access. + * + * @author Arjen Poutsma + * @since 4.0 + * @see #setBody + */ +public interface StreamingHttpOutputMessage extends HttpOutputMessage { + + /** + * Set the streaming body callback for this message. + * @param body the streaming body callback + */ + void setBody(Body body); + + + /** + * Defines the contract for bodies that can be written directly to an + * {@link OutputStream}. Useful with HTTP client libraries that provide + * indirect access to an {@link OutputStream} via a callback mechanism. + */ + @FunctionalInterface + interface Body { + + /** + * Write this body to the given {@link OutputStream}. + * @param outputStream the output stream to write to + * @throws IOException in case of I/O errors + */ + void writeTo(OutputStream outputStream) throws IOException; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/ZeroCopyHttpOutputMessage.java b/spring-web/src/main/java/org/springframework/http/ZeroCopyHttpOutputMessage.java new file mode 100644 index 0000000000000000000000000000000000000000..9aec7be680ded2342b43606143644259487e171b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/ZeroCopyHttpOutputMessage.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.io.File; +import java.nio.file.Path; + +import reactor.core.publisher.Mono; + +/** + * Sub-interface of {@code ReactiveOutputMessage} that has support for "zero-copy" + * file transfers. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 5.0 + * @see Zero-copy + */ +public interface ZeroCopyHttpOutputMessage extends ReactiveHttpOutputMessage { + + /** + * Use the given {@link File} to write the body of the message to the underlying + * HTTP layer. + * @param file the file to transfer + * @param position the position within the file from which the transfer is to begin + * @param count the number of bytes to be transferred + * @return a publisher that indicates completion or error. + */ + default Mono writeWith(File file, long position, long count) { + return writeWith(file.toPath(), position, count); + } + + /** + * Use the given {@link Path} to write the body of the message to the underlying + * HTTP layer. + * @param file the file to transfer + * @param position the position within the file from which the transfer is to begin + * @param count the number of bytes to be transferred + * @return a publisher that indicates completion or error. + * @since 5.1 + */ + Mono writeWith(Path file, long position, long count); + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/AbstractAsyncClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/AbstractAsyncClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..37b09134d94acddf4f3446adcbee51c73b8fe913 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/AbstractAsyncClientHttpRequest.java @@ -0,0 +1,85 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.io.OutputStream; + +import org.springframework.http.HttpHeaders; +import org.springframework.util.Assert; +import org.springframework.util.concurrent.ListenableFuture; + +/** + * Abstract base for {@link AsyncClientHttpRequest} that makes sure that headers and body + * are not written multiple times. + * + * @author Arjen Poutsma + * @since 4.0 + * @deprecated as of Spring 5.0, in favor of {@link org.springframework.http.client.reactive.AbstractClientHttpRequest} + */ +@Deprecated +abstract class AbstractAsyncClientHttpRequest implements AsyncClientHttpRequest { + + private final HttpHeaders headers = new HttpHeaders(); + + private boolean executed = false; + + + @Override + public final HttpHeaders getHeaders() { + return (this.executed ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers); + } + + @Override + public final OutputStream getBody() throws IOException { + assertNotExecuted(); + return getBodyInternal(this.headers); + } + + @Override + public ListenableFuture executeAsync() throws IOException { + assertNotExecuted(); + ListenableFuture result = executeInternal(this.headers); + this.executed = true; + return result; + } + + /** + * Asserts that this request has not been {@linkplain #executeAsync() executed} yet. + * @throws IllegalStateException if this request has been executed + */ + protected void assertNotExecuted() { + Assert.state(!this.executed, "ClientHttpRequest already executed"); + } + + + /** + * Abstract template method that returns the body. + * @param headers the HTTP headers + * @return the body output stream + */ + protected abstract OutputStream getBodyInternal(HttpHeaders headers) throws IOException; + + /** + * Abstract template method that writes the given headers and content to the HTTP request. + * @param headers the HTTP headers + * @return the response object for the executed request + */ + protected abstract ListenableFuture executeInternal(HttpHeaders headers) + throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/AbstractBufferingAsyncClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/AbstractBufferingAsyncClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..f5c9aafe0e9fd292730179d647dd1bd96264b1a1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/AbstractBufferingAsyncClientHttpRequest.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +import org.springframework.http.HttpHeaders; +import org.springframework.util.concurrent.ListenableFuture; + +/** + * Base implementation of {@link AsyncClientHttpRequest} that buffers output + * in a byte array before sending it over the wire. + * + * @author Arjen Poutsma + * @since 4.0 + * @deprecated as of Spring 5.0, with no direct replacement + */ +@Deprecated +abstract class AbstractBufferingAsyncClientHttpRequest extends AbstractAsyncClientHttpRequest { + + private ByteArrayOutputStream bufferedOutput = new ByteArrayOutputStream(1024); + + + @Override + protected OutputStream getBodyInternal(HttpHeaders headers) throws IOException { + return this.bufferedOutput; + } + + @Override + protected ListenableFuture executeInternal(HttpHeaders headers) throws IOException { + byte[] bytes = this.bufferedOutput.toByteArray(); + if (headers.getContentLength() < 0) { + headers.setContentLength(bytes.length); + } + ListenableFuture result = executeInternal(headers, bytes); + this.bufferedOutput = new ByteArrayOutputStream(0); + return result; + } + + /** + * Abstract template method that writes the given headers and content to the HTTP request. + * @param headers the HTTP headers + * @param bufferedOutput the body content + * @return the response object for the executed request + */ + protected abstract ListenableFuture executeInternal( + HttpHeaders headers, byte[] bufferedOutput) throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/AbstractBufferingClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/AbstractBufferingClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..abedb2c051c6ee16f42fceceef53a851b73c0500 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/AbstractBufferingClientHttpRequest.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +import org.springframework.http.HttpHeaders; + +/** + * Base implementation of {@link ClientHttpRequest} that buffers output + * in a byte array before sending it over the wire. + * + * @author Arjen Poutsma + * @since 3.0.6 + */ +abstract class AbstractBufferingClientHttpRequest extends AbstractClientHttpRequest { + + private ByteArrayOutputStream bufferedOutput = new ByteArrayOutputStream(1024); + + + @Override + protected OutputStream getBodyInternal(HttpHeaders headers) throws IOException { + return this.bufferedOutput; + } + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers) throws IOException { + byte[] bytes = this.bufferedOutput.toByteArray(); + if (headers.getContentLength() < 0) { + headers.setContentLength(bytes.length); + } + ClientHttpResponse result = executeInternal(headers, bytes); + this.bufferedOutput = new ByteArrayOutputStream(0); + return result; + } + + /** + * Abstract template method that writes the given headers and content to the HTTP request. + * @param headers the HTTP headers + * @param bufferedOutput the body content + * @return the response object for the executed request + */ + protected abstract ClientHttpResponse executeInternal(HttpHeaders headers, byte[] bufferedOutput) + throws IOException; + + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/AbstractClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/AbstractClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..2848d75f52a02fa6fdc0f92ca41de8eb24f068ac --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/AbstractClientHttpRequest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.io.OutputStream; + +import org.springframework.http.HttpHeaders; +import org.springframework.util.Assert; + +/** + * Abstract base for {@link ClientHttpRequest} that makes sure that headers + * and body are not written multiple times. + * + * @author Arjen Poutsma + * @since 3.0 + */ +public abstract class AbstractClientHttpRequest implements ClientHttpRequest { + + private final HttpHeaders headers = new HttpHeaders(); + + private boolean executed = false; + + + @Override + public final HttpHeaders getHeaders() { + return (this.executed ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers); + } + + @Override + public final OutputStream getBody() throws IOException { + assertNotExecuted(); + return getBodyInternal(this.headers); + } + + @Override + public final ClientHttpResponse execute() throws IOException { + assertNotExecuted(); + ClientHttpResponse result = executeInternal(this.headers); + this.executed = true; + return result; + } + + /** + * Assert that this request has not been {@linkplain #execute() executed} yet. + * @throws IllegalStateException if this request has been executed + */ + protected void assertNotExecuted() { + Assert.state(!this.executed, "ClientHttpRequest already executed"); + } + + + /** + * Abstract template method that returns the body. + * @param headers the HTTP headers + * @return the body output stream + */ + protected abstract OutputStream getBodyInternal(HttpHeaders headers) throws IOException; + + /** + * Abstract template method that writes the given headers and content to the HTTP request. + * @param headers the HTTP headers + * @return the response object for the executed request + */ + protected abstract ClientHttpResponse executeInternal(HttpHeaders headers) throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/AbstractClientHttpRequestFactoryWrapper.java b/spring-web/src/main/java/org/springframework/http/client/AbstractClientHttpRequestFactoryWrapper.java new file mode 100644 index 0000000000000000000000000000000000000000..2f8a6446ab01a3ac3267e21ae00c1ca5e426c45e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/AbstractClientHttpRequestFactoryWrapper.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.URI; + +import org.springframework.http.HttpMethod; +import org.springframework.util.Assert; + +/** + * Abstract base class for {@link ClientHttpRequestFactory} implementations + * that decorate another request factory. + * + * @author Arjen Poutsma + * @since 3.1 + */ +public abstract class AbstractClientHttpRequestFactoryWrapper implements ClientHttpRequestFactory { + + private final ClientHttpRequestFactory requestFactory; + + + /** + * Create a {@code AbstractClientHttpRequestFactoryWrapper} wrapping the given request factory. + * @param requestFactory the request factory to be wrapped + */ + protected AbstractClientHttpRequestFactoryWrapper(ClientHttpRequestFactory requestFactory) { + Assert.notNull(requestFactory, "ClientHttpRequestFactory must not be null"); + this.requestFactory = requestFactory; + } + + + /** + * This implementation simply calls {@link #createRequest(URI, HttpMethod, ClientHttpRequestFactory)} + * with the wrapped request factory provided to the + * {@linkplain #AbstractClientHttpRequestFactoryWrapper(ClientHttpRequestFactory) constructor}. + */ + @Override + public final ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { + return createRequest(uri, httpMethod, this.requestFactory); + } + + /** + * Create a new {@link ClientHttpRequest} for the specified URI and HTTP method + * by using the passed-on request factory. + *

Called from {@link #createRequest(URI, HttpMethod)}. + * @param uri the URI to create a request for + * @param httpMethod the HTTP method to execute + * @param requestFactory the wrapped request factory + * @return the created request + * @throws IOException in case of I/O errors + */ + protected abstract ClientHttpRequest createRequest( + URI uri, HttpMethod httpMethod, ClientHttpRequestFactory requestFactory) throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/AbstractClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/AbstractClientHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..f45c55c23b43329b6ba0745f901b3d5d261a029b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/AbstractClientHttpResponse.java @@ -0,0 +1,36 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; + +import org.springframework.http.HttpStatus; + +/** + * Abstract base for {@link ClientHttpResponse}. + * + * @author Arjen Poutsma + * @since 3.1.1 + */ +public abstract class AbstractClientHttpResponse implements ClientHttpResponse { + + @Override + public HttpStatus getStatusCode() throws IOException { + return HttpStatus.valueOf(getRawStatusCode()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/AsyncClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/AsyncClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..39f3c83ae887688829b344a155fd1c07789ede47 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/AsyncClientHttpRequest.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; + +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.HttpRequest; +import org.springframework.util.concurrent.ListenableFuture; + +/** + * Represents a client-side asynchronous HTTP request. Created via an + * implementation of the {@link AsyncClientHttpRequestFactory}. + * + *

A {@code AsyncHttpRequest} can be {@linkplain #executeAsync() executed}, + * getting a future {@link ClientHttpResponse} which can be read from. + * + * @author Arjen Poutsma + * @since 4.0 + * @see AsyncClientHttpRequestFactory#createAsyncRequest + * @deprecated as of Spring 5.0, in favor of {@link org.springframework.web.reactive.function.client.ClientRequest} + */ +@Deprecated +public interface AsyncClientHttpRequest extends HttpRequest, HttpOutputMessage { + + /** + * Execute this request asynchronously, resulting in a Future handle. + * {@link ClientHttpResponse} that can be read. + * @return the future response result of the execution + * @throws java.io.IOException in case of I/O errors + */ + ListenableFuture executeAsync() throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/AsyncClientHttpRequestExecution.java b/spring-web/src/main/java/org/springframework/http/client/AsyncClientHttpRequestExecution.java new file mode 100644 index 0000000000000000000000000000000000000000..97ecb225cb32d035fbc82d31106a71ca2a50dafe --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/AsyncClientHttpRequestExecution.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; + +import org.springframework.http.HttpRequest; +import org.springframework.util.concurrent.ListenableFuture; + +/** + * Represents the context of a client-side HTTP request execution. + * + *

Used to invoke the next interceptor in the interceptor chain, or - + * if the calling interceptor is last - execute the request itself. + * + * @author Jakub Narloch + * @author Rossen Stoyanchev + * @since 4.3 + * @see AsyncClientHttpRequestInterceptor + * @deprecated as of Spring 5.0, in favor of + * {@link org.springframework.web.reactive.function.client.ExchangeFilterFunction} + */ +@Deprecated +public interface AsyncClientHttpRequestExecution { + + /** + * Resume the request execution by invoking the next interceptor in the chain + * or executing the request to the remote service. + * @param request the HTTP request, containing the HTTP method and headers + * @param body the body of the request + * @return a corresponding future handle + * @throws IOException in case of I/O errors + */ + ListenableFuture executeAsync(HttpRequest request, byte[] body) throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/AsyncClientHttpRequestFactory.java b/spring-web/src/main/java/org/springframework/http/client/AsyncClientHttpRequestFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..509d1da713a3815acdfeced659f93646537a70ea --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/AsyncClientHttpRequestFactory.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.URI; + +import org.springframework.http.HttpMethod; + +/** + * Factory for {@link AsyncClientHttpRequest} objects. + * Requests are created by the {@link #createAsyncRequest(URI, HttpMethod)} method. + * + * @author Arjen Poutsma + * @since 4.0 + * @deprecated as of Spring 5.0, in favor of {@link org.springframework.http.client.reactive.ClientHttpConnector} + */ +@Deprecated +public interface AsyncClientHttpRequestFactory { + + /** + * Create a new asynchronous {@link AsyncClientHttpRequest} for the specified URI + * and HTTP method. + *

The returned request can be written to, and then executed by calling + * {@link AsyncClientHttpRequest#executeAsync()}. + * @param uri the URI to create a request for + * @param httpMethod the HTTP method to execute + * @return the created request + * @throws IOException in case of I/O errors + */ + AsyncClientHttpRequest createAsyncRequest(URI uri, HttpMethod httpMethod) throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/AsyncClientHttpRequestInterceptor.java b/spring-web/src/main/java/org/springframework/http/client/AsyncClientHttpRequestInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..34ed8c024e68051baa370fd7d12d4adaf0a03cee --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/AsyncClientHttpRequestInterceptor.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; + +import org.springframework.http.HttpRequest; +import org.springframework.util.concurrent.ListenableFuture; + +/** + * Intercepts client-side HTTP requests. Implementations of this interface can be + * {@linkplain org.springframework.web.client.AsyncRestTemplate#setInterceptors registered} + * with the {@link org.springframework.web.client.AsyncRestTemplate} as to modify + * the outgoing {@link HttpRequest} and/or register to modify the incoming + * {@link ClientHttpResponse} with help of a + * {@link org.springframework.util.concurrent.ListenableFutureAdapter}. + * + *

The main entry point for interceptors is {@link #intercept}. + * + * @author Jakub Narloch + * @author Rossen Stoyanchev + * @since 4.3 + * @see org.springframework.web.client.AsyncRestTemplate + * @see org.springframework.http.client.support.InterceptingAsyncHttpAccessor + * @deprecated as of Spring 5.0, in favor of + * {@link org.springframework.web.reactive.function.client.ExchangeFilterFunction} + */ +@Deprecated +public interface AsyncClientHttpRequestInterceptor { + + /** + * Intercept the given request, and return a response future. The given + * {@link AsyncClientHttpRequestExecution} allows the interceptor to pass on + * the request to the next entity in the chain. + *

An implementation might follow this pattern: + *

    + *
  1. Examine the {@linkplain HttpRequest request} and body
  2. + *
  3. Optionally {@linkplain org.springframework.http.client.support.HttpRequestWrapper + * wrap} the request to filter HTTP attributes.
  4. + *
  5. Optionally modify the body of the request.
  6. + *
  7. One of the following: + *
      + *
    • execute the request through {@link ClientHttpRequestExecution}
    • + *
    • don't execute the request to block the execution altogether
    • + *
    + *
  8. Optionally adapt the response to filter HTTP attributes with the help of + * {@link org.springframework.util.concurrent.ListenableFutureAdapter + * ListenableFutureAdapter}.
  9. + *
+ * @param request the request, containing method, URI, and headers + * @param body the body of the request + * @param execution the request execution + * @return the response future + * @throws IOException in case of I/O errors + */ + ListenableFuture intercept(HttpRequest request, byte[] body, + AsyncClientHttpRequestExecution execution) throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/BufferingClientHttpRequestFactory.java b/spring-web/src/main/java/org/springframework/http/client/BufferingClientHttpRequestFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..ecf0a544757ffff20f6ba42b9b9f1e06c7d3bfe3 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/BufferingClientHttpRequestFactory.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.URI; + +import org.springframework.http.HttpMethod; + +/** + * Wrapper for a {@link ClientHttpRequestFactory} that buffers + * all outgoing and incoming streams in memory. + * + *

Using this wrapper allows for multiple reads of the + * {@linkplain ClientHttpResponse#getBody() response body}. + * + * @author Arjen Poutsma + * @since 3.1 + */ +public class BufferingClientHttpRequestFactory extends AbstractClientHttpRequestFactoryWrapper { + + /** + * Create a buffering wrapper for the given {@link ClientHttpRequestFactory}. + * @param requestFactory the target request factory to wrap + */ + public BufferingClientHttpRequestFactory(ClientHttpRequestFactory requestFactory) { + super(requestFactory); + } + + + @Override + protected ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod, ClientHttpRequestFactory requestFactory) + throws IOException { + + ClientHttpRequest request = requestFactory.createRequest(uri, httpMethod); + if (shouldBuffer(uri, httpMethod)) { + return new BufferingClientHttpRequestWrapper(request); + } + else { + return request; + } + } + + /** + * Indicates whether the request/response exchange for the given URI and method + * should be buffered in memory. + *

The default implementation returns {@code true} for all URIs and methods. + * Subclasses can override this method to change this behavior. + * @param uri the URI + * @param httpMethod the method + * @return {@code true} if the exchange should be buffered; {@code false} otherwise + */ + protected boolean shouldBuffer(URI uri, HttpMethod httpMethod) { + return true; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/BufferingClientHttpRequestWrapper.java b/spring-web/src/main/java/org/springframework/http/client/BufferingClientHttpRequestWrapper.java new file mode 100644 index 0000000000000000000000000000000000000000..c7daa6115e2a1895bd39cff2b9e9dd16afc4885d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/BufferingClientHttpRequestWrapper.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.URI; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.StreamUtils; + +/** + * Simple implementation of {@link ClientHttpRequest} that wraps another request. + * + * @author Arjen Poutsma + * @since 3.1 + */ +final class BufferingClientHttpRequestWrapper extends AbstractBufferingClientHttpRequest { + + private final ClientHttpRequest request; + + + BufferingClientHttpRequestWrapper(ClientHttpRequest request) { + this.request = request; + } + + + @Override + @Nullable + public HttpMethod getMethod() { + return this.request.getMethod(); + } + + @Override + public String getMethodValue() { + return this.request.getMethodValue(); + } + + @Override + public URI getURI() { + return this.request.getURI(); + } + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers, byte[] bufferedOutput) throws IOException { + this.request.getHeaders().putAll(headers); + StreamUtils.copy(bufferedOutput, this.request.getBody()); + ClientHttpResponse response = this.request.execute(); + return new BufferingClientHttpResponseWrapper(response); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/BufferingClientHttpResponseWrapper.java b/spring-web/src/main/java/org/springframework/http/client/BufferingClientHttpResponseWrapper.java new file mode 100644 index 0000000000000000000000000000000000000000..b409a9518c527653ff164a57cffec58a51866649 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/BufferingClientHttpResponseWrapper.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; +import org.springframework.util.StreamUtils; + +/** + * Simple implementation of {@link ClientHttpResponse} that reads the response's body + * into memory, thus allowing for multiple invocations of {@link #getBody()}. + * + * @author Arjen Poutsma + * @since 3.1 + */ +final class BufferingClientHttpResponseWrapper implements ClientHttpResponse { + + private final ClientHttpResponse response; + + @Nullable + private byte[] body; + + + BufferingClientHttpResponseWrapper(ClientHttpResponse response) { + this.response = response; + } + + + @Override + public HttpStatus getStatusCode() throws IOException { + return this.response.getStatusCode(); + } + + @Override + public int getRawStatusCode() throws IOException { + return this.response.getRawStatusCode(); + } + + @Override + public String getStatusText() throws IOException { + return this.response.getStatusText(); + } + + @Override + public HttpHeaders getHeaders() { + return this.response.getHeaders(); + } + + @Override + public InputStream getBody() throws IOException { + if (this.body == null) { + this.body = StreamUtils.copyToByteArray(this.response.getBody()); + } + return new ByteArrayInputStream(this.body); + } + + @Override + public void close() { + this.response.close(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/ClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/ClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..a925b76f6d1f94d289502e2ece4ff0162d81966b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/ClientHttpRequest.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; + +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.HttpRequest; + +/** + * Represents a client-side HTTP request. + * Created via an implementation of the {@link ClientHttpRequestFactory}. + * + *

A {@code ClientHttpRequest} can be {@linkplain #execute() executed}, + * receiving a {@link ClientHttpResponse} which can be read from. + * + * @author Arjen Poutsma + * @since 3.0 + * @see ClientHttpRequestFactory#createRequest(java.net.URI, HttpMethod) + */ +public interface ClientHttpRequest extends HttpRequest, HttpOutputMessage { + + /** + * Execute this request, resulting in a {@link ClientHttpResponse} that can be read. + * @return the response result of the execution + * @throws IOException in case of I/O errors + */ + ClientHttpResponse execute() throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/ClientHttpRequestExecution.java b/spring-web/src/main/java/org/springframework/http/client/ClientHttpRequestExecution.java new file mode 100644 index 0000000000000000000000000000000000000000..31a1c685ecfd84a3d194df1936ed61b515c73cfa --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/ClientHttpRequestExecution.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; + +import org.springframework.http.HttpRequest; + +/** + * Represents the context of a client-side HTTP request execution. + * + *

Used to invoke the next interceptor in the interceptor chain, + * or - if the calling interceptor is last - execute the request itself. + * + * @author Arjen Poutsma + * @since 3.1 + * @see ClientHttpRequestInterceptor + */ +@FunctionalInterface +public interface ClientHttpRequestExecution { + + /** + * Execute the request with the given request attributes and body, + * and return the response. + * @param request the request, containing method, URI, and headers + * @param body the body of the request to execute + * @return the response + * @throws IOException in case of I/O errors + */ + ClientHttpResponse execute(HttpRequest request, byte[] body) throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/ClientHttpRequestFactory.java b/spring-web/src/main/java/org/springframework/http/client/ClientHttpRequestFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..2c433fb6fee7466ff874c33e8f67ae774d58636a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/ClientHttpRequestFactory.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.URI; + +import org.springframework.http.HttpMethod; + +/** + * Factory for {@link ClientHttpRequest} objects. + * Requests are created by the {@link #createRequest(URI, HttpMethod)} method. + * + * @author Arjen Poutsma + * @since 3.0 + */ +@FunctionalInterface +public interface ClientHttpRequestFactory { + + /** + * Create a new {@link ClientHttpRequest} for the specified URI and HTTP method. + *

The returned request can be written to, and then executed by calling + * {@link ClientHttpRequest#execute()}. + * @param uri the URI to create a request for + * @param httpMethod the HTTP method to execute + * @return the created request + * @throws IOException in case of I/O errors + */ + ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/ClientHttpRequestInterceptor.java b/spring-web/src/main/java/org/springframework/http/client/ClientHttpRequestInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..966869878643e433c02512435201a282de76fe6f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/ClientHttpRequestInterceptor.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; + +import org.springframework.http.HttpRequest; + +/** + * Intercepts client-side HTTP requests. Implementations of this interface can be + * {@linkplain org.springframework.web.client.RestTemplate#setInterceptors registered} + * with the {@link org.springframework.web.client.RestTemplate RestTemplate}, + * as to modify the outgoing {@link ClientHttpRequest} and/or the incoming + * {@link ClientHttpResponse}. + * + *

The main entry point for interceptors is + * {@link #intercept(HttpRequest, byte[], ClientHttpRequestExecution)}. + * + * @author Arjen Poutsma + * @since 3.1 + */ +@FunctionalInterface +public interface ClientHttpRequestInterceptor { + + /** + * Intercept the given request, and return a response. The given + * {@link ClientHttpRequestExecution} allows the interceptor to pass on the + * request and response to the next entity in the chain. + *

A typical implementation of this method would follow the following pattern: + *

    + *
  1. Examine the {@linkplain HttpRequest request} and body
  2. + *
  3. Optionally {@linkplain org.springframework.http.client.support.HttpRequestWrapper + * wrap} the request to filter HTTP attributes.
  4. + *
  5. Optionally modify the body of the request.
  6. + *
  7. Either + *
      + *
    • execute the request using + * {@link ClientHttpRequestExecution#execute(org.springframework.http.HttpRequest, byte[])},
    • + * or + *
    • do not execute the request to block the execution altogether.
    • + *
    + *
  8. Optionally wrap the response to filter HTTP attributes.
  9. + *
+ * @param request the request, containing method, URI, and headers + * @param body the body of the request + * @param execution the request execution + * @return the response + * @throws IOException in case of I/O errors + */ + ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) + throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/ClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/ClientHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..b3b08c2350ea5363a283dbc921c1bb81c65295e7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/ClientHttpResponse.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.Closeable; +import java.io.IOException; + +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpStatus; + +/** + * Represents a client-side HTTP response. + * Obtained via an calling of the {@link ClientHttpRequest#execute()}. + * + *

A {@code ClientHttpResponse} must be {@linkplain #close() closed}, + * typically in a {@code finally} block. + * + * @author Arjen Poutsma + * @since 3.0 + */ +public interface ClientHttpResponse extends HttpInputMessage, Closeable { + + /** + * Return the HTTP status code as an {@link HttpStatus} enum value. + * @return the HTTP status as an HttpStatus enum value (never {@code null}) + * @throws IOException in case of I/O errors + * @throws IllegalArgumentException in case of an unknown HTTP status code + * @since #getRawStatusCode() + * @see HttpStatus#valueOf(int) + */ + HttpStatus getStatusCode() throws IOException; + + /** + * Return the HTTP status code (potentially non-standard and not + * resolvable through the {@link HttpStatus} enum) as an integer. + * @return the HTTP status as an integer value + * @throws IOException in case of I/O errors + * @since 3.1.1 + * @see #getStatusCode() + * @see HttpStatus#resolve(int) + */ + int getRawStatusCode() throws IOException; + + /** + * Return the HTTP status text of the response. + * @return the HTTP status text + * @throws IOException in case of I/O errors + */ + String getStatusText() throws IOException; + + /** + * Close this response, freeing any resources created. + */ + @Override + void close(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/HttpComponentsAsyncClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsAsyncClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..d1166051c759d2bdfae48103644bfcbeffc1ba21 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsAsyncClientHttpRequest.java @@ -0,0 +1,171 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.URI; +import java.util.concurrent.Future; + +import org.apache.http.HttpEntity; +import org.apache.http.HttpEntityEnclosingRequest; +import org.apache.http.HttpResponse; +import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.concurrent.FutureCallback; +import org.apache.http.nio.client.HttpAsyncClient; +import org.apache.http.nio.entity.NByteArrayEntity; +import org.apache.http.protocol.HttpContext; + +import org.springframework.http.HttpHeaders; +import org.springframework.util.concurrent.FailureCallback; +import org.springframework.util.concurrent.FutureAdapter; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.util.concurrent.ListenableFutureCallbackRegistry; +import org.springframework.util.concurrent.SuccessCallback; + + +/** + * {@link ClientHttpRequest} implementation based on + * Apache HttpComponents HttpAsyncClient. + * + *

Created via the {@link HttpComponentsClientHttpRequestFactory}. + * + * @author Oleg Kalnichevski + * @author Arjen Poutsma + * @since 4.0 + * @see HttpComponentsClientHttpRequestFactory#createRequest + * @deprecated as of Spring 5.0, with no direct replacement + */ +@Deprecated +final class HttpComponentsAsyncClientHttpRequest extends AbstractBufferingAsyncClientHttpRequest { + + private final HttpAsyncClient httpClient; + + private final HttpUriRequest httpRequest; + + private final HttpContext httpContext; + + + HttpComponentsAsyncClientHttpRequest(HttpAsyncClient client, HttpUriRequest request, HttpContext context) { + this.httpClient = client; + this.httpRequest = request; + this.httpContext = context; + } + + + @Override + public String getMethodValue() { + return this.httpRequest.getMethod(); + } + + @Override + public URI getURI() { + return this.httpRequest.getURI(); + } + + HttpContext getHttpContext() { + return this.httpContext; + } + + @Override + protected ListenableFuture executeInternal(HttpHeaders headers, byte[] bufferedOutput) + throws IOException { + + HttpComponentsClientHttpRequest.addHeaders(this.httpRequest, headers); + + if (this.httpRequest instanceof HttpEntityEnclosingRequest) { + HttpEntityEnclosingRequest entityEnclosingRequest = (HttpEntityEnclosingRequest) this.httpRequest; + HttpEntity requestEntity = new NByteArrayEntity(bufferedOutput); + entityEnclosingRequest.setEntity(requestEntity); + } + + HttpResponseFutureCallback callback = new HttpResponseFutureCallback(this.httpRequest); + Future futureResponse = this.httpClient.execute(this.httpRequest, this.httpContext, callback); + return new ClientHttpResponseFuture(futureResponse, callback); + } + + + private static class HttpResponseFutureCallback implements FutureCallback { + + private final HttpUriRequest request; + + private final ListenableFutureCallbackRegistry callbacks = + new ListenableFutureCallbackRegistry<>(); + + public HttpResponseFutureCallback(HttpUriRequest request) { + this.request = request; + } + + public void addCallback(ListenableFutureCallback callback) { + this.callbacks.addCallback(callback); + } + + public void addSuccessCallback(SuccessCallback callback) { + this.callbacks.addSuccessCallback(callback); + } + + public void addFailureCallback(FailureCallback callback) { + this.callbacks.addFailureCallback(callback); + } + + @Override + public void completed(HttpResponse result) { + this.callbacks.success(new HttpComponentsAsyncClientHttpResponse(result)); + } + + @Override + public void failed(Exception ex) { + this.callbacks.failure(ex); + } + + @Override + public void cancelled() { + this.request.abort(); + } + } + + + private static class ClientHttpResponseFuture extends FutureAdapter + implements ListenableFuture { + + private final HttpResponseFutureCallback callback; + + public ClientHttpResponseFuture(Future response, HttpResponseFutureCallback callback) { + super(response); + this.callback = callback; + } + + @Override + protected ClientHttpResponse adapt(HttpResponse response) { + return new HttpComponentsAsyncClientHttpResponse(response); + } + + @Override + public void addCallback(ListenableFutureCallback callback) { + this.callback.addCallback(callback); + } + + @Override + public void addCallback(SuccessCallback successCallback, + FailureCallback failureCallback) { + + this.callback.addSuccessCallback(successCallback); + this.callback.addFailureCallback(failureCallback); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/HttpComponentsAsyncClientHttpRequestFactory.java b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsAsyncClientHttpRequestFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..fbf603009aba5e76041ffbd4e5e619439c9642ef --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsAsyncClientHttpRequestFactory.java @@ -0,0 +1,214 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.Closeable; +import java.io.IOException; +import java.net.URI; + +import org.apache.http.client.HttpClient; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.Configurable; +import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.nio.client.CloseableHttpAsyncClient; +import org.apache.http.impl.nio.client.HttpAsyncClients; +import org.apache.http.nio.client.HttpAsyncClient; +import org.apache.http.protocol.HttpContext; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.http.HttpMethod; +import org.springframework.util.Assert; + +/** + * Asynchronous extension of the {@link HttpComponentsClientHttpRequestFactory}. Uses + * Apache HttpComponents + * HttpAsyncClient 4.0 to create requests. + * + * @author Arjen Poutsma + * @author Stephane Nicoll + * @since 4.0 + * @see HttpAsyncClient + * @deprecated as of Spring 5.0, with no direct replacement + */ +@Deprecated +public class HttpComponentsAsyncClientHttpRequestFactory extends HttpComponentsClientHttpRequestFactory + implements AsyncClientHttpRequestFactory, InitializingBean { + + private HttpAsyncClient asyncClient; + + + /** + * Create a new instance of the {@code HttpComponentsAsyncClientHttpRequestFactory} + * with a default {@link HttpAsyncClient} and {@link HttpClient}. + */ + public HttpComponentsAsyncClientHttpRequestFactory() { + super(); + this.asyncClient = HttpAsyncClients.createSystem(); + } + + /** + * Create a new instance of the {@code HttpComponentsAsyncClientHttpRequestFactory} + * with the given {@link HttpAsyncClient} instance and a default {@link HttpClient}. + * @param asyncClient the HttpAsyncClient instance to use for this request factory + * @since 4.3.10 + */ + public HttpComponentsAsyncClientHttpRequestFactory(HttpAsyncClient asyncClient) { + super(); + this.asyncClient = asyncClient; + } + + /** + * Create a new instance of the {@code HttpComponentsAsyncClientHttpRequestFactory} + * with the given {@link CloseableHttpAsyncClient} instance and a default {@link HttpClient}. + * @param asyncClient the CloseableHttpAsyncClient instance to use for this request factory + */ + public HttpComponentsAsyncClientHttpRequestFactory(CloseableHttpAsyncClient asyncClient) { + super(); + this.asyncClient = asyncClient; + } + + /** + * Create a new instance of the {@code HttpComponentsAsyncClientHttpRequestFactory} + * with the given {@link HttpClient} and {@link HttpAsyncClient} instances. + * @param httpClient the HttpClient instance to use for this request factory + * @param asyncClient the HttpAsyncClient instance to use for this request factory + * @since 4.3.10 + */ + public HttpComponentsAsyncClientHttpRequestFactory(HttpClient httpClient, HttpAsyncClient asyncClient) { + super(httpClient); + this.asyncClient = asyncClient; + } + + /** + * Create a new instance of the {@code HttpComponentsAsyncClientHttpRequestFactory} + * with the given {@link CloseableHttpClient} and {@link CloseableHttpAsyncClient} instances. + * @param httpClient the CloseableHttpClient instance to use for this request factory + * @param asyncClient the CloseableHttpAsyncClient instance to use for this request factory + */ + public HttpComponentsAsyncClientHttpRequestFactory( + CloseableHttpClient httpClient, CloseableHttpAsyncClient asyncClient) { + + super(httpClient); + this.asyncClient = asyncClient; + } + + + /** + * Set the {@code HttpAsyncClient} used for + * {@linkplain #createAsyncRequest(URI, HttpMethod) synchronous execution}. + * @since 4.3.10 + * @see #setHttpClient(HttpClient) + */ + public void setAsyncClient(HttpAsyncClient asyncClient) { + Assert.notNull(asyncClient, "HttpAsyncClient must not be null"); + this.asyncClient = asyncClient; + } + + /** + * Return the {@code HttpAsyncClient} used for + * {@linkplain #createAsyncRequest(URI, HttpMethod) synchronous execution}. + * @since 4.3.10 + * @see #getHttpClient() + */ + public HttpAsyncClient getAsyncClient() { + return this.asyncClient; + } + + /** + * Set the {@code CloseableHttpAsyncClient} used for + * {@linkplain #createAsyncRequest(URI, HttpMethod) asynchronous execution}. + * @deprecated as of 4.3.10, in favor of {@link #setAsyncClient(HttpAsyncClient)} + */ + @Deprecated + public void setHttpAsyncClient(CloseableHttpAsyncClient asyncClient) { + this.asyncClient = asyncClient; + } + + /** + * Return the {@code CloseableHttpAsyncClient} used for + * {@linkplain #createAsyncRequest(URI, HttpMethod) asynchronous execution}. + * @deprecated as of 4.3.10, in favor of {@link #getAsyncClient()} + */ + @Deprecated + public CloseableHttpAsyncClient getHttpAsyncClient() { + Assert.state(this.asyncClient instanceof CloseableHttpAsyncClient, + "No CloseableHttpAsyncClient - use getAsyncClient() instead"); + return (CloseableHttpAsyncClient) this.asyncClient; + } + + + @Override + public void afterPropertiesSet() { + startAsyncClient(); + } + + private HttpAsyncClient startAsyncClient() { + HttpAsyncClient client = getAsyncClient(); + if (client instanceof CloseableHttpAsyncClient) { + CloseableHttpAsyncClient closeableAsyncClient = (CloseableHttpAsyncClient) client; + if (!closeableAsyncClient.isRunning()) { + closeableAsyncClient.start(); + } + } + return client; + } + + @Override + public AsyncClientHttpRequest createAsyncRequest(URI uri, HttpMethod httpMethod) throws IOException { + HttpAsyncClient client = startAsyncClient(); + + HttpUriRequest httpRequest = createHttpUriRequest(httpMethod, uri); + postProcessHttpRequest(httpRequest); + HttpContext context = createHttpContext(httpMethod, uri); + if (context == null) { + context = HttpClientContext.create(); + } + + // Request configuration not set in the context + if (context.getAttribute(HttpClientContext.REQUEST_CONFIG) == null) { + // Use request configuration given by the user, when available + RequestConfig config = null; + if (httpRequest instanceof Configurable) { + config = ((Configurable) httpRequest).getConfig(); + } + if (config == null) { + config = createRequestConfig(client); + } + if (config != null) { + context.setAttribute(HttpClientContext.REQUEST_CONFIG, config); + } + } + + return new HttpComponentsAsyncClientHttpRequest(client, httpRequest, context); + } + + @Override + public void destroy() throws Exception { + try { + super.destroy(); + } + finally { + HttpAsyncClient asyncClient = getAsyncClient(); + if (asyncClient instanceof Closeable) { + ((Closeable) asyncClient).close(); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/HttpComponentsAsyncClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsAsyncClientHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..cba3a722c43ffa475136062aea3ed5575eb4d925 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsAsyncClientHttpResponse.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.io.InputStream; + +import org.apache.http.Header; +import org.apache.http.HttpEntity; +import org.apache.http.HttpResponse; + +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.StreamUtils; + +/** + * {@link ClientHttpResponse} implementation based on + * Apache HttpComponents HttpAsyncClient. + * + *

Created via the {@link HttpComponentsAsyncClientHttpRequest}. + * + * @author Oleg Kalnichevski + * @author Arjen Poutsma + * @since 4.0 + * @see HttpComponentsAsyncClientHttpRequest#executeAsync() + * @deprecated as of Spring 5.0, with no direct replacement + */ +@Deprecated +final class HttpComponentsAsyncClientHttpResponse extends AbstractClientHttpResponse { + + private final HttpResponse httpResponse; + + @Nullable + private HttpHeaders headers; + + + HttpComponentsAsyncClientHttpResponse(HttpResponse httpResponse) { + this.httpResponse = httpResponse; + } + + + @Override + public int getRawStatusCode() throws IOException { + return this.httpResponse.getStatusLine().getStatusCode(); + } + + @Override + public String getStatusText() throws IOException { + return this.httpResponse.getStatusLine().getReasonPhrase(); + } + + @Override + public HttpHeaders getHeaders() { + if (this.headers == null) { + this.headers = new HttpHeaders(); + for (Header header : this.httpResponse.getAllHeaders()) { + this.headers.add(header.getName(), header.getValue()); + } + } + return this.headers; + } + + @Override + public InputStream getBody() throws IOException { + HttpEntity entity = this.httpResponse.getEntity(); + return (entity != null ? entity.getContent() : StreamUtils.emptyInput()); + } + + @Override + public void close() { + // HTTP responses returned by async HTTP client are not bound to an + // active connection and do not have to deallocate any resources... + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/HttpComponentsClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..e1b328df0ffd01c6473247c6512bde8a81990061 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsClientHttpRequest.java @@ -0,0 +1,112 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.URI; + +import org.apache.http.HttpEntity; +import org.apache.http.HttpEntityEnclosingRequest; +import org.apache.http.HttpResponse; +import org.apache.http.client.HttpClient; +import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.http.protocol.HTTP; +import org.apache.http.protocol.HttpContext; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.util.StringUtils; + +/** + * {@link ClientHttpRequest} implementation based on + * Apache HttpComponents HttpClient. + * + *

Created via the {@link HttpComponentsClientHttpRequestFactory}. + * + * @author Oleg Kalnichevski + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.1 + * @see HttpComponentsClientHttpRequestFactory#createRequest(URI, HttpMethod) + */ +final class HttpComponentsClientHttpRequest extends AbstractBufferingClientHttpRequest { + + private final HttpClient httpClient; + + private final HttpUriRequest httpRequest; + + private final HttpContext httpContext; + + + HttpComponentsClientHttpRequest(HttpClient client, HttpUriRequest request, HttpContext context) { + this.httpClient = client; + this.httpRequest = request; + this.httpContext = context; + } + + + @Override + public String getMethodValue() { + return this.httpRequest.getMethod(); + } + + @Override + public URI getURI() { + return this.httpRequest.getURI(); + } + + HttpContext getHttpContext() { + return this.httpContext; + } + + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers, byte[] bufferedOutput) throws IOException { + addHeaders(this.httpRequest, headers); + + if (this.httpRequest instanceof HttpEntityEnclosingRequest) { + HttpEntityEnclosingRequest entityEnclosingRequest = (HttpEntityEnclosingRequest) this.httpRequest; + HttpEntity requestEntity = new ByteArrayEntity(bufferedOutput); + entityEnclosingRequest.setEntity(requestEntity); + } + HttpResponse httpResponse = this.httpClient.execute(this.httpRequest, this.httpContext); + return new HttpComponentsClientHttpResponse(httpResponse); + } + + + /** + * Add the given headers to the given HTTP request. + * @param httpRequest the request to add the headers to + * @param headers the headers to add + */ + static void addHeaders(HttpUriRequest httpRequest, HttpHeaders headers) { + headers.forEach((headerName, headerValues) -> { + if (HttpHeaders.COOKIE.equalsIgnoreCase(headerName)) { // RFC 6265 + String headerValue = StringUtils.collectionToDelimitedString(headerValues, "; "); + httpRequest.addHeader(headerName, headerValue); + } + else if (!HTTP.CONTENT_LEN.equalsIgnoreCase(headerName) && + !HTTP.TRANSFER_ENCODING.equalsIgnoreCase(headerName)) { + for (String headerValue : headerValues) { + httpRequest.addHeader(headerName, headerValue); + } + } + }); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/HttpComponentsClientHttpRequestFactory.java b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsClientHttpRequestFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..5d6cbd9fb7566f38da2566abb9362713ed68eac3 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsClientHttpRequestFactory.java @@ -0,0 +1,338 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.Closeable; +import java.io.IOException; +import java.net.URI; + +import org.apache.http.client.HttpClient; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.Configurable; +import org.apache.http.client.methods.HttpEntityEnclosingRequestBase; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpHead; +import org.apache.http.client.methods.HttpOptions; +import org.apache.http.client.methods.HttpPatch; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpPut; +import org.apache.http.client.methods.HttpTrace; +import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.http.impl.client.HttpClients; +import org.apache.http.protocol.HttpContext; + +import org.springframework.beans.factory.DisposableBean; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@link org.springframework.http.client.ClientHttpRequestFactory} implementation that + * uses Apache HttpComponents + * HttpClient to create requests. + * + *

Allows to use a pre-configured {@link HttpClient} instance - + * potentially with authentication, HTTP connection pooling, etc. + * + *

NOTE: Requires Apache HttpComponents 4.3 or higher, as of Spring 4.0. + * + * @author Oleg Kalnichevski + * @author Arjen Poutsma + * @author Stephane Nicoll + * @author Juergen Hoeller + * @since 3.1 + */ +public class HttpComponentsClientHttpRequestFactory implements ClientHttpRequestFactory, DisposableBean { + + private HttpClient httpClient; + + @Nullable + private RequestConfig requestConfig; + + private boolean bufferRequestBody = true; + + + /** + * Create a new instance of the {@code HttpComponentsClientHttpRequestFactory} + * with a default {@link HttpClient} based on system properties. + */ + public HttpComponentsClientHttpRequestFactory() { + this.httpClient = HttpClients.createSystem(); + } + + /** + * Create a new instance of the {@code HttpComponentsClientHttpRequestFactory} + * with the given {@link HttpClient} instance. + * @param httpClient the HttpClient instance to use for this request factory + */ + public HttpComponentsClientHttpRequestFactory(HttpClient httpClient) { + this.httpClient = httpClient; + } + + + /** + * Set the {@code HttpClient} used for + * {@linkplain #createRequest(URI, HttpMethod) synchronous execution}. + */ + public void setHttpClient(HttpClient httpClient) { + Assert.notNull(httpClient, "HttpClient must not be null"); + this.httpClient = httpClient; + } + + /** + * Return the {@code HttpClient} used for + * {@linkplain #createRequest(URI, HttpMethod) synchronous execution}. + */ + public HttpClient getHttpClient() { + return this.httpClient; + } + + /** + * Set the connection timeout for the underlying {@link RequestConfig}. + * A timeout value of 0 specifies an infinite timeout. + *

Additional properties can be configured by specifying a + * {@link RequestConfig} instance on a custom {@link HttpClient}. + *

This options does not affect connection timeouts for SSL + * handshakes or CONNECT requests; for that, it is required to + * use the {@link org.apache.http.config.SocketConfig} on the + * {@link HttpClient} itself. + * @param timeout the timeout value in milliseconds + * @see RequestConfig#getConnectTimeout() + * @see org.apache.http.config.SocketConfig#getSoTimeout + */ + public void setConnectTimeout(int timeout) { + Assert.isTrue(timeout >= 0, "Timeout must be a non-negative value"); + this.requestConfig = requestConfigBuilder().setConnectTimeout(timeout).build(); + } + + /** + * Set the timeout in milliseconds used when requesting a connection + * from the connection manager using the underlying {@link RequestConfig}. + * A timeout value of 0 specifies an infinite timeout. + *

Additional properties can be configured by specifying a + * {@link RequestConfig} instance on a custom {@link HttpClient}. + * @param connectionRequestTimeout the timeout value to request a connection in milliseconds + * @see RequestConfig#getConnectionRequestTimeout() + */ + public void setConnectionRequestTimeout(int connectionRequestTimeout) { + this.requestConfig = requestConfigBuilder() + .setConnectionRequestTimeout(connectionRequestTimeout).build(); + } + + /** + * Set the socket read timeout for the underlying {@link RequestConfig}. + * A timeout value of 0 specifies an infinite timeout. + *

Additional properties can be configured by specifying a + * {@link RequestConfig} instance on a custom {@link HttpClient}. + * @param timeout the timeout value in milliseconds + * @see RequestConfig#getSocketTimeout() + */ + public void setReadTimeout(int timeout) { + Assert.isTrue(timeout >= 0, "Timeout must be a non-negative value"); + this.requestConfig = requestConfigBuilder().setSocketTimeout(timeout).build(); + } + + /** + * Indicates whether this request factory should buffer the request body internally. + *

Default is {@code true}. When sending large amounts of data via POST or PUT, it is + * recommended to change this property to {@code false}, so as not to run out of memory. + * @since 4.0 + */ + public void setBufferRequestBody(boolean bufferRequestBody) { + this.bufferRequestBody = bufferRequestBody; + } + + + @Override + public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { + HttpClient client = getHttpClient(); + + HttpUriRequest httpRequest = createHttpUriRequest(httpMethod, uri); + postProcessHttpRequest(httpRequest); + HttpContext context = createHttpContext(httpMethod, uri); + if (context == null) { + context = HttpClientContext.create(); + } + + // Request configuration not set in the context + if (context.getAttribute(HttpClientContext.REQUEST_CONFIG) == null) { + // Use request configuration given by the user, when available + RequestConfig config = null; + if (httpRequest instanceof Configurable) { + config = ((Configurable) httpRequest).getConfig(); + } + if (config == null) { + config = createRequestConfig(client); + } + if (config != null) { + context.setAttribute(HttpClientContext.REQUEST_CONFIG, config); + } + } + + if (this.bufferRequestBody) { + return new HttpComponentsClientHttpRequest(client, httpRequest, context); + } + else { + return new HttpComponentsStreamingClientHttpRequest(client, httpRequest, context); + } + } + + + /** + * Return a builder for modifying the factory-level {@link RequestConfig}. + * @since 4.2 + */ + private RequestConfig.Builder requestConfigBuilder() { + return (this.requestConfig != null ? RequestConfig.copy(this.requestConfig) : RequestConfig.custom()); + } + + /** + * Create a default {@link RequestConfig} to use with the given client. + * Can return {@code null} to indicate that no custom request config should + * be set and the defaults of the {@link HttpClient} should be used. + *

The default implementation tries to merge the defaults of the client + * with the local customizations of this factory instance, if any. + * @param client the {@link HttpClient} (or {@code HttpAsyncClient}) to check + * @return the actual RequestConfig to use (may be {@code null}) + * @since 4.2 + * @see #mergeRequestConfig(RequestConfig) + */ + @Nullable + protected RequestConfig createRequestConfig(Object client) { + if (client instanceof Configurable) { + RequestConfig clientRequestConfig = ((Configurable) client).getConfig(); + return mergeRequestConfig(clientRequestConfig); + } + return this.requestConfig; + } + + /** + * Merge the given {@link HttpClient}-level {@link RequestConfig} with + * the factory-level {@link RequestConfig}, if necessary. + * @param clientConfig the config held by the current + * @return the merged request config + * @since 4.2 + */ + protected RequestConfig mergeRequestConfig(RequestConfig clientConfig) { + if (this.requestConfig == null) { // nothing to merge + return clientConfig; + } + + RequestConfig.Builder builder = RequestConfig.copy(clientConfig); + int connectTimeout = this.requestConfig.getConnectTimeout(); + if (connectTimeout >= 0) { + builder.setConnectTimeout(connectTimeout); + } + int connectionRequestTimeout = this.requestConfig.getConnectionRequestTimeout(); + if (connectionRequestTimeout >= 0) { + builder.setConnectionRequestTimeout(connectionRequestTimeout); + } + int socketTimeout = this.requestConfig.getSocketTimeout(); + if (socketTimeout >= 0) { + builder.setSocketTimeout(socketTimeout); + } + return builder.build(); + } + + /** + * Create a Commons HttpMethodBase object for the given HTTP method and URI specification. + * @param httpMethod the HTTP method + * @param uri the URI + * @return the Commons HttpMethodBase object + */ + protected HttpUriRequest createHttpUriRequest(HttpMethod httpMethod, URI uri) { + switch (httpMethod) { + case GET: + return new HttpGet(uri); + case HEAD: + return new HttpHead(uri); + case POST: + return new HttpPost(uri); + case PUT: + return new HttpPut(uri); + case PATCH: + return new HttpPatch(uri); + case DELETE: + return new HttpDelete(uri); + case OPTIONS: + return new HttpOptions(uri); + case TRACE: + return new HttpTrace(uri); + default: + throw new IllegalArgumentException("Invalid HTTP method: " + httpMethod); + } + } + + /** + * Template method that allows for manipulating the {@link HttpUriRequest} before it is + * returned as part of a {@link HttpComponentsClientHttpRequest}. + *

The default implementation is empty. + * @param request the request to process + */ + protected void postProcessHttpRequest(HttpUriRequest request) { + } + + /** + * Template methods that creates a {@link HttpContext} for the given HTTP method and URI. + *

The default implementation returns {@code null}. + * @param httpMethod the HTTP method + * @param uri the URI + * @return the http context + */ + @Nullable + protected HttpContext createHttpContext(HttpMethod httpMethod, URI uri) { + return null; + } + + + /** + * Shutdown hook that closes the underlying + * {@link org.apache.http.conn.HttpClientConnectionManager ClientConnectionManager}'s + * connection pool, if any. + */ + @Override + public void destroy() throws Exception { + HttpClient httpClient = getHttpClient(); + if (httpClient instanceof Closeable) { + ((Closeable) httpClient).close(); + } + } + + + /** + * An alternative to {@link org.apache.http.client.methods.HttpDelete} that + * extends {@link org.apache.http.client.methods.HttpEntityEnclosingRequestBase} + * rather than {@link org.apache.http.client.methods.HttpRequestBase} and + * hence allows HTTP delete with a request body. For use with the RestTemplate + * exchange methods which allow the combination of HTTP DELETE with an entity. + * @since 4.1.2 + */ + private static class HttpDelete extends HttpEntityEnclosingRequestBase { + + public HttpDelete(URI uri) { + super(); + setURI(uri); + } + + @Override + public String getMethod() { + return "DELETE"; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/HttpComponentsClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsClientHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..7d48ace787769123302dd56b60824324d5b68870 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsClientHttpResponse.java @@ -0,0 +1,102 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.Closeable; +import java.io.IOException; +import java.io.InputStream; + +import org.apache.http.Header; +import org.apache.http.HttpEntity; +import org.apache.http.HttpResponse; +import org.apache.http.util.EntityUtils; + +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.StreamUtils; + +/** + * {@link ClientHttpResponse} implementation based on + * Apache HttpComponents HttpClient. + * + *

Created via the {@link HttpComponentsClientHttpRequest}. + * + * @author Oleg Kalnichevski + * @author Arjen Poutsma + * @since 3.1 + * @see HttpComponentsClientHttpRequest#execute() + */ +final class HttpComponentsClientHttpResponse extends AbstractClientHttpResponse { + + private final HttpResponse httpResponse; + + @Nullable + private HttpHeaders headers; + + + HttpComponentsClientHttpResponse(HttpResponse httpResponse) { + this.httpResponse = httpResponse; + } + + + @Override + public int getRawStatusCode() throws IOException { + return this.httpResponse.getStatusLine().getStatusCode(); + } + + @Override + public String getStatusText() throws IOException { + return this.httpResponse.getStatusLine().getReasonPhrase(); + } + + @Override + public HttpHeaders getHeaders() { + if (this.headers == null) { + this.headers = new HttpHeaders(); + for (Header header : this.httpResponse.getAllHeaders()) { + this.headers.add(header.getName(), header.getValue()); + } + } + return this.headers; + } + + @Override + public InputStream getBody() throws IOException { + HttpEntity entity = this.httpResponse.getEntity(); + return (entity != null ? entity.getContent() : StreamUtils.emptyInput()); + } + + @Override + public void close() { + // Release underlying connection back to the connection manager + try { + try { + // Attempt to keep connection alive by consuming its remaining content + EntityUtils.consume(this.httpResponse.getEntity()); + } + finally { + if (this.httpResponse instanceof Closeable) { + ((Closeable) this.httpResponse).close(); + } + } + } + catch (IOException ex) { + // Ignore exception on close... + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/HttpComponentsStreamingClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsStreamingClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..7b145e2a306f5c1f673221131d7311433b56081b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsStreamingClientHttpRequest.java @@ -0,0 +1,167 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; + +import org.apache.http.Header; +import org.apache.http.HttpEntity; +import org.apache.http.HttpEntityEnclosingRequest; +import org.apache.http.HttpResponse; +import org.apache.http.client.HttpClient; +import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.message.BasicHeader; +import org.apache.http.protocol.HttpContext; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.lang.Nullable; + +/** + * {@link ClientHttpRequest} implementation based on + * Apache HttpComponents HttpClient in streaming mode. + * + *

Created via the {@link HttpComponentsClientHttpRequestFactory}. + * + * @author Arjen Poutsma + * @since 4.0 + * @see HttpComponentsClientHttpRequestFactory#createRequest(java.net.URI, org.springframework.http.HttpMethod) + */ +final class HttpComponentsStreamingClientHttpRequest extends AbstractClientHttpRequest + implements StreamingHttpOutputMessage { + + private final HttpClient httpClient; + + private final HttpUriRequest httpRequest; + + private final HttpContext httpContext; + + @Nullable + private Body body; + + + HttpComponentsStreamingClientHttpRequest(HttpClient client, HttpUriRequest request, HttpContext context) { + this.httpClient = client; + this.httpRequest = request; + this.httpContext = context; + } + + + @Override + public String getMethodValue() { + return this.httpRequest.getMethod(); + } + + @Override + public URI getURI() { + return this.httpRequest.getURI(); + } + + @Override + public void setBody(Body body) { + assertNotExecuted(); + this.body = body; + } + + @Override + protected OutputStream getBodyInternal(HttpHeaders headers) throws IOException { + throw new UnsupportedOperationException("getBody not supported"); + } + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers) throws IOException { + HttpComponentsClientHttpRequest.addHeaders(this.httpRequest, headers); + + if (this.httpRequest instanceof HttpEntityEnclosingRequest && this.body != null) { + HttpEntityEnclosingRequest entityEnclosingRequest = (HttpEntityEnclosingRequest) this.httpRequest; + HttpEntity requestEntity = new StreamingHttpEntity(getHeaders(), this.body); + entityEnclosingRequest.setEntity(requestEntity); + } + + HttpResponse httpResponse = this.httpClient.execute(this.httpRequest, this.httpContext); + return new HttpComponentsClientHttpResponse(httpResponse); + } + + + private static class StreamingHttpEntity implements HttpEntity { + + private final HttpHeaders headers; + + private final StreamingHttpOutputMessage.Body body; + + public StreamingHttpEntity(HttpHeaders headers, StreamingHttpOutputMessage.Body body) { + this.headers = headers; + this.body = body; + } + + @Override + public boolean isRepeatable() { + return false; + } + + @Override + public boolean isChunked() { + return false; + } + + @Override + public long getContentLength() { + return this.headers.getContentLength(); + } + + @Override + @Nullable + public Header getContentType() { + MediaType contentType = this.headers.getContentType(); + return (contentType != null ? new BasicHeader("Content-Type", contentType.toString()) : null); + } + + @Override + @Nullable + public Header getContentEncoding() { + String contentEncoding = this.headers.getFirst("Content-Encoding"); + return (contentEncoding != null ? new BasicHeader("Content-Encoding", contentEncoding) : null); + + } + + @Override + public InputStream getContent() throws IOException, IllegalStateException { + throw new IllegalStateException("No content available"); + } + + @Override + public void writeTo(OutputStream outputStream) throws IOException { + this.body.writeTo(outputStream); + } + + @Override + public boolean isStreaming() { + return true; + } + + @Override + @Deprecated + public void consumeContent() throws IOException { + throw new UnsupportedOperationException(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/InterceptingAsyncClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/InterceptingAsyncClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..4dd835f29404b6459c115832f5a13206866b5a6d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/InterceptingAsyncClientHttpRequest.java @@ -0,0 +1,125 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.URI; +import java.util.Iterator; +import java.util.List; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRequest; +import org.springframework.util.Assert; +import org.springframework.util.StreamUtils; +import org.springframework.util.concurrent.ListenableFuture; + +/** + * An {@link AsyncClientHttpRequest} wrapper that enriches it proceeds the actual + * request execution with calling the registered interceptors. + * + * @author Jakub Narloch + * @author Rossen Stoyanchev + * @see InterceptingAsyncClientHttpRequestFactory + * @deprecated as of Spring 5.0, with no direct replacement + */ +@Deprecated +class InterceptingAsyncClientHttpRequest extends AbstractBufferingAsyncClientHttpRequest { + + private AsyncClientHttpRequestFactory requestFactory; + + private List interceptors; + + private URI uri; + + private HttpMethod httpMethod; + + + /** + * Create new instance of {@link InterceptingAsyncClientHttpRequest}. + * @param requestFactory the async request factory + * @param interceptors the list of interceptors + * @param uri the request URI + * @param httpMethod the HTTP method + */ + public InterceptingAsyncClientHttpRequest(AsyncClientHttpRequestFactory requestFactory, + List interceptors, URI uri, HttpMethod httpMethod) { + + this.requestFactory = requestFactory; + this.interceptors = interceptors; + this.uri = uri; + this.httpMethod = httpMethod; + } + + + @Override + protected ListenableFuture executeInternal(HttpHeaders headers, byte[] body) + throws IOException { + + return new AsyncRequestExecution().executeAsync(this, body); + } + + @Override + public HttpMethod getMethod() { + return this.httpMethod; + } + + @Override + public String getMethodValue() { + return this.httpMethod.name(); + } + + @Override + public URI getURI() { + return this.uri; + } + + + private class AsyncRequestExecution implements AsyncClientHttpRequestExecution { + + private Iterator iterator; + + public AsyncRequestExecution() { + this.iterator = interceptors.iterator(); + } + + @Override + public ListenableFuture executeAsync(HttpRequest request, byte[] body) + throws IOException { + + if (this.iterator.hasNext()) { + AsyncClientHttpRequestInterceptor interceptor = this.iterator.next(); + return interceptor.intercept(request, body, this); + } + else { + URI uri = request.getURI(); + HttpMethod method = request.getMethod(); + HttpHeaders headers = request.getHeaders(); + + Assert.state(method != null, "No standard HTTP method"); + AsyncClientHttpRequest delegate = requestFactory.createAsyncRequest(uri, method); + delegate.getHeaders().putAll(headers); + if (body.length > 0) { + StreamUtils.copy(body, delegate.getBody()); + } + + return delegate.executeAsync(); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/InterceptingAsyncClientHttpRequestFactory.java b/spring-web/src/main/java/org/springframework/http/client/InterceptingAsyncClientHttpRequestFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..9721e46cc2cce12e40862528d9560c8ba8580ee0 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/InterceptingAsyncClientHttpRequestFactory.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.net.URI; +import java.util.Collections; +import java.util.List; + +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; + +/** + * Wrapper for a {@link AsyncClientHttpRequestFactory} that has support for + * {@link AsyncClientHttpRequestInterceptor AsyncClientHttpRequestInterceptors}. + * + * @author Jakub Narloch + * @since 4.3 + * @see InterceptingAsyncClientHttpRequest + * @deprecated as of Spring 5.0, with no direct replacement + */ +@Deprecated +public class InterceptingAsyncClientHttpRequestFactory implements AsyncClientHttpRequestFactory { + + private AsyncClientHttpRequestFactory delegate; + + private List interceptors; + + + /** + * Create new instance of {@link InterceptingAsyncClientHttpRequestFactory} + * with delegated request factory and list of interceptors. + * @param delegate the request factory to delegate to + * @param interceptors the list of interceptors to use + */ + public InterceptingAsyncClientHttpRequestFactory(AsyncClientHttpRequestFactory delegate, + @Nullable List interceptors) { + + this.delegate = delegate; + this.interceptors = (interceptors != null ? interceptors : Collections.emptyList()); + } + + + @Override + public AsyncClientHttpRequest createAsyncRequest(URI uri, HttpMethod method) { + return new InterceptingAsyncClientHttpRequest(this.delegate, this.interceptors, uri, method); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/InterceptingClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/InterceptingClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..7d44e41f766c030e3c50ea2d7c878ad14d1adf25 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/InterceptingClientHttpRequest.java @@ -0,0 +1,113 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.URI; +import java.util.Iterator; +import java.util.List; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRequest; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.util.Assert; +import org.springframework.util.StreamUtils; + +/** + * Wrapper for a {@link ClientHttpRequest} that has support for {@link ClientHttpRequestInterceptor ClientHttpRequest} that has support for {@link ClientHttpRequestInterceptors}. + * + * @author Arjen Poutsma + * @since 3.1 + */ +class InterceptingClientHttpRequest extends AbstractBufferingClientHttpRequest { + + private final ClientHttpRequestFactory requestFactory; + + private final List interceptors; + + private HttpMethod method; + + private URI uri; + + + protected InterceptingClientHttpRequest(ClientHttpRequestFactory requestFactory, + List interceptors, URI uri, HttpMethod method) { + + this.requestFactory = requestFactory; + this.interceptors = interceptors; + this.method = method; + this.uri = uri; + } + + + @Override + public HttpMethod getMethod() { + return this.method; + } + + @Override + public String getMethodValue() { + return this.method.name(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + protected final ClientHttpResponse executeInternal(HttpHeaders headers, byte[] bufferedOutput) throws IOException { + InterceptingRequestExecution requestExecution = new InterceptingRequestExecution(); + return requestExecution.execute(this, bufferedOutput); + } + + + private class InterceptingRequestExecution implements ClientHttpRequestExecution { + + private final Iterator iterator; + + public InterceptingRequestExecution() { + this.iterator = interceptors.iterator(); + } + + @Override + public ClientHttpResponse execute(HttpRequest request, byte[] body) throws IOException { + if (this.iterator.hasNext()) { + ClientHttpRequestInterceptor nextInterceptor = this.iterator.next(); + return nextInterceptor.intercept(request, body, this); + } + else { + HttpMethod method = request.getMethod(); + Assert.state(method != null, "No standard HTTP method"); + ClientHttpRequest delegate = requestFactory.createRequest(request.getURI(), method); + request.getHeaders().forEach((key, value) -> delegate.getHeaders().addAll(key, value)); + if (body.length > 0) { + if (delegate instanceof StreamingHttpOutputMessage) { + StreamingHttpOutputMessage streamingOutputMessage = (StreamingHttpOutputMessage) delegate; + streamingOutputMessage.setBody(outputStream -> StreamUtils.copy(body, outputStream)); + } + else { + StreamUtils.copy(body, delegate.getBody()); + } + } + return delegate.execute(); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/InterceptingClientHttpRequestFactory.java b/spring-web/src/main/java/org/springframework/http/client/InterceptingClientHttpRequestFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..d5e26381c0b0ce85a1e8c74a32a519ae0dffd9e9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/InterceptingClientHttpRequestFactory.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.net.URI; +import java.util.Collections; +import java.util.List; + +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; + +/** + * {@link ClientHttpRequestFactory} wrapper with support for + * {@link ClientHttpRequestInterceptor ClientHttpRequestInterceptors}. + * + * @author Arjen Poutsma + * @since 3.1 + * @see ClientHttpRequestFactory + * @see ClientHttpRequestInterceptor + */ +public class InterceptingClientHttpRequestFactory extends AbstractClientHttpRequestFactoryWrapper { + + private final List interceptors; + + + /** + * Create a new instance of the {@code InterceptingClientHttpRequestFactory} with the given parameters. + * @param requestFactory the request factory to wrap + * @param interceptors the interceptors that are to be applied (can be {@code null}) + */ + public InterceptingClientHttpRequestFactory(ClientHttpRequestFactory requestFactory, + @Nullable List interceptors) { + + super(requestFactory); + this.interceptors = (interceptors != null ? interceptors : Collections.emptyList()); + } + + + @Override + protected ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod, ClientHttpRequestFactory requestFactory) { + return new InterceptingClientHttpRequest(requestFactory, this.interceptors, uri, httpMethod); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java b/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..a3f40f4349a4671bdd5ffd61189b5f676134a926 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java @@ -0,0 +1,326 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import org.reactivestreams.Publisher; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; +import org.springframework.core.ResolvableTypeProvider; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * Builder for the body of a multipart request, producing + * {@code MultiValueMap}, which can be provided to the + * {@code WebClient} through the {@code syncBody} method. + * + * Examples: + *

+ *
+ * // Add form field
+ * MultipartBodyBuilder builder = new MultipartBodyBuilder();
+ * builder.part("form field", "form value").header("foo", "bar");
+ *
+ * // Add file part
+ * Resource image = new ClassPathResource("image.jpg");
+ * builder.part("image", image).header("foo", "bar");
+ *
+ * // Add content (e.g. JSON)
+ * Account account = ...
+ * builder.part("account", account).header("foo", "bar");
+ *
+ * // Add content from Publisher
+ * Mono<Account> accountMono = ...
+ * builder.asyncPart("account", accountMono).header("foo", "bar");
+ *
+ * // Build and use
+ * MultiValueMap<String, HttpEntity<?>> multipartBody = builder.build();
+ *
+ * Mono<Void> result = webClient.post()
+ *     .uri("...")
+ *     .syncBody(multipartBody)
+ *     .retrieve()
+ *     .bodyToMono(Void.class)
+ * 
+ * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @since 5.0.2 + * @see RFC 7578 + */ +public final class MultipartBodyBuilder { + + private final LinkedMultiValueMap parts = new LinkedMultiValueMap<>(); + + + /** + * Creates a new, empty instance of the {@code MultipartBodyBuilder}. + */ + public MultipartBodyBuilder() { + } + + + /** + * Add a part where the Object may be: + *
    + *
  • String -- form field + *
  • {@link org.springframework.core.io.Resource Resource} -- file part + *
  • Object -- content to be encoded (e.g. to JSON) + *
  • HttpEntity -- part content and headers although generally it's + * easier to add headers through the returned builder
  • + *
+ * @param name the name of the part to add + * @param part the part data + * @return builder that allows for further customization of part headers + */ + public PartBuilder part(String name, Object part) { + return part(name, part, null); + } + + /** + * Variant of {@link #part(String, Object)} that also accepts a MediaType. + * @param name the name of the part to add + * @param part the part data + * @param contentType the media type to help with encoding the part + * @return builder that allows for further customization of part headers + */ + public PartBuilder part(String name, Object part, @Nullable MediaType contentType) { + Assert.hasLength(name, "'name' must not be empty"); + Assert.notNull(part, "'part' must not be null"); + + if (part instanceof PublisherEntity) { + PublisherPartBuilder builder = new PublisherPartBuilder<>((PublisherEntity) part); + if (contentType != null) { + builder.header(HttpHeaders.CONTENT_TYPE, contentType.toString()); + } + this.parts.add(name, builder); + return builder; + } + + Object partBody; + HttpHeaders partHeaders = null; + if (part instanceof HttpEntity) { + partBody = ((HttpEntity) part).getBody(); + partHeaders = new HttpHeaders(); + partHeaders.putAll(((HttpEntity) part).getHeaders()); + } + else { + partBody = part; + } + + if (partBody instanceof Publisher) { + throw new IllegalArgumentException( + "Use asyncPart(String, Publisher, Class)" + + " or asyncPart(String, Publisher, ParameterizedTypeReference) or" + + " or MultipartBodyBuilder.PublisherEntity"); + } + + DefaultPartBuilder builder = new DefaultPartBuilder(partHeaders, partBody); + if (contentType != null) { + builder.header(HttpHeaders.CONTENT_TYPE, contentType.toString()); + } + this.parts.add(name, builder); + return builder; + } + + /** + * Add a part from {@link Publisher} content. + * @param name the name of the part to add + * @param publisher the part contents + * @param elementClass the type of elements contained in the publisher + * @return builder that allows for further customization of part headers + */ + public > PartBuilder asyncPart(String name, P publisher, Class elementClass) { + Assert.hasLength(name, "'name' must not be empty"); + Assert.notNull(publisher, "'publisher' must not be null"); + Assert.notNull(elementClass, "'elementClass' must not be null"); + + PublisherPartBuilder builder = new PublisherPartBuilder<>(null, publisher, elementClass); + this.parts.add(name, builder); + return builder; + + } + + /** + * Variant of {@link #asyncPart(String, Publisher, Class)} with a + * {@link ParameterizedTypeReference} for the element type information. + * @param name the name of the part to add + * @param publisher the part contents + * @param typeReference the type of elements contained in the publisher + * @return builder that allows for further customization of part headers + */ + public > PartBuilder asyncPart( + String name, P publisher, ParameterizedTypeReference typeReference) { + + Assert.hasLength(name, "'name' must not be empty"); + Assert.notNull(publisher, "'publisher' must not be null"); + Assert.notNull(typeReference, "'typeReference' must not be null"); + + PublisherPartBuilder builder = new PublisherPartBuilder<>(null, publisher, typeReference); + this.parts.add(name, builder); + return builder; + } + + /** + * Return a {@code MultiValueMap} with the configured parts. + */ + public MultiValueMap> build() { + MultiValueMap> result = new LinkedMultiValueMap<>(this.parts.size()); + for (Map.Entry> entry : this.parts.entrySet()) { + for (DefaultPartBuilder builder : entry.getValue()) { + HttpEntity entity = builder.build(); + result.add(entry.getKey(), entity); + } + } + return result; + } + + + /** + * Builder that allows for further customization of part headers. + */ + public interface PartBuilder { + + /** + * Add part header values. + * @param headerName the part header name + * @param headerValues the part header value(s) + * @return this builder + * @see HttpHeaders#addAll(String, List) + */ + PartBuilder header(String headerName, String... headerValues); + + /** + * Manipulate the part headers through the given consumer. + * @param headersConsumer consumer to manipulate the part headers with + * @return this builder + */ + PartBuilder headers(Consumer headersConsumer); + } + + + private static class DefaultPartBuilder implements PartBuilder { + + @Nullable + protected HttpHeaders headers; + + @Nullable + protected final Object body; + + public DefaultPartBuilder(@Nullable HttpHeaders headers, @Nullable Object body) { + this.headers = headers; + this.body = body; + } + + @Override + public PartBuilder header(String headerName, String... headerValues) { + initHeadersIfNecessary().addAll(headerName, Arrays.asList(headerValues)); + return this; + } + + @Override + public PartBuilder headers(Consumer headersConsumer) { + headersConsumer.accept(initHeadersIfNecessary()); + return this; + } + + private HttpHeaders initHeadersIfNecessary() { + if (this.headers == null) { + this.headers = new HttpHeaders(); + } + return this.headers; + } + + public HttpEntity build() { + return new HttpEntity<>(this.body, this.headers); + } + } + + + private static class PublisherPartBuilder> extends DefaultPartBuilder { + + private final ResolvableType resolvableType; + + public PublisherPartBuilder(@Nullable HttpHeaders headers, P body, Class elementClass) { + super(headers, body); + this.resolvableType = ResolvableType.forClass(elementClass); + } + + public PublisherPartBuilder(@Nullable HttpHeaders headers, P body, ParameterizedTypeReference typeRef) { + super(headers, body); + this.resolvableType = ResolvableType.forType(typeRef); + } + + public PublisherPartBuilder(PublisherEntity other) { + super(other.getHeaders(), other.getBody()); + this.resolvableType = other.getResolvableType(); + } + + @Override + @SuppressWarnings("unchecked") + public HttpEntity build() { + P publisher = (P) this.body; + Assert.state(publisher != null, "Publisher must not be null"); + return new PublisherEntity<>(this.headers, publisher, this.resolvableType); + } + } + + + /** + * Specialization of {@link HttpEntity} for use with a + * {@link Publisher}-based body, for which we also need to keep track of + * the element type. + * @param the type contained in the publisher + * @param

the publisher + */ + static final class PublisherEntity> extends HttpEntity

+ implements ResolvableTypeProvider { + + private final ResolvableType resolvableType; + + PublisherEntity( + @Nullable MultiValueMap headers, P publisher, ResolvableType resolvableType) { + + super(publisher, headers); + Assert.notNull(publisher, "'publisher' must not be null"); + Assert.notNull(resolvableType, "'resolvableType' must not be null"); + this.resolvableType = resolvableType; + } + + /** + * Return the element type for the {@code Publisher} body. + */ + @Override + @NonNull + public ResolvableType getResolvableType() { + return this.resolvableType; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/Netty4ClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/Netty4ClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..c4c7d6f37d40422f06d22a0485d297b1e894a22b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/Netty4ClientHttpRequest.java @@ -0,0 +1,187 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.io.OutputStream; +import java.net.URI; +import java.util.concurrent.ExecutionException; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBufOutputStream; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpVersion; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.SettableListenableFuture; + +/** + * {@link ClientHttpRequest} implementation based on Netty 4. + * + *

Created via the {@link Netty4ClientHttpRequestFactory}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 4.1.2 + * @deprecated as of Spring 5.0, in favor of + * {@link org.springframework.http.client.reactive.ReactorClientHttpConnector} + */ +@Deprecated +class Netty4ClientHttpRequest extends AbstractAsyncClientHttpRequest implements ClientHttpRequest { + + private final Bootstrap bootstrap; + + private final URI uri; + + private final HttpMethod method; + + private final ByteBufOutputStream body; + + + public Netty4ClientHttpRequest(Bootstrap bootstrap, URI uri, HttpMethod method) { + this.bootstrap = bootstrap; + this.uri = uri; + this.method = method; + this.body = new ByteBufOutputStream(Unpooled.buffer(1024)); + } + + + @Override + public HttpMethod getMethod() { + return this.method; + } + + @Override + public String getMethodValue() { + return this.method.name(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public ClientHttpResponse execute() throws IOException { + try { + return executeAsync().get(); + } + catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + throw new IOException("Interrupted during request execution", ex); + } + catch (ExecutionException ex) { + if (ex.getCause() instanceof IOException) { + throw (IOException) ex.getCause(); + } + else { + throw new IOException(ex.getMessage(), ex.getCause()); + } + } + } + + @Override + protected OutputStream getBodyInternal(HttpHeaders headers) throws IOException { + return this.body; + } + + @Override + protected ListenableFuture executeInternal(final HttpHeaders headers) throws IOException { + final SettableListenableFuture responseFuture = new SettableListenableFuture<>(); + + ChannelFutureListener connectionListener = future -> { + if (future.isSuccess()) { + Channel channel = future.channel(); + channel.pipeline().addLast(new RequestExecuteHandler(responseFuture)); + FullHttpRequest nettyRequest = createFullHttpRequest(headers); + channel.writeAndFlush(nettyRequest); + } + else { + responseFuture.setException(future.cause()); + } + }; + + this.bootstrap.connect(this.uri.getHost(), getPort(this.uri)).addListener(connectionListener); + return responseFuture; + } + + private FullHttpRequest createFullHttpRequest(HttpHeaders headers) { + io.netty.handler.codec.http.HttpMethod nettyMethod = + io.netty.handler.codec.http.HttpMethod.valueOf(this.method.name()); + + String authority = this.uri.getRawAuthority(); + String path = this.uri.toString().substring(this.uri.toString().indexOf(authority) + authority.length()); + FullHttpRequest nettyRequest = new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, nettyMethod, path, this.body.buffer()); + + nettyRequest.headers().set(HttpHeaders.HOST, this.uri.getHost() + ":" + getPort(this.uri)); + nettyRequest.headers().set(HttpHeaders.CONNECTION, "close"); + headers.forEach((headerName, headerValues) -> nettyRequest.headers().add(headerName, headerValues)); + if (!nettyRequest.headers().contains(HttpHeaders.CONTENT_LENGTH) && this.body.buffer().readableBytes() > 0) { + nettyRequest.headers().set(HttpHeaders.CONTENT_LENGTH, this.body.buffer().readableBytes()); + } + + return nettyRequest; + } + + private static int getPort(URI uri) { + int port = uri.getPort(); + if (port == -1) { + if ("http".equalsIgnoreCase(uri.getScheme())) { + port = 80; + } + else if ("https".equalsIgnoreCase(uri.getScheme())) { + port = 443; + } + } + return port; + } + + + /** + * A SimpleChannelInboundHandler to update the given SettableListenableFuture. + */ + private static class RequestExecuteHandler extends SimpleChannelInboundHandler { + + private final SettableListenableFuture responseFuture; + + public RequestExecuteHandler(SettableListenableFuture responseFuture) { + this.responseFuture = responseFuture; + } + + @Override + protected void channelRead0(ChannelHandlerContext context, FullHttpResponse response) throws Exception { + this.responseFuture.set(new Netty4ClientHttpResponse(context, response)); + } + + @Override + public void exceptionCaught(ChannelHandlerContext context, Throwable cause) throws Exception { + this.responseFuture.setException(cause); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/Netty4ClientHttpRequestFactory.java b/spring-web/src/main/java/org/springframework/http/client/Netty4ClientHttpRequestFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..bc0c83472fb983cf353f04af83ae1da316d3e5ac --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/Netty4ClientHttpRequestFactory.java @@ -0,0 +1,243 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.URI; +import java.util.concurrent.TimeUnit; + +import javax.net.ssl.SSLException; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.SocketChannelConfig; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.timeout.ReadTimeoutHandler; + +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@link org.springframework.http.client.ClientHttpRequestFactory} implementation + * that uses Netty 4 to create requests. + * + *

Allows to use a pre-configured {@link EventLoopGroup} instance: useful for + * sharing across multiple clients. + * + *

Note that this implementation consistently closes the HTTP connection on each + * request. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Brian Clozel + * @author Mark Paluch + * @since 4.1.2 + * @deprecated as of Spring 5.0, in favor of + * {@link org.springframework.http.client.reactive.ReactorClientHttpConnector} + */ +@Deprecated +public class Netty4ClientHttpRequestFactory implements ClientHttpRequestFactory, + AsyncClientHttpRequestFactory, InitializingBean, DisposableBean { + + /** + * The default maximum response size. + * @see #setMaxResponseSize(int) + */ + public static final int DEFAULT_MAX_RESPONSE_SIZE = 1024 * 1024 * 10; + + + private final EventLoopGroup eventLoopGroup; + + private final boolean defaultEventLoopGroup; + + private int maxResponseSize = DEFAULT_MAX_RESPONSE_SIZE; + + @Nullable + private SslContext sslContext; + + private int connectTimeout = -1; + + private int readTimeout = -1; + + @Nullable + private volatile Bootstrap bootstrap; + + + /** + * Create a new {@code Netty4ClientHttpRequestFactory} with a default + * {@link NioEventLoopGroup}. + */ + public Netty4ClientHttpRequestFactory() { + int ioWorkerCount = Runtime.getRuntime().availableProcessors() * 2; + this.eventLoopGroup = new NioEventLoopGroup(ioWorkerCount); + this.defaultEventLoopGroup = true; + } + + /** + * Create a new {@code Netty4ClientHttpRequestFactory} with the given + * {@link EventLoopGroup}. + *

NOTE: the given group will not be + * {@linkplain EventLoopGroup#shutdownGracefully() shutdown} by this factory; + * doing so becomes the responsibility of the caller. + */ + public Netty4ClientHttpRequestFactory(EventLoopGroup eventLoopGroup) { + Assert.notNull(eventLoopGroup, "EventLoopGroup must not be null"); + this.eventLoopGroup = eventLoopGroup; + this.defaultEventLoopGroup = false; + } + + + /** + * Set the default maximum response size. + *

By default this is set to {@link #DEFAULT_MAX_RESPONSE_SIZE}. + * @since 4.1.5 + * @see HttpObjectAggregator#HttpObjectAggregator(int) + */ + public void setMaxResponseSize(int maxResponseSize) { + this.maxResponseSize = maxResponseSize; + } + + /** + * Set the SSL context. When configured it is used to create and insert an + * {@link io.netty.handler.ssl.SslHandler} in the channel pipeline. + *

A default client SslContext is configured if none has been provided. + */ + public void setSslContext(SslContext sslContext) { + this.sslContext = sslContext; + } + + /** + * Set the underlying connect timeout (in milliseconds). + * A timeout value of 0 specifies an infinite timeout. + * @see ChannelConfig#setConnectTimeoutMillis(int) + */ + public void setConnectTimeout(int connectTimeout) { + this.connectTimeout = connectTimeout; + } + + /** + * Set the underlying URLConnection's read timeout (in milliseconds). + * A timeout value of 0 specifies an infinite timeout. + * @see ReadTimeoutHandler + */ + public void setReadTimeout(int readTimeout) { + this.readTimeout = readTimeout; + } + + + @Override + public void afterPropertiesSet() { + if (this.sslContext == null) { + this.sslContext = getDefaultClientSslContext(); + } + } + + private SslContext getDefaultClientSslContext() { + try { + return SslContextBuilder.forClient().build(); + } + catch (SSLException ex) { + throw new IllegalStateException("Could not create default client SslContext", ex); + } + } + + + @Override + public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { + return createRequestInternal(uri, httpMethod); + } + + @Override + public AsyncClientHttpRequest createAsyncRequest(URI uri, HttpMethod httpMethod) throws IOException { + return createRequestInternal(uri, httpMethod); + } + + private Netty4ClientHttpRequest createRequestInternal(URI uri, HttpMethod httpMethod) { + return new Netty4ClientHttpRequest(getBootstrap(uri), uri, httpMethod); + } + + private Bootstrap getBootstrap(URI uri) { + boolean isSecure = (uri.getPort() == 443 || "https".equalsIgnoreCase(uri.getScheme())); + if (isSecure) { + return buildBootstrap(uri, true); + } + else { + Bootstrap bootstrap = this.bootstrap; + if (bootstrap == null) { + bootstrap = buildBootstrap(uri, false); + this.bootstrap = bootstrap; + } + return bootstrap; + } + } + + private Bootstrap buildBootstrap(URI uri, boolean isSecure) { + Bootstrap bootstrap = new Bootstrap(); + bootstrap.group(this.eventLoopGroup).channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel channel) throws Exception { + configureChannel(channel.config()); + ChannelPipeline pipeline = channel.pipeline(); + if (isSecure) { + Assert.notNull(sslContext, "sslContext should not be null"); + pipeline.addLast(sslContext.newHandler(channel.alloc(), uri.getHost(), uri.getPort())); + } + pipeline.addLast(new HttpClientCodec()); + pipeline.addLast(new HttpObjectAggregator(maxResponseSize)); + if (readTimeout > 0) { + pipeline.addLast(new ReadTimeoutHandler(readTimeout, + TimeUnit.MILLISECONDS)); + } + } + }); + return bootstrap; + } + + /** + * Template method for changing properties on the given {@link SocketChannelConfig}. + *

The default implementation sets the connect timeout based on the set property. + * @param config the channel configuration + */ + protected void configureChannel(SocketChannelConfig config) { + if (this.connectTimeout >= 0) { + config.setConnectTimeoutMillis(this.connectTimeout); + } + } + + + @Override + public void destroy() throws InterruptedException { + if (this.defaultEventLoopGroup) { + // Clean up the EventLoopGroup if we created it in the constructor + this.eventLoopGroup.shutdownGracefully().sync(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/Netty4ClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/Netty4ClientHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..6ac0de0a19e6ed2c3269f26d1a84112da60c6ba2 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/Netty4ClientHttpResponse.java @@ -0,0 +1,96 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Map; + +import io.netty.buffer.ByteBufInputStream; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.FullHttpResponse; + +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@link ClientHttpResponse} implementation based on Netty 4. + * + * @author Arjen Poutsma + * @since 4.1.2 + * @deprecated as of Spring 5.0, in favor of + * {@link org.springframework.http.client.reactive.ReactorClientHttpConnector} + */ +@Deprecated +class Netty4ClientHttpResponse extends AbstractClientHttpResponse { + + private final ChannelHandlerContext context; + + private final FullHttpResponse nettyResponse; + + private final ByteBufInputStream body; + + @Nullable + private volatile HttpHeaders headers; + + + public Netty4ClientHttpResponse(ChannelHandlerContext context, FullHttpResponse nettyResponse) { + Assert.notNull(context, "ChannelHandlerContext must not be null"); + Assert.notNull(nettyResponse, "FullHttpResponse must not be null"); + this.context = context; + this.nettyResponse = nettyResponse; + this.body = new ByteBufInputStream(this.nettyResponse.content()); + this.nettyResponse.retain(); + } + + + @Override + public int getRawStatusCode() throws IOException { + return this.nettyResponse.getStatus().code(); + } + + @Override + public String getStatusText() throws IOException { + return this.nettyResponse.getStatus().reasonPhrase(); + } + + @Override + public HttpHeaders getHeaders() { + HttpHeaders headers = this.headers; + if (headers == null) { + headers = new HttpHeaders(); + for (Map.Entry entry : this.nettyResponse.headers()) { + headers.add(entry.getKey(), entry.getValue()); + } + this.headers = headers; + } + return headers; + } + + @Override + public InputStream getBody() throws IOException { + return this.body; + } + + @Override + public void close() { + this.nettyResponse.release(); + this.context.close(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/OkHttp3AsyncClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/OkHttp3AsyncClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..ee8d2eb233948a5c1e14b6a59dee9b96a9f17072 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/OkHttp3AsyncClientHttpRequest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.URI; + +import okhttp3.Call; +import okhttp3.Callback; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.SettableListenableFuture; + +/** + * {@link AsyncClientHttpRequest} implementation based on OkHttp 3.x. + * + *

Created via the {@link OkHttp3ClientHttpRequestFactory}. + * + * @author Luciano Leggieri + * @author Arjen Poutsma + * @author Roy Clarkson + * @since 4.3 + * @deprecated as of Spring 5.0, with no direct replacement + */ +@Deprecated +class OkHttp3AsyncClientHttpRequest extends AbstractBufferingAsyncClientHttpRequest { + + private final OkHttpClient client; + + private final URI uri; + + private final HttpMethod method; + + + public OkHttp3AsyncClientHttpRequest(OkHttpClient client, URI uri, HttpMethod method) { + this.client = client; + this.uri = uri; + this.method = method; + } + + + @Override + public HttpMethod getMethod() { + return this.method; + } + + @Override + public String getMethodValue() { + return this.method.name(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + protected ListenableFuture executeInternal(HttpHeaders headers, byte[] content) + throws IOException { + + Request request = OkHttp3ClientHttpRequestFactory.buildRequest(headers, content, this.uri, this.method); + return new OkHttpListenableFuture(this.client.newCall(request)); + } + + + private static class OkHttpListenableFuture extends SettableListenableFuture { + + private final Call call; + + public OkHttpListenableFuture(Call call) { + this.call = call; + this.call.enqueue(new Callback() { + @Override + public void onResponse(Call call, Response response) { + set(new OkHttp3ClientHttpResponse(response)); + } + @Override + public void onFailure(Call call, IOException ex) { + setException(ex); + } + }); + } + + @Override + protected void interruptTask() { + this.call.cancel(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/OkHttp3ClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/OkHttp3ClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..74ff955332cd7986dc9ed02d7ace37df28420704 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/OkHttp3ClientHttpRequest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.URI; + +import okhttp3.OkHttpClient; +import okhttp3.Request; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; + +/** + * {@link ClientHttpRequest} implementation based on OkHttp 3.x. + * + *

Created via the {@link OkHttp3ClientHttpRequestFactory}. + * + * @author Luciano Leggieri + * @author Arjen Poutsma + * @author Roy Clarkson + * @since 4.3 + */ +class OkHttp3ClientHttpRequest extends AbstractBufferingClientHttpRequest { + + private final OkHttpClient client; + + private final URI uri; + + private final HttpMethod method; + + + public OkHttp3ClientHttpRequest(OkHttpClient client, URI uri, HttpMethod method) { + this.client = client; + this.uri = uri; + this.method = method; + } + + + @Override + public HttpMethod getMethod() { + return this.method; + } + + @Override + public String getMethodValue() { + return this.method.name(); + } + + @Override + public URI getURI() { + return this.uri; + } + + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers, byte[] content) throws IOException { + Request request = OkHttp3ClientHttpRequestFactory.buildRequest(headers, content, this.uri, this.method); + return new OkHttp3ClientHttpResponse(this.client.newCall(request).execute()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/OkHttp3ClientHttpRequestFactory.java b/spring-web/src/main/java/org/springframework/http/client/OkHttp3ClientHttpRequestFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..a0ca57711e1cb243cca02be8091926ba193f7064 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/OkHttp3ClientHttpRequestFactory.java @@ -0,0 +1,151 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URI; +import java.util.concurrent.TimeUnit; + +import okhttp3.Cache; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; + +import org.springframework.beans.factory.DisposableBean; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * {@link ClientHttpRequestFactory} implementation that uses + * OkHttp 3.x to create requests. + * + * @author Luciano Leggieri + * @author Arjen Poutsma + * @author Roy Clarkson + * @since 4.3 + */ +@SuppressWarnings("deprecation") +public class OkHttp3ClientHttpRequestFactory + implements ClientHttpRequestFactory, AsyncClientHttpRequestFactory, DisposableBean { + + private OkHttpClient client; + + private final boolean defaultClient; + + + /** + * Create a factory with a default {@link OkHttpClient} instance. + */ + public OkHttp3ClientHttpRequestFactory() { + this.client = new OkHttpClient(); + this.defaultClient = true; + } + + /** + * Create a factory with the given {@link OkHttpClient} instance. + * @param client the client to use + */ + public OkHttp3ClientHttpRequestFactory(OkHttpClient client) { + Assert.notNull(client, "OkHttpClient must not be null"); + this.client = client; + this.defaultClient = false; + } + + + /** + * Set the underlying read timeout in milliseconds. + * A value of 0 specifies an infinite timeout. + */ + public void setReadTimeout(int readTimeout) { + this.client = this.client.newBuilder() + .readTimeout(readTimeout, TimeUnit.MILLISECONDS) + .build(); + } + + /** + * Set the underlying write timeout in milliseconds. + * A value of 0 specifies an infinite timeout. + */ + public void setWriteTimeout(int writeTimeout) { + this.client = this.client.newBuilder() + .writeTimeout(writeTimeout, TimeUnit.MILLISECONDS) + .build(); + } + + /** + * Set the underlying connect timeout in milliseconds. + * A value of 0 specifies an infinite timeout. + */ + public void setConnectTimeout(int connectTimeout) { + this.client = this.client.newBuilder() + .connectTimeout(connectTimeout, TimeUnit.MILLISECONDS) + .build(); + } + + + @Override + public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) { + return new OkHttp3ClientHttpRequest(this.client, uri, httpMethod); + } + + @Override + public AsyncClientHttpRequest createAsyncRequest(URI uri, HttpMethod httpMethod) { + return new OkHttp3AsyncClientHttpRequest(this.client, uri, httpMethod); + } + + + @Override + public void destroy() throws IOException { + if (this.defaultClient) { + // Clean up the client if we created it in the constructor + Cache cache = this.client.cache(); + if (cache != null) { + cache.close(); + } + this.client.dispatcher().executorService().shutdown(); + } + } + + + static Request buildRequest(HttpHeaders headers, byte[] content, URI uri, HttpMethod method) + throws MalformedURLException { + + okhttp3.MediaType contentType = getContentType(headers); + RequestBody body = (content.length > 0 || + okhttp3.internal.http.HttpMethod.requiresRequestBody(method.name()) ? + RequestBody.create(contentType, content) : null); + + Request.Builder builder = new Request.Builder().url(uri.toURL()).method(method.name(), body); + headers.forEach((headerName, headerValues) -> { + for (String headerValue : headerValues) { + builder.addHeader(headerName, headerValue); + } + }); + return builder.build(); + } + + @Nullable + private static okhttp3.MediaType getContentType(HttpHeaders headers) { + String rawContentType = headers.getFirst(HttpHeaders.CONTENT_TYPE); + return (StringUtils.hasText(rawContentType) ? okhttp3.MediaType.parse(rawContentType) : null); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/OkHttp3ClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/OkHttp3ClientHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..c266454225a96f10aa7fbcfb34b28875034476b8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/OkHttp3ClientHttpResponse.java @@ -0,0 +1,91 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.io.InputStream; + +import okhttp3.Response; +import okhttp3.ResponseBody; + +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StreamUtils; + +/** + * {@link ClientHttpResponse} implementation based on OkHttp 3.x. + * + * @author Luciano Leggieri + * @author Arjen Poutsma + * @author Roy Clarkson + * @since 4.3 + */ +class OkHttp3ClientHttpResponse extends AbstractClientHttpResponse { + + private final Response response; + + @Nullable + private volatile HttpHeaders headers; + + + public OkHttp3ClientHttpResponse(Response response) { + Assert.notNull(response, "Response must not be null"); + this.response = response; + } + + + @Override + public int getRawStatusCode() { + return this.response.code(); + } + + @Override + public String getStatusText() { + return this.response.message(); + } + + @Override + public InputStream getBody() throws IOException { + ResponseBody body = this.response.body(); + return (body != null ? body.byteStream() : StreamUtils.emptyInput()); + } + + @Override + public HttpHeaders getHeaders() { + HttpHeaders headers = this.headers; + if (headers == null) { + headers = new HttpHeaders(); + for (String headerName : this.response.headers().names()) { + for (String headerValue : this.response.headers(headerName)) { + headers.add(headerName, headerValue); + } + } + this.headers = headers; + } + return headers; + } + + @Override + public void close() { + ResponseBody body = this.response.body(); + if (body != null) { + body.close(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/SimpleBufferingAsyncClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/SimpleBufferingAsyncClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..2be5c170df026fedbb2b3cc9e0927680adb80c31 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/SimpleBufferingAsyncClientHttpRequest.java @@ -0,0 +1,103 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.HttpURLConnection; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.concurrent.Callable; + +import org.springframework.core.task.AsyncListenableTaskExecutor; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.util.FileCopyUtils; +import org.springframework.util.concurrent.ListenableFuture; + +/** + * {@link org.springframework.http.client.ClientHttpRequest} implementation that uses + * standard JDK facilities to execute buffered requests. Created via the + * {@link org.springframework.http.client.SimpleClientHttpRequestFactory}. + * + * @author Arjen Poutsma + * @since 3.0 + * @see org.springframework.http.client.SimpleClientHttpRequestFactory#createRequest + * @deprecated as of Spring 5.0, with no direct replacement + */ +@Deprecated +final class SimpleBufferingAsyncClientHttpRequest extends AbstractBufferingAsyncClientHttpRequest { + + private final HttpURLConnection connection; + + private final boolean outputStreaming; + + private final AsyncListenableTaskExecutor taskExecutor; + + + SimpleBufferingAsyncClientHttpRequest(HttpURLConnection connection, + boolean outputStreaming, AsyncListenableTaskExecutor taskExecutor) { + + this.connection = connection; + this.outputStreaming = outputStreaming; + this.taskExecutor = taskExecutor; + } + + + @Override + public String getMethodValue() { + return this.connection.getRequestMethod(); + } + + @Override + public URI getURI() { + try { + return this.connection.getURL().toURI(); + } + catch (URISyntaxException ex) { + throw new IllegalStateException("Could not get HttpURLConnection URI: " + ex.getMessage(), ex); + } + } + + @Override + protected ListenableFuture executeInternal( + final HttpHeaders headers, final byte[] bufferedOutput) throws IOException { + + return this.taskExecutor.submitListenable(new Callable() { + @Override + public ClientHttpResponse call() throws Exception { + SimpleBufferingClientHttpRequest.addHeaders(connection, headers); + // JDK <1.8 doesn't support getOutputStream with HTTP DELETE + if (getMethod() == HttpMethod.DELETE && bufferedOutput.length == 0) { + connection.setDoOutput(false); + } + if (connection.getDoOutput() && outputStreaming) { + connection.setFixedLengthStreamingMode(bufferedOutput.length); + } + connection.connect(); + if (connection.getDoOutput()) { + FileCopyUtils.copy(bufferedOutput, connection.getOutputStream()); + } + else { + // Immediately trigger the request in a no-output scenario as well + connection.getResponseCode(); + } + return new SimpleClientHttpResponse(connection); + } + }); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/SimpleBufferingClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/SimpleBufferingClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..52c3eb60e81eac0039f8613a85fdc28613647b14 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/SimpleBufferingClientHttpRequest.java @@ -0,0 +1,108 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.HttpURLConnection; +import java.net.URI; +import java.net.URISyntaxException; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.util.FileCopyUtils; +import org.springframework.util.StringUtils; + +/** + * {@link ClientHttpRequest} implementation that uses standard JDK facilities to + * execute buffered requests. Created via the {@link SimpleClientHttpRequestFactory}. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.0 + * @see SimpleClientHttpRequestFactory#createRequest(java.net.URI, HttpMethod) + */ +final class SimpleBufferingClientHttpRequest extends AbstractBufferingClientHttpRequest { + + private final HttpURLConnection connection; + + private final boolean outputStreaming; + + + SimpleBufferingClientHttpRequest(HttpURLConnection connection, boolean outputStreaming) { + this.connection = connection; + this.outputStreaming = outputStreaming; + } + + + @Override + public String getMethodValue() { + return this.connection.getRequestMethod(); + } + + @Override + public URI getURI() { + try { + return this.connection.getURL().toURI(); + } + catch (URISyntaxException ex) { + throw new IllegalStateException("Could not get HttpURLConnection URI: " + ex.getMessage(), ex); + } + } + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers, byte[] bufferedOutput) throws IOException { + addHeaders(this.connection, headers); + // JDK <1.8 doesn't support getOutputStream with HTTP DELETE + if (getMethod() == HttpMethod.DELETE && bufferedOutput.length == 0) { + this.connection.setDoOutput(false); + } + if (this.connection.getDoOutput() && this.outputStreaming) { + this.connection.setFixedLengthStreamingMode(bufferedOutput.length); + } + this.connection.connect(); + if (this.connection.getDoOutput()) { + FileCopyUtils.copy(bufferedOutput, this.connection.getOutputStream()); + } + else { + // Immediately trigger the request in a no-output scenario as well + this.connection.getResponseCode(); + } + return new SimpleClientHttpResponse(this.connection); + } + + + /** + * Add the given headers to the given HTTP connection. + * @param connection the connection to add the headers to + * @param headers the headers to add + */ + static void addHeaders(HttpURLConnection connection, HttpHeaders headers) { + headers.forEach((headerName, headerValues) -> { + if (HttpHeaders.COOKIE.equalsIgnoreCase(headerName)) { // RFC 6265 + String headerValue = StringUtils.collectionToDelimitedString(headerValues, "; "); + connection.setRequestProperty(headerName, headerValue); + } + else { + for (String headerValue : headerValues) { + String actualHeaderValue = headerValue != null ? headerValue : ""; + connection.addRequestProperty(headerName, actualHeaderValue); + } + } + }); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/SimpleClientHttpRequestFactory.java b/spring-web/src/main/java/org/springframework/http/client/SimpleClientHttpRequestFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..3665f93e71e867296bb580e129865b08a93c37ac --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/SimpleClientHttpRequestFactory.java @@ -0,0 +1,229 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.HttpURLConnection; +import java.net.Proxy; +import java.net.URI; +import java.net.URL; +import java.net.URLConnection; + +import org.springframework.core.task.AsyncListenableTaskExecutor; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@link ClientHttpRequestFactory} implementation that uses standard JDK facilities. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.0 + * @see java.net.HttpURLConnection + * @see HttpComponentsClientHttpRequestFactory + */ +@SuppressWarnings("deprecation") +public class SimpleClientHttpRequestFactory implements ClientHttpRequestFactory, AsyncClientHttpRequestFactory { + + private static final int DEFAULT_CHUNK_SIZE = 4096; + + + @Nullable + private Proxy proxy; + + private boolean bufferRequestBody = true; + + private int chunkSize = DEFAULT_CHUNK_SIZE; + + private int connectTimeout = -1; + + private int readTimeout = -1; + + private boolean outputStreaming = true; + + @Nullable + private AsyncListenableTaskExecutor taskExecutor; + + + /** + * Set the {@link Proxy} to use for this request factory. + */ + public void setProxy(Proxy proxy) { + this.proxy = proxy; + } + + /** + * Indicate whether this request factory should buffer the + * {@linkplain ClientHttpRequest#getBody() request body} internally. + *

Default is {@code true}. When sending large amounts of data via POST or PUT, + * it is recommended to change this property to {@code false}, so as not to run + * out of memory. This will result in a {@link ClientHttpRequest} that either + * streams directly to the underlying {@link HttpURLConnection} (if the + * {@link org.springframework.http.HttpHeaders#getContentLength() Content-Length} + * is known in advance), or that will use "Chunked transfer encoding" + * (if the {@code Content-Length} is not known in advance). + * @see #setChunkSize(int) + * @see HttpURLConnection#setFixedLengthStreamingMode(int) + */ + public void setBufferRequestBody(boolean bufferRequestBody) { + this.bufferRequestBody = bufferRequestBody; + } + + /** + * Set the number of bytes to write in each chunk when not buffering request + * bodies locally. + *

Note that this parameter is only used when + * {@link #setBufferRequestBody(boolean) bufferRequestBody} is set to {@code false}, + * and the {@link org.springframework.http.HttpHeaders#getContentLength() Content-Length} + * is not known in advance. + * @see #setBufferRequestBody(boolean) + */ + public void setChunkSize(int chunkSize) { + this.chunkSize = chunkSize; + } + + /** + * Set the underlying URLConnection's connect timeout (in milliseconds). + * A timeout value of 0 specifies an infinite timeout. + *

Default is the system's default timeout. + * @see URLConnection#setConnectTimeout(int) + */ + public void setConnectTimeout(int connectTimeout) { + this.connectTimeout = connectTimeout; + } + + /** + * Set the underlying URLConnection's read timeout (in milliseconds). + * A timeout value of 0 specifies an infinite timeout. + *

Default is the system's default timeout. + * @see URLConnection#setReadTimeout(int) + */ + public void setReadTimeout(int readTimeout) { + this.readTimeout = readTimeout; + } + + /** + * Set if the underlying URLConnection can be set to 'output streaming' mode. + * Default is {@code true}. + *

When output streaming is enabled, authentication and redirection cannot be handled automatically. + * If output streaming is disabled, the {@link HttpURLConnection#setFixedLengthStreamingMode} and + * {@link HttpURLConnection#setChunkedStreamingMode} methods of the underlying connection will never + * be called. + * @param outputStreaming if output streaming is enabled + */ + public void setOutputStreaming(boolean outputStreaming) { + this.outputStreaming = outputStreaming; + } + + /** + * Set the task executor for this request factory. Setting this property is required + * for {@linkplain #createAsyncRequest(URI, HttpMethod) creating asynchronous requests}. + * @param taskExecutor the task executor + */ + public void setTaskExecutor(AsyncListenableTaskExecutor taskExecutor) { + this.taskExecutor = taskExecutor; + } + + + @Override + public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { + HttpURLConnection connection = openConnection(uri.toURL(), this.proxy); + prepareConnection(connection, httpMethod.name()); + + if (this.bufferRequestBody) { + return new SimpleBufferingClientHttpRequest(connection, this.outputStreaming); + } + else { + return new SimpleStreamingClientHttpRequest(connection, this.chunkSize, this.outputStreaming); + } + } + + /** + * {@inheritDoc} + *

Setting the {@link #setTaskExecutor taskExecutor} property is required before calling this method. + */ + @Override + public AsyncClientHttpRequest createAsyncRequest(URI uri, HttpMethod httpMethod) throws IOException { + Assert.state(this.taskExecutor != null, "Asynchronous execution requires TaskExecutor to be set"); + + HttpURLConnection connection = openConnection(uri.toURL(), this.proxy); + prepareConnection(connection, httpMethod.name()); + + if (this.bufferRequestBody) { + return new SimpleBufferingAsyncClientHttpRequest( + connection, this.outputStreaming, this.taskExecutor); + } + else { + return new SimpleStreamingAsyncClientHttpRequest( + connection, this.chunkSize, this.outputStreaming, this.taskExecutor); + } + } + + /** + * Opens and returns a connection to the given URL. + *

The default implementation uses the given {@linkplain #setProxy(java.net.Proxy) proxy} - + * if any - to open a connection. + * @param url the URL to open a connection to + * @param proxy the proxy to use, may be {@code null} + * @return the opened connection + * @throws IOException in case of I/O errors + */ + protected HttpURLConnection openConnection(URL url, @Nullable Proxy proxy) throws IOException { + URLConnection urlConnection = (proxy != null ? url.openConnection(proxy) : url.openConnection()); + if (!HttpURLConnection.class.isInstance(urlConnection)) { + throw new IllegalStateException("HttpURLConnection required for [" + url + "] but got: " + urlConnection); + } + return (HttpURLConnection) urlConnection; + } + + /** + * Template method for preparing the given {@link HttpURLConnection}. + *

The default implementation prepares the connection for input and output, and sets the HTTP method. + * @param connection the connection to prepare + * @param httpMethod the HTTP request method ({@code GET}, {@code POST}, etc.) + * @throws IOException in case of I/O errors + */ + protected void prepareConnection(HttpURLConnection connection, String httpMethod) throws IOException { + if (this.connectTimeout >= 0) { + connection.setConnectTimeout(this.connectTimeout); + } + if (this.readTimeout >= 0) { + connection.setReadTimeout(this.readTimeout); + } + + connection.setDoInput(true); + + if ("GET".equals(httpMethod)) { + connection.setInstanceFollowRedirects(true); + } + else { + connection.setInstanceFollowRedirects(false); + } + + if ("POST".equals(httpMethod) || "PUT".equals(httpMethod) || + "PATCH".equals(httpMethod) || "DELETE".equals(httpMethod)) { + connection.setDoOutput(true); + } + else { + connection.setDoOutput(false); + } + + connection.setRequestMethod(httpMethod); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/SimpleClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/SimpleClientHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..2c157ae7ef09bcc590fb87fe804748d1bc982aed --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/SimpleClientHttpResponse.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.io.InputStream; +import java.net.HttpURLConnection; + +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.StreamUtils; +import org.springframework.util.StringUtils; + +/** + * {@link ClientHttpResponse} implementation that uses standard JDK facilities. + * Obtained via {@link SimpleBufferingClientHttpRequest#execute()} and + * {@link SimpleStreamingClientHttpRequest#execute()}. + * + * @author Arjen Poutsma + * @author Brian Clozel + * @since 3.0 + */ +final class SimpleClientHttpResponse extends AbstractClientHttpResponse { + + private final HttpURLConnection connection; + + @Nullable + private HttpHeaders headers; + + @Nullable + private InputStream responseStream; + + + SimpleClientHttpResponse(HttpURLConnection connection) { + this.connection = connection; + } + + + @Override + public int getRawStatusCode() throws IOException { + return this.connection.getResponseCode(); + } + + @Override + public String getStatusText() throws IOException { + return this.connection.getResponseMessage(); + } + + @Override + public HttpHeaders getHeaders() { + if (this.headers == null) { + this.headers = new HttpHeaders(); + // Header field 0 is the status line for most HttpURLConnections, but not on GAE + String name = this.connection.getHeaderFieldKey(0); + if (StringUtils.hasLength(name)) { + this.headers.add(name, this.connection.getHeaderField(0)); + } + int i = 1; + while (true) { + name = this.connection.getHeaderFieldKey(i); + if (!StringUtils.hasLength(name)) { + break; + } + this.headers.add(name, this.connection.getHeaderField(i)); + i++; + } + } + return this.headers; + } + + @Override + public InputStream getBody() throws IOException { + InputStream errorStream = this.connection.getErrorStream(); + this.responseStream = (errorStream != null ? errorStream : this.connection.getInputStream()); + return this.responseStream; + } + + @Override + public void close() { + try { + if (this.responseStream == null) { + getBody(); + } + StreamUtils.drain(this.responseStream); + this.responseStream.close(); + } + catch (Exception ex) { + // ignore + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/SimpleStreamingAsyncClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/SimpleStreamingAsyncClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..de4da24aabdf3c520ef126abbec9652b03e6e8ad --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/SimpleStreamingAsyncClientHttpRequest.java @@ -0,0 +1,129 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.concurrent.Callable; + +import org.springframework.core.task.AsyncListenableTaskExecutor; +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.StreamUtils; +import org.springframework.util.concurrent.ListenableFuture; + +/** + * {@link org.springframework.http.client.ClientHttpRequest} implementation + * that uses standard Java facilities to execute streaming requests. Created + * via the {@link org.springframework.http.client.SimpleClientHttpRequestFactory}. + * + * @author Arjen Poutsma + * @since 3.0 + * @see org.springframework.http.client.SimpleClientHttpRequestFactory#createRequest + * @see org.springframework.http.client.support.AsyncHttpAccessor + * @see org.springframework.web.client.AsyncRestTemplate + * @deprecated as of Spring 5.0, with no direct replacement + */ +@Deprecated +final class SimpleStreamingAsyncClientHttpRequest extends AbstractAsyncClientHttpRequest { + + private final HttpURLConnection connection; + + private final int chunkSize; + + @Nullable + private OutputStream body; + + private final boolean outputStreaming; + + private final AsyncListenableTaskExecutor taskExecutor; + + + SimpleStreamingAsyncClientHttpRequest(HttpURLConnection connection, int chunkSize, + boolean outputStreaming, AsyncListenableTaskExecutor taskExecutor) { + + this.connection = connection; + this.chunkSize = chunkSize; + this.outputStreaming = outputStreaming; + this.taskExecutor = taskExecutor; + } + + + @Override + public String getMethodValue() { + return this.connection.getRequestMethod(); + } + + @Override + public URI getURI() { + try { + return this.connection.getURL().toURI(); + } + catch (URISyntaxException ex) { + throw new IllegalStateException( + "Could not get HttpURLConnection URI: " + ex.getMessage(), ex); + } + } + + @Override + protected OutputStream getBodyInternal(HttpHeaders headers) throws IOException { + if (this.body == null) { + if (this.outputStreaming) { + long contentLength = headers.getContentLength(); + if (contentLength >= 0) { + this.connection.setFixedLengthStreamingMode(contentLength); + } + else { + this.connection.setChunkedStreamingMode(this.chunkSize); + } + } + SimpleBufferingClientHttpRequest.addHeaders(this.connection, headers); + this.connection.connect(); + this.body = this.connection.getOutputStream(); + } + return StreamUtils.nonClosing(this.body); + } + + @Override + protected ListenableFuture executeInternal(final HttpHeaders headers) throws IOException { + return this.taskExecutor.submitListenable(new Callable() { + @Override + public ClientHttpResponse call() throws Exception { + try { + if (body != null) { + body.close(); + } + else { + SimpleBufferingClientHttpRequest.addHeaders(connection, headers); + connection.connect(); + // Immediately trigger the request in a no-output scenario as well + connection.getResponseCode(); + } + } + catch (IOException ex) { + // ignore + } + return new SimpleClientHttpResponse(connection); + } + }); + + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/SimpleStreamingClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/SimpleStreamingClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..2b751ac1c21ad0096ef86251acb02f0221139979 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/SimpleStreamingClientHttpRequest.java @@ -0,0 +1,112 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URI; +import java.net.URISyntaxException; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.StreamUtils; + +/** + * {@link ClientHttpRequest} implementation that uses standard JDK facilities to + * execute streaming requests. Created via the {@link SimpleClientHttpRequestFactory}. + * + * @author Arjen Poutsma + * @since 3.0 + * @see SimpleClientHttpRequestFactory#createRequest(java.net.URI, HttpMethod) + * @see org.springframework.http.client.support.HttpAccessor + * @see org.springframework.web.client.RestTemplate + */ +final class SimpleStreamingClientHttpRequest extends AbstractClientHttpRequest { + + private final HttpURLConnection connection; + + private final int chunkSize; + + @Nullable + private OutputStream body; + + private final boolean outputStreaming; + + + SimpleStreamingClientHttpRequest(HttpURLConnection connection, int chunkSize, boolean outputStreaming) { + this.connection = connection; + this.chunkSize = chunkSize; + this.outputStreaming = outputStreaming; + } + + + @Override + public String getMethodValue() { + return this.connection.getRequestMethod(); + } + + @Override + public URI getURI() { + try { + return this.connection.getURL().toURI(); + } + catch (URISyntaxException ex) { + throw new IllegalStateException("Could not get HttpURLConnection URI: " + ex.getMessage(), ex); + } + } + + @Override + protected OutputStream getBodyInternal(HttpHeaders headers) throws IOException { + if (this.body == null) { + if (this.outputStreaming) { + long contentLength = headers.getContentLength(); + if (contentLength >= 0) { + this.connection.setFixedLengthStreamingMode(contentLength); + } + else { + this.connection.setChunkedStreamingMode(this.chunkSize); + } + } + SimpleBufferingClientHttpRequest.addHeaders(this.connection, headers); + this.connection.connect(); + this.body = this.connection.getOutputStream(); + } + return StreamUtils.nonClosing(this.body); + } + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers) throws IOException { + try { + if (this.body != null) { + this.body.close(); + } + else { + SimpleBufferingClientHttpRequest.addHeaders(this.connection, headers); + this.connection.connect(); + // Immediately trigger the request in a no-output scenario as well + this.connection.getResponseCode(); + } + } + catch (IOException ex) { + // ignore + } + return new SimpleClientHttpResponse(this.connection); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/package-info.java b/spring-web/src/main/java/org/springframework/http/client/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..9c9935e3e9afdbdc425fb407fb1fbb31940d3ab7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/package-info.java @@ -0,0 +1,11 @@ +/** + * Contains an abstraction over client-side HTTP. This package + * contains the {@code ClientHttpRequest} and {@code ClientHttpResponse}, + * as well as a basic implementation of these interfaces. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.client; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/AbstractClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/reactive/AbstractClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..55678dfecc07d4003876622a4eb91c7e1776a012 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/AbstractClientHttpRequest.java @@ -0,0 +1,151 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.reactive; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * Base class for {@link ClientHttpRequest} implementations. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + */ +public abstract class AbstractClientHttpRequest implements ClientHttpRequest { + + /** + * COMMITTING -> COMMITTED is the period after doCommit is called but before + * the response status and headers have been applied to the underlying + * response during which time pre-commit actions can still make changes to + * the response status and headers. + */ + private enum State {NEW, COMMITTING, COMMITTED} + + + private final HttpHeaders headers; + + private final MultiValueMap cookies; + + private final AtomicReference state = new AtomicReference<>(State.NEW); + + private final List>> commitActions = new ArrayList<>(4); + + + public AbstractClientHttpRequest() { + this(new HttpHeaders()); + } + + public AbstractClientHttpRequest(HttpHeaders headers) { + Assert.notNull(headers, "HttpHeaders must not be null"); + this.headers = headers; + this.cookies = new LinkedMultiValueMap<>(); + } + + + @Override + public HttpHeaders getHeaders() { + if (State.COMMITTED.equals(this.state.get())) { + return HttpHeaders.readOnlyHttpHeaders(this.headers); + } + return this.headers; + } + + @Override + public MultiValueMap getCookies() { + if (State.COMMITTED.equals(this.state.get())) { + return CollectionUtils.unmodifiableMultiValueMap(this.cookies); + } + return this.cookies; + } + + @Override + public void beforeCommit(Supplier> action) { + Assert.notNull(action, "Action must not be null"); + this.commitActions.add(action); + } + + @Override + public boolean isCommitted() { + return (this.state.get() != State.NEW); + } + + /** + * A variant of {@link #doCommit(Supplier)} for a request without body. + * @return a completion publisher + */ + protected Mono doCommit() { + return doCommit(null); + } + + /** + * Apply {@link #beforeCommit(Supplier) beforeCommit} actions, apply the + * request headers/cookies, and write the request body. + * @param writeAction the action to write the request body (may be {@code null}) + * @return a completion publisher + */ + protected Mono doCommit(@Nullable Supplier> writeAction) { + if (!this.state.compareAndSet(State.NEW, State.COMMITTING)) { + return Mono.empty(); + } + + this.commitActions.add(() -> + Mono.fromRunnable(() -> { + applyHeaders(); + applyCookies(); + this.state.set(State.COMMITTED); + })); + + if (writeAction != null) { + this.commitActions.add(writeAction); + } + + List> actions = this.commitActions.stream() + .map(Supplier::get).collect(Collectors.toList()); + + return Flux.concat(actions).then(); + } + + + /** + * Apply header changes from {@link #getHeaders()} to the underlying response. + * This method is called once only. + */ + protected abstract void applyHeaders(); + + /** + * Add cookies from {@link #getHeaders()} to the underlying response. + * This method is called once only. + */ + protected abstract void applyCookies(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/ClientHttpConnector.java b/spring-web/src/main/java/org/springframework/http/client/reactive/ClientHttpConnector.java new file mode 100644 index 0000000000000000000000000000000000000000..8de59c9c260afe8545aec514877ffd7924ba0b7c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/ClientHttpConnector.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.reactive; + +import java.net.URI; +import java.util.function.Function; + +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpMethod; + +/** + * Abstraction over HTTP clients driving the underlying HTTP client to connect + * to the origin server and provide all necessary infrastructure to send a + * {@link ClientHttpRequest} and receive a {@link ClientHttpResponse}. + * + * @author Brian Clozel + * @since 5.0 + */ +public interface ClientHttpConnector { + + /** + * Connect to the origin server using the given {@code HttpMethod} and + * {@code URI} and apply the given {@code requestCallback} when the HTTP + * request of the underlying API can be initialized and written to. + * @param method the HTTP request method + * @param uri the HTTP request URI + * @param requestCallback a function that prepares and writes to the request, + * returning a publisher that signals when it's done writing. + * Implementations can return a {@code Mono} by calling + * {@link ClientHttpRequest#writeWith} or {@link ClientHttpRequest#setComplete}. + * @return publisher for the {@link ClientHttpResponse} + */ + Mono connect(HttpMethod method, URI uri, + Function> requestCallback); + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/ClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/reactive/ClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..4b41c261a0511c5a0784a83d48a7d1b51d52981f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/ClientHttpRequest.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.reactive; + +import java.net.URI; + +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpMethod; +import org.springframework.http.ReactiveHttpOutputMessage; +import org.springframework.util.MultiValueMap; + +/** + * Represents a client-side reactive HTTP request. + * + * @author Arjen Poutsma + * @author Brian Clozel + * @since 5.0 + */ +public interface ClientHttpRequest extends ReactiveHttpOutputMessage { + + /** + * Return the HTTP method of the request. + */ + HttpMethod getMethod(); + + /** + * Return the URI of the request. + */ + URI getURI(); + + /** + * Return a mutable map of request cookies to send to the server. + */ + MultiValueMap getCookies(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/ClientHttpRequestDecorator.java b/spring-web/src/main/java/org/springframework/http/client/reactive/ClientHttpRequestDecorator.java new file mode 100644 index 0000000000000000000000000000000000000000..35ed867b8bfd0466c491a52fed3e11af56136a7c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/ClientHttpRequestDecorator.java @@ -0,0 +1,114 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.reactive; + +import java.net.URI; +import java.util.function.Supplier; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; + +/** + * Wraps another {@link ClientHttpRequest} and delegates all methods to it. + * Sub-classes can override specific methods selectively. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ClientHttpRequestDecorator implements ClientHttpRequest { + + private final ClientHttpRequest delegate; + + + public ClientHttpRequestDecorator(ClientHttpRequest delegate) { + Assert.notNull(delegate, "Delegate is required"); + this.delegate = delegate; + } + + + public ClientHttpRequest getDelegate() { + return this.delegate; + } + + + // ClientHttpRequest delegation methods... + + @Override + public HttpMethod getMethod() { + return this.delegate.getMethod(); + } + + @Override + public URI getURI() { + return this.delegate.getURI(); + } + + @Override + public HttpHeaders getHeaders() { + return this.delegate.getHeaders(); + } + + @Override + public MultiValueMap getCookies() { + return this.delegate.getCookies(); + } + + @Override + public DataBufferFactory bufferFactory() { + return this.delegate.bufferFactory(); + } + + @Override + public void beforeCommit(Supplier> action) { + this.delegate.beforeCommit(action); + } + + @Override + public boolean isCommitted() { + return this.delegate.isCommitted(); + } + + @Override + public Mono writeWith(Publisher body) { + return this.delegate.writeWith(body); + } + + @Override + public Mono writeAndFlushWith(Publisher> body) { + return this.delegate.writeAndFlushWith(body); + } + + @Override + public Mono setComplete() { + return this.delegate.setComplete(); + } + + + @Override + public String toString() { + return getClass().getSimpleName() + " [delegate=" + getDelegate() + "]"; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/ClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/reactive/ClientHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..43e711635b63438c8799570c78c0532f24d6ad0f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/ClientHttpResponse.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.reactive; + +import org.springframework.http.HttpStatus; +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.http.ResponseCookie; +import org.springframework.util.MultiValueMap; + +/** + * Represents a client-side reactive HTTP response. + * + * @author Arjen Poutsma + * @author Brian Clozel + * @since 5.0 + */ +public interface ClientHttpResponse extends ReactiveHttpInputMessage { + + /** + * Return the HTTP status code as an {@link HttpStatus} enum value. + * @return the HTTP status as an HttpStatus enum value (never {@code null}) + * @throws IllegalArgumentException in case of an unknown HTTP status code + * @since #getRawStatusCode() + * @see HttpStatus#valueOf(int) + */ + HttpStatus getStatusCode(); + + /** + * Return the HTTP status code (potentially non-standard and not + * resolvable through the {@link HttpStatus} enum) as an integer. + * @return the HTTP status as an integer value + * @since 5.0.6 + * @see #getStatusCode() + * @see HttpStatus#resolve(int) + */ + int getRawStatusCode(); + + /** + * Return a read-only map of response cookies received from the server. + */ + MultiValueMap getCookies(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/ClientHttpResponseDecorator.java b/spring-web/src/main/java/org/springframework/http/client/reactive/ClientHttpResponseDecorator.java new file mode 100644 index 0000000000000000000000000000000000000000..233ea5047647ef9267c92b9843324373539712f6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/ClientHttpResponseDecorator.java @@ -0,0 +1,83 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.reactive; + +import reactor.core.publisher.Flux; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseCookie; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; + +/** + * Wraps another {@link ClientHttpResponse} and delegates all methods to it. + * Sub-classes can override specific methods selectively. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ClientHttpResponseDecorator implements ClientHttpResponse { + + private final ClientHttpResponse delegate; + + + public ClientHttpResponseDecorator(ClientHttpResponse delegate) { + Assert.notNull(delegate, "Delegate is required"); + this.delegate = delegate; + } + + + public ClientHttpResponse getDelegate() { + return this.delegate; + } + + + // ClientHttpResponse delegation methods... + + @Override + public HttpStatus getStatusCode() { + return this.delegate.getStatusCode(); + } + + @Override + public int getRawStatusCode() { + return this.delegate.getRawStatusCode(); + } + + @Override + public HttpHeaders getHeaders() { + return this.delegate.getHeaders(); + } + + @Override + public MultiValueMap getCookies() { + return this.delegate.getCookies(); + } + + @Override + public Flux getBody() { + return this.delegate.getBody(); + } + + @Override + public String toString() { + return getClass().getSimpleName() + " [delegate=" + getDelegate() + "]"; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/JettyClientHttpConnector.java b/spring-web/src/main/java/org/springframework/http/client/reactive/JettyClientHttpConnector.java new file mode 100644 index 0000000000000000000000000000000000000000..c166fce6c327b7cc3ab04ca4b9001b8cecdaa57c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/JettyClientHttpConnector.java @@ -0,0 +1,130 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.reactive; + +import java.net.URI; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.eclipse.jetty.client.HttpClient; +import org.eclipse.jetty.reactive.client.ContentChunk; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@link ClientHttpConnector} for the Jetty Reactive Streams HttpClient. + * + * @author Sebastien Deleuze + * @since 5.1 + * @see Jetty ReactiveStreams HttpClient + */ +public class JettyClientHttpConnector implements ClientHttpConnector { + + private final HttpClient httpClient; + + private DataBufferFactory bufferFactory = new DefaultDataBufferFactory(); + + + /** + * Default constructor that creates a new instance of {@link HttpClient}. + */ + public JettyClientHttpConnector() { + this(new HttpClient()); + } + + /** + * Constructor with an {@link JettyResourceFactory} that will manage shared resources. + * @param resourceFactory the {@link JettyResourceFactory} to use + * @param customizer the lambda used to customize the {@link HttpClient} + */ + public JettyClientHttpConnector( + JettyResourceFactory resourceFactory, @Nullable Consumer customizer) { + + HttpClient httpClient = new HttpClient(); + httpClient.setExecutor(resourceFactory.getExecutor()); + httpClient.setByteBufferPool(resourceFactory.getByteBufferPool()); + httpClient.setScheduler(resourceFactory.getScheduler()); + if (customizer != null) { + customizer.accept(httpClient); + } + this.httpClient = httpClient; + } + + /** + * Constructor with an initialized {@link HttpClient}. + */ + public JettyClientHttpConnector(HttpClient httpClient) { + Assert.notNull(httpClient, "HttpClient is required"); + this.httpClient = httpClient; + } + + + public void setBufferFactory(DataBufferFactory bufferFactory) { + this.bufferFactory = bufferFactory; + } + + + @Override + public Mono connect(HttpMethod method, URI uri, + Function> requestCallback) { + + if (!uri.isAbsolute()) { + return Mono.error(new IllegalArgumentException("URI is not absolute: " + uri)); + } + + if (!this.httpClient.isStarted()) { + try { + this.httpClient.start(); + } + catch (Exception ex) { + return Mono.error(ex); + } + } + + JettyClientHttpRequest clientHttpRequest = new JettyClientHttpRequest( + this.httpClient.newRequest(uri).method(method.toString()), this.bufferFactory); + + return requestCallback.apply(clientHttpRequest).then(Mono.from( + clientHttpRequest.getReactiveRequest().response((response, chunks) -> { + Flux content = Flux.from(chunks).map(this::toDataBuffer); + return Mono.just(new JettyClientHttpResponse(response, content)); + }))); + } + + private DataBuffer toDataBuffer(ContentChunk chunk) { + + // We must copy until this is resolved: + // https://github.com/eclipse/jetty.project/issues/2429 + + // Use copy instead of buffer wrapping because Callback#succeeded() is + // used not only to release the buffer but also to request more data + // which is a problem for codecs that buffer data. + + DataBuffer buffer = this.bufferFactory.allocateBuffer(chunk.buffer.capacity()); + buffer.write(chunk.buffer); + chunk.callback.succeeded(); + return buffer; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/JettyClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/reactive/JettyClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..7532945a88d19fac7a778bd3985bf7936104ef56 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/JettyClientHttpRequest.java @@ -0,0 +1,151 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.reactive; + +import java.net.HttpCookie; +import java.net.URI; +import java.util.Collection; +import java.util.function.Function; + +import org.eclipse.jetty.client.api.Request; +import org.eclipse.jetty.reactive.client.ContentChunk; +import org.eclipse.jetty.reactive.client.ReactiveRequest; +import org.eclipse.jetty.util.Callback; +import org.reactivestreams.Publisher; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.PooledDataBuffer; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; + +/** + * {@link ClientHttpRequest} implementation for the Jetty ReactiveStreams HTTP client. + * + * @author Sebastien Deleuze + * @since 5.1 + * @see Jetty ReactiveStreams HttpClient + */ +class JettyClientHttpRequest extends AbstractClientHttpRequest { + + private final Request jettyRequest; + + private final DataBufferFactory bufferFactory; + + @Nullable + private ReactiveRequest reactiveRequest; + + + public JettyClientHttpRequest(Request jettyRequest, DataBufferFactory bufferFactory) { + this.jettyRequest = jettyRequest; + this.bufferFactory = bufferFactory; + } + + + @Override + public HttpMethod getMethod() { + return HttpMethod.valueOf(this.jettyRequest.getMethod()); + } + + @Override + public URI getURI() { + return this.jettyRequest.getURI(); + } + + @Override + public Mono setComplete() { + return doCommit(this::completes); + } + + @Override + public DataBufferFactory bufferFactory() { + return this.bufferFactory; + } + + @Override + public Mono writeWith(Publisher body) { + Flux chunks = Flux.from(body).map(this::toContentChunk); + ReactiveRequest.Content content = ReactiveRequest.Content.fromPublisher(chunks, getContentType()); + this.reactiveRequest = ReactiveRequest.newBuilder(this.jettyRequest).content(content).build(); + return doCommit(this::completes); + } + + @Override + public Mono writeAndFlushWith(Publisher> body) { + Flux chunks = Flux.from(body) + .flatMap(Function.identity()) + .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release) + .map(this::toContentChunk); + ReactiveRequest.Content content = ReactiveRequest.Content.fromPublisher(chunks, getContentType()); + this.reactiveRequest = ReactiveRequest.newBuilder(this.jettyRequest).content(content).build(); + return doCommit(this::completes); + } + + private String getContentType() { + MediaType contentType = getHeaders().getContentType(); + return contentType != null ? contentType.toString() : MediaType.APPLICATION_OCTET_STREAM_VALUE; + } + + private Mono completes() { + return Mono.empty(); + } + + private ContentChunk toContentChunk(DataBuffer buffer) { + return new ContentChunk(buffer.asByteBuffer(), new Callback() { + @Override + public void succeeded() { + DataBufferUtils.release(buffer); + } + @Override + public void failed(Throwable x) { + DataBufferUtils.release(buffer); + throw Exceptions.propagate(x); + } + }); + } + + + @Override + protected void applyCookies() { + getCookies().values().stream().flatMap(Collection::stream) + .map(cookie -> new HttpCookie(cookie.getName(), cookie.getValue())) + .forEach(this.jettyRequest::cookie); + } + + @Override + protected void applyHeaders() { + HttpHeaders headers = getHeaders(); + headers.forEach((key, value) -> value.forEach(v -> this.jettyRequest.header(key, v))); + if (!headers.containsKey(HttpHeaders.ACCEPT)) { + this.jettyRequest.header(HttpHeaders.ACCEPT, "*/*"); + } + } + + ReactiveRequest getReactiveRequest() { + if (this.reactiveRequest == null) { + this.reactiveRequest = ReactiveRequest.newBuilder(this.jettyRequest).build(); + } + return this.reactiveRequest; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/JettyClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/reactive/JettyClientHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..60d94517dac69c90ead261f4cea4ba48eedf5a0e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/JettyClientHttpResponse.java @@ -0,0 +1,96 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.reactive; + +import java.net.HttpCookie; +import java.util.List; + +import org.eclipse.jetty.reactive.client.ReactiveResponse; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseCookie; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * {@link ClientHttpResponse} implementation for the Jetty ReactiveStreams HTTP client. + * + * @author Sebastien Deleuze + * @since 5.1 + * @see Jetty ReactiveStreams HttpClient + */ +class JettyClientHttpResponse implements ClientHttpResponse { + + private final ReactiveResponse reactiveResponse; + + private final Flux content; + + + public JettyClientHttpResponse(ReactiveResponse reactiveResponse, Publisher content) { + this.reactiveResponse = reactiveResponse; + this.content = Flux.from(content); + } + + + @Override + public HttpStatus getStatusCode() { + return HttpStatus.valueOf(getRawStatusCode()); + } + + @Override + public int getRawStatusCode() { + return this.reactiveResponse.getStatus(); + } + + @Override + public MultiValueMap getCookies() { + MultiValueMap result = new LinkedMultiValueMap<>(); + List cookieHeader = getHeaders().get(HttpHeaders.SET_COOKIE); + if (cookieHeader != null) { + cookieHeader.forEach(header -> + HttpCookie.parse(header) + .forEach(cookie -> result.add(cookie.getName(), + ResponseCookie.from(cookie.getName(), cookie.getValue()) + .domain(cookie.getDomain()) + .path(cookie.getPath()) + .maxAge(cookie.getMaxAge()) + .secure(cookie.getSecure()) + .httpOnly(cookie.isHttpOnly()) + .build()))); + } + return CollectionUtils.unmodifiableMultiValueMap(result); + } + + @Override + public Flux getBody() { + return this.content; + } + + @Override + public HttpHeaders getHeaders() { + HttpHeaders headers = new HttpHeaders(); + this.reactiveResponse.getHeaders().stream() + .forEach(field -> headers.add(field.getName(), field.getValue())); + return headers; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/JettyResourceFactory.java b/spring-web/src/main/java/org/springframework/http/client/reactive/JettyResourceFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..b31e86ae1ea6fffc611e929a2079eafc7d4b1a5c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/JettyResourceFactory.java @@ -0,0 +1,168 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.reactive; + + +import java.nio.ByteBuffer; +import java.util.concurrent.Executor; + +import org.eclipse.jetty.io.ByteBufferPool; +import org.eclipse.jetty.io.MappedByteBufferPool; +import org.eclipse.jetty.util.ProcessorUtils; +import org.eclipse.jetty.util.component.LifeCycle; +import org.eclipse.jetty.util.thread.QueuedThreadPool; +import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler; +import org.eclipse.jetty.util.thread.Scheduler; +import org.eclipse.jetty.util.thread.ThreadPool; + +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Factory to manage Jetty resources, i.e. {@link Executor}, {@link ByteBufferPool} and + * {@link Scheduler}, within the lifecycle of a Spring {@code ApplicationContext}. + * + *

This factory implements {@link InitializingBean} and {@link DisposableBean} + * and is expected typically to be declared as a Spring-managed bean. + * + * @author Sebastien Deleuze + * @since 5.1 + */ +public class JettyResourceFactory implements InitializingBean, DisposableBean { + + @Nullable + private Executor executor; + + @Nullable + private ByteBufferPool byteBufferPool; + + @Nullable + private Scheduler scheduler; + + private String threadPrefix = "jetty-http"; + + + /** + * Configure the {@link Executor} to use. + *

By default, initialized with a {@link QueuedThreadPool}. + * @param executor the executor to use + */ + public void setExecutor(@Nullable Executor executor) { + this.executor = executor; + } + + /** + * Configure the {@link ByteBufferPool} to use. + *

By default, initialized with a {@link MappedByteBufferPool}. + * @param byteBufferPool the {@link ByteBuffer} pool to use + */ + public void setByteBufferPool(@Nullable ByteBufferPool byteBufferPool) { + this.byteBufferPool = byteBufferPool; + } + + /** + * Configure the {@link Scheduler} to use. + *

By default, initialized with a {@link ScheduledExecutorScheduler}. + * @param scheduler the {@link Scheduler} to use + */ + public void setScheduler(@Nullable Scheduler scheduler) { + this.scheduler = scheduler; + } + + /** + * Configure the thread prefix to initialize {@link QueuedThreadPool} executor with. This + * is used only when a {@link Executor} instance isn't + * {@link #setExecutor(Executor) provided}. + *

By default set to "jetty-http". + * @param threadPrefix the thread prefix to use + */ + public void setThreadPrefix(String threadPrefix) { + Assert.notNull(threadPrefix, "Thread prefix is required"); + this.threadPrefix = threadPrefix; + } + + /** + * Return the configured {@link Executor}. + */ + @Nullable + public Executor getExecutor() { + return this.executor; + } + + /** + * Return the configured {@link ByteBufferPool}. + */ + @Nullable + public ByteBufferPool getByteBufferPool() { + return this.byteBufferPool; + } + + /** + * Return the configured {@link Scheduler}. + */ + @Nullable + public Scheduler getScheduler() { + return this.scheduler; + } + + @Override + public void afterPropertiesSet() throws Exception { + String name = this.threadPrefix + "@" + Integer.toHexString(hashCode()); + if (this.executor == null) { + QueuedThreadPool threadPool = new QueuedThreadPool(); + threadPool.setName(name); + this.executor = threadPool; + } + if (this.byteBufferPool == null) { + this.byteBufferPool = new MappedByteBufferPool(2048, + this.executor instanceof ThreadPool.SizedThreadPool + ? ((ThreadPool.SizedThreadPool) executor).getMaxThreads() / 2 + : ProcessorUtils.availableProcessors() * 2); + } + if (this.scheduler == null) { + this.scheduler = new ScheduledExecutorScheduler(name + "-scheduler", false); + } + + if (this.executor instanceof LifeCycle) { + ((LifeCycle)this.executor).start(); + } + this.scheduler.start(); + } + + @Override + public void destroy() throws Exception { + try { + if (this.executor instanceof LifeCycle) { + ((LifeCycle)this.executor).stop(); + } + } + catch (Throwable ex) { + // ignore + } + try { + if (this.scheduler != null) { + this.scheduler.stop(); + } + } + catch (Throwable ex) { + // ignore + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorClientHttpConnector.java b/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorClientHttpConnector.java new file mode 100644 index 0000000000000000000000000000000000000000..98dd9dd8babbca935ef529d3febb50b915fcd2ae --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorClientHttpConnector.java @@ -0,0 +1,127 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.reactive; + +import java.net.URI; +import java.util.function.Function; + +import io.netty.buffer.ByteBufAllocator; +import reactor.core.publisher.Mono; +import reactor.netty.NettyInbound; +import reactor.netty.NettyOutbound; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.client.HttpClientRequest; +import reactor.netty.http.client.HttpClientResponse; +import reactor.netty.resources.ConnectionProvider; +import reactor.netty.resources.LoopResources; + +import org.springframework.http.HttpMethod; +import org.springframework.util.Assert; + +/** + * Reactor-Netty implementation of {@link ClientHttpConnector}. + * + * @author Brian Clozel + * @since 5.0 + * @see reactor.netty.http.client.HttpClient + */ +public class ReactorClientHttpConnector implements ClientHttpConnector { + + private final static Function defaultInitializer = client -> client.compress(true); + + + private final HttpClient httpClient; + + + /** + * Default constructor. Initializes {@link HttpClient} via: + *

+	 * HttpClient.create().compress()
+	 * 
+ */ + public ReactorClientHttpConnector() { + this.httpClient = defaultInitializer.apply(HttpClient.create()); + } + + /** + * Constructor with externally managed Reactor Netty resources, including + * {@link LoopResources} for event loop threads, and {@link ConnectionProvider} + * for the connection pool. + *

This constructor should be used only when you don't want the client + * to participate in the Reactor Netty global resources. By default the + * client participates in the Reactor Netty global resources held in + * {@link reactor.netty.http.HttpResources}, which is recommended since + * fixed, shared resources are favored for event loop concurrency. However, + * consider declaring a {@link ReactorResourceFactory} bean with + * {@code globalResources=true} in order to ensure the Reactor Netty global + * resources are shut down when the Spring ApplicationContext is closed. + * @param factory the resource factory to obtain the resources from + * @param mapper a mapper for further initialization of the created client + * @since 5.1 + */ + public ReactorClientHttpConnector(ReactorResourceFactory factory, Function mapper) { + this.httpClient = defaultInitializer.andThen(mapper).apply(initHttpClient(factory)); + } + + private static HttpClient initHttpClient(ReactorResourceFactory resourceFactory) { + ConnectionProvider provider = resourceFactory.getConnectionProvider(); + LoopResources resources = resourceFactory.getLoopResources(); + Assert.notNull(provider, "No ConnectionProvider: is ReactorResourceFactory not initialized yet?"); + Assert.notNull(resources, "No LoopResources: is ReactorResourceFactory not initialized yet?"); + return HttpClient.create(provider).tcpConfiguration(tcpClient -> tcpClient.runOn(resources)); + } + + /** + * Constructor with a pre-configured {@code HttpClient} instance. + * @param httpClient the client to use + * @since 5.1 + */ + public ReactorClientHttpConnector(HttpClient httpClient) { + Assert.notNull(httpClient, "HttpClient is required"); + this.httpClient = httpClient; + } + + + @Override + public Mono connect(HttpMethod method, URI uri, + Function> requestCallback) { + + if (!uri.isAbsolute()) { + return Mono.error(new IllegalArgumentException("URI is not absolute: " + uri)); + } + + return this.httpClient + .request(io.netty.handler.codec.http.HttpMethod.valueOf(method.name())) + .uri(uri.toString()) + .send((request, outbound) -> requestCallback.apply(adaptRequest(method, uri, request, outbound))) + .responseConnection((res, con) -> Mono.just(adaptResponse(res, con.inbound(), con.outbound().alloc()))) + .next(); + } + + private ReactorClientHttpRequest adaptRequest(HttpMethod method, URI uri, HttpClientRequest request, + NettyOutbound nettyOutbound) { + + return new ReactorClientHttpRequest(method, uri, request, nettyOutbound); + } + + private ClientHttpResponse adaptResponse(HttpClientResponse response, NettyInbound nettyInbound, + ByteBufAllocator allocator) { + + return new ReactorClientHttpResponse(response, nettyInbound, allocator); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..8cc4d423488c3e48c70f35978420a3fe44a11eba --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorClientHttpRequest.java @@ -0,0 +1,129 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.reactive; + +import java.net.URI; +import java.nio.file.Path; +import java.util.Collection; + +import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.http.cookie.DefaultCookie; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.NettyOutbound; +import reactor.netty.http.client.HttpClientRequest; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.http.HttpMethod; +import org.springframework.http.ZeroCopyHttpOutputMessage; + +/** + * {@link ClientHttpRequest} implementation for the Reactor-Netty HTTP client. + * + * @author Brian Clozel + * @since 5.0 + * @see reactor.netty.http.client.HttpClient + */ +class ReactorClientHttpRequest extends AbstractClientHttpRequest implements ZeroCopyHttpOutputMessage { + + private final HttpMethod httpMethod; + + private final URI uri; + + private final HttpClientRequest request; + + private final NettyOutbound outbound; + + private final NettyDataBufferFactory bufferFactory; + + + public ReactorClientHttpRequest(HttpMethod method, URI uri, HttpClientRequest request, NettyOutbound outbound) { + this.httpMethod = method; + this.uri = uri; + this.request = request; + this.outbound = outbound; + this.bufferFactory = new NettyDataBufferFactory(outbound.alloc()); + } + + + @Override + public DataBufferFactory bufferFactory() { + return this.bufferFactory; + } + + @Override + public HttpMethod getMethod() { + return this.httpMethod; + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public Mono writeWith(Publisher body) { + return doCommit(() -> { + // Send as Mono if possible as an optimization hint to Reactor Netty + if (body instanceof Mono) { + Mono byteBufMono = Mono.from(body).map(NettyDataBufferFactory::toByteBuf); + return this.outbound.send(byteBufMono).then(); + + } + else { + Flux byteBufFlux = Flux.from(body).map(NettyDataBufferFactory::toByteBuf); + return this.outbound.send(byteBufFlux).then(); + } + }); + } + + @Override + public Mono writeAndFlushWith(Publisher> body) { + Publisher> byteBufs = Flux.from(body).map(ReactorClientHttpRequest::toByteBufs); + return doCommit(() -> this.outbound.sendGroups(byteBufs).then()); + } + + private static Publisher toByteBufs(Publisher dataBuffers) { + return Flux.from(dataBuffers).map(NettyDataBufferFactory::toByteBuf); + } + + @Override + public Mono writeWith(Path file, long position, long count) { + return doCommit(() -> this.outbound.sendFile(file, position, count).then()); + } + + @Override + public Mono setComplete() { + return doCommit(this.outbound::then); + } + + @Override + protected void applyHeaders() { + getHeaders().forEach((key, value) -> this.request.requestHeaders().set(key, value)); + } + + @Override + protected void applyCookies() { + getCookies().values().stream().flatMap(Collection::stream) + .map(cookie -> new DefaultCookie(cookie.getName(), cookie.getValue())) + .forEach(this.request::addCookie); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorClientHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..38b9b58f36142ae9df3640dd86c3650ee63b2cea --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorClientHttpResponse.java @@ -0,0 +1,120 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.reactive; + +import java.util.Collection; +import java.util.concurrent.atomic.AtomicBoolean; + +import io.netty.buffer.ByteBufAllocator; +import reactor.core.publisher.Flux; +import reactor.netty.NettyInbound; +import reactor.netty.http.client.HttpClientResponse; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseCookie; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * {@link ClientHttpResponse} implementation for the Reactor-Netty HTTP client. + * + * @author Brian Clozel + * @since 5.0 + * @see reactor.netty.http.client.HttpClient + */ +class ReactorClientHttpResponse implements ClientHttpResponse { + + private final NettyDataBufferFactory bufferFactory; + + private final HttpClientResponse response; + + private final NettyInbound inbound; + + private final AtomicBoolean rejectSubscribers = new AtomicBoolean(); + + + public ReactorClientHttpResponse(HttpClientResponse response, NettyInbound inbound, ByteBufAllocator alloc) { + this.response = response; + this.inbound = inbound; + this.bufferFactory = new NettyDataBufferFactory(alloc); + } + + + @Override + public Flux getBody() { + return this.inbound.receive() + .doOnSubscribe(s -> { + if (this.rejectSubscribers.get()) { + throw new IllegalStateException("The client response body can only be consumed once."); + } + }) + .doOnCancel(() -> + // https://github.com/reactor/reactor-netty/issues/503 + // FluxReceive rejects multiple subscribers, but not after a cancel(). + // Subsequent subscribers after cancel() will not be rejected, but will hang instead. + // So we need to intercept and reject them in that case. + this.rejectSubscribers.set(true)) + .map(byteBuf -> { + byteBuf.retain(); + return this.bufferFactory.wrap(byteBuf); + }); + } + + @Override + public HttpHeaders getHeaders() { + HttpHeaders headers = new HttpHeaders(); + this.response.responseHeaders().entries().forEach(e -> headers.add(e.getKey(), e.getValue())); + return headers; + } + + @Override + public HttpStatus getStatusCode() { + return HttpStatus.valueOf(getRawStatusCode()); + } + + @Override + public int getRawStatusCode() { + return this.response.status().code(); + } + + @Override + public MultiValueMap getCookies() { + MultiValueMap result = new LinkedMultiValueMap<>(); + this.response.cookies().values().stream().flatMap(Collection::stream) + .forEach(cookie -> + result.add(cookie.name(), ResponseCookie.from(cookie.name(), cookie.value()) + .domain(cookie.domain()) + .path(cookie.path()) + .maxAge(cookie.maxAge()) + .secure(cookie.isSecure()) + .httpOnly(cookie.isHttpOnly()) + .build())); + return CollectionUtils.unmodifiableMultiValueMap(result); + } + + @Override + public String toString() { + return "ReactorClientHttpResponse{" + + "request=[" + this.response.method().name() + " " + this.response.uri() + "]," + + "status=" + getRawStatusCode() + '}'; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorResourceFactory.java b/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorResourceFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..fdfd1a78d8f90acae74ba86c3dadd8228e73d23a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorResourceFactory.java @@ -0,0 +1,209 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.reactive; + +import java.util.function.Consumer; +import java.util.function.Supplier; + +import reactor.netty.http.HttpResources; +import reactor.netty.resources.ConnectionProvider; +import reactor.netty.resources.LoopResources; + +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Factory to manage Reactor Netty resources, i.e. {@link LoopResources} for + * event loop threads, and {@link ConnectionProvider} for the connection pool, + * within the lifecycle of a Spring {@code ApplicationContext}. + * + *

This factory implements {@link InitializingBean} and {@link DisposableBean} + * and is expected typically to be declared as a Spring-managed bean. + * + * @author Rossen Stoyanchev + * @since 5.1 + */ +public class ReactorResourceFactory implements InitializingBean, DisposableBean { + + private boolean useGlobalResources = true; + + @Nullable + private Consumer globalResourcesConsumer; + + private Supplier connectionProviderSupplier = () -> ConnectionProvider.elastic("webflux"); + + private Supplier loopResourcesSupplier = () -> LoopResources.create("webflux-http"); + + @Nullable + private ConnectionProvider connectionProvider; + + @Nullable + private LoopResources loopResources; + + private boolean manageConnectionProvider = false; + + private boolean manageLoopResources = false; + + + /** + * Whether to use global Reactor Netty resources via {@link HttpResources}. + *

Default is "true" in which case this factory initializes and stops the + * global Reactor Netty resources within Spring's {@code ApplicationContext} + * lifecycle. If set to "false" the factory manages its resources independent + * of the global ones. + * @param useGlobalResources whether to expose and manage the global resources + * @see #addGlobalResourcesConsumer(Consumer) + */ + public void setUseGlobalResources(boolean useGlobalResources) { + this.useGlobalResources = useGlobalResources; + } + + /** + * Whether this factory exposes the global + * {@link reactor.netty.http.HttpResources HttpResources} holder. + */ + public boolean isUseGlobalResources() { + return this.useGlobalResources; + } + + /** + * Add a Consumer for configuring the global Reactor Netty resources on + * startup. When this option is used, {@link #setUseGlobalResources} is also + * enabled. + * @param consumer the consumer to apply + * @see #setUseGlobalResources(boolean) + */ + public void addGlobalResourcesConsumer(Consumer consumer) { + this.useGlobalResources = true; + this.globalResourcesConsumer = this.globalResourcesConsumer != null ? + this.globalResourcesConsumer.andThen(consumer) : consumer; + } + + /** + * Use this option when you don't want to participate in global resources and + * you want to customize the creation of the managed {@code ConnectionProvider}. + *

By default, {@code ConnectionProvider.elastic("http")} is used. + *

Note that this option is ignored if {@code userGlobalResources=false} or + * {@link #setConnectionProvider(ConnectionProvider)} is set. + * @param supplier the supplier to use + */ + public void setConnectionProviderSupplier(Supplier supplier) { + this.connectionProviderSupplier = supplier; + } + + /** + * Use this option when you don't want to participate in global resources and + * you want to customize the creation of the managed {@code LoopResources}. + *

By default, {@code LoopResources.create("reactor-http")} is used. + *

Note that this option is ignored if {@code userGlobalResources=false} or + * {@link #setLoopResources(LoopResources)} is set. + * @param supplier the supplier to use + */ + public void setLoopResourcesSupplier(Supplier supplier) { + this.loopResourcesSupplier = supplier; + } + + /** + * Use this option when you want to provide an externally managed + * {@link ConnectionProvider} instance. + * @param connectionProvider the connection provider to use as is + */ + public void setConnectionProvider(ConnectionProvider connectionProvider) { + this.connectionProvider = connectionProvider; + } + + /** + * Return the configured {@link ConnectionProvider}. + */ + public ConnectionProvider getConnectionProvider() { + Assert.state(this.connectionProvider != null, "ConnectionProvider not initialized yet"); + return this.connectionProvider; + } + + /** + * Use this option when you want to provide an externally managed + * {@link LoopResources} instance. + * @param loopResources the loop resources to use as is + */ + public void setLoopResources(LoopResources loopResources) { + this.loopResources = loopResources; + } + + /** + * Return the configured {@link LoopResources}. + */ + public LoopResources getLoopResources() { + Assert.state(this.loopResources != null, "LoopResources not initialized yet"); + return this.loopResources; + } + + + @Override + public void afterPropertiesSet() { + if (this.useGlobalResources) { + Assert.isTrue(this.loopResources == null && this.connectionProvider == null, + "'useGlobalResources' is mutually exclusive with explicitly configured resources"); + HttpResources httpResources = HttpResources.get(); + if (this.globalResourcesConsumer != null) { + this.globalResourcesConsumer.accept(httpResources); + } + this.connectionProvider = httpResources; + this.loopResources = httpResources; + } + else { + if (this.loopResources == null) { + this.manageLoopResources = true; + this.loopResources = this.loopResourcesSupplier.get(); + } + if (this.connectionProvider == null) { + this.manageConnectionProvider = true; + this.connectionProvider = this.connectionProviderSupplier.get(); + } + } + } + + @Override + public void destroy() { + if (this.useGlobalResources) { + HttpResources.disposeLoopsAndConnectionsLater().block(); + } + else { + try { + ConnectionProvider provider = this.connectionProvider; + if (provider != null && this.manageConnectionProvider) { + provider.disposeLater().block(); + } + } + catch (Throwable ex) { + // ignore + } + + try { + LoopResources resources = this.loopResources; + if (resources != null && this.manageLoopResources) { + resources.disposeLater().block(); + } + } + catch (Throwable ex) { + // ignore + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/package-info.java b/spring-web/src/main/java/org/springframework/http/client/reactive/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..b17f099cdb86b777919716e2940a82bcac64930d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/package-info.java @@ -0,0 +1,12 @@ +/** + * Abstractions for reactive HTTP client support including + * {@link org.springframework.http.client.reactive.ClientHttpRequest} and + * {@link org.springframework.http.client.reactive.ClientHttpResponse} as well as a + * {@link org.springframework.http.client.reactive.ClientHttpConnector}. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.client.reactive; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/client/support/AsyncHttpAccessor.java b/spring-web/src/main/java/org/springframework/http/client/support/AsyncHttpAccessor.java new file mode 100644 index 0000000000000000000000000000000000000000..faf8c309c7278ec1f7107f2b3016edf0e54b84ff --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/support/AsyncHttpAccessor.java @@ -0,0 +1,92 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.support; + +import java.io.IOException; +import java.net.URI; + +import org.apache.commons.logging.Log; + +import org.springframework.http.HttpLogging; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Base class for {@link org.springframework.web.client.AsyncRestTemplate} + * and other HTTP accessing gateway helpers, defining common properties + * such as the {@link org.springframework.http.client.AsyncClientHttpRequestFactory} + * to operate on. + * + *

Not intended to be used directly. See + * {@link org.springframework.web.client.AsyncRestTemplate}. + * + * @author Arjen Poutsma + * @since 4.0 + * @see org.springframework.web.client.AsyncRestTemplate + * @deprecated as of Spring 5.0, with no direct replacement + */ +@Deprecated +public class AsyncHttpAccessor { + + /** Logger available to subclasses. */ + protected final Log logger = HttpLogging.forLogName(getClass()); + + @Nullable + private org.springframework.http.client.AsyncClientHttpRequestFactory asyncRequestFactory; + + + /** + * Set the request factory that this accessor uses for obtaining {@link + * org.springframework.http.client.ClientHttpRequest HttpRequests}. + */ + public void setAsyncRequestFactory( + org.springframework.http.client.AsyncClientHttpRequestFactory asyncRequestFactory) { + + Assert.notNull(asyncRequestFactory, "AsyncClientHttpRequestFactory must not be null"); + this.asyncRequestFactory = asyncRequestFactory; + } + + /** + * Return the request factory that this accessor uses for obtaining {@link + * org.springframework.http.client.ClientHttpRequest HttpRequests}. + */ + public org.springframework.http.client.AsyncClientHttpRequestFactory getAsyncRequestFactory() { + Assert.state(this.asyncRequestFactory != null, "No AsyncClientHttpRequestFactory set"); + return this.asyncRequestFactory; + } + + /** + * Create a new {@link org.springframework.http.client.AsyncClientHttpRequest} via this template's + * {@link org.springframework.http.client.AsyncClientHttpRequestFactory}. + * @param url the URL to connect to + * @param method the HTTP method to execute (GET, POST, etc.) + * @return the created request + * @throws IOException in case of I/O errors + */ + protected org.springframework.http.client.AsyncClientHttpRequest createAsyncRequest(URI url, HttpMethod method) + throws IOException { + + org.springframework.http.client.AsyncClientHttpRequest request = + getAsyncRequestFactory().createAsyncRequest(url, method); + if (logger.isDebugEnabled()) { + logger.debug("Created asynchronous " + method.name() + " request for \"" + url + "\""); + } + return request; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/support/BasicAuthenticationInterceptor.java b/spring-web/src/main/java/org/springframework/http/client/support/BasicAuthenticationInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..42a4c18c21911195ea8307f4397bdb50a4395bd5 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/support/BasicAuthenticationInterceptor.java @@ -0,0 +1,87 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.support; + +import java.io.IOException; +import java.nio.charset.Charset; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpRequest; +import org.springframework.http.client.ClientHttpRequestExecution; +import org.springframework.http.client.ClientHttpRequestInterceptor; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@link ClientHttpRequestInterceptor} to apply a given HTTP Basic Authentication + * username/password pair, unless a custom Authorization header has been set before. + * + * @author Juergen Hoeller + * @since 5.1.1 + * @see HttpHeaders#setBasicAuth + * @see HttpHeaders#AUTHORIZATION + */ +public class BasicAuthenticationInterceptor implements ClientHttpRequestInterceptor { + + private final String username; + + private final String password; + + @Nullable + private final Charset charset; + + + /** + * Create a new interceptor which adds Basic Authentication for the + * given username and password. + * @param username the username to use + * @param password the password to use + * @see HttpHeaders#setBasicAuth(String, String) + */ + public BasicAuthenticationInterceptor(String username, String password) { + this(username, password, null); + } + + /** + * Create a new interceptor which adds Basic Authentication for the + * given username and password, encoded using the specified charset. + * @param username the username to use + * @param password the password to use + * @param charset the charset to use + * @see HttpHeaders#setBasicAuth(String, String, Charset) + */ + public BasicAuthenticationInterceptor(String username, String password, @Nullable Charset charset) { + Assert.doesNotContain(username, ":", "Username must not contain a colon"); + this.username = username; + this.password = password; + this.charset = charset; + } + + + @Override + public ClientHttpResponse intercept( + HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException { + + HttpHeaders headers = request.getHeaders(); + if (!headers.containsKey(HttpHeaders.AUTHORIZATION)) { + headers.setBasicAuth(this.username, this.password, this.charset); + } + return execution.execute(request, body); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/support/BasicAuthorizationInterceptor.java b/spring-web/src/main/java/org/springframework/http/client/support/BasicAuthorizationInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..87936fb0003f3b5067c165716e2925132b6b787c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/support/BasicAuthorizationInterceptor.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.support; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import org.springframework.http.HttpRequest; +import org.springframework.http.client.ClientHttpRequestExecution; +import org.springframework.http.client.ClientHttpRequestInterceptor; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.Base64Utils; + +/** + * {@link ClientHttpRequestInterceptor} to apply a BASIC authorization header. + * + * @author Phillip Webb + * @since 4.3.1 + * @deprecated as of 5.1.1, in favor of {@link BasicAuthenticationInterceptor} + * which reuses {@link org.springframework.http.HttpHeaders#setBasicAuth}, + * sharing its default charset ISO-8859-1 instead of UTF-8 as used here + */ +@Deprecated +public class BasicAuthorizationInterceptor implements ClientHttpRequestInterceptor { + + private final String username; + + private final String password; + + + /** + * Create a new interceptor which adds a BASIC authorization header + * for the given username and password. + * @param username the username to use + * @param password the password to use + */ + public BasicAuthorizationInterceptor(@Nullable String username, @Nullable String password) { + Assert.doesNotContain(username, ":", "Username must not contain a colon"); + this.username = (username != null ? username : ""); + this.password = (password != null ? password : ""); + } + + + @Override + public ClientHttpResponse intercept( + HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException { + + String token = Base64Utils.encodeToString( + (this.username + ":" + this.password).getBytes(StandardCharsets.UTF_8)); + request.getHeaders().add("Authorization", "Basic " + token); + return execution.execute(request, body); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/support/HttpAccessor.java b/spring-web/src/main/java/org/springframework/http/client/support/HttpAccessor.java new file mode 100644 index 0000000000000000000000000000000000000000..895fde9ef0524e9f929debffd99cbfdd992e8022 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/support/HttpAccessor.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.support; + +import java.io.IOException; +import java.net.URI; + +import org.apache.commons.logging.Log; + +import org.springframework.http.HttpLogging; +import org.springframework.http.HttpMethod; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.util.Assert; + +/** + * Base class for {@link org.springframework.web.client.RestTemplate} + * and other HTTP accessing gateway helpers, defining common properties + * such as the {@link ClientHttpRequestFactory} to operate on. + * + *

Not intended to be used directly. + * See {@link org.springframework.web.client.RestTemplate} for an entry point. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.0 + * @see ClientHttpRequestFactory + * @see org.springframework.web.client.RestTemplate + */ +public abstract class HttpAccessor { + + /** Logger available to subclasses. */ + protected final Log logger = HttpLogging.forLogName(getClass()); + + private ClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); + + + /** + * Set the request factory that this accessor uses for obtaining client request handles. + *

The default is a {@link SimpleClientHttpRequestFactory} based on the JDK's own + * HTTP libraries ({@link java.net.HttpURLConnection}). + *

Note that the standard JDK HTTP library does not support the HTTP PATCH method. + * Configure the Apache HttpComponents or OkHttp request factory to enable PATCH. + * @see #createRequest(URI, HttpMethod) + * @see SimpleClientHttpRequestFactory + * @see org.springframework.http.client.HttpComponentsAsyncClientHttpRequestFactory + * @see org.springframework.http.client.OkHttp3ClientHttpRequestFactory + */ + public void setRequestFactory(ClientHttpRequestFactory requestFactory) { + Assert.notNull(requestFactory, "ClientHttpRequestFactory must not be null"); + this.requestFactory = requestFactory; + } + + /** + * Return the request factory that this accessor uses for obtaining client request handles. + */ + public ClientHttpRequestFactory getRequestFactory() { + return this.requestFactory; + } + + + /** + * Create a new {@link ClientHttpRequest} via this template's {@link ClientHttpRequestFactory}. + * @param url the URL to connect to + * @param method the HTTP method to execute (GET, POST, etc) + * @return the created request + * @throws IOException in case of I/O errors + * @see #getRequestFactory() + * @see ClientHttpRequestFactory#createRequest(URI, HttpMethod) + */ + protected ClientHttpRequest createRequest(URI url, HttpMethod method) throws IOException { + ClientHttpRequest request = getRequestFactory().createRequest(url, method); + if (logger.isDebugEnabled()) { + logger.debug("HTTP " + method.name() + " " + url); + } + return request; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/support/HttpRequestWrapper.java b/spring-web/src/main/java/org/springframework/http/client/support/HttpRequestWrapper.java new file mode 100644 index 0000000000000000000000000000000000000000..d3a72f55449f86a104dbf1b74ebe279fd7525430 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/support/HttpRequestWrapper.java @@ -0,0 +1,91 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.support; + +import java.net.URI; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRequest; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Provides a convenient implementation of the {@link HttpRequest} interface + * that can be overridden to adapt the request. + * + *

These methods default to calling through to the wrapped request object. + * + * @author Arjen Poutsma + * @since 3.1 + */ +public class HttpRequestWrapper implements HttpRequest { + + private final HttpRequest request; + + + /** + * Create a new {@code HttpRequest} wrapping the given request object. + * @param request the request object to be wrapped + */ + public HttpRequestWrapper(HttpRequest request) { + Assert.notNull(request, "HttpRequest must not be null"); + this.request = request; + } + + + /** + * Return the wrapped request. + */ + public HttpRequest getRequest() { + return this.request; + } + + /** + * Return the method of the wrapped request. + */ + @Override + @Nullable + public HttpMethod getMethod() { + return this.request.getMethod(); + } + + /** + * Return the method value of the wrapped request. + */ + @Override + public String getMethodValue() { + return this.request.getMethodValue(); + } + + /** + * Return the URI of the wrapped request. + */ + @Override + public URI getURI() { + return this.request.getURI(); + } + + /** + * Return the headers of the wrapped request. + */ + @Override + public HttpHeaders getHeaders() { + return this.request.getHeaders(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/support/InterceptingAsyncHttpAccessor.java b/spring-web/src/main/java/org/springframework/http/client/support/InterceptingAsyncHttpAccessor.java new file mode 100644 index 0000000000000000000000000000000000000000..9f8586ec84e1bf0d8791a22a0153639475336391 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/support/InterceptingAsyncHttpAccessor.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.support; + +import java.util.ArrayList; +import java.util.List; + +import org.springframework.util.CollectionUtils; + +/** + * The HTTP accessor that extends the base {@link AsyncHttpAccessor} with + * request intercepting functionality. + * + * @author Jakub Narloch + * @author Rossen Stoyanchev + * @since 4.3 + * @deprecated as of Spring 5.0, with no direct replacement + */ +@Deprecated +public abstract class InterceptingAsyncHttpAccessor extends AsyncHttpAccessor { + + private List interceptors = + new ArrayList<>(); + + + /** + * Set the request interceptors that this accessor should use. + * @param interceptors the list of interceptors + */ + public void setInterceptors(List interceptors) { + this.interceptors = interceptors; + } + + /** + * Return the request interceptor that this accessor uses. + */ + public List getInterceptors() { + return this.interceptors; + } + + + @Override + public org.springframework.http.client.AsyncClientHttpRequestFactory getAsyncRequestFactory() { + org.springframework.http.client.AsyncClientHttpRequestFactory delegate = super.getAsyncRequestFactory(); + if (!CollectionUtils.isEmpty(getInterceptors())) { + return new org.springframework.http.client.InterceptingAsyncClientHttpRequestFactory(delegate, getInterceptors()); + } + else { + return delegate; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/support/InterceptingHttpAccessor.java b/spring-web/src/main/java/org/springframework/http/client/support/InterceptingHttpAccessor.java new file mode 100644 index 0000000000000000000000000000000000000000..c1b359f38717957e2f57ee375d4c9097167be8f6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/support/InterceptingHttpAccessor.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.support; + +import java.util.ArrayList; +import java.util.List; + +import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.ClientHttpRequestInterceptor; +import org.springframework.http.client.InterceptingClientHttpRequestFactory; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; + +/** + * Base class for {@link org.springframework.web.client.RestTemplate} + * and other HTTP accessing gateway helpers, adding interceptor-related + * properties to {@link HttpAccessor}'s common properties. + * + *

Not intended to be used directly. + * See {@link org.springframework.web.client.RestTemplate} for an entry point. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.0 + * @see ClientHttpRequestInterceptor + * @see InterceptingClientHttpRequestFactory + * @see org.springframework.web.client.RestTemplate + */ +public abstract class InterceptingHttpAccessor extends HttpAccessor { + + private final List interceptors = new ArrayList<>(); + + @Nullable + private volatile ClientHttpRequestFactory interceptingRequestFactory; + + + /** + * Set the request interceptors that this accessor should use. + *

The interceptors will get sorted according to their order + * once the {@link ClientHttpRequestFactory} will be built. + * @see #getRequestFactory() + * @see AnnotationAwareOrderComparator + */ + public void setInterceptors(List interceptors) { + // Take getInterceptors() List as-is when passed in here + if (this.interceptors != interceptors) { + this.interceptors.clear(); + this.interceptors.addAll(interceptors); + AnnotationAwareOrderComparator.sort(this.interceptors); + } + } + + /** + * Return the request interceptors that this accessor uses. + *

The returned {@link List} is active and may get appended to. + */ + public List getInterceptors() { + return this.interceptors; + } + + /** + * {@inheritDoc} + */ + @Override + public void setRequestFactory(ClientHttpRequestFactory requestFactory) { + super.setRequestFactory(requestFactory); + this.interceptingRequestFactory = null; + } + + /** + * Overridden to expose an {@link InterceptingClientHttpRequestFactory} + * if necessary. + * @see #getInterceptors() + */ + @Override + public ClientHttpRequestFactory getRequestFactory() { + List interceptors = getInterceptors(); + if (!CollectionUtils.isEmpty(interceptors)) { + ClientHttpRequestFactory factory = this.interceptingRequestFactory; + if (factory == null) { + factory = new InterceptingClientHttpRequestFactory(super.getRequestFactory(), interceptors); + this.interceptingRequestFactory = factory; + } + return factory; + } + else { + return super.getRequestFactory(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/support/ProxyFactoryBean.java b/spring-web/src/main/java/org/springframework/http/client/support/ProxyFactoryBean.java new file mode 100644 index 0000000000000000000000000000000000000000..e1b9e0c8c7fee5dedba9c4b7533373aa65afc534 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/support/ProxyFactoryBean.java @@ -0,0 +1,101 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.support; + +import java.net.InetSocketAddress; +import java.net.Proxy; +import java.net.SocketAddress; + +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@link FactoryBean} that creates a {@link Proxy java.net.Proxy}. + * + * @author Arjen Poutsma + * @since 3.0.4 + * @see FactoryBean + * @see Proxy + */ +public class ProxyFactoryBean implements FactoryBean, InitializingBean { + + private Proxy.Type type = Proxy.Type.HTTP; + + @Nullable + private String hostname; + + private int port = -1; + + @Nullable + private Proxy proxy; + + + /** + * Set the proxy type. + *

Defaults to {@link java.net.Proxy.Type#HTTP}. + */ + public void setType(Proxy.Type type) { + this.type = type; + } + + /** + * Set the proxy host name. + */ + public void setHostname(String hostname) { + this.hostname = hostname; + } + + /** + * Set the proxy port. + */ + public void setPort(int port) { + this.port = port; + } + + + @Override + public void afterPropertiesSet() throws IllegalArgumentException { + Assert.notNull(this.type, "Property 'type' is required"); + Assert.notNull(this.hostname, "Property 'hostname' is required"); + if (this.port < 0 || this.port > 65535) { + throw new IllegalArgumentException("Property 'port' value out of range: " + this.port); + } + + SocketAddress socketAddress = new InetSocketAddress(this.hostname, this.port); + this.proxy = new Proxy(this.type, socketAddress); + } + + + @Override + @Nullable + public Proxy getObject() { + return this.proxy; + } + + @Override + public Class getObjectType() { + return Proxy.class; + } + + @Override + public boolean isSingleton() { + return true; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/support/package-info.java b/spring-web/src/main/java/org/springframework/http/client/support/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..0308cb51a3d179374c27561510f8bb90c798ce77 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/support/package-info.java @@ -0,0 +1,10 @@ +/** + * This package provides generic HTTP support classes, + * to be used by higher-level classes like RestTemplate. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.client.support; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/codec/ClientCodecConfigurer.java b/spring-web/src/main/java/org/springframework/http/codec/ClientCodecConfigurer.java new file mode 100644 index 0000000000000000000000000000000000000000..070b0610ee1fd9e152766d204fa26883367af1ca --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/ClientCodecConfigurer.java @@ -0,0 +1,127 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import org.springframework.core.codec.Decoder; +import org.springframework.core.codec.Encoder; + +/** + * Extension of {@link CodecConfigurer} for HTTP message reader and writer + * options relevant on the client side. + * + *

HTTP message readers for the following are registered by default: + *

    {@code byte[]} + *
  • {@link java.nio.ByteBuffer} + *
  • {@link org.springframework.core.io.buffer.DataBuffer DataBuffer} + *
  • {@link org.springframework.core.io.Resource Resource} + *
  • {@link String} + *
  • {@link org.springframework.util.MultiValueMap + * MultiValueMap<String,String>} for form data + *
  • JSON and Smile, if Jackson is present + *
  • XML, if JAXB2 is present + *
  • Server-Sent Events + *
+ * + *

HTTP message writers registered by default: + *

    {@code byte[]} + *
  • {@link java.nio.ByteBuffer} + *
  • {@link org.springframework.core.io.buffer.DataBuffer DataBuffer} + *
  • {@link org.springframework.core.io.Resource Resource} + *
  • {@link String} + *
  • {@link org.springframework.util.MultiValueMap + * MultiValueMap<String,String>} for form data + *
  • {@link org.springframework.util.MultiValueMap + * MultiValueMap<String,Object>} for multipart data + *
  • JSON and Smile, if Jackson is present + *
  • XML, if JAXB2 is present + *
+ * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface ClientCodecConfigurer extends CodecConfigurer { + + /** + * {@inheritDoc} + *

On the client side, built-in default also include customizations related + * to multipart readers and writers, as well as the decoder for SSE. + */ + @Override + ClientDefaultCodecs defaultCodecs(); + + /** + * {@inheritDoc}. + */ + @Override + ClientCodecConfigurer clone(); + + + /** + * Static factory method for a {@code ClientCodecConfigurer}. + */ + static ClientCodecConfigurer create() { + return CodecConfigurerFactory.create(ClientCodecConfigurer.class); + } + + + /** + * {@link CodecConfigurer.DefaultCodecs} extension with extra client-side options. + */ + interface ClientDefaultCodecs extends DefaultCodecs { + + /** + * Configure encoders or writers for use with + * {@link org.springframework.http.codec.multipart.MultipartHttpMessageWriter + * MultipartHttpMessageWriter}. + */ + MultipartCodecs multipartCodecs(); + + /** + * Configure the {@code Decoder} to use for Server-Sent Events. + *

By default if this is not set, and Jackson is available, the + * {@link #jackson2JsonDecoder} override is used instead. Use this property + * if you want to further customize the SSE decoder. + *

Note that {@link #maxInMemorySize(int)}, if configured, will be + * applied to the given decoder. + * @param decoder the decoder to use + */ + void serverSentEventDecoder(Decoder decoder); + } + + + /** + * Registry and container for multipart HTTP message writers. + */ + interface MultipartCodecs { + + /** + * Add a Part {@code Encoder}, internally wrapped with + * {@link EncoderHttpMessageWriter}. + * @param encoder the encoder to add + */ + MultipartCodecs encoder(Encoder encoder); + + /** + * Add a Part {@link HttpMessageWriter}. For writers of type + * {@link EncoderHttpMessageWriter} consider using the shortcut + * {@link #encoder(Encoder)} instead. + * @param writer the writer to add + */ + MultipartCodecs writer(HttpMessageWriter writer); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/CodecConfigurer.java b/spring-web/src/main/java/org/springframework/http/codec/CodecConfigurer.java new file mode 100644 index 0000000000000000000000000000000000000000..8d387d198589404d8ed62b6f71a7fe87bff1c9e9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/CodecConfigurer.java @@ -0,0 +1,318 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.util.List; +import java.util.function.Consumer; + +import org.springframework.core.codec.Decoder; +import org.springframework.core.codec.Encoder; +import org.springframework.lang.Nullable; + +/** + * Defines a common interface for configuring either client or server HTTP + * message readers and writers. This is used as follows: + *

    + *
  • Use {@link ClientCodecConfigurer#create()} or + * {@link ServerCodecConfigurer#create()} to create an instance. + *
  • Use {@link #defaultCodecs()} to customize HTTP message readers or writers + * registered by default. + *
  • Use {@link #customCodecs()} to add custom HTTP message readers or writers. + *
  • Use {@link #getReaders()} and {@link #getWriters()} to obtain the list of + * configured HTTP message readers and writers. + *
+ * + *

HTTP message readers and writers are divided into 3 categories that are + * ordered as follows: + *

    + *
  1. Typed readers and writers that support specific types, e.g. byte[], String. + *
  2. Object readers and writers, e.g. JSON, XML. + *
  3. Catch-all readers or writers, e.g. String with any media type. + *
+ * + *

Typed and object readers are further sub-divided and ordered as follows: + *

    + *
  1. Default HTTP reader and writer registrations. + *
  2. Custom readers and writers. + *
+ * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface CodecConfigurer { + + /** + * Provides a way to customize or replace HTTP message readers and writers + * registered by default. + * @see #registerDefaults(boolean) + */ + DefaultCodecs defaultCodecs(); + + /** + * Register custom HTTP message readers or writers in addition to the ones + * registered by default. + */ + CustomCodecs customCodecs(); + + /** + * Provides a way to completely turn off registration of default HTTP message + * readers and writers, and instead rely only on the ones provided via + * {@link #customCodecs()}. + *

By default this is set to {@code "true"} in which case default + * registrations are made; setting this to {@code false} disables default + * registrations. + */ + void registerDefaults(boolean registerDefaults); + + + /** + * Obtain the configured HTTP message readers. + */ + List> getReaders(); + + /** + * Obtain the configured HTTP message writers. + */ + List> getWriters(); + + /** + * Create a copy of this {@link CodecConfigurer}. The returned clone has its + * own lists of default and custom codecs and generally can be configured + * independently. Keep in mind however that codec instances (if any are + * configured) are themselves not cloned. + * @since 5.1.12 + */ + CodecConfigurer clone(); + + + /** + * Customize or replace the HTTP message readers and writers registered by + * default. The options are further extended by + * {@link ClientCodecConfigurer.ClientDefaultCodecs ClientDefaultCodecs} and + * {@link ServerCodecConfigurer.ServerDefaultCodecs ServerDefaultCodecs}. + */ + interface DefaultCodecs { + + /** + * Override the default Jackson JSON {@code Decoder}. + *

Note that {@link #maxInMemorySize(int)}, if configured, will be + * applied to the given decoder. + * @param decoder the decoder instance to use + * @see org.springframework.http.codec.json.Jackson2JsonDecoder + */ + void jackson2JsonDecoder(Decoder decoder); + + /** + * Override the default Jackson JSON {@code Encoder}. + * @param encoder the encoder instance to use + * @see org.springframework.http.codec.json.Jackson2JsonEncoder + */ + void jackson2JsonEncoder(Encoder encoder); + + /** + * Override the default Protobuf {@code Decoder}. + *

Note that {@link #maxInMemorySize(int)}, if configured, will be + * applied to the given decoder. + * @param decoder the decoder instance to use + * @since 5.1 + * @see org.springframework.http.codec.protobuf.ProtobufDecoder + */ + void protobufDecoder(Decoder decoder); + + /** + * Override the default Protobuf {@code Encoder}. + * @param encoder the encoder instance to use + * @since 5.1 + * @see org.springframework.http.codec.protobuf.ProtobufEncoder + * @see org.springframework.http.codec.protobuf.ProtobufHttpMessageWriter + */ + void protobufEncoder(Encoder encoder); + + /** + * Override the default JAXB2 {@code Decoder}. + *

Note that {@link #maxInMemorySize(int)}, if configured, will be + * applied to the given decoder. + * @param decoder the decoder instance to use + * @since 5.1.3 + * @see org.springframework.http.codec.xml.Jaxb2XmlDecoder + */ + void jaxb2Decoder(Decoder decoder); + + /** + * Override the default JABX2 {@code Encoder}. + * @param encoder the encoder instance to use + * @since 5.1.3 + * @see org.springframework.http.codec.xml.Jaxb2XmlEncoder + */ + void jaxb2Encoder(Encoder encoder); + + /** + * Configure a limit on the number of bytes that can be buffered whenever + * the input stream needs to be aggregated. This can be a result of + * decoding to a single {@code DataBuffer}, + * {@link java.nio.ByteBuffer ByteBuffer}, {@code byte[]}, + * {@link org.springframework.core.io.Resource Resource}, {@code String}, etc. + * It can also occur when splitting the input stream, e.g. delimited text, + * in which case the limit applies to data buffered between delimiters. + *

By default this is not set, in which case individual codec defaults + * apply. In 5.1 most codecs are not limited except {@code FormHttpMessageReader} + * which is limited to 256K. In 5.2 all codecs are limited to 256K by default. + * @param byteCount the max number of bytes to buffer, or -1 for unlimited + * @since 5.1.11 + */ + void maxInMemorySize(int byteCount); + + /** + * Whether to log form data at DEBUG level, and headers at TRACE level. + * Both may contain sensitive information. + *

By default set to {@code false} so that request details are not shown. + * @param enable whether to enable or not + * @since 5.1 + */ + void enableLoggingRequestDetails(boolean enable); + } + + + /** + * Registry for custom HTTP message readers and writers. + */ + interface CustomCodecs { + + /** + * Register a custom codec. This is expected to be one of the following: + *

    + *
  • {@link HttpMessageReader} + *
  • {@link HttpMessageWriter} + *
  • {@link Encoder} (wrapped internally with {@link EncoderHttpMessageWriter}) + *
  • {@link Decoder} (wrapped internally with {@link DecoderHttpMessageReader}) + *
+ * @param codec the codec to register + * @since 5.1.13 + */ + void register(Object codec); + + /** + * Variant of {@link #register(Object)} that also applies the below + * properties, if configured, via {@link #defaultCodecs()}: + *
    + *
  • {@link CodecConfigurer.DefaultCodecs#maxInMemorySize(int) maxInMemorySize} + *
  • {@link CodecConfigurer.DefaultCodecs#enableLoggingRequestDetails(boolean) enableLoggingRequestDetails} + *
+ *

The properties are applied every time {@link #getReaders()} or + * {@link #getWriters()} are used to obtain the list of configured + * readers or writers. + * @param codec the codec to register and apply default config to + * @since 5.1.13 + */ + void registerWithDefaultConfig(Object codec); + + /** + * Variant of {@link #register(Object)} that also allows the caller to + * apply the properties from {@link DefaultCodecConfig} to the given + * codec. If you want to apply all the properties, prefer using + * {@link #registerWithDefaultConfig(Object)}. + *

The consumer is called every time {@link #getReaders()} or + * {@link #getWriters()} are used to obtain the list of configured + * readers or writers. + * @param codec the codec to register + * @param configConsumer consumer of the default config + * @since 5.1.13 + */ + void registerWithDefaultConfig(Object codec, Consumer configConsumer); + + /** + * Add a custom {@code Decoder} internally wrapped with + * {@link DecoderHttpMessageReader}). + * @param decoder the decoder to add + * @deprecated as of 5.1.13, use {@link #register(Object)} or + * {@link #registerWithDefaultConfig(Object)} instead. + */ + @Deprecated + void decoder(Decoder decoder); + + /** + * Add a custom {@code Encoder}, internally wrapped with + * {@link EncoderHttpMessageWriter}. + * @param encoder the encoder to add + * @deprecated as of 5.1.13, use {@link #register(Object)} or + * {@link #registerWithDefaultConfig(Object)} instead. + */ + @Deprecated + void encoder(Encoder encoder); + + /** + * Add a custom {@link HttpMessageReader}. For readers of type + * {@link DecoderHttpMessageReader} consider using the shortcut + * {@link #decoder(Decoder)} instead. + * @param reader the reader to add + * @deprecated as of 5.1.13, use {@link #register(Object)} or + * {@link #registerWithDefaultConfig(Object)} instead. + */ + @Deprecated + void reader(HttpMessageReader reader); + + /** + * Add a custom {@link HttpMessageWriter}. For writers of type + * {@link EncoderHttpMessageWriter} consider using the shortcut + * {@link #encoder(Encoder)} instead. + * @param writer the writer to add + * @deprecated as of 5.1.13, use {@link #register(Object)} or + * {@link #registerWithDefaultConfig(Object)} instead. + */ + @Deprecated + void writer(HttpMessageWriter writer); + + /** + * Register a callback for the {@link DefaultCodecConfig configuration} + * applied to default codecs. This allows custom codecs to follow general + * guidelines applied to default ones, such as logging details and limiting + * the amount of buffered data. + * @param codecsConfigConsumer the default codecs configuration callback + * @deprecated as of 5.1.13, use {@link #registerWithDefaultConfig(Object)} + * or {@link #registerWithDefaultConfig(Object, Consumer)} instead. + */ + @Deprecated + void withDefaultCodecConfig(Consumer codecsConfigConsumer); + } + + + /** + * Exposes the values of properties configured through + * {@link #defaultCodecs()} that are applied to default codecs. + * The main purpose of this interface is to provide access to them so they + * can also be applied to custom codecs if needed. + * @since 5.1.12 + * @see CustomCodecs#registerWithDefaultConfig(Object, Consumer) + */ + interface DefaultCodecConfig { + + /** + * Get the configured limit on the number of bytes that can be buffered whenever + * the input stream needs to be aggregated. + */ + @Nullable + Integer maxInMemorySize(); + + /** + * Whether to log form data at DEBUG level, and headers at TRACE level. + * Both may contain sensitive information. + */ + @Nullable + Boolean isEnableLoggingRequestDetails(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/CodecConfigurerFactory.java b/spring-web/src/main/java/org/springframework/http/codec/CodecConfigurerFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..9280593588d164d953ed36e048d6b591d2913b41 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/CodecConfigurerFactory.java @@ -0,0 +1,75 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +import org.springframework.beans.BeanUtils; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.support.PropertiesLoaderUtils; +import org.springframework.util.ClassUtils; + +/** + * Internal delegate for loading the default codec configurer class names. + * Models a loose relationship with the default implementations in the support + * package, literally only needing to know the default class name to use. + * + * @author Juergen Hoeller + * @since 5.0.1 + * @see ClientCodecConfigurer#create() + * @see ServerCodecConfigurer#create() + */ +final class CodecConfigurerFactory { + + private static final String DEFAULT_CONFIGURERS_PATH = "CodecConfigurer.properties"; + + private static final Map, Class> defaultCodecConfigurers = new HashMap<>(4); + + static { + try { + Properties props = PropertiesLoaderUtils.loadProperties( + new ClassPathResource(DEFAULT_CONFIGURERS_PATH, CodecConfigurerFactory.class)); + for (String ifcName : props.stringPropertyNames()) { + String implName = props.getProperty(ifcName); + Class ifc = ClassUtils.forName(ifcName, CodecConfigurerFactory.class.getClassLoader()); + Class impl = ClassUtils.forName(implName, CodecConfigurerFactory.class.getClassLoader()); + defaultCodecConfigurers.put(ifc, impl); + } + } + catch (IOException | ClassNotFoundException ex) { + throw new IllegalStateException(ex); + } + } + + + private CodecConfigurerFactory() { + } + + + @SuppressWarnings("unchecked") + public static T create(Class ifc) { + Class impl = defaultCodecConfigurers.get(ifc); + if (impl == null) { + throw new IllegalStateException("No default codec configurer found for " + ifc); + } + return (T) BeanUtils.instantiateClass(impl); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/DecoderHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/DecoderHttpMessageReader.java new file mode 100644 index 0000000000000000000000000000000000000000..b5d689d2e60558f02c479d92793aaeaa2fb5e88a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/DecoderHttpMessageReader.java @@ -0,0 +1,159 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.util.List; +import java.util.Map; + +import org.apache.commons.logging.Log; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.AbstractDecoder; +import org.springframework.core.codec.Decoder; +import org.springframework.core.codec.Hints; +import org.springframework.http.HttpLogging; +import org.springframework.http.HttpMessage; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@code HttpMessageReader} that wraps and delegates to a {@link Decoder}. + * + *

Also a {@code HttpMessageReader} that pre-resolves decoding hints + * from the extra information available on the server side such as the request + * or controller method parameter annotations. + * + * @author Arjen Poutsma + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 5.0 + * @param the type of objects in the decoded output stream + */ +public class DecoderHttpMessageReader implements HttpMessageReader { + + private final Decoder decoder; + + private final List mediaTypes; + + + /** + * Create an instance wrapping the given {@link Decoder}. + */ + public DecoderHttpMessageReader(Decoder decoder) { + Assert.notNull(decoder, "Decoder is required"); + initLogger(decoder); + this.decoder = decoder; + this.mediaTypes = MediaType.asMediaTypes(decoder.getDecodableMimeTypes()); + } + + private static void initLogger(Decoder decoder) { + if (decoder instanceof AbstractDecoder && + decoder.getClass().getName().startsWith("org.springframework.core.codec")) { + Log logger = HttpLogging.forLog(((AbstractDecoder) decoder).getLogger()); + ((AbstractDecoder) decoder).setLogger(logger); + } + } + + + /** + * Return the {@link Decoder} of this reader. + */ + public Decoder getDecoder() { + return this.decoder; + } + + @Override + public List getReadableMediaTypes() { + return this.mediaTypes; + } + + + @Override + public boolean canRead(ResolvableType elementType, @Nullable MediaType mediaType) { + return this.decoder.canDecode(elementType, mediaType); + } + + @Override + public Flux read(ResolvableType elementType, ReactiveHttpInputMessage message, Map hints) { + MediaType contentType = getContentType(message); + return this.decoder.decode(message.getBody(), elementType, contentType, hints); + } + + @Override + public Mono readMono(ResolvableType elementType, ReactiveHttpInputMessage message, Map hints) { + MediaType contentType = getContentType(message); + return this.decoder.decodeToMono(message.getBody(), elementType, contentType, hints); + } + + /** + * Determine the Content-Type of the HTTP message based on the + * "Content-Type" header or otherwise default to + * {@link MediaType#APPLICATION_OCTET_STREAM}. + * @param inputMessage the HTTP message + * @return the MediaType, possibly {@code null}. + */ + @Nullable + protected MediaType getContentType(HttpMessage inputMessage) { + MediaType contentType = inputMessage.getHeaders().getContentType(); + return (contentType != null ? contentType : MediaType.APPLICATION_OCTET_STREAM); + } + + + // Server-side only... + + @Override + public Flux read(ResolvableType actualType, ResolvableType elementType, + ServerHttpRequest request, ServerHttpResponse response, Map hints) { + + Map allHints = Hints.merge(hints, + getReadHints(actualType, elementType, request, response)); + + return read(elementType, request, allHints); + } + + @Override + public Mono readMono(ResolvableType actualType, ResolvableType elementType, + ServerHttpRequest request, ServerHttpResponse response, Map hints) { + + Map allHints = Hints.merge(hints, + getReadHints(actualType, elementType, request, response)); + + return readMono(elementType, request, allHints); + } + + /** + * Get additional hints for decoding for example based on the server request + * or annotations from controller method parameters. By default, delegate to + * the decoder if it is an instance of {@link HttpMessageDecoder}. + */ + protected Map getReadHints(ResolvableType actualType, + ResolvableType elementType, ServerHttpRequest request, ServerHttpResponse response) { + + if (this.decoder instanceof HttpMessageDecoder) { + HttpMessageDecoder decoder = (HttpMessageDecoder) this.decoder; + return decoder.getDecodeHints(actualType, elementType, request, response); + } + return Hints.none(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/EncoderHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/EncoderHttpMessageWriter.java new file mode 100644 index 0000000000000000000000000000000000000000..12adac50997c56615278a670e1bac561772d0656 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/EncoderHttpMessageWriter.java @@ -0,0 +1,222 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.util.List; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.AbstractEncoder; +import org.springframework.core.codec.Encoder; +import org.springframework.core.codec.Hints; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.PooledDataBuffer; +import org.springframework.http.HttpLogging; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpOutputMessage; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * {@code HttpMessageWriter} that wraps and delegates to an {@link Encoder}. + * + *

Also a {@code HttpMessageWriter} that pre-resolves encoding hints + * from the extra information available on the server side such as the request + * or controller method annotations. + * + * @author Arjen Poutsma + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @author Brian Clozel + * @author Sam Brannen + * @since 5.0 + * @param the type of objects in the input stream + */ +public class EncoderHttpMessageWriter implements HttpMessageWriter { + + private final Encoder encoder; + + private final List mediaTypes; + + @Nullable + private final MediaType defaultMediaType; + + + /** + * Create an instance wrapping the given {@link Encoder}. + */ + public EncoderHttpMessageWriter(Encoder encoder) { + Assert.notNull(encoder, "Encoder is required"); + initLogger(encoder); + this.encoder = encoder; + this.mediaTypes = MediaType.asMediaTypes(encoder.getEncodableMimeTypes()); + this.defaultMediaType = initDefaultMediaType(this.mediaTypes); + } + + private static void initLogger(Encoder encoder) { + if (encoder instanceof AbstractEncoder && + encoder.getClass().getName().startsWith("org.springframework.core.codec")) { + Log logger = HttpLogging.forLog(((AbstractEncoder) encoder).getLogger()); + ((AbstractEncoder) encoder).setLogger(logger); + } + } + + @Nullable + private static MediaType initDefaultMediaType(List mediaTypes) { + return mediaTypes.stream().filter(MediaType::isConcrete).findFirst().orElse(null); + } + + + /** + * Return the {@code Encoder} of this writer. + */ + public Encoder getEncoder() { + return this.encoder; + } + + @Override + public List getWritableMediaTypes() { + return this.mediaTypes; + } + + + @Override + public boolean canWrite(ResolvableType elementType, @Nullable MediaType mediaType) { + return this.encoder.canEncode(elementType, mediaType); + } + + @Override + public Mono write(Publisher inputStream, ResolvableType elementType, + @Nullable MediaType mediaType, ReactiveHttpOutputMessage message, Map hints) { + + MediaType contentType = updateContentType(message, mediaType); + + Flux body = this.encoder.encode( + inputStream, message.bufferFactory(), elementType, contentType, hints); + + if (inputStream instanceof Mono) { + return body + .singleOrEmpty() + .switchIfEmpty(Mono.defer(() -> { + message.getHeaders().setContentLength(0); + return message.setComplete().then(Mono.empty()); + })) + .flatMap(buffer -> { + message.getHeaders().setContentLength(buffer.readableByteCount()); + return message.writeWith(Mono.just(buffer) + .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release)); + }) + .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release); + } + + if (isStreamingMediaType(contentType)) { + return message.writeAndFlushWith(body.map(buffer -> + Mono.just(buffer).doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release))); + } + + return message.writeWith(body); + } + + @Nullable + private MediaType updateContentType(ReactiveHttpOutputMessage message, @Nullable MediaType mediaType) { + MediaType result = message.getHeaders().getContentType(); + if (result != null) { + return result; + } + MediaType fallback = this.defaultMediaType; + result = (useFallback(mediaType, fallback) ? fallback : mediaType); + if (result != null) { + result = addDefaultCharset(result, fallback); + message.getHeaders().setContentType(result); + } + return result; + } + + private static boolean useFallback(@Nullable MediaType main, @Nullable MediaType fallback) { + return (main == null || !main.isConcrete() || + main.equals(MediaType.APPLICATION_OCTET_STREAM) && fallback != null); + } + + private static MediaType addDefaultCharset(MediaType main, @Nullable MediaType defaultType) { + if (main.getCharset() == null && defaultType != null && defaultType.getCharset() != null) { + return new MediaType(main, defaultType.getCharset()); + } + return main; + } + + private boolean isStreamingMediaType(@Nullable MediaType mediaType) { + if (mediaType == null || !(this.encoder instanceof HttpMessageEncoder)) { + return false; + } + for (MediaType streamingMediaType : ((HttpMessageEncoder) this.encoder).getStreamingMediaTypes()) { + if (mediaType.isCompatibleWith(streamingMediaType) && matchParameters(mediaType, streamingMediaType)) { + return true; + } + } + return false; + } + + private boolean matchParameters(MediaType streamingMediaType, MediaType mediaType) { + for (String name : streamingMediaType.getParameters().keySet()) { + String s1 = streamingMediaType.getParameter(name); + String s2 = mediaType.getParameter(name); + if (StringUtils.hasText(s1) && StringUtils.hasText(s2) && !s1.equalsIgnoreCase(s2)) { + return false; + } + } + return true; + } + + + // Server side only... + + @Override + public Mono write(Publisher inputStream, ResolvableType actualType, + ResolvableType elementType, @Nullable MediaType mediaType, ServerHttpRequest request, + ServerHttpResponse response, Map hints) { + + Map allHints = Hints.merge(hints, + getWriteHints(actualType, elementType, mediaType, request, response)); + + return write(inputStream, elementType, mediaType, response, allHints); + } + + /** + * Get additional hints for encoding for example based on the server request + * or annotations from controller method parameters. By default, delegate to + * the encoder if it is an instance of {@link HttpMessageEncoder}. + */ + protected Map getWriteHints(ResolvableType streamType, ResolvableType elementType, + @Nullable MediaType mediaType, ServerHttpRequest request, ServerHttpResponse response) { + + if (this.encoder instanceof HttpMessageEncoder) { + HttpMessageEncoder encoder = (HttpMessageEncoder) this.encoder; + return encoder.getEncodeHints(streamType, elementType, mediaType, request, response); + } + return Hints.none(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/FormHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/FormHttpMessageReader.java new file mode 100644 index 0000000000000000000000000000000000000000..39c75a8578b08c13a771a7ab001782b2a72f1cd9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/FormHttpMessageReader.java @@ -0,0 +1,185 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; +import java.nio.CharBuffer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Hints; +import org.springframework.core.io.buffer.DataBufferLimitException; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.log.LogFormatUtils; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * Implementation of an {@link HttpMessageReader} to read HTML form data, i.e. + * request body with media type {@code "application/x-www-form-urlencoded"}. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class FormHttpMessageReader extends LoggingCodecSupport + implements HttpMessageReader> { + + /** + * The default charset used by the reader. + */ + public static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + + private static final ResolvableType MULTIVALUE_STRINGS_TYPE = + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class); + + + private Charset defaultCharset = DEFAULT_CHARSET; + + private int maxInMemorySize = 256 * 1024; + + + /** + * Set the default character set to use for reading form data when the + * request Content-Type header does not explicitly specify it. + *

By default this is set to "UTF-8". + */ + public void setDefaultCharset(Charset charset) { + Assert.notNull(charset, "Charset must not be null"); + this.defaultCharset = charset; + } + + /** + * Return the configured default charset. + */ + public Charset getDefaultCharset() { + return this.defaultCharset; + } + + /** + * Set the max number of bytes for input form data. As form data is buffered + * before it is parsed, this helps to limit the amount of buffering. Once + * the limit is exceeded, {@link DataBufferLimitException} is raised. + *

By default this is set to 256K. + * @param byteCount the max number of bytes to buffer, or -1 for unlimited + * @since 5.1.11 + */ + public void setMaxInMemorySize(int byteCount) { + this.maxInMemorySize = byteCount; + } + + /** + * Return the {@link #setMaxInMemorySize configured} byte count limit. + * @since 5.1.11 + */ + public int getMaxInMemorySize() { + return this.maxInMemorySize; + } + + + @Override + public boolean canRead(ResolvableType elementType, @Nullable MediaType mediaType) { + boolean multiValueUnresolved = + elementType.hasUnresolvableGenerics() && + MultiValueMap.class.isAssignableFrom(elementType.toClass()); + + return ((MULTIVALUE_STRINGS_TYPE.isAssignableFrom(elementType) || multiValueUnresolved) && + (mediaType == null || MediaType.APPLICATION_FORM_URLENCODED.isCompatibleWith(mediaType))); + } + + @Override + public Flux> read(ResolvableType elementType, + ReactiveHttpInputMessage message, Map hints) { + + return Flux.from(readMono(elementType, message, hints)); + } + + @Override + public Mono> readMono(ResolvableType elementType, + ReactiveHttpInputMessage message, Map hints) { + + MediaType contentType = message.getHeaders().getContentType(); + Charset charset = getMediaTypeCharset(contentType); + + return DataBufferUtils.join(message.getBody(), getMaxInMemorySize()) + .map(buffer -> { + CharBuffer charBuffer = charset.decode(buffer.asByteBuffer()); + String body = charBuffer.toString(); + DataBufferUtils.release(buffer); + MultiValueMap formData = parseFormData(charset, body); + logFormData(formData, hints); + return formData; + }); + } + + private void logFormData(MultiValueMap formData, Map hints) { + LogFormatUtils.traceDebug(logger, traceOn -> Hints.getLogPrefix(hints) + "Read " + + (isEnableLoggingRequestDetails() ? + LogFormatUtils.formatValue(formData, !traceOn) : + "form fields " + formData.keySet() + " (content masked)")); + } + + private Charset getMediaTypeCharset(@Nullable MediaType mediaType) { + if (mediaType != null && mediaType.getCharset() != null) { + return mediaType.getCharset(); + } + else { + return getDefaultCharset(); + } + } + + private MultiValueMap parseFormData(Charset charset, String body) { + String[] pairs = StringUtils.tokenizeToStringArray(body, "&"); + MultiValueMap result = new LinkedMultiValueMap<>(pairs.length); + try { + for (String pair : pairs) { + int idx = pair.indexOf('='); + if (idx == -1) { + result.add(URLDecoder.decode(pair, charset.name()), null); + } + else { + String name = URLDecoder.decode(pair.substring(0, idx), charset.name()); + String value = URLDecoder.decode(pair.substring(idx + 1), charset.name()); + result.add(name, value); + } + } + } + catch (UnsupportedEncodingException ex) { + throw new IllegalStateException(ex); + } + return result; + } + + @Override + public List getReadableMediaTypes() { + return Collections.singletonList(MediaType.APPLICATION_FORM_URLENCODED); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/FormHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/FormHttpMessageWriter.java new file mode 100644 index 0000000000000000000000000000000000000000..c982dab0b7fe21d20f6f0dcf4127591b283f2dc8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/FormHttpMessageWriter.java @@ -0,0 +1,183 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Hints; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.log.LogFormatUtils; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpOutputMessage; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; + +/** + * {@link HttpMessageWriter} for writing a {@code MultiValueMap} + * as HTML form data, i.e. {@code "application/x-www-form-urlencoded"}, to the + * body of a request. + * + *

Note that unless the media type is explicitly set to + * {@link MediaType#APPLICATION_FORM_URLENCODED}, the {@link #canWrite} method + * will need generic type information to confirm the target map has String values. + * This is because a MultiValueMap with non-String values can be used to write + * multipart requests. + * + *

To support both form data and multipart requests, consider using + * {@link org.springframework.http.codec.multipart.MultipartHttpMessageWriter} + * configured with this writer as the fallback for writing plain form data. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 5.0 + * @see org.springframework.http.codec.multipart.MultipartHttpMessageWriter + */ +public class FormHttpMessageWriter extends LoggingCodecSupport + implements HttpMessageWriter> { + + /** + * The default charset used by the writer. + */ + public static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + + private static final MediaType DEFAULT_FORM_DATA_MEDIA_TYPE = + new MediaType(MediaType.APPLICATION_FORM_URLENCODED, DEFAULT_CHARSET); + + private static final List MEDIA_TYPES = + Collections.singletonList(MediaType.APPLICATION_FORM_URLENCODED); + + private static final ResolvableType MULTIVALUE_TYPE = + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class); + + + private Charset defaultCharset = DEFAULT_CHARSET; + + + /** + * Set the default character set to use for writing form data when the response + * Content-Type header does not explicitly specify it. + *

By default this is set to "UTF-8". + */ + public void setDefaultCharset(Charset charset) { + Assert.notNull(charset, "Charset must not be null"); + this.defaultCharset = charset; + } + + /** + * Return the configured default charset. + */ + public Charset getDefaultCharset() { + return this.defaultCharset; + } + + + @Override + public List getWritableMediaTypes() { + return MEDIA_TYPES; + } + + + @Override + public boolean canWrite(ResolvableType elementType, @Nullable MediaType mediaType) { + if (!MultiValueMap.class.isAssignableFrom(elementType.toClass())) { + return false; + } + if (MediaType.APPLICATION_FORM_URLENCODED.isCompatibleWith(mediaType)) { + // Optimistically, any MultiValueMap with or without generics + return true; + } + if (mediaType == null) { + // Only String-based MultiValueMap + return MULTIVALUE_TYPE.isAssignableFrom(elementType); + } + return false; + } + + @Override + public Mono write(Publisher> inputStream, + ResolvableType elementType, @Nullable MediaType mediaType, ReactiveHttpOutputMessage message, + Map hints) { + + mediaType = getMediaType(mediaType); + message.getHeaders().setContentType(mediaType); + + Charset charset = mediaType.getCharset(); + Assert.notNull(charset, "No charset"); // should never occur + + return Mono.from(inputStream).flatMap(form -> { + logFormData(form, hints); + String value = serializeForm(form, charset); + ByteBuffer byteBuffer = charset.encode(value); + DataBuffer buffer = message.bufferFactory().wrap(byteBuffer); // wrapping only, no allocation + message.getHeaders().setContentLength(byteBuffer.remaining()); + return message.writeWith(Mono.just(buffer)); + }); + } + + private MediaType getMediaType(@Nullable MediaType mediaType) { + if (mediaType == null) { + return DEFAULT_FORM_DATA_MEDIA_TYPE; + } + else if (mediaType.getCharset() == null) { + return new MediaType(mediaType, getDefaultCharset()); + } + else { + return mediaType; + } + } + + private void logFormData(MultiValueMap form, Map hints) { + LogFormatUtils.traceDebug(logger, traceOn -> Hints.getLogPrefix(hints) + "Writing " + + (isEnableLoggingRequestDetails() ? + LogFormatUtils.formatValue(form, !traceOn) : + "form fields " + form.keySet() + " (content masked)")); + } + + protected String serializeForm(MultiValueMap formData, Charset charset) { + StringBuilder builder = new StringBuilder(); + formData.forEach((name, values) -> + values.forEach(value -> { + try { + if (builder.length() != 0) { + builder.append('&'); + } + builder.append(URLEncoder.encode(name, charset.name())); + if (value != null) { + builder.append('='); + builder.append(URLEncoder.encode(value, charset.name())); + } + } + catch (UnsupportedEncodingException ex) { + throw new IllegalStateException(ex); + } + })); + return builder.toString(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/HttpMessageDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/HttpMessageDecoder.java new file mode 100644 index 0000000000000000000000000000000000000000..7afa500ee8ca24e1dfc180c6b5621306403054e6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/HttpMessageDecoder.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.util.Map; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Decoder; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; + +/** + * Extension of {@code Decoder} exposing extra methods relevant in the context + * of HTTP request or response body decoding. + * + * @author Rossen Stoyanchev + * @since 5.0 + * @param the type of elements in the output stream + */ +public interface HttpMessageDecoder extends Decoder { + + /** + * Get decoding hints based on the server request or annotations on the + * target controller method parameter. + * @param actualType the actual target type to decode to, possibly a reactive + * wrapper and sourced from {@link org.springframework.core.MethodParameter}, + * i.e. providing access to method parameter annotations + * @param elementType the element type within {@code Flux/Mono} that we're + * trying to decode to + * @param request the current request + * @param response the current response + * @return a Map with hints, possibly empty + */ + Map getDecodeHints(ResolvableType actualType, ResolvableType elementType, + ServerHttpRequest request, ServerHttpResponse response); + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/HttpMessageEncoder.java b/spring-web/src/main/java/org/springframework/http/codec/HttpMessageEncoder.java new file mode 100644 index 0000000000000000000000000000000000000000..b968236a9c9a5777325d58c47c9f4b9fe9be3d1a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/HttpMessageEncoder.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.util.List; +import java.util.Map; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Encoder; +import org.springframework.core.codec.Hints; +import org.springframework.http.MediaType; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; + +/** + * Extension of {@code Encoder} exposing extra methods relevant in the context + * of HTTP request or response body encoding. + * + * @author Rossen Stoyanchev + * @since 5.0 + * @param the type of elements in the input stream + */ +public interface HttpMessageEncoder extends Encoder { + + /** + * Return "streaming" media types for which flushing should be performed + * automatically vs at the end of the input stream. + */ + List getStreamingMediaTypes(); + + /** + * Get decoding hints based on the server request or annotations on the + * target controller method parameter. + * @param actualType the actual source type to encode, possibly a reactive + * wrapper and sourced from {@link org.springframework.core.MethodParameter}, + * i.e. providing access to method annotations. + * @param elementType the element type within {@code Flux/Mono} that we're + * trying to encode. + * @param request the current request + * @param response the current response + * @return a Map with hints, possibly empty + */ + default Map getEncodeHints(ResolvableType actualType, ResolvableType elementType, + @Nullable MediaType mediaType, ServerHttpRequest request, ServerHttpResponse response) { + + return Hints.none(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/HttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/HttpMessageReader.java new file mode 100644 index 0000000000000000000000000000000000000000..ed99da6c9ab756780c679f9669f29e7390edb32c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/HttpMessageReader.java @@ -0,0 +1,116 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.util.List; +import java.util.Map; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.MethodParameter; +import org.springframework.core.ResolvableType; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; + +/** + * Strategy for reading from a {@link ReactiveHttpInputMessage} and decoding + * the stream of bytes to Objects of type {@code }. + * + * @author Rossen Stoyanchev + * @author Arjen Poutsma + * @author Sebastien Deleuze + * @since 5.0 + * @param the type of objects in the decoded output stream + */ +public interface HttpMessageReader { + + /** + * Return the {@link MediaType}'s that this reader supports. + */ + List getReadableMediaTypes(); + + /** + * Whether the given object type is supported by this reader. + * @param elementType the type of object to check + * @param mediaType the media type for the read (possibly {@code null}) + * @return {@code true} if readable, {@code false} otherwise + */ + boolean canRead(ResolvableType elementType, @Nullable MediaType mediaType); + + /** + * Read from the input message and encode to a stream of objects. + * @param elementType the type of objects in the stream which must have been + * previously checked via {@link #canRead(ResolvableType, MediaType)} + * @param message the message to read from + * @param hints additional information about how to read and decode the input + * @return the decoded stream of elements + */ + Flux read(ResolvableType elementType, ReactiveHttpInputMessage message, Map hints); + + /** + * Read from the input message and encode to a single object. + * @param elementType the type of objects in the stream which must have been + * previously checked via {@link #canRead(ResolvableType, MediaType)} + * @param message the message to read from + * @param hints additional information about how to read and decode the input + * @return the decoded object + */ + Mono readMono(ResolvableType elementType, ReactiveHttpInputMessage message, Map hints); + + /** + * Server-side only alternative to + * {@link #read(ResolvableType, ReactiveHttpInputMessage, Map)} + * with additional context available. + * @param actualType the actual type of the target method parameter; + * for annotated controllers, the {@link MethodParameter} can be accessed + * via {@link ResolvableType#getSource()}. + * @param elementType the type of Objects in the output stream + * @param request the current request + * @param response the current response + * @param hints additional information about how to read the body + * @return the decoded stream of elements + */ + default Flux read(ResolvableType actualType, ResolvableType elementType, ServerHttpRequest request, + ServerHttpResponse response, Map hints) { + + return read(elementType, request, hints); + } + + /** + * Server-side only alternative to + * {@link #readMono(ResolvableType, ReactiveHttpInputMessage, Map)} + * with additional, context available. + * @param actualType the actual type of the target method parameter; + * for annotated controllers, the {@link MethodParameter} can be accessed + * via {@link ResolvableType#getSource()}. + * @param elementType the type of Objects in the output stream + * @param request the current request + * @param response the current response + * @param hints additional information about how to read the body + * @return the decoded stream of elements + */ + default Mono readMono(ResolvableType actualType, ResolvableType elementType, ServerHttpRequest request, + ServerHttpResponse response, Map hints) { + + return readMono(elementType, request, hints); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/HttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/HttpMessageWriter.java new file mode 100644 index 0000000000000000000000000000000000000000..0b1b94b714d5603737b25ff693bc6205ad1b0855 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/HttpMessageWriter.java @@ -0,0 +1,93 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.util.List; +import java.util.Map; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + +import org.springframework.core.MethodParameter; +import org.springframework.core.ResolvableType; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpOutputMessage; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; + +/** + * Strategy for encoding a stream of objects of type {@code } and writing + * the encoded stream of bytes to an {@link ReactiveHttpOutputMessage}. + * + * @author Rossen Stoyanchev + * @author Arjen Poutsma + * @author Sebastien Deleuze + * @since 5.0 + * @param the type of objects in the input stream + */ +public interface HttpMessageWriter { + + /** + * Return the {@link MediaType}'s that this writer supports. + */ + List getWritableMediaTypes(); + + /** + * Whether the given object type is supported by this writer. + * @param elementType the type of object to check + * @param mediaType the media type for the write (possibly {@code null}) + * @return {@code true} if writable, {@code false} otherwise + */ + boolean canWrite(ResolvableType elementType, @Nullable MediaType mediaType); + + /** + * Write an given stream of object to the output message. + * @param inputStream the objects to write + * @param elementType the type of objects in the stream which must have been + * previously checked via {@link #canWrite(ResolvableType, MediaType)} + * @param mediaType the content type for the write (possibly {@code null} to + * indicate that the default content type of the writer must be used) + * @param message the message to write to + * @param hints additional information about how to encode and write + * @return indicates completion or error + */ + Mono write(Publisher inputStream, ResolvableType elementType, + @Nullable MediaType mediaType, ReactiveHttpOutputMessage message, Map hints); + + /** + * Server-side only alternative to + * {@link #write(Publisher, ResolvableType, MediaType, ReactiveHttpOutputMessage, Map)} + * with additional context available. + * @param actualType the actual return type of the method that returned the + * value; for annotated controllers, the {@link MethodParameter} can be + * accessed via {@link ResolvableType#getSource()}. + * @param elementType the type of Objects in the input stream + * @param mediaType the content type to use (possibly {@code null} indicating + * the default content type of the writer should be used) + * @param request the current request + * @param response the current response + * @return a {@link Mono} that indicates completion of writing or error + */ + default Mono write(Publisher inputStream, ResolvableType actualType, + ResolvableType elementType, @Nullable MediaType mediaType, ServerHttpRequest request, + ServerHttpResponse response, Map hints) { + + return write(inputStream, elementType, mediaType, response, hints); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/LoggingCodecSupport.java b/spring-web/src/main/java/org/springframework/http/codec/LoggingCodecSupport.java new file mode 100644 index 0000000000000000000000000000000000000000..f2e73480c83241a3f0179614a5919b1873f4a2a8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/LoggingCodecSupport.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import org.apache.commons.logging.Log; + +import org.springframework.http.HttpLogging; + +/** + * Base class for {@link org.springframework.core.codec.Encoder}, + * {@link org.springframework.core.codec.Decoder}, {@link HttpMessageReader}, or + * {@link HttpMessageWriter} that uses a logger and shows potentially sensitive + * request data. + * + * @author Rossen Stoyanchev + * @since 5.1 + */ +public class LoggingCodecSupport { + + protected final Log logger = HttpLogging.forLogName(getClass()); + + /** Whether to log potentially sensitive info (form data at DEBUG and headers at TRACE). */ + private boolean enableLoggingRequestDetails = false; + + + /** + * Whether to log form data at DEBUG level, and headers at TRACE level. + * Both may contain sensitive information. + *

By default set to {@code false} so that request details are not shown. + * @param enable whether to enable or not + */ + public void setEnableLoggingRequestDetails(boolean enable) { + this.enableLoggingRequestDetails = enable; + } + + /** + * Whether any logging of values being encoded or decoded is explicitly + * disabled regardless of log level. + */ + public boolean isEnableLoggingRequestDetails() { + return this.enableLoggingRequestDetails; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/ResourceHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/ResourceHttpMessageWriter.java new file mode 100644 index 0000000000000000000000000000000000000000..4376d39d06c6485c6a2a1811abb6643ed1ded5a9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/ResourceHttpMessageWriter.java @@ -0,0 +1,255 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.apache.commons.logging.Log; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Hints; +import org.springframework.core.codec.ResourceDecoder; +import org.springframework.core.codec.ResourceEncoder; +import org.springframework.core.codec.ResourceRegionEncoder; +import org.springframework.core.io.InputStreamResource; +import org.springframework.core.io.Resource; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.support.ResourceRegion; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpLogging; +import org.springframework.http.HttpRange; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.MediaTypeFactory; +import org.springframework.http.ReactiveHttpOutputMessage; +import org.springframework.http.ZeroCopyHttpOutputMessage; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.MimeTypeUtils; + +/** + * {@code HttpMessageWriter} that can write a {@link Resource}. + * + *

Also an implementation of {@code HttpMessageWriter} with support for writing one + * or more {@link ResourceRegion}'s based on the HTTP ranges specified in the request. + * + *

For reading to a Resource, use {@link ResourceDecoder} wrapped with + * {@link DecoderHttpMessageReader}. + * + * @author Arjen Poutsma + * @author Brian Clozel + * @author Rossen Stoyanchev + * @since 5.0 + * @see ResourceEncoder + * @see ResourceRegionEncoder + * @see HttpRange + */ +public class ResourceHttpMessageWriter implements HttpMessageWriter { + + private static final ResolvableType REGION_TYPE = ResolvableType.forClass(ResourceRegion.class); + + private static final Log logger = HttpLogging.forLogName(ResourceHttpMessageWriter.class); + + + private final ResourceEncoder encoder; + + private final ResourceRegionEncoder regionEncoder; + + private final List mediaTypes; + + + public ResourceHttpMessageWriter() { + this(ResourceEncoder.DEFAULT_BUFFER_SIZE); + } + + public ResourceHttpMessageWriter(int bufferSize) { + this.encoder = new ResourceEncoder(bufferSize); + this.regionEncoder = new ResourceRegionEncoder(bufferSize); + this.mediaTypes = MediaType.asMediaTypes(this.encoder.getEncodableMimeTypes()); + } + + + @Override + public boolean canWrite(ResolvableType elementType, @Nullable MediaType mediaType) { + return this.encoder.canEncode(elementType, mediaType); + } + + @Override + public List getWritableMediaTypes() { + return this.mediaTypes; + } + + + // Client or server: single Resource... + + @Override + public Mono write(Publisher inputStream, ResolvableType elementType, + @Nullable MediaType mediaType, ReactiveHttpOutputMessage message, Map hints) { + + return Mono.from(inputStream).flatMap(resource -> + writeResource(resource, elementType, mediaType, message, hints)); + } + + private Mono writeResource(Resource resource, ResolvableType type, @Nullable MediaType mediaType, + ReactiveHttpOutputMessage message, Map hints) { + + HttpHeaders headers = message.getHeaders(); + MediaType resourceMediaType = getResourceMediaType(mediaType, resource, hints); + headers.setContentType(resourceMediaType); + + if (headers.getContentLength() < 0) { + long length = lengthOf(resource); + if (length != -1) { + headers.setContentLength(length); + } + } + + return zeroCopy(resource, null, message, hints) + .orElseGet(() -> { + Mono input = Mono.just(resource); + DataBufferFactory factory = message.bufferFactory(); + Flux body = this.encoder.encode(input, factory, type, resourceMediaType, hints); + return message.writeWith(body); + }); + } + + private static MediaType getResourceMediaType( + @Nullable MediaType mediaType, Resource resource, Map hints) { + + if (mediaType != null && mediaType.isConcrete() && !mediaType.equals(MediaType.APPLICATION_OCTET_STREAM)) { + return mediaType; + } + mediaType = MediaTypeFactory.getMediaType(resource).orElse(MediaType.APPLICATION_OCTET_STREAM); + if (logger.isDebugEnabled() && !Hints.isLoggingSuppressed(hints)) { + logger.debug(Hints.getLogPrefix(hints) + "Resource associated with '" + mediaType + "'"); + } + return mediaType; + } + + private static long lengthOf(Resource resource) { + // Don't consume InputStream... + if (InputStreamResource.class != resource.getClass()) { + try { + return resource.contentLength(); + } + catch (IOException ignored) { + } + } + return -1; + } + + private static Optional> zeroCopy(Resource resource, @Nullable ResourceRegion region, + ReactiveHttpOutputMessage message, Map hints) { + + if (message instanceof ZeroCopyHttpOutputMessage && resource.isFile()) { + try { + File file = resource.getFile(); + long pos = region != null ? region.getPosition() : 0; + long count = region != null ? region.getCount() : file.length(); + if (logger.isDebugEnabled()) { + String formatted = region != null ? "region " + pos + "-" + (count) + " of " : ""; + logger.debug(Hints.getLogPrefix(hints) + "Zero-copy " + formatted + "[" + resource + "]"); + } + return Optional.of(((ZeroCopyHttpOutputMessage) message).writeWith(file, pos, count)); + } + catch (IOException ex) { + // should not happen + } + } + return Optional.empty(); + } + + + // Server-side only: single Resource or sub-regions... + + @Override + public Mono write(Publisher inputStream, @Nullable ResolvableType actualType, + ResolvableType elementType, @Nullable MediaType mediaType, ServerHttpRequest request, + ServerHttpResponse response, Map hints) { + + HttpHeaders headers = response.getHeaders(); + headers.set(HttpHeaders.ACCEPT_RANGES, "bytes"); + + List ranges; + try { + ranges = request.getHeaders().getRange(); + } + catch (IllegalArgumentException ex) { + response.setStatusCode(HttpStatus.REQUESTED_RANGE_NOT_SATISFIABLE); + return response.setComplete(); + } + + return Mono.from(inputStream).flatMap(resource -> { + if (ranges.isEmpty()) { + return writeResource(resource, elementType, mediaType, response, hints); + } + response.setStatusCode(HttpStatus.PARTIAL_CONTENT); + List regions = HttpRange.toResourceRegions(ranges, resource); + MediaType resourceMediaType = getResourceMediaType(mediaType, resource, hints); + if (regions.size() == 1){ + ResourceRegion region = regions.get(0); + headers.setContentType(resourceMediaType); + long contentLength = lengthOf(resource); + if (contentLength != -1) { + long start = region.getPosition(); + long end = start + region.getCount() - 1; + end = Math.min(end, contentLength - 1); + headers.add("Content-Range", "bytes " + start + '-' + end + '/' + contentLength); + headers.setContentLength(end - start + 1); + } + return writeSingleRegion(region, response, hints); + } + else { + String boundary = MimeTypeUtils.generateMultipartBoundaryString(); + MediaType multipartType = MediaType.parseMediaType("multipart/byteranges;boundary=" + boundary); + headers.setContentType(multipartType); + Map allHints = Hints.merge(hints, ResourceRegionEncoder.BOUNDARY_STRING_HINT, boundary); + return encodeAndWriteRegions(Flux.fromIterable(regions), resourceMediaType, response, allHints); + } + }); + } + + private Mono writeSingleRegion(ResourceRegion region, ReactiveHttpOutputMessage message, + Map hints) { + + return zeroCopy(region.getResource(), region, message, hints) + .orElseGet(() -> { + Publisher input = Mono.just(region); + MediaType mediaType = message.getHeaders().getContentType(); + return encodeAndWriteRegions(input, mediaType, message, hints); + }); + } + + private Mono encodeAndWriteRegions(Publisher publisher, + @Nullable MediaType mediaType, ReactiveHttpOutputMessage message, Map hints) { + + Flux body = this.regionEncoder.encode( + publisher, message.bufferFactory(), REGION_TYPE, mediaType, hints); + + return message.writeWith(body); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/ServerCodecConfigurer.java b/spring-web/src/main/java/org/springframework/http/codec/ServerCodecConfigurer.java new file mode 100644 index 0000000000000000000000000000000000000000..0029b7b2345fde07e3d820d64e8a3558341937e5 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/ServerCodecConfigurer.java @@ -0,0 +1,111 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import org.springframework.core.codec.Encoder; + +/** + * Extension of {@link CodecConfigurer} for HTTP message reader and writer + * options relevant on the server side. + * + *

HTTP message readers for the following are registered by default: + *

    {@code byte[]} + *
  • {@link java.nio.ByteBuffer} + *
  • {@link org.springframework.core.io.buffer.DataBuffer DataBuffer} + *
  • {@link org.springframework.core.io.Resource Resource} + *
  • {@link String} + *
  • {@link org.springframework.util.MultiValueMap + * MultiValueMap<String,String>} for form data + *
  • {@link org.springframework.util.MultiValueMap + * MultiValueMap<String,Object>} for multipart data + *
  • JSON and Smile, if Jackson is present + *
  • XML, if JAXB2 is present + *
+ * + *

HTTP message writers registered by default: + *

    {@code byte[]} + *
  • {@link java.nio.ByteBuffer} + *
  • {@link org.springframework.core.io.buffer.DataBuffer DataBuffer} + *
  • {@link org.springframework.core.io.Resource Resource} + *
  • {@link String} + *
  • {@link org.springframework.util.MultiValueMap + * MultiValueMap<String,String>} for form data + *
  • JSON and Smile, if Jackson is present + *
  • XML, if JAXB2 is present + *
  • Server-Sent Events + *
+ * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface ServerCodecConfigurer extends CodecConfigurer { + + /** + * {@inheritDoc} + *

On the server side, built-in default also include customizations + * related to the encoder for SSE. + */ + @Override + ServerDefaultCodecs defaultCodecs(); + + /** + * {@inheritDoc}. + */ + @Override + ServerCodecConfigurer clone(); + + + /** + * Static factory method for a {@code ServerCodecConfigurer}. + */ + static ServerCodecConfigurer create() { + return CodecConfigurerFactory.create(ServerCodecConfigurer.class); + } + + + /** + * {@link CodecConfigurer.DefaultCodecs} extension with extra client-side options. + */ + interface ServerDefaultCodecs extends DefaultCodecs { + + /** + * Configure the {@code HttpMessageReader} to use for multipart requests. + *

By default, if + * Synchronoss NIO Multipart + * is present, this is set to + * {@link org.springframework.http.codec.multipart.MultipartHttpMessageReader + * MultipartHttpMessageReader} created with an instance of + * {@link org.springframework.http.codec.multipart.SynchronossPartHttpMessageReader + * SynchronossPartHttpMessageReader}. + *

Note that {@link #maxInMemorySize(int)} and/or + * {@link #enableLoggingRequestDetails(boolean)}, if configured, will be + * applied to the given reader, if applicable. + * @param reader the message reader to use for multipart requests. + * @since 5.1.11 + */ + void multipartReader(HttpMessageReader reader); + + /** + * Configure the {@code Encoder} to use for Server-Sent Events. + *

By default if this is not set, and Jackson is available, the + * {@link #jackson2JsonEncoder} override is used instead. Use this property + * if you want to further customize the SSE encoder. + */ + void serverSentEventEncoder(Encoder encoder); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEvent.java b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEvent.java new file mode 100644 index 0000000000000000000000000000000000000000..c3a1571110fdc193c63fae0b2efbd5cb73c7d96c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEvent.java @@ -0,0 +1,245 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.time.Duration; + +import org.springframework.lang.Nullable; + +/** + * Representation for a Server-Sent Event for use with Spring's reactive Web support. + * {@code Flux} or {@code Observable} is the + * reactive equivalent to Spring MVC's {@code SseEmitter}. + * + * @author Sebastien Deleuze + * @author Arjen Poutsma + * @since 5.0 + * @param the type of data that this event contains + * @see ServerSentEventHttpMessageWriter + * @see Server-Sent Events W3C recommendation + */ +public final class ServerSentEvent { + + @Nullable + private final String id; + + @Nullable + private final String event; + + @Nullable + private final Duration retry; + + @Nullable + private final String comment; + + @Nullable + private final T data; + + + private ServerSentEvent(@Nullable String id, @Nullable String event, @Nullable Duration retry, + @Nullable String comment, @Nullable T data) { + + this.id = id; + this.event = event; + this.retry = retry; + this.comment = comment; + this.data = data; + } + + + /** + * Return the {@code id} field of this event, if available. + */ + @Nullable + public String id() { + return this.id; + } + + /** + * Return the {@code event} field of this event, if available. + */ + @Nullable + public String event() { + return this.event; + } + + /** + * Return the {@code retry} field of this event, if available. + */ + @Nullable + public Duration retry() { + return this.retry; + } + + /** + * Return the comment of this event, if available. + */ + @Nullable + public String comment() { + return this.comment; + } + + /** + * Return the {@code data} field of this event, if available. + */ + @Nullable + public T data() { + return this.data; + } + + + @Override + public String toString() { + return ("ServerSentEvent [id = '" + this.id + "\', event='" + this.event + "\', retry=" + + this.retry + ", comment='" + this.comment + "', data=" + this.data + ']'); + } + + + /** + * Return a builder for a {@code SseEvent}. + * @param the type of data that this event contains + * @return the builder + */ + public static Builder builder() { + return new BuilderImpl<>(); + } + + /** + * Return a builder for a {@code SseEvent}, populated with the give {@linkplain #data() data}. + * @param the type of data that this event contains + * @return the builder + */ + public static Builder builder(T data) { + return new BuilderImpl<>(data); + } + + + /** + * A mutable builder for a {@code SseEvent}. + * + * @param the type of data that this event contains + */ + public interface Builder { + + /** + * Set the value of the {@code id} field. + * @param id the value of the id field + * @return {@code this} builder + */ + Builder id(String id); + + /** + * Set the value of the {@code event} field. + * @param event the value of the event field + * @return {@code this} builder + */ + Builder event(String event); + + /** + * Set the value of the {@code retry} field. + * @param retry the value of the retry field + * @return {@code this} builder + */ + Builder retry(Duration retry); + + /** + * Set SSE comment. If a multi-line comment is provided, it will be turned into multiple + * SSE comment lines as defined in Server-Sent Events W3C recommendation. + * @param comment the comment to set + * @return {@code this} builder + */ + Builder comment(String comment); + + /** + * Set the value of the {@code data} field. If the {@code data} argument is a multi-line + * {@code String}, it will be turned into multiple {@code data} field lines as defined + * in the Server-Sent Events W3C recommendation. If {@code data} is not a String, it will + * be {@linkplain org.springframework.http.codec.json.Jackson2JsonEncoder encoded} into JSON. + * @param data the value of the data field + * @return {@code this} builder + */ + Builder data(@Nullable T data); + + /** + * Builds the event. + * @return the built event + */ + ServerSentEvent build(); + } + + + private static class BuilderImpl implements Builder { + + @Nullable + private String id; + + @Nullable + private String event; + + @Nullable + private Duration retry; + + @Nullable + private String comment; + + @Nullable + private T data; + + public BuilderImpl() { + } + + public BuilderImpl(T data) { + this.data = data; + } + + @Override + public Builder id(String id) { + this.id = id; + return this; + } + + @Override + public Builder event(String event) { + this.event = event; + return this; + } + + @Override + public Builder retry(Duration retry) { + this.retry = retry; + return this; + } + + @Override + public Builder comment(String comment) { + this.comment = comment; + return this; + } + + @Override + public Builder data(@Nullable T data) { + this.data = data; + return this; + } + + @Override + public ServerSentEvent build() { + return new ServerSentEvent<>(this.id, this.event, this.retry, this.comment, this.data); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageReader.java new file mode 100644 index 0000000000000000000000000000000000000000..9135b962313ebc174c2ff428f8ab7f527fda2833 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageReader.java @@ -0,0 +1,187 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.CodecException; +import org.springframework.core.codec.Decoder; +import org.springframework.core.codec.StringDecoder; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.lang.Nullable; + +/** + * Reader that supports a stream of {@link ServerSentEvent ServerSentEvents} and also plain + * {@link Object Objects} which is the same as an {@link ServerSentEvent} with data only. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ServerSentEventHttpMessageReader implements HttpMessageReader { + + private static final ResolvableType STRING_TYPE = ResolvableType.forClass(String.class); + + private static final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(); + + private static final StringDecoder stringDecoder = StringDecoder.textPlainOnly(); + + + @Nullable + private final Decoder decoder; + + + /** + * Constructor without a {@code Decoder}. In this mode only {@code String} + * is supported as the data of an event. + */ + public ServerSentEventHttpMessageReader() { + this(null); + } + + /** + * Constructor with JSON {@code Decoder} for decoding to Objects. + * Support for decoding to {@code String} event data is built-in. + */ + public ServerSentEventHttpMessageReader(@Nullable Decoder decoder) { + this.decoder = decoder; + } + + + /** + * Return the configured {@code Decoder}. + */ + @Nullable + public Decoder getDecoder() { + return this.decoder; + } + + @Override + public List getReadableMediaTypes() { + return Collections.singletonList(MediaType.TEXT_EVENT_STREAM); + } + + @Override + public boolean canRead(ResolvableType elementType, @Nullable MediaType mediaType) { + return (MediaType.TEXT_EVENT_STREAM.includes(mediaType) || isServerSentEvent(elementType)); + } + + private boolean isServerSentEvent(ResolvableType elementType) { + return ServerSentEvent.class.isAssignableFrom(elementType.toClass()); + } + + + @Override + public Flux read( + ResolvableType elementType, ReactiveHttpInputMessage message, Map hints) { + + boolean shouldWrap = isServerSentEvent(elementType); + ResolvableType valueType = (shouldWrap ? elementType.getGeneric() : elementType); + + return stringDecoder.decode(message.getBody(), STRING_TYPE, null, hints) + .bufferUntil(line -> line.equals("")) + .concatMap(lines -> buildEvent(lines, valueType, shouldWrap, hints)); + } + + private Mono buildEvent(List lines, ResolvableType valueType, boolean shouldWrap, + Map hints) { + + ServerSentEvent.Builder sseBuilder = shouldWrap ? ServerSentEvent.builder() : null; + StringBuilder data = null; + StringBuilder comment = null; + + for (String line : lines) { + if (line.startsWith("data:")) { + data = (data != null ? data : new StringBuilder()); + data.append(line.substring(5).trim()).append("\n"); + } + if (shouldWrap) { + if (line.startsWith("id:")) { + sseBuilder.id(line.substring(3).trim()); + } + else if (line.startsWith("event:")) { + sseBuilder.event(line.substring(6).trim()); + } + else if (line.startsWith("retry:")) { + sseBuilder.retry(Duration.ofMillis(Long.parseLong(line.substring(6).trim()))); + } + else if (line.startsWith(":")) { + comment = (comment != null ? comment : new StringBuilder()); + comment.append(line.substring(1).trim()).append("\n"); + } + } + } + + Mono decodedData = (data != null ? decodeData(data.toString(), valueType, hints) : Mono.empty()); + + if (shouldWrap) { + if (comment != null) { + sseBuilder.comment(comment.substring(0, comment.length() - 1)); + } + return decodedData.map(o -> { + sseBuilder.data(o); + return sseBuilder.build(); + }); + } + else { + return decodedData; + } + } + + private Mono decodeData(String data, ResolvableType dataType, Map hints) { + if (String.class == dataType.resolve()) { + return Mono.just(data.substring(0, data.length() - 1)); + } + + if (this.decoder == null) { + return Mono.error(new CodecException("No SSE decoder configured and the data is not String.")); + } + + byte[] bytes = data.getBytes(StandardCharsets.UTF_8); + DataBuffer buffer = bufferFactory.wrap(bytes); // wrapping only, no allocation + return this.decoder.decodeToMono(Mono.just(buffer), dataType, MediaType.TEXT_EVENT_STREAM, hints); + } + + @Override + public Mono readMono( + ResolvableType elementType, ReactiveHttpInputMessage message, Map hints) { + + // We're ahead of String + "*/*" + // Let's see if we can aggregate the output (lest we time out)... + + if (elementType.resolve() == String.class) { + Flux body = message.getBody(); + return stringDecoder.decodeToMono(body, elementType, null, null).cast(Object.class); + } + + return Mono.error(new UnsupportedOperationException( + "ServerSentEventHttpMessageReader only supports reading stream of events as a Flux")); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java new file mode 100644 index 0000000000000000000000000000000000000000..04dc30b3cc646d768706808c470a9196713199d1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java @@ -0,0 +1,211 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.CodecException; +import org.springframework.core.codec.Encoder; +import org.springframework.core.codec.Hints; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.PooledDataBuffer; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpOutputMessage; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * {@code HttpMessageWriter} for {@code "text/event-stream"} responses. + * + * @author Sebastien Deleuze + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ServerSentEventHttpMessageWriter implements HttpMessageWriter { + + private static final MediaType DEFAULT_MEDIA_TYPE = new MediaType("text", "event-stream", StandardCharsets.UTF_8); + + private static final List WRITABLE_MEDIA_TYPES = Collections.singletonList(MediaType.TEXT_EVENT_STREAM); + + + @Nullable + private final Encoder encoder; + + + /** + * Constructor without an {@code Encoder}. In this mode only {@code String} + * is supported for event data to be encoded. + */ + public ServerSentEventHttpMessageWriter() { + this(null); + } + + /** + * Constructor with JSON {@code Encoder} for encoding objects. + * Support for {@code String} event data is built-in. + * @param encoder the Encoder to use (may be {@code null}) + */ + public ServerSentEventHttpMessageWriter(@Nullable Encoder encoder) { + this.encoder = encoder; + } + + + /** + * Return the configured {@code Encoder}, if any. + */ + @Nullable + public Encoder getEncoder() { + return this.encoder; + } + + @Override + public List getWritableMediaTypes() { + return WRITABLE_MEDIA_TYPES; + } + + + @Override + public boolean canWrite(ResolvableType elementType, @Nullable MediaType mediaType) { + return (mediaType == null || MediaType.TEXT_EVENT_STREAM.includes(mediaType) || + ServerSentEvent.class.isAssignableFrom(elementType.toClass())); + } + + @Override + public Mono write(Publisher input, ResolvableType elementType, @Nullable MediaType mediaType, + ReactiveHttpOutputMessage message, Map hints) { + + mediaType = (mediaType != null && mediaType.getCharset() != null ? mediaType : DEFAULT_MEDIA_TYPE); + DataBufferFactory bufferFactory = message.bufferFactory(); + + message.getHeaders().setContentType(mediaType); + return message.writeAndFlushWith(encode(input, elementType, mediaType, bufferFactory, hints)); + } + + private Flux> encode(Publisher input, ResolvableType elementType, + MediaType mediaType, DataBufferFactory factory, Map hints) { + + ResolvableType valueType = (ServerSentEvent.class.isAssignableFrom(elementType.toClass()) ? + elementType.getGeneric() : elementType); + + return Flux.from(input).map(element -> { + + ServerSentEvent sse = (element instanceof ServerSentEvent ? + (ServerSentEvent) element : ServerSentEvent.builder().data(element).build()); + + StringBuilder sb = new StringBuilder(); + String id = sse.id(); + String event = sse.event(); + Duration retry = sse.retry(); + String comment = sse.comment(); + Object data = sse.data(); + if (id != null) { + writeField("id", id, sb); + } + if (event != null) { + writeField("event", event, sb); + } + if (retry != null) { + writeField("retry", retry.toMillis(), sb); + } + if (comment != null) { + sb.append(':').append(StringUtils.replace(comment, "\n", "\n:")).append("\n"); + } + if (data != null) { + sb.append("data:"); + } + + Flux flux = Flux.concat( + encodeText(sb, mediaType, factory), + encodeData(data, valueType, mediaType, factory, hints), + encodeText("\n", mediaType, factory)); + + return flux.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release); + }); + } + + private void writeField(String fieldName, Object fieldValue, StringBuilder sb) { + sb.append(fieldName); + sb.append(':'); + sb.append(fieldValue.toString()); + sb.append("\n"); + } + + @SuppressWarnings("unchecked") + private Flux encodeData(@Nullable T dataValue, ResolvableType valueType, + MediaType mediaType, DataBufferFactory factory, Map hints) { + + if (dataValue == null) { + return Flux.empty(); + } + + if (dataValue instanceof String) { + String text = (String) dataValue; + return Flux.from(encodeText(StringUtils.replace(text, "\n", "\ndata:") + "\n", mediaType, factory)); + } + + if (this.encoder == null) { + return Flux.error(new CodecException("No SSE encoder configured and the data is not String.")); + } + + return ((Encoder) this.encoder) + .encode(Mono.just(dataValue), factory, valueType, mediaType, hints) + .concatWith(encodeText("\n", mediaType, factory)); + } + + private Mono encodeText(CharSequence text, MediaType mediaType, DataBufferFactory bufferFactory) { + Assert.notNull(mediaType.getCharset(), "Expected MediaType with charset"); + byte[] bytes = text.toString().getBytes(mediaType.getCharset()); + return Mono.just(bufferFactory.wrap(bytes)); // wrapping, not allocating + } + + @Override + public Mono write(Publisher input, ResolvableType actualType, ResolvableType elementType, + @Nullable MediaType mediaType, ServerHttpRequest request, ServerHttpResponse response, + Map hints) { + + Map allHints = Hints.merge(hints, + getEncodeHints(actualType, elementType, mediaType, request, response)); + + return write(input, elementType, mediaType, response, allHints); + } + + private Map getEncodeHints(ResolvableType actualType, ResolvableType elementType, + @Nullable MediaType mediaType, ServerHttpRequest request, ServerHttpResponse response) { + + if (this.encoder instanceof HttpMessageEncoder) { + HttpMessageEncoder encoder = (HttpMessageEncoder) this.encoder; + return encoder.getEncodeHints(actualType, elementType, mediaType, request, response); + } + return Hints.none(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/AbstractJackson2Decoder.java b/spring-web/src/main/java/org/springframework/http/codec/json/AbstractJackson2Decoder.java new file mode 100644 index 0000000000000000000000000000000000000000..f5086a121dfa14123d9bb9d63d163d9f079922a4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/json/AbstractJackson2Decoder.java @@ -0,0 +1,222 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.json; + +import java.io.IOException; +import java.lang.annotation.Annotation; +import java.math.BigDecimal; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ObjectReader; +import com.fasterxml.jackson.databind.exc.InvalidDefinitionException; +import com.fasterxml.jackson.databind.util.TokenBuffer; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.MethodParameter; +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.CodecException; +import org.springframework.core.codec.DecodingException; +import org.springframework.core.codec.Hints; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; +import org.springframework.core.log.LogFormatUtils; +import org.springframework.http.codec.HttpMessageDecoder; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; + +/** + * Abstract base class for Jackson 2.9 decoding, leveraging non-blocking parsing. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @author Arjen Poutsma + * @since 5.0 + * @see Add support for non-blocking ("async") JSON parsing + */ +public abstract class AbstractJackson2Decoder extends Jackson2CodecSupport implements HttpMessageDecoder { + + private int maxInMemorySize = -1; + + + /** + * Until https://github.com/FasterXML/jackson-core/issues/476 is resolved, + * we need to ensure buffer recycling is off. + */ + private final JsonFactory jsonFactory; + + + /** + * Constructor with a Jackson {@link ObjectMapper} to use. + */ + protected AbstractJackson2Decoder(ObjectMapper mapper, MimeType... mimeTypes) { + super(mapper, mimeTypes); + this.jsonFactory = mapper.getFactory().copy() + .disable(JsonFactory.Feature.USE_THREAD_LOCAL_FOR_BUFFER_RECYCLING); + } + + + /** + * Set the max number of bytes that can be buffered by this decoder. This + * is either the size of the entire input when decoding as a whole, or the + * size of one top-level JSON object within a JSON stream. When the limit + * is exceeded, {@link DataBufferLimitException} is raised. + *

By default in 5.1 this is set to -1, unlimited. In 5.2 the default + * value for this limit is set to 256K. + * @param byteCount the max number of bytes to buffer, or -1 for unlimited + * @since 5.1.11 + */ + public void setMaxInMemorySize(int byteCount) { + this.maxInMemorySize = byteCount; + } + + /** + * Return the {@link #setMaxInMemorySize configured} byte count limit. + * @since 5.1.11 + */ + public int getMaxInMemorySize() { + return this.maxInMemorySize; + } + + + @Override + public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) { + JavaType javaType = getObjectMapper().constructType(elementType.getType()); + // Skip String: CharSequenceDecoder + "*/*" comes after + return (!CharSequence.class.isAssignableFrom(elementType.toClass()) && + getObjectMapper().canDeserialize(javaType) && supportsMimeType(mimeType)); + } + + @Override + public Flux decode(Publisher input, ResolvableType elementType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + boolean forceUseOfBigDecimal = getObjectMapper().isEnabled(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); + if (BigDecimal.class.equals(elementType.getType())) { + forceUseOfBigDecimal = true; + } + + Flux processed = processInput(input, elementType, mimeType, hints); + Flux tokens = Jackson2Tokenizer.tokenize(processed, this.jsonFactory, getObjectMapper(), + true, forceUseOfBigDecimal, getMaxInMemorySize()); + return decodeInternal(tokens, elementType, mimeType, hints); + } + + /** + * Process the input publisher into a flux. Default implementation returns + * {@link Flux#from(Publisher)}, but subclasses can choose to to customize + * this behaviour. + * @param input the {@code DataBuffer} input stream to process + * @param elementType the expected type of elements in the output stream + * @param mimeType the MIME type associated with the input stream (optional) + * @param hints additional information about how to do encode + * @return the processed flux + * @since 5.1.14 + */ + protected Flux processInput(Publisher input, ResolvableType elementType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + return Flux.from(input); + } + + @Override + public Mono decodeToMono(Publisher input, ResolvableType elementType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + boolean forceUseOfBigDecimal = getObjectMapper().isEnabled(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); + if (BigDecimal.class.equals(elementType.getType())) { + forceUseOfBigDecimal = true; + } + + Flux processed = processInput(input, elementType, mimeType, hints); + Flux tokens = Jackson2Tokenizer.tokenize(processed, this.jsonFactory, getObjectMapper(), + false, forceUseOfBigDecimal, getMaxInMemorySize()); + return decodeInternal(tokens, elementType, mimeType, hints).singleOrEmpty(); + } + + private Flux decodeInternal(Flux tokens, ResolvableType elementType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + Assert.notNull(tokens, "'tokens' must not be null"); + Assert.notNull(elementType, "'elementType' must not be null"); + + MethodParameter param = getParameter(elementType); + Class contextClass = (param != null ? param.getContainingClass() : null); + JavaType javaType = getJavaType(elementType.getType(), contextClass); + Class jsonView = (hints != null ? (Class) hints.get(Jackson2CodecSupport.JSON_VIEW_HINT) : null); + + ObjectReader reader = (jsonView != null ? + getObjectMapper().readerWithView(jsonView).forType(javaType) : + getObjectMapper().readerFor(javaType)); + + return tokens.flatMap(tokenBuffer -> { + try { + Object value = reader.readValue(tokenBuffer.asParser(getObjectMapper())); + if (!Hints.isLoggingSuppressed(hints)) { + LogFormatUtils.traceDebug(logger, traceOn -> { + String formatted = LogFormatUtils.formatValue(value, !traceOn); + return Hints.getLogPrefix(hints) + "Decoded [" + formatted + "]"; + }); + } + return Mono.justOrEmpty(value); + } + catch (InvalidDefinitionException ex) { + return Mono.error(new CodecException("Type definition error: " + ex.getType(), ex)); + } + catch (JsonProcessingException ex) { + return Mono.error(new DecodingException("JSON decoding error: " + ex.getOriginalMessage(), ex)); + } + catch (IOException ex) { + return Mono.error(new DecodingException("I/O error while parsing input stream", ex)); + } + }); + } + + + // HttpMessageDecoder + + @Override + public Map getDecodeHints(ResolvableType actualType, ResolvableType elementType, + ServerHttpRequest request, ServerHttpResponse response) { + + return getHints(actualType); + } + + @Override + public List getDecodableMimeTypes() { + return getMimeTypes(); + } + + + // Jackson2CodecSupport + + @Override + protected A getAnnotation(MethodParameter parameter, Class annotType) { + return parameter.getParameterAnnotation(annotType); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/AbstractJackson2Encoder.java b/spring-web/src/main/java/org/springframework/http/codec/json/AbstractJackson2Encoder.java new file mode 100644 index 0000000000000000000000000000000000000000..98d23ad2ac1b9538bf2ffc7dd90e38decf60dc66 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/json/AbstractJackson2Encoder.java @@ -0,0 +1,325 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.json; + +import java.io.IOException; +import java.lang.annotation.Annotation; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.core.JsonEncoding; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.util.ByteArrayBuilder; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ObjectWriter; +import com.fasterxml.jackson.databind.SequenceWriter; +import com.fasterxml.jackson.databind.exc.InvalidDefinitionException; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.MethodParameter; +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.CodecException; +import org.springframework.core.codec.EncodingException; +import org.springframework.core.codec.Hints; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.log.LogFormatUtils; +import org.springframework.http.MediaType; +import org.springframework.http.codec.HttpMessageEncoder; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; + +/** + * Base class providing support methods for Jackson 2.9 encoding. For non-streaming use + * cases, {@link Flux} elements are collected into a {@link List} before serialization for + * performance reason. + * + * @author Sebastien Deleuze + * @author Arjen Poutsma + * @since 5.0 + */ +public abstract class AbstractJackson2Encoder extends Jackson2CodecSupport implements HttpMessageEncoder { + + private static final byte[] NEWLINE_SEPARATOR = {'\n'}; + + private static final Map STREAM_SEPARATORS; + + private static final Map ENCODINGS; + + static { + STREAM_SEPARATORS = new HashMap<>(4); + STREAM_SEPARATORS.put(MediaType.APPLICATION_STREAM_JSON, NEWLINE_SEPARATOR); + STREAM_SEPARATORS.put(MediaType.parseMediaType("application/stream+x-jackson-smile"), new byte[0]); + + ENCODINGS = new HashMap<>(JsonEncoding.values().length + 1); + for (JsonEncoding encoding : JsonEncoding.values()) { + ENCODINGS.put(encoding.getJavaName(), encoding); + } + ENCODINGS.put("US-ASCII", JsonEncoding.UTF8); + } + + + private final List streamingMediaTypes = new ArrayList<>(1); + + + /** + * Constructor with a Jackson {@link ObjectMapper} to use. + */ + protected AbstractJackson2Encoder(ObjectMapper mapper, MimeType... mimeTypes) { + super(mapper, mimeTypes); + } + + + /** + * Configure "streaming" media types for which flushing should be performed + * automatically vs at the end of the stream. + *

By default this is set to {@link MediaType#APPLICATION_STREAM_JSON}. + * @param mediaTypes one or more media types to add to the list + * @see HttpMessageEncoder#getStreamingMediaTypes() + */ + public void setStreamingMediaTypes(List mediaTypes) { + this.streamingMediaTypes.clear(); + this.streamingMediaTypes.addAll(mediaTypes); + } + + + @Override + public boolean canEncode(ResolvableType elementType, @Nullable MimeType mimeType) { + Class clazz = elementType.toClass(); + if (!supportsMimeType(mimeType)) { + return false; + } + if (mimeType != null && mimeType.getCharset() != null) { + Charset charset = mimeType.getCharset(); + if (!ENCODINGS.containsKey(charset.name())) { + return false; + } + } + return (Object.class == clazz || + (!String.class.isAssignableFrom(elementType.resolve(clazz)) && getObjectMapper().canSerialize(clazz))); + } + + @Override + public Flux encode(Publisher inputStream, DataBufferFactory bufferFactory, + ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { + + Assert.notNull(inputStream, "'inputStream' must not be null"); + Assert.notNull(bufferFactory, "'bufferFactory' must not be null"); + Assert.notNull(elementType, "'elementType' must not be null"); + + if (inputStream instanceof Mono) { + return Mono.from(inputStream) + .map(value -> encodeValue(value, bufferFactory, elementType, mimeType, hints)) + .flux(); + } + else { + byte[] separator = streamSeparator(mimeType); + if (separator != null) { // streaming + try { + ObjectWriter writer = createObjectWriter(elementType, mimeType, hints); + ByteArrayBuilder byteBuilder = new ByteArrayBuilder(writer.getFactory()._getBufferRecycler()); + JsonEncoding encoding = getJsonEncoding(mimeType); + JsonGenerator generator = getObjectMapper().getFactory().createGenerator(byteBuilder, encoding); + SequenceWriter sequenceWriter = writer.writeValues(generator); + + return Flux.from(inputStream) + .map(value -> encodeStreamingValue(value, bufferFactory, hints, sequenceWriter, byteBuilder, + separator)); + } + catch (IOException ex) { + return Flux.error(ex); + } + } + else { // non-streaming + ResolvableType listType = ResolvableType.forClassWithGenerics(List.class, elementType); + return Flux.from(inputStream) + .collectList() + .map(list -> encodeValue(list, bufferFactory, listType, mimeType, hints)) + .flux(); + } + + } + } + + public DataBuffer encodeValue(Object value, DataBufferFactory bufferFactory, + ResolvableType valueType, @Nullable MimeType mimeType, @Nullable Map hints) { + + ObjectWriter writer = createObjectWriter(valueType, mimeType, hints); + ByteArrayBuilder byteBuilder = new ByteArrayBuilder(writer.getFactory()._getBufferRecycler()); + JsonEncoding encoding = getJsonEncoding(mimeType); + + logValue(hints, value); + + try { + JsonGenerator generator = getObjectMapper().getFactory().createGenerator(byteBuilder, encoding); + writer.writeValue(generator, value); + generator.flush(); + } + catch (InvalidDefinitionException ex) { + throw new CodecException("Type definition error: " + ex.getType(), ex); + } + catch (JsonProcessingException ex) { + throw new EncodingException("JSON encoding error: " + ex.getOriginalMessage(), ex); + } + catch (IOException ex) { + throw new IllegalStateException("Unexpected I/O error while writing to byte array builder", ex); + } + + byte[] bytes = byteBuilder.toByteArray(); + DataBuffer buffer = bufferFactory.allocateBuffer(bytes.length); + buffer.write(bytes); + + return buffer; + } + + private DataBuffer encodeStreamingValue(Object value, DataBufferFactory bufferFactory, @Nullable Map hints, + SequenceWriter sequenceWriter, ByteArrayBuilder byteArrayBuilder, byte[] separator) { + + logValue(hints, value); + + try { + sequenceWriter.write(value); + sequenceWriter.flush(); + } + catch (InvalidDefinitionException ex) { + throw new CodecException("Type definition error: " + ex.getType(), ex); + } + catch (JsonProcessingException ex) { + throw new EncodingException("JSON encoding error: " + ex.getOriginalMessage(), ex); + } + catch (IOException ex) { + throw new IllegalStateException("Unexpected I/O error while writing to byte array builder", ex); + } + + byte[] bytes = byteArrayBuilder.toByteArray(); + byteArrayBuilder.reset(); + + int offset; + int length; + if (bytes.length > 0 && bytes[0] == ' ') { + // SequenceWriter writes an unnecessary space in between values + offset = 1; + length = bytes.length - 1; + } + else { + offset = 0; + length = bytes.length; + } + DataBuffer buffer = bufferFactory.allocateBuffer(length + separator.length); + buffer.write(bytes, offset, length); + buffer.write(separator); + + return buffer; + } + + private void logValue(@Nullable Map hints, Object value) { + if (!Hints.isLoggingSuppressed(hints)) { + LogFormatUtils.traceDebug(logger, traceOn -> { + String formatted = LogFormatUtils.formatValue(value, !traceOn); + return Hints.getLogPrefix(hints) + "Encoding [" + formatted + "]"; + }); + } + } + + private ObjectWriter createObjectWriter(ResolvableType valueType, @Nullable MimeType mimeType, + @Nullable Map hints) { + + JavaType javaType = getJavaType(valueType.getType(), null); + Class jsonView = (hints != null ? (Class) hints.get(Jackson2CodecSupport.JSON_VIEW_HINT) : null); + ObjectWriter writer = (jsonView != null ? + getObjectMapper().writerWithView(jsonView) : getObjectMapper().writer()); + + if (javaType.isContainerType()) { + writer = writer.forType(javaType); + } + + return customizeWriter(writer, mimeType, valueType, hints); + } + + protected ObjectWriter customizeWriter(ObjectWriter writer, @Nullable MimeType mimeType, + ResolvableType elementType, @Nullable Map hints) { + + return writer; + } + + @Nullable + private byte[] streamSeparator(@Nullable MimeType mimeType) { + for (MediaType streamingMediaType : this.streamingMediaTypes) { + if (streamingMediaType.isCompatibleWith(mimeType)) { + return STREAM_SEPARATORS.getOrDefault(streamingMediaType, NEWLINE_SEPARATOR); + } + } + return null; + } + + /** + * Determine the JSON encoding to use for the given mime type. + * @param mimeType the mime type as requested by the caller + * @return the JSON encoding to use (never {@code null}) + * @since 5.0.5 + */ + protected JsonEncoding getJsonEncoding(@Nullable MimeType mimeType) { + if (mimeType != null && mimeType.getCharset() != null) { + Charset charset = mimeType.getCharset(); + JsonEncoding result = ENCODINGS.get(charset.name()); + if (result != null) { + return result; + } + } + return JsonEncoding.UTF8; + } + + + // HttpMessageEncoder + + @Override + public List getEncodableMimeTypes() { + return getMimeTypes(); + } + + @Override + public List getStreamingMediaTypes() { + return Collections.unmodifiableList(this.streamingMediaTypes); + } + + @Override + public Map getEncodeHints(@Nullable ResolvableType actualType, ResolvableType elementType, + @Nullable MediaType mediaType, ServerHttpRequest request, ServerHttpResponse response) { + + return (actualType != null ? getHints(actualType) : Hints.none()); + } + + + // Jackson2CodecSupport + + @Override + protected A getAnnotation(MethodParameter parameter, Class annotType) { + return parameter.getMethodAnnotation(annotType); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2CodecSupport.java b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2CodecSupport.java new file mode 100644 index 0000000000000000000000000000000000000000..4d7d2771ffc77dbb150a3547b70e377210229437 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2CodecSupport.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.json; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Type; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.commons.logging.Log; + +import org.springframework.core.GenericTypeResolver; +import org.springframework.core.MethodParameter; +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Hints; +import org.springframework.http.HttpLogging; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; +import org.springframework.util.ObjectUtils; + +/** + * Base class providing support methods for Jackson 2.9 encoding and decoding. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 5.0 + */ +public abstract class Jackson2CodecSupport { + + /** + * The key for the hint to specify a "JSON View" for encoding or decoding + * with the value expected to be a {@link Class}. + * @see Jackson JSON Views + */ + public static final String JSON_VIEW_HINT = Jackson2CodecSupport.class.getName() + ".jsonView"; + + private static final String JSON_VIEW_HINT_ERROR = + "@JsonView only supported for write hints with exactly 1 class argument: "; + + private static final List DEFAULT_MIME_TYPES = Collections.unmodifiableList( + Arrays.asList( + new MimeType("application", "json", StandardCharsets.UTF_8), + new MimeType("application", "*+json", StandardCharsets.UTF_8))); + + + protected final Log logger = HttpLogging.forLogName(getClass()); + + private final ObjectMapper objectMapper; + + private final List mimeTypes; + + + /** + * Constructor with a Jackson {@link ObjectMapper} to use. + */ + protected Jackson2CodecSupport(ObjectMapper objectMapper, MimeType... mimeTypes) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + this.mimeTypes = !ObjectUtils.isEmpty(mimeTypes) ? + Collections.unmodifiableList(Arrays.asList(mimeTypes)) : DEFAULT_MIME_TYPES; + } + + + public ObjectMapper getObjectMapper() { + return this.objectMapper; + } + + /** + * Subclasses should expose this as "decodable" or "encodable" mime types. + */ + protected List getMimeTypes() { + return this.mimeTypes; + } + + + protected boolean supportsMimeType(@Nullable MimeType mimeType) { + return (mimeType == null || this.mimeTypes.stream().anyMatch(m -> m.isCompatibleWith(mimeType))); + } + + protected JavaType getJavaType(Type type, @Nullable Class contextClass) { + return this.objectMapper.constructType(GenericTypeResolver.resolveType(type, contextClass)); + } + + protected Map getHints(ResolvableType resolvableType) { + MethodParameter param = getParameter(resolvableType); + if (param != null) { + JsonView annotation = getAnnotation(param, JsonView.class); + if (annotation != null) { + Class[] classes = annotation.value(); + Assert.isTrue(classes.length == 1, JSON_VIEW_HINT_ERROR + param); + return Hints.from(JSON_VIEW_HINT, classes[0]); + } + } + return Hints.none(); + } + + @Nullable + protected MethodParameter getParameter(ResolvableType type) { + return (type.getSource() instanceof MethodParameter ? (MethodParameter) type.getSource() : null); + } + + @Nullable + protected abstract A getAnnotation(MethodParameter parameter, Class annotType); + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2JsonDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2JsonDecoder.java new file mode 100644 index 0000000000000000000000000000000000000000..861fa05be268f883fbd157c2fe32e4e9dba47130 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2JsonDecoder.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.json; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Map; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.StringDecoder; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; +import org.springframework.lang.Nullable; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; + +/** + * Decode a byte stream into JSON and convert to Object's with Jackson 2.9, + * leveraging non-blocking parsing. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 5.0 + * @see Jackson2JsonEncoder + */ +public class Jackson2JsonDecoder extends AbstractJackson2Decoder { + + private static final StringDecoder STRING_DECODER = StringDecoder.textPlainOnly(Arrays.asList(",", "\n"), false); + + private static final ResolvableType STRING_TYPE = ResolvableType.forClass(String.class); + + + public Jackson2JsonDecoder() { + super(Jackson2ObjectMapperBuilder.json().build()); + } + + public Jackson2JsonDecoder(ObjectMapper mapper, MimeType... mimeTypes) { + super(mapper, mimeTypes); + } + + @Override + protected Flux processInput(Publisher input, ResolvableType elementType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + Flux flux = Flux.from(input); + if (mimeType == null) { + return flux; + } + + // Jackson asynchronous parser only supports UTF-8 + Charset charset = mimeType.getCharset(); + if (charset == null || StandardCharsets.UTF_8.equals(charset) || StandardCharsets.US_ASCII.equals(charset)) { + return flux; + } + + // Potentially, the memory consumption of this conversion could be improved by using CharBuffers instead + // of allocating Strings, but that would require refactoring the buffer tokenization code from StringDecoder + + MimeType textMimeType = new MimeType(MimeTypeUtils.TEXT_PLAIN, charset); + Flux decoded = STRING_DECODER.decode(input, STRING_TYPE, textMimeType, null); + DataBufferFactory factory = new DefaultDataBufferFactory(); + return decoded.map(s -> factory.wrap(s.getBytes(StandardCharsets.UTF_8))); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2JsonEncoder.java b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2JsonEncoder.java new file mode 100644 index 0000000000000000000000000000000000000000..5ac99c5ddef6ecd4d5736624bf0470cea9ab0c5c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2JsonEncoder.java @@ -0,0 +1,80 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.json; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.core.PrettyPrinter; +import com.fasterxml.jackson.core.util.DefaultIndenter; +import com.fasterxml.jackson.core.util.DefaultPrettyPrinter; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ObjectWriter; +import com.fasterxml.jackson.databind.SerializationFeature; +import reactor.core.publisher.Flux; + +import org.springframework.core.ResolvableType; +import org.springframework.http.MediaType; +import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; +import org.springframework.lang.Nullable; +import org.springframework.util.MimeType; + +/** + * Encode from an {@code Object} stream to a byte stream of JSON objects using Jackson 2.9. + * For non-streaming use cases, {@link Flux} elements are collected into a {@link List} + * before serialization for performance reason. + * + * @author Sebastien Deleuze + * @author Arjen Poutsma + * @since 5.0 + * @see Jackson2JsonDecoder + */ +public class Jackson2JsonEncoder extends AbstractJackson2Encoder { + + @Nullable + private final PrettyPrinter ssePrettyPrinter; + + + public Jackson2JsonEncoder() { + this(Jackson2ObjectMapperBuilder.json().build()); + } + + public Jackson2JsonEncoder(ObjectMapper mapper, MimeType... mimeTypes) { + super(mapper, mimeTypes); + setStreamingMediaTypes(Collections.singletonList(MediaType.APPLICATION_STREAM_JSON)); + this.ssePrettyPrinter = initSsePrettyPrinter(); + } + + private static PrettyPrinter initSsePrettyPrinter() { + DefaultPrettyPrinter printer = new DefaultPrettyPrinter(); + printer.indentObjectsWith(new DefaultIndenter(" ", "\ndata:")); + return printer; + } + + + @Override + protected ObjectWriter customizeWriter(ObjectWriter writer, @Nullable MimeType mimeType, + ResolvableType elementType, @Nullable Map hints) { + + return (this.ssePrettyPrinter != null && + MediaType.TEXT_EVENT_STREAM.isCompatibleWith(mimeType) && + writer.getConfig().isEnabled(SerializationFeature.INDENT_OUTPUT) ? + writer.with(this.ssePrettyPrinter) : writer); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2SmileDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2SmileDecoder.java new file mode 100644 index 0000000000000000000000000000000000000000..3d69190b185570769b81ecb2173f5f506aee5529 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2SmileDecoder.java @@ -0,0 +1,53 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.json; + +import java.nio.charset.StandardCharsets; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.smile.SmileFactory; + +import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; + +/** + * Decode a byte stream into Smile and convert to Object's with Jackson 2.9, + * leveraging non-blocking parsing. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 5.0 + * @see Jackson2JsonEncoder + */ +public class Jackson2SmileDecoder extends AbstractJackson2Decoder { + + private static final MimeType[] DEFAULT_SMILE_MIME_TYPES = new MimeType[] { + new MimeType("application", "x-jackson-smile", StandardCharsets.UTF_8), + new MimeType("application", "*+x-jackson-smile", StandardCharsets.UTF_8)}; + + + public Jackson2SmileDecoder() { + this(Jackson2ObjectMapperBuilder.smile().build(), DEFAULT_SMILE_MIME_TYPES); + } + + public Jackson2SmileDecoder(ObjectMapper mapper, MimeType... mimeTypes) { + super(mapper, mimeTypes); + Assert.isAssignable(SmileFactory.class, mapper.getFactory().getClass()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2SmileEncoder.java b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2SmileEncoder.java new file mode 100644 index 0000000000000000000000000000000000000000..60499abfa4d06249fd8f99cab9166f41ea13e2a6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2SmileEncoder.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.json; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.smile.SmileFactory; +import reactor.core.publisher.Flux; + +import org.springframework.http.MediaType; +import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; + +/** + * Encode from an {@code Object} stream to a byte stream of Smile objects using Jackson 2.9. + * For non-streaming use cases, {@link Flux} elements are collected into a {@link List} + * before serialization for performance reason. + * + * @author Sebastien Deleuze + * @since 5.0 + * @see Jackson2SmileDecoder + */ +public class Jackson2SmileEncoder extends AbstractJackson2Encoder { + + private static final MimeType[] DEFAULT_SMILE_MIME_TYPES = new MimeType[] { + new MimeType("application", "x-jackson-smile", StandardCharsets.UTF_8), + new MimeType("application", "*+x-jackson-smile", StandardCharsets.UTF_8)}; + + + public Jackson2SmileEncoder() { + this(Jackson2ObjectMapperBuilder.smile().build(), DEFAULT_SMILE_MIME_TYPES); + } + + public Jackson2SmileEncoder(ObjectMapper mapper, MimeType... mimeTypes) { + super(mapper, mimeTypes); + Assert.isAssignable(SmileFactory.class, mapper.getFactory().getClass()); + setStreamingMediaTypes(Collections.singletonList(new MediaType("application", "stream+x-jackson-smile"))); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2Tokenizer.java b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2Tokenizer.java new file mode 100644 index 0000000000000000000000000000000000000000..9234799314da2154c0eeb913f5b7694fbd3d506c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2Tokenizer.java @@ -0,0 +1,256 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.json; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.JsonToken; +import com.fasterxml.jackson.core.async.ByteArrayFeeder; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.deser.DefaultDeserializationContext; +import com.fasterxml.jackson.databind.util.TokenBuffer; +import reactor.core.publisher.Flux; + +import org.springframework.core.codec.DecodingException; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; +import org.springframework.core.io.buffer.DataBufferUtils; + +/** + * {@link Function} to transform a JSON stream of arbitrary size, byte array + * chunks into a {@code Flux} where each token buffer is a + * well-formed JSON object. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 5.0 + */ +final class Jackson2Tokenizer { + + private final JsonParser parser; + + private final DeserializationContext deserializationContext; + + private final boolean tokenizeArrayElements; + + private final boolean forceUseOfBigDecimal; + + private final int maxInMemorySize; + + private int objectDepth; + + private int arrayDepth; + + private int byteCount; + + private TokenBuffer tokenBuffer; + + + // TODO: change to ByteBufferFeeder when supported by Jackson + // See https://github.com/FasterXML/jackson-core/issues/478 + private final ByteArrayFeeder inputFeeder; + + + private Jackson2Tokenizer(JsonParser parser, DeserializationContext deserializationContext, + boolean tokenizeArrayElements, boolean forceUseOfBigDecimal, int maxInMemorySize) { + + this.parser = parser; + this.deserializationContext = deserializationContext; + this.tokenizeArrayElements = tokenizeArrayElements; + this.forceUseOfBigDecimal = forceUseOfBigDecimal; + this.inputFeeder = (ByteArrayFeeder) this.parser.getNonBlockingInputFeeder(); + this.maxInMemorySize = maxInMemorySize; + this.tokenBuffer = createToken(); + } + + + + private Flux tokenize(DataBuffer dataBuffer) { + int bufferSize = dataBuffer.readableByteCount(); + byte[] bytes = new byte[dataBuffer.readableByteCount()]; + dataBuffer.read(bytes); + DataBufferUtils.release(dataBuffer); + + try { + this.inputFeeder.feedInput(bytes, 0, bytes.length); + List result = parseTokenBufferFlux(); + assertInMemorySize(bufferSize, result); + return Flux.fromIterable(result); + } + catch (JsonProcessingException ex) { + return Flux.error(new DecodingException("JSON decoding error: " + ex.getOriginalMessage(), ex)); + } + catch (IOException ex) { + return Flux.error(ex); + } + } + + private Flux endOfInput() { + this.inputFeeder.endOfInput(); + try { + List result = parseTokenBufferFlux(); + return Flux.fromIterable(result); + } + catch (JsonProcessingException ex) { + return Flux.error(new DecodingException("JSON decoding error: " + ex.getOriginalMessage(), ex)); + } + catch (IOException ex) { + return Flux.error(ex); + } + } + + private List parseTokenBufferFlux() throws IOException { + List result = new ArrayList<>(); + + // SPR-16151: Smile data format uses null to separate documents + boolean previousNull = false; + while (!this.parser.isClosed()) { + JsonToken token = this.parser.nextToken(); + if (token == JsonToken.NOT_AVAILABLE || + token == null && previousNull) { + break; + } + else if (token == null ) { // !previousNull + previousNull = true; + continue; + } + else { + previousNull = false; + } + updateDepth(token); + if (!this.tokenizeArrayElements) { + processTokenNormal(token, result); + } + else { + processTokenArray(token, result); + } + } + return result; + } + + private void updateDepth(JsonToken token) { + switch (token) { + case START_OBJECT: + this.objectDepth++; + break; + case END_OBJECT: + this.objectDepth--; + break; + case START_ARRAY: + this.arrayDepth++; + break; + case END_ARRAY: + this.arrayDepth--; + break; + } + } + + private void processTokenNormal(JsonToken token, List result) throws IOException { + this.tokenBuffer.copyCurrentEvent(this.parser); + + if ((token.isStructEnd() || token.isScalarValue()) && this.objectDepth == 0 && this.arrayDepth == 0) { + result.add(this.tokenBuffer); + this.tokenBuffer = createToken(); + } + } + + private void processTokenArray(JsonToken token, List result) throws IOException { + if (!isTopLevelArrayToken(token)) { + this.tokenBuffer.copyCurrentEvent(this.parser); + } + + if (this.objectDepth == 0 && (this.arrayDepth == 0 || this.arrayDepth == 1) && + (token == JsonToken.END_OBJECT || token.isScalarValue())) { + result.add(this.tokenBuffer); + this.tokenBuffer = createToken(); + } + } + + private TokenBuffer createToken() { + TokenBuffer tokenBuffer = new TokenBuffer(this.parser, this.deserializationContext); + tokenBuffer.forceUseOfBigDecimal(this.forceUseOfBigDecimal); + return tokenBuffer; + } + + private boolean isTopLevelArrayToken(JsonToken token) { + return this.objectDepth == 0 && ((token == JsonToken.START_ARRAY && this.arrayDepth == 1) || + (token == JsonToken.END_ARRAY && this.arrayDepth == 0)); + } + + private void assertInMemorySize(int currentBufferSize, List result) { + if (this.maxInMemorySize >= 0) { + if (!result.isEmpty()) { + this.byteCount = 0; + } + else if (currentBufferSize > Integer.MAX_VALUE - this.byteCount) { + raiseLimitException(); + } + else { + this.byteCount += currentBufferSize; + if (this.byteCount > this.maxInMemorySize) { + raiseLimitException(); + } + } + } + } + + private void raiseLimitException() { + throw new DataBufferLimitException( + "Exceeded limit on max bytes per JSON object: " + this.maxInMemorySize); + } + + + /** + * Tokenize the given {@code Flux} into {@code Flux}. + * @param dataBuffers the source data buffers + * @param jsonFactory the factory to use + * @param objectMapper the current mapper instance + * @param tokenizeArrays if {@code true} and the "top level" JSON object is + * an array, each element is returned individually immediately after it is received + * @param forceUseOfBigDecimal if {@code true}, any floating point values encountered + * in source will use {@link java.math.BigDecimal} + * @param maxInMemorySize maximum memory size + * @return the resulting token buffers + */ + public static Flux tokenize(Flux dataBuffers, JsonFactory jsonFactory, + ObjectMapper objectMapper, boolean tokenizeArrays, boolean forceUseOfBigDecimal, int maxInMemorySize) { + + try { + JsonParser parser = jsonFactory.createNonBlockingByteArrayParser(); + DeserializationContext context = objectMapper.getDeserializationContext(); + if (context instanceof DefaultDeserializationContext) { + context = ((DefaultDeserializationContext) context).createInstance( + objectMapper.getDeserializationConfig(), parser, objectMapper.getInjectableValues()); + } + Jackson2Tokenizer tokenizer = + new Jackson2Tokenizer(parser, context, tokenizeArrays, forceUseOfBigDecimal, maxInMemorySize); + return dataBuffers.flatMap(tokenizer::tokenize, Flux::error, tokenizer::endOfInput); + } + catch (IOException ex) { + return Flux.error(ex); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/package-info.java b/spring-web/src/main/java/org/springframework/http/codec/json/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..710d9c02512e8b9b938f5384e070ba9d94cfebfa --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/json/package-info.java @@ -0,0 +1,9 @@ +/** + * JSON encoder and decoder support. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.codec.json; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/FilePart.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/FilePart.java new file mode 100644 index 0000000000000000000000000000000000000000..0765a9dd1ba723bfe96b2ccb0f10753cf5fb3029 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/FilePart.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.multipart; + +import java.io.File; +import java.nio.file.Path; + +import reactor.core.publisher.Mono; + +/** + * Specialization of {@link Part} that represents an uploaded file received in + * a multipart request. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 5.0 + */ +public interface FilePart extends Part { + + /** + * Return the original filename in the client's filesystem. + */ + String filename(); + + /** + * Convenience method to copy the content of the file in this part to the + * given destination file. If the destination file already exists, it will + * be truncated first. + *

The default implementation delegates to {@link #transferTo(Path)}. + * @param dest the target file + * @return completion {@code Mono} with the result of the file transfer, + * possibly {@link IllegalStateException} if the part isn't a file + * @see #transferTo(Path) + */ + default Mono transferTo(File dest) { + return transferTo(dest.toPath()); + } + + /** + * Convenience method to copy the content of the file in this part to the + * given destination file. If the destination file already exists, it will + * be truncated first. + * @param dest the target file + * @return completion {@code Mono} with the result of the file transfer, + * possibly {@link IllegalStateException} if the part isn't a file + * @since 5.1 + * @see #transferTo(File) + */ + Mono transferTo(Path dest); + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/FormFieldPart.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/FormFieldPart.java new file mode 100644 index 0000000000000000000000000000000000000000..b10cea7472ece041bf1bb32ba708f3b8e8c4a4bd --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/FormFieldPart.java @@ -0,0 +1,32 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.multipart; + +/** + * Specialization of {@link Part} for a form field. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface FormFieldPart extends Part { + + /** + * Return the form field value. + */ + String value(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageReader.java new file mode 100644 index 0000000000000000000000000000000000000000..3c8c4b483e7028ffc9ea911237b2bb66527ec50a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageReader.java @@ -0,0 +1,121 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.multipart; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Hints; +import org.springframework.core.log.LogFormatUtils; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.LoggingCodecSupport; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * {@code HttpMessageReader} for reading {@code "multipart/form-data"} requests + * into a {@code MultiValueMap}. + * + *

Note that this reader depends on access to an + * {@code HttpMessageReader} for the actual parsing of multipart content. + * The purpose of this reader is to collect the parts into a map. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class MultipartHttpMessageReader extends LoggingCodecSupport + implements HttpMessageReader> { + + private static final ResolvableType MULTIPART_VALUE_TYPE = ResolvableType.forClassWithGenerics( + MultiValueMap.class, String.class, Part.class); + + + private final HttpMessageReader partReader; + + + public MultipartHttpMessageReader(HttpMessageReader partReader) { + Assert.notNull(partReader, "'partReader' is required"); + this.partReader = partReader; + } + + + /** + * Return the configured parts reader. + * @since 5.1.11 + */ + public HttpMessageReader getPartReader() { + return this.partReader; + } + + @Override + public List getReadableMediaTypes() { + return Collections.singletonList(MediaType.MULTIPART_FORM_DATA); + } + + @Override + public boolean canRead(ResolvableType elementType, @Nullable MediaType mediaType) { + return MULTIPART_VALUE_TYPE.isAssignableFrom(elementType) && + (mediaType == null || MediaType.MULTIPART_FORM_DATA.isCompatibleWith(mediaType)); + } + + + @Override + public Flux> read(ResolvableType elementType, + ReactiveHttpInputMessage message, Map hints) { + + return Flux.from(readMono(elementType, message, hints)); + } + + + @Override + public Mono> readMono(ResolvableType elementType, + ReactiveHttpInputMessage inputMessage, Map hints) { + + Map allHints = Hints.merge(hints, Hints.SUPPRESS_LOGGING_HINT, true); + + return this.partReader.read(elementType, inputMessage, allHints) + .collectMultimap(Part::name) + .doOnNext(map -> + LogFormatUtils.traceDebug(logger, traceOn -> Hints.getLogPrefix(hints) + "Parsed " + + (isEnableLoggingRequestDetails() ? + LogFormatUtils.formatValue(map, !traceOn) : + "parts " + map.keySet() + " (content masked)"))) + .map(this::toMultiValueMap); + } + + private LinkedMultiValueMap toMultiValueMap(Map> map) { + return new LinkedMultiValueMap<>(map.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> toList(e.getValue())))); + } + + private List toList(Collection collection) { + return collection instanceof List ? (List) collection : new ArrayList<>(collection); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java new file mode 100644 index 0000000000000000000000000000000000000000..13a1dfdc684ab7a483595bcbfd206d3162b58863 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java @@ -0,0 +1,455 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.multipart; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.ResolvableTypeProvider; +import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.core.codec.CodecException; +import org.springframework.core.codec.Hints; +import org.springframework.core.io.Resource; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.PooledDataBuffer; +import org.springframework.core.log.LogFormatUtils; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpOutputMessage; +import org.springframework.http.codec.EncoderHttpMessageWriter; +import org.springframework.http.codec.FormHttpMessageWriter; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.LoggingCodecSupport; +import org.springframework.http.codec.ResourceHttpMessageWriter; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.MultiValueMap; + +/** + * {@link HttpMessageWriter} for writing a {@code MultiValueMap} + * as multipart form data, i.e. {@code "multipart/form-data"}, to the body + * of a request. + * + *

The serialization of individual parts is delegated to other writers. + * By default only {@link String} and {@link Resource} parts are supported but + * you can configure others through a constructor argument. + * + *

This writer can be configured with a {@link FormHttpMessageWriter} to + * delegate to. It is the preferred way of supporting both form data and + * multipart data (as opposed to registering each writer separately) so that + * when the {@link MediaType} is not specified and generics are not present on + * the target element type, we can inspect the values in the actual map and + * decide whether to write plain form data (String values only) or otherwise. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 5.0 + * @see FormHttpMessageWriter + */ +public class MultipartHttpMessageWriter extends LoggingCodecSupport + implements HttpMessageWriter> { + + /** + * THe default charset used by the writer. + */ + public static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + + /** Suppress logging from individual part writers (full map logged at this level). */ + private static final Map DEFAULT_HINTS = Hints.from(Hints.SUPPRESS_LOGGING_HINT, true); + + + private final List> partWriters; + + @Nullable + private final HttpMessageWriter> formWriter; + + private Charset charset = DEFAULT_CHARSET; + + private final List supportedMediaTypes; + + + /** + * Constructor with a default list of part writers (String and Resource). + */ + public MultipartHttpMessageWriter() { + this(Arrays.asList( + new EncoderHttpMessageWriter<>(CharSequenceEncoder.textPlainOnly()), + new ResourceHttpMessageWriter() + )); + } + + /** + * Constructor with explicit list of writers for serializing parts. + */ + public MultipartHttpMessageWriter(List> partWriters) { + this(partWriters, new FormHttpMessageWriter()); + } + + /** + * Constructor with explicit list of writers for serializing parts and a + * writer for plain form data to fall back when no media type is specified + * and the actual map consists of String values only. + * @param partWriters the writers for serializing parts + * @param formWriter the fallback writer for form data, {@code null} by default + */ + public MultipartHttpMessageWriter(List> partWriters, + @Nullable HttpMessageWriter> formWriter) { + + this.partWriters = partWriters; + this.formWriter = formWriter; + this.supportedMediaTypes = initMediaTypes(formWriter); + } + + private static List initMediaTypes(@Nullable HttpMessageWriter formWriter) { + List result = new ArrayList<>(); + result.add(MediaType.MULTIPART_FORM_DATA); + if (formWriter != null) { + result.addAll(formWriter.getWritableMediaTypes()); + } + return Collections.unmodifiableList(result); + } + + + /** + * Return the configured part writers. + * @since 5.0.7 + */ + public List> getPartWriters() { + return Collections.unmodifiableList(this.partWriters); + } + + + /** + * Return the configured form writer. + * @since 5.1.13 + */ + @Nullable + public HttpMessageWriter> getFormWriter() { + return this.formWriter; + } + + /** + * Set the character set to use for part headers such as + * "Content-Disposition" (and its filename parameter). + *

By default this is set to "UTF-8". + */ + public void setCharset(Charset charset) { + Assert.notNull(charset, "Charset must not be null"); + this.charset = charset; + } + + /** + * Return the configured charset for part headers. + */ + public Charset getCharset() { + return this.charset; + } + + + @Override + public List getWritableMediaTypes() { + return this.supportedMediaTypes; + } + + @Override + public boolean canWrite(ResolvableType elementType, @Nullable MediaType mediaType) { + return (MultiValueMap.class.isAssignableFrom(elementType.toClass()) && + (mediaType == null || + this.supportedMediaTypes.stream().anyMatch(element -> element.isCompatibleWith(mediaType)))); + } + + @Override + public Mono write(Publisher> inputStream, + ResolvableType elementType, @Nullable MediaType mediaType, ReactiveHttpOutputMessage outputMessage, + Map hints) { + + return Mono.from(inputStream) + .flatMap(map -> { + if (this.formWriter == null || isMultipart(map, mediaType)) { + return writeMultipart(map, outputMessage, hints); + } + else { + @SuppressWarnings("unchecked") + Mono> input = Mono.just((MultiValueMap) map); + return this.formWriter.write(input, elementType, mediaType, outputMessage, hints); + } + }); + } + + private boolean isMultipart(MultiValueMap map, @Nullable MediaType contentType) { + if (contentType != null) { + return MediaType.MULTIPART_FORM_DATA.includes(contentType); + } + for (List values : map.values()) { + for (Object value : values) { + if (value != null && !(value instanceof String)) { + return true; + } + } + } + return false; + } + + private Mono writeMultipart( + MultiValueMap map, ReactiveHttpOutputMessage outputMessage, Map hints) { + + byte[] boundary = generateMultipartBoundary(); + + Map params = new HashMap<>(2); + params.put("boundary", new String(boundary, StandardCharsets.US_ASCII)); + params.put("charset", getCharset().name()); + + outputMessage.getHeaders().setContentType(new MediaType(MediaType.MULTIPART_FORM_DATA, params)); + + LogFormatUtils.traceDebug(logger, traceOn -> Hints.getLogPrefix(hints) + "Encoding " + + (isEnableLoggingRequestDetails() ? + LogFormatUtils.formatValue(map, !traceOn) : + "parts " + map.keySet() + " (content masked)")); + + DataBufferFactory bufferFactory = outputMessage.bufferFactory(); + + Flux body = Flux.fromIterable(map.entrySet()) + .concatMap(entry -> encodePartValues(boundary, entry.getKey(), entry.getValue(), bufferFactory)) + .concatWith(generateLastLine(boundary, bufferFactory)) + .doOnDiscard(PooledDataBuffer.class, PooledDataBuffer::release); + + return outputMessage.writeWith(body); + } + + /** + * Generate a multipart boundary. + *

By default delegates to {@link MimeTypeUtils#generateMultipartBoundary()}. + */ + protected byte[] generateMultipartBoundary() { + return MimeTypeUtils.generateMultipartBoundary(); + } + + private Flux encodePartValues( + byte[] boundary, String name, List values, DataBufferFactory bufferFactory) { + + return Flux.fromIterable(values) + .concatMap(value -> encodePart(boundary, name, value, bufferFactory)); + } + + @SuppressWarnings("unchecked") + private Flux encodePart(byte[] boundary, String name, T value, DataBufferFactory bufferFactory) { + MultipartHttpOutputMessage outputMessage = new MultipartHttpOutputMessage(bufferFactory, getCharset()); + HttpHeaders outputHeaders = outputMessage.getHeaders(); + + T body; + ResolvableType resolvableType = null; + if (value instanceof HttpEntity) { + HttpEntity httpEntity = (HttpEntity) value; + outputHeaders.putAll(httpEntity.getHeaders()); + body = httpEntity.getBody(); + Assert.state(body != null, "MultipartHttpMessageWriter only supports HttpEntity with body"); + if (httpEntity instanceof ResolvableTypeProvider) { + resolvableType = ((ResolvableTypeProvider) httpEntity).getResolvableType(); + } + } + else { + body = value; + } + if (resolvableType == null) { + resolvableType = ResolvableType.forClass(body.getClass()); + } + + if (!outputHeaders.containsKey(HttpHeaders.CONTENT_DISPOSITION)) { + if (body instanceof Resource) { + outputHeaders.setContentDispositionFormData(name, ((Resource) body).getFilename()); + } + else if (resolvableType.resolve() == Resource.class) { + body = (T) Mono.from((Publisher) body).doOnNext(o -> outputHeaders + .setContentDispositionFormData(name, ((Resource) o).getFilename())); + } + else { + outputHeaders.setContentDispositionFormData(name, null); + } + } + + MediaType contentType = outputHeaders.getContentType(); + + final ResolvableType finalBodyType = resolvableType; + Optional> writer = this.partWriters.stream() + .filter(partWriter -> partWriter.canWrite(finalBodyType, contentType)) + .findFirst(); + + if (!writer.isPresent()) { + return Flux.error(new CodecException("No suitable writer found for part: " + name)); + } + + Publisher bodyPublisher = + body instanceof Publisher ? (Publisher) body : Mono.just(body); + + // The writer will call MultipartHttpOutputMessage#write which doesn't actually write + // but only stores the body Flux and returns Mono.empty(). + + Mono partContentReady = ((HttpMessageWriter) writer.get()) + .write(bodyPublisher, resolvableType, contentType, outputMessage, DEFAULT_HINTS); + + // After partContentReady, we can access the part content from MultipartHttpOutputMessage + // and use it for writing to the actual request body + + Flux partContent = partContentReady.thenMany(Flux.defer(outputMessage::getBody)); + + return Flux.concat( + generateBoundaryLine(boundary, bufferFactory), + partContent, + generateNewLine(bufferFactory)); + } + + + private Mono generateBoundaryLine(byte[] boundary, DataBufferFactory bufferFactory) { + return Mono.fromCallable(() -> { + DataBuffer buffer = bufferFactory.allocateBuffer(boundary.length + 4); + buffer.write((byte)'-'); + buffer.write((byte)'-'); + buffer.write(boundary); + buffer.write((byte)'\r'); + buffer.write((byte)'\n'); + return buffer; + }); + } + + private Mono generateNewLine(DataBufferFactory bufferFactory) { + return Mono.fromCallable(() -> { + DataBuffer buffer = bufferFactory.allocateBuffer(2); + buffer.write((byte)'\r'); + buffer.write((byte)'\n'); + return buffer; + }); + } + + private Mono generateLastLine(byte[] boundary, DataBufferFactory bufferFactory) { + return Mono.fromCallable(() -> { + DataBuffer buffer = bufferFactory.allocateBuffer(boundary.length + 6); + buffer.write((byte)'-'); + buffer.write((byte)'-'); + buffer.write(boundary); + buffer.write((byte)'-'); + buffer.write((byte)'-'); + buffer.write((byte)'\r'); + buffer.write((byte)'\n'); + return buffer; + }); + } + + + private static class MultipartHttpOutputMessage implements ReactiveHttpOutputMessage { + + private final DataBufferFactory bufferFactory; + + private final Charset charset; + + private final HttpHeaders headers = new HttpHeaders(); + + private final AtomicBoolean committed = new AtomicBoolean(); + + @Nullable + private Flux body; + + public MultipartHttpOutputMessage(DataBufferFactory bufferFactory, Charset charset) { + this.bufferFactory = bufferFactory; + this.charset = charset; + } + + @Override + public HttpHeaders getHeaders() { + return (this.body != null ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers); + } + + @Override + public DataBufferFactory bufferFactory() { + return this.bufferFactory; + } + + @Override + public void beforeCommit(Supplier> action) { + this.committed.set(true); + } + + @Override + public boolean isCommitted() { + return this.committed.get(); + } + + @Override + public Mono writeWith(Publisher body) { + if (this.body != null) { + return Mono.error(new IllegalStateException("Multiple calls to writeWith() not supported")); + } + this.body = generateHeaders().concatWith(body); + + // We don't actually want to write (just save the body Flux) + return Mono.empty(); + } + + private Mono generateHeaders() { + return Mono.fromCallable(() -> { + DataBuffer buffer = this.bufferFactory.allocateBuffer(); + for (Map.Entry> entry : this.headers.entrySet()) { + byte[] headerName = entry.getKey().getBytes(this.charset); + for (String headerValueString : entry.getValue()) { + byte[] headerValue = headerValueString.getBytes(this.charset); + buffer.write(headerName); + buffer.write((byte)':'); + buffer.write((byte)' '); + buffer.write(headerValue); + buffer.write((byte)'\r'); + buffer.write((byte)'\n'); + } + } + buffer.write((byte)'\r'); + buffer.write((byte)'\n'); + return buffer; + }); + } + + @Override + public Mono writeAndFlushWith(Publisher> body) { + return Mono.error(new UnsupportedOperationException()); + } + + public Flux getBody() { + return (this.body != null ? this.body : + Flux.error(new IllegalStateException("Body has not been written yet"))); + } + + @Override + public Mono setComplete() { + return Mono.error(new UnsupportedOperationException()); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/Part.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/Part.java new file mode 100644 index 0000000000000000000000000000000000000000..c611adf22ae67c648a67e56a7d20094f7c0097ac --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/Part.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.multipart; + +import reactor.core.publisher.Flux; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.HttpHeaders; + +/** + * Representation for a part in a "multipart/form-data" request. + * + *

The origin of a multipart request may be a browser form in which case each + * part is either a {@link FormFieldPart} or a {@link FilePart}. + * + *

Multipart requests may also be used outside of a browser for data of any + * content type (e.g. JSON, PDF, etc). + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 5.0 + * @see RFC 7578 (multipart/form-data) + * @see RFC 2183 (Content-Disposition) + * @see HTML5 (multipart forms) + */ +public interface Part { + + /** + * Return the name of the part in the multipart form. + * @return the name of the part, never {@code null} or empty + */ + String name(); + + /** + * Return the headers associated with the part. + */ + HttpHeaders headers(); + + /** + * Return the content for this part. + *

Note that for a {@link FormFieldPart} the content may be accessed + * more easily via {@link FormFieldPart#value()}. + */ + Flux content(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java new file mode 100644 index 0000000000000000000000000000000000000000..f4194071b9f3f6382942e790a3fa743a76cc0ab1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java @@ -0,0 +1,540 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.multipart; + +import java.io.IOException; +import java.nio.channels.Channels; +import java.nio.channels.FileChannel; +import java.nio.channels.ReadableByteChannel; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.nio.file.OpenOption; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; + +import org.synchronoss.cloud.nio.multipart.DefaultPartBodyStreamStorageFactory; +import org.synchronoss.cloud.nio.multipart.Multipart; +import org.synchronoss.cloud.nio.multipart.MultipartContext; +import org.synchronoss.cloud.nio.multipart.MultipartUtils; +import org.synchronoss.cloud.nio.multipart.NioMultipartParser; +import org.synchronoss.cloud.nio.multipart.NioMultipartParserListener; +import org.synchronoss.cloud.nio.multipart.PartBodyStreamStorageFactory; +import org.synchronoss.cloud.nio.stream.storage.StreamStorage; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; +import reactor.core.publisher.SignalType; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.DecodingException; +import org.springframework.core.codec.Hints; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferLimitException; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.core.log.LogFormatUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.LoggingCodecSupport; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@code HttpMessageReader} for parsing {@code "multipart/form-data"} requests + * to a stream of {@link Part}'s using the Synchronoss NIO Multipart library. + * + *

This reader can be provided to {@link MultipartHttpMessageReader} in order + * to aggregate all parts into a Map. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @author Arjen Poutsma + * @author Brian Clozel + * @since 5.0 + * @see Synchronoss NIO Multipart + * @see MultipartHttpMessageReader + */ +public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implements HttpMessageReader { + + // Static DataBufferFactory to copy from FileInputStream or wrap bytes[]. + private static final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(); + + + private int maxInMemorySize = -1; + + private long maxDiskUsagePerPart = -1; + + private int maxParts = -1; + + + /** + * Configure the maximum amount of memory that is allowed to use per part. + * When the limit is exceeded: + *

    + *
  • file parts are written to a temporary file. + *
  • non-file parts are rejected with {@link DataBufferLimitException}. + *
+ *

By default in 5.1 this is set to -1 in which case this limit is + * not enforced and all parts may be written to disk and are limited only + * by the {@link #setMaxDiskUsagePerPart(long) maxDiskUsagePerPart} property. + * In 5.2 this default value for this limit is set to 256K. + * @param byteCount the in-memory limit in bytes, or -1 for unlimited + * @since 5.1.11 + */ + public void setMaxInMemorySize(int byteCount) { + this.maxInMemorySize = byteCount; + } + + /** + * Get the {@link #setMaxInMemorySize configured} maximum in-memory size. + * @since 5.1.11 + */ + public int getMaxInMemorySize() { + return this.maxInMemorySize; + } + + /** + * Configure the maximum amount of disk space allowed for file parts. + *

By default this is set to -1. + * @param maxDiskUsagePerPart the disk limit in bytes, or -1 for unlimited + * @since 5.1.11 + */ + public void setMaxDiskUsagePerPart(long maxDiskUsagePerPart) { + this.maxDiskUsagePerPart = maxDiskUsagePerPart; + } + + /** + * Get the {@link #setMaxDiskUsagePerPart configured} maximum disk usage. + * @since 5.1.11 + */ + public long getMaxDiskUsagePerPart() { + return this.maxDiskUsagePerPart; + } + + /** + * Specify the maximum number of parts allowed in a given multipart request. + * @since 5.1.11 + */ + public void setMaxParts(int maxParts) { + this.maxParts = maxParts; + } + + /** + * Return the {@link #setMaxParts configured} limit on the number of parts. + * @since 5.1.11 + */ + public int getMaxParts() { + return this.maxParts; + } + + + @Override + public List getReadableMediaTypes() { + return Collections.singletonList(MediaType.MULTIPART_FORM_DATA); + } + + @Override + public boolean canRead(ResolvableType elementType, @Nullable MediaType mediaType) { + return Part.class.equals(elementType.toClass()) && + (mediaType == null || MediaType.MULTIPART_FORM_DATA.isCompatibleWith(mediaType)); + } + + @Override + public Flux read(ResolvableType elementType, ReactiveHttpInputMessage message, Map hints) { + return Flux.create(new SynchronossPartGenerator(message)) + .doOnNext(part -> { + if (!Hints.isLoggingSuppressed(hints)) { + LogFormatUtils.traceDebug(logger, traceOn -> Hints.getLogPrefix(hints) + "Parsed " + + (isEnableLoggingRequestDetails() ? + LogFormatUtils.formatValue(part, !traceOn) : + "parts '" + part.name() + "' (content masked)")); + } + }); + } + + @Override + public Mono readMono(ResolvableType elementType, ReactiveHttpInputMessage message, Map hints) { + return Mono.error(new UnsupportedOperationException("Cannot read multipart request body into single Part")); + } + + + /** + * Subscribe to the input stream and feed the Synchronoss parser. Then listen + * for parser output, creating parts, and pushing them into the FluxSink. + */ + private class SynchronossPartGenerator extends BaseSubscriber implements Consumer> { + + private final ReactiveHttpInputMessage inputMessage; + + private final LimitedPartBodyStreamStorageFactory storageFactory = new LimitedPartBodyStreamStorageFactory(); + + @Nullable + private NioMultipartParserListener listener; + + @Nullable + private NioMultipartParser parser; + + public SynchronossPartGenerator(ReactiveHttpInputMessage inputMessage) { + this.inputMessage = inputMessage; + } + + @Override + public void accept(FluxSink sink) { + HttpHeaders headers = this.inputMessage.getHeaders(); + MediaType mediaType = headers.getContentType(); + Assert.state(mediaType != null, "No content type set"); + + int length = getContentLength(headers); + Charset charset = Optional.ofNullable(mediaType.getCharset()).orElse(StandardCharsets.UTF_8); + MultipartContext context = new MultipartContext(mediaType.toString(), length, charset.name()); + + this.listener = new FluxSinkAdapterListener(sink, context, this.storageFactory); + + this.parser = Multipart + .multipart(context) + .usePartBodyStreamStorageFactory(this.storageFactory) + .forNIO(this.listener); + + this.inputMessage.getBody().subscribe(this); + } + + @Override + protected void hookOnNext(DataBuffer buffer) { + Assert.state(this.parser != null && this.listener != null, "Not initialized yet"); + + int size = buffer.readableByteCount(); + this.storageFactory.increaseByteCount(size); + byte[] resultBytes = new byte[size]; + buffer.read(resultBytes); + + try { + this.parser.write(resultBytes); + } + catch (IOException ex) { + cancel(); + int index = this.storageFactory.getCurrentPartIndex(); + this.listener.onError("Parser error for part [" + index + "]", ex); + } + finally { + DataBufferUtils.release(buffer); + } + } + + @Override + protected void hookOnError(Throwable ex) { + if (this.listener != null) { + int index = this.storageFactory.getCurrentPartIndex(); + this.listener.onError("Failure while parsing part[" + index + "]", ex); + } + } + + @Override + protected void hookOnComplete() { + if (this.listener != null) { + this.listener.onAllPartsFinished(); + } + } + + @Override + protected void hookFinally(SignalType type) { + try { + if (this.parser != null) { + this.parser.close(); + } + } + catch (IOException ex) { + // ignore + } + } + + private int getContentLength(HttpHeaders headers) { + // Until this is fixed https://github.com/synchronoss/nio-multipart/issues/10 + long length = headers.getContentLength(); + return (int) length == length ? (int) length : -1; + } + } + + + private class LimitedPartBodyStreamStorageFactory implements PartBodyStreamStorageFactory { + + private final PartBodyStreamStorageFactory storageFactory = (maxInMemorySize > 0 ? + new DefaultPartBodyStreamStorageFactory(maxInMemorySize) : + new DefaultPartBodyStreamStorageFactory()); + + private int index = 1; + + private boolean isFilePart; + + private long partSize; + + public int getCurrentPartIndex() { + return this.index; + } + + @Override + public StreamStorage newStreamStorageForPartBody(Map> headers, int index) { + this.index = index; + this.isFilePart = (MultipartUtils.getFileName(headers) != null); + this.partSize = 0; + if (maxParts > 0 && index > maxParts) { + throw new DecodingException("Too many parts (" + index + " allowed)"); + } + return this.storageFactory.newStreamStorageForPartBody(headers, index); + } + + public void increaseByteCount(long byteCount) { + this.partSize += byteCount; + if (maxInMemorySize > 0 && !this.isFilePart && this.partSize >= maxInMemorySize) { + throw new DataBufferLimitException("Part[" + this.index + "] " + + "exceeded the in-memory limit of " + maxInMemorySize + " bytes"); + } + if (maxDiskUsagePerPart > 0 && this.isFilePart && this.partSize > maxDiskUsagePerPart) { + throw new DecodingException("Part[" + this.index + "] " + + "exceeded the disk usage limit of " + maxDiskUsagePerPart + " bytes"); + } + } + + public void partFinished() { + this.index++; + this.isFilePart = false; + this.partSize = 0; + } + } + + + /** + * Listen for parser output and adapt to {@code Flux>}. + */ + private static class FluxSinkAdapterListener implements NioMultipartParserListener { + + private final FluxSink sink; + + private final MultipartContext context; + + private final LimitedPartBodyStreamStorageFactory storageFactory; + + private final AtomicInteger terminated = new AtomicInteger(0); + + FluxSinkAdapterListener( + FluxSink sink, MultipartContext context, LimitedPartBodyStreamStorageFactory factory) { + + this.sink = sink; + this.context = context; + this.storageFactory = factory; + } + + @Override + public void onPartFinished(StreamStorage storage, Map> headers) { + HttpHeaders httpHeaders = new HttpHeaders(); + httpHeaders.putAll(headers); + this.storageFactory.partFinished(); + this.sink.next(createPart(storage, httpHeaders)); + } + + private Part createPart(StreamStorage storage, HttpHeaders httpHeaders) { + String filename = MultipartUtils.getFileName(httpHeaders); + if (filename != null) { + return new SynchronossFilePart(httpHeaders, filename, storage); + } + else if (MultipartUtils.isFormField(httpHeaders, this.context)) { + String value = MultipartUtils.readFormParameterValue(storage, httpHeaders); + return new SynchronossFormFieldPart(httpHeaders, value); + } + else { + return new SynchronossPart(httpHeaders, storage); + } + } + + @Override + public void onError(String message, Throwable cause) { + if (this.terminated.getAndIncrement() == 0) { + this.sink.error(new DecodingException(message, cause)); + } + } + + @Override + public void onAllPartsFinished() { + if (this.terminated.getAndIncrement() == 0) { + this.sink.complete(); + } + } + + @Override + public void onNestedPartStarted(Map> headersFromParentPart) { + } + + @Override + public void onNestedPartFinished() { + } + } + + + private abstract static class AbstractSynchronossPart implements Part { + + private final String name; + + private final HttpHeaders headers; + + AbstractSynchronossPart(HttpHeaders headers) { + Assert.notNull(headers, "HttpHeaders is required"); + this.name = MultipartUtils.getFieldName(headers); + this.headers = headers; + } + + @Override + public String name() { + return this.name; + } + + @Override + public HttpHeaders headers() { + return this.headers; + } + + @Override + public String toString() { + return "Part '" + this.name + "', headers=" + this.headers; + } + } + + + private static class SynchronossPart extends AbstractSynchronossPart { + + private final StreamStorage storage; + + SynchronossPart(HttpHeaders headers, StreamStorage storage) { + super(headers); + Assert.notNull(storage, "StreamStorage is required"); + this.storage = storage; + } + + @Override + public Flux content() { + return DataBufferUtils.readInputStream(getStorage()::getInputStream, bufferFactory, 4096); + } + + protected StreamStorage getStorage() { + return this.storage; + } + } + + + private static class SynchronossFilePart extends SynchronossPart implements FilePart { + + private static final OpenOption[] FILE_CHANNEL_OPTIONS = + {StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE}; + + private final String filename; + + SynchronossFilePart(HttpHeaders headers, String filename, StreamStorage storage) { + super(headers, storage); + this.filename = filename; + } + + @Override + public String filename() { + return this.filename; + } + + @Override + public Mono transferTo(Path dest) { + ReadableByteChannel input = null; + FileChannel output = null; + try { + input = Channels.newChannel(getStorage().getInputStream()); + output = FileChannel.open(dest, FILE_CHANNEL_OPTIONS); + long size = (input instanceof FileChannel ? ((FileChannel) input).size() : Long.MAX_VALUE); + long totalWritten = 0; + while (totalWritten < size) { + long written = output.transferFrom(input, totalWritten, size - totalWritten); + if (written <= 0) { + break; + } + totalWritten += written; + } + } + catch (IOException ex) { + return Mono.error(ex); + } + finally { + if (input != null) { + try { + input.close(); + } + catch (IOException ignored) { + } + } + if (output != null) { + try { + output.close(); + } + catch (IOException ignored) { + } + } + } + return Mono.empty(); + } + + @Override + public String toString() { + return "Part '" + name() + "', filename='" + this.filename + "'"; + } + } + + + private static class SynchronossFormFieldPart extends AbstractSynchronossPart implements FormFieldPart { + + private final String content; + + SynchronossFormFieldPart(HttpHeaders headers, String content) { + super(headers); + this.content = content; + } + + @Override + public String value() { + return this.content; + } + + @Override + public Flux content() { + byte[] bytes = this.content.getBytes(getCharset()); + return Flux.just(bufferFactory.wrap(bytes)); + } + + private Charset getCharset() { + String name = MultipartUtils.getCharEncoding(headers()); + return (name != null ? Charset.forName(name) : StandardCharsets.UTF_8); + } + + @Override + public String toString() { + return "Part '" + name() + "=" + this.content + "'"; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/package-info.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..44dad3697d6ec660854b5bdff84e7a7e422a0c22 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/package-info.java @@ -0,0 +1,9 @@ +/** + * Multipart support. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.codec.multipart; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/codec/package-info.java b/spring-web/src/main/java/org/springframework/http/codec/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..a4c6b8505b208b213b96d11f7e2ffadae82419b9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/package-info.java @@ -0,0 +1,15 @@ +/** + * Provides implementations of {@link org.springframework.core.codec.Encoder} + * and {@link org.springframework.core.codec.Decoder} for web use. + * + *

Also declares a high-level + * {@link org.springframework.http.codec.HttpMessageReader} and + * {@link org.springframework.http.codec.HttpMessageWriter} for reading and + * writing the body of HTTP requests and responses. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.codec; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufCodecSupport.java b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufCodecSupport.java new file mode 100644 index 0000000000000000000000000000000000000000..84bcc880bb092be744cfc03483ef1aebbb45fb28 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufCodecSupport.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.protobuf; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.springframework.lang.Nullable; +import org.springframework.util.MimeType; + +/** + * Base class providing support methods for Protobuf encoding and decoding. + * + * @author Sebastien Deleuze + * @since 5.1 + */ +public abstract class ProtobufCodecSupport { + + static final List MIME_TYPES = Collections.unmodifiableList( + Arrays.asList( + new MimeType("application", "x-protobuf"), + new MimeType("application", "octet-stream"))); + + static final String DELIMITED_KEY = "delimited"; + + static final String DELIMITED_VALUE = "true"; + + + protected boolean supportsMimeType(@Nullable MimeType mimeType) { + return (mimeType == null || MIME_TYPES.stream().anyMatch(m -> m.isCompatibleWith(mimeType))); + } + + protected List getMimeTypes() { + return MIME_TYPES; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java new file mode 100644 index 0000000000000000000000000000000000000000..a2aba3addd51366442851a70b2f1e53ca30e195e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java @@ -0,0 +1,318 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.protobuf; + +import java.io.IOException; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; +import java.util.function.Function; + +import com.google.protobuf.CodedInputStream; +import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.Message; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Decoder; +import org.springframework.core.codec.DecodingException; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ConcurrentReferenceHashMap; +import org.springframework.util.MimeType; + +/** + * A {@code Decoder} that reads {@link com.google.protobuf.Message}s using + * Google Protocol Buffers. + * + *

Flux deserialized via + * {@link #decode(Publisher, ResolvableType, MimeType, Map)} are expected to use + * + * delimited Protobuf messages with the size of each message specified before + * the message itself. Single values deserialized via + * {@link #decodeToMono(Publisher, ResolvableType, MimeType, Map)} are expected + * to use regular Protobuf message format (without the size prepended before + * the message). + * + *

Notice that default instance of Protobuf message produces empty byte + * array, so {@code Mono.just(Msg.getDefaultInstance())} sent over the network + * will be deserialized as an empty {@link Mono}. + * + *

To generate {@code Message} Java classes, you need to install the + * {@code protoc} binary. + * + *

This decoder requires Protobuf 3 or higher, and supports + * {@code "application/x-protobuf"} and {@code "application/octet-stream"} with + * the official {@code "com.google.protobuf:protobuf-java"} library. + * + * @author Sébastien Deleuze + * @since 5.1 + * @see ProtobufEncoder + */ +public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder { + + /** The default max size for aggregating messages. */ + protected static final int DEFAULT_MESSAGE_MAX_SIZE = 64 * 1024; + + private static final ConcurrentMap, Method> methodCache = new ConcurrentReferenceHashMap<>(); + + + private final ExtensionRegistry extensionRegistry; + + private int maxMessageSize = DEFAULT_MESSAGE_MAX_SIZE; + + + /** + * Construct a new {@code ProtobufDecoder}. + */ + public ProtobufDecoder() { + this(ExtensionRegistry.newInstance()); + } + + /** + * Construct a new {@code ProtobufDecoder} with an initializer that allows the + * registration of message extensions. + * @param extensionRegistry a message extension registry + */ + public ProtobufDecoder(ExtensionRegistry extensionRegistry) { + Assert.notNull(extensionRegistry, "ExtensionRegistry must not be null"); + this.extensionRegistry = extensionRegistry; + } + + + /** + * The max size allowed per message. + *

By default in 5.1 this is set to 64K. In 5.2 the default for this limit + * is set to 256K. + * @param maxMessageSize the max size per message, or -1 for unlimited + */ + public void setMaxMessageSize(int maxMessageSize) { + this.maxMessageSize = maxMessageSize; + } + + /** + * Return the {@link #setMaxMessageSize configured} message size limit. + * @since 5.1.11 + */ + public int getMaxMessageSize() { + return this.maxMessageSize; + } + + + @Override + public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) { + return Message.class.isAssignableFrom(elementType.toClass()) && supportsMimeType(mimeType); + } + + @Override + public Flux decode(Publisher inputStream, ResolvableType elementType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + MessageDecoderFunction decoderFunction = + new MessageDecoderFunction(elementType, this.maxMessageSize); + + return Flux.from(inputStream) + .flatMapIterable(decoderFunction) + .doOnTerminate(decoderFunction::discard); + } + + @Override + public Mono decodeToMono(Publisher inputStream, ResolvableType elementType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + return DataBufferUtils.join(inputStream, getMaxMessageSize()).map(dataBuffer -> { + try { + Message.Builder builder = getMessageBuilder(elementType.toClass()); + ByteBuffer buffer = dataBuffer.asByteBuffer(); + builder.mergeFrom(CodedInputStream.newInstance(buffer), this.extensionRegistry); + return builder.build(); + } + catch (IOException ex) { + throw new DecodingException("I/O error while parsing input stream", ex); + } + catch (Exception ex) { + throw new DecodingException("Could not read Protobuf message: " + ex.getMessage(), ex); + } + finally { + DataBufferUtils.release(dataBuffer); + } + } + ); + } + + /** + * Create a new {@code Message.Builder} instance for the given class. + *

This method uses a ConcurrentHashMap for caching method lookups. + */ + private static Message.Builder getMessageBuilder(Class clazz) throws Exception { + Method method = methodCache.get(clazz); + if (method == null) { + method = clazz.getMethod("newBuilder"); + methodCache.put(clazz, method); + } + return (Message.Builder) method.invoke(clazz); + } + + @Override + public List getDecodableMimeTypes() { + return getMimeTypes(); + } + + + private class MessageDecoderFunction implements Function> { + + private final ResolvableType elementType; + + private final int maxMessageSize; + + @Nullable + private DataBuffer output; + + private int messageBytesToRead; + + private int offset; + + + public MessageDecoderFunction(ResolvableType elementType, int maxMessageSize) { + this.elementType = elementType; + this.maxMessageSize = maxMessageSize; + } + + + @Override + public Iterable apply(DataBuffer input) { + try { + List messages = new ArrayList<>(); + int remainingBytesToRead; + int chunkBytesToRead; + + do { + if (this.output == null) { + if (!readMessageSize(input)) { + return messages; + } + if (this.maxMessageSize > 0 && this.messageBytesToRead > this.maxMessageSize) { + throw new DataBufferLimitException( + "The number of bytes to read for message " + + "(" + this.messageBytesToRead + ") exceeds " + + "the configured limit (" + this.maxMessageSize + ")"); + } + this.output = input.factory().allocateBuffer(this.messageBytesToRead); + } + + chunkBytesToRead = this.messageBytesToRead >= input.readableByteCount() ? + input.readableByteCount() : this.messageBytesToRead; + remainingBytesToRead = input.readableByteCount() - chunkBytesToRead; + + byte[] bytesToWrite = new byte[chunkBytesToRead]; + input.read(bytesToWrite, 0, chunkBytesToRead); + this.output.write(bytesToWrite); + this.messageBytesToRead -= chunkBytesToRead; + + if (this.messageBytesToRead == 0) { + CodedInputStream stream = CodedInputStream.newInstance(this.output.asByteBuffer()); + DataBufferUtils.release(this.output); + this.output = null; + Message message = getMessageBuilder(this.elementType.toClass()) + .mergeFrom(stream, extensionRegistry) + .build(); + messages.add(message); + } + } while (remainingBytesToRead > 0); + return messages; + } + catch (DecodingException ex) { + throw ex; + } + catch (IOException ex) { + throw new DecodingException("I/O error while parsing input stream", ex); + } + catch (Exception ex) { + throw new DecodingException("Could not read Protobuf message: " + ex.getMessage(), ex); + } + finally { + DataBufferUtils.release(input); + } + } + + /** + * Parse message size as a varint from the input stream, updating {@code messageBytesToRead} and + * {@code offset} fields if needed to allow processing of upcoming chunks. + * Inspired from {@link CodedInputStream#readRawVarint32(int, java.io.InputStream)} + * + * @return {code true} when the message size is parsed successfully, {code false} when the message size is + * truncated + * @see Base 128 Varints + */ + private boolean readMessageSize(DataBuffer input) { + if (this.offset == 0) { + if (input.readableByteCount() == 0) { + return false; + } + int firstByte = input.read(); + if ((firstByte & 0x80) == 0) { + this.messageBytesToRead = firstByte; + return true; + } + this.messageBytesToRead = firstByte & 0x7f; + this.offset = 7; + } + + if (this.offset < 32) { + for (; this.offset < 32; this.offset += 7) { + if (input.readableByteCount() == 0) { + return false; + } + final int b = input.read(); + this.messageBytesToRead |= (b & 0x7f) << offset; + if ((b & 0x80) == 0) { + this.offset = 0; + return true; + } + } + } + // Keep reading up to 64 bits. + for (; this.offset < 64; this.offset += 7) { + if (input.readableByteCount() == 0) { + return false; + } + final int b = input.read(); + if ((b & 0x80) == 0) { + this.offset = 0; + return true; + } + } + this.offset = 0; + throw new DecodingException("Cannot parse message size: malformed varint"); + } + + public void discard() { + if (this.output != null) { + DataBufferUtils.release(this.output); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufEncoder.java b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufEncoder.java new file mode 100644 index 0000000000000000000000000000000000000000..3be1ac477442d0ca63c8908c38db780f869d4908 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufEncoder.java @@ -0,0 +1,111 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.protobuf; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import com.google.protobuf.Message; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.MediaType; +import org.springframework.http.codec.HttpMessageEncoder; +import org.springframework.lang.Nullable; +import org.springframework.util.MimeType; + +/** + * An {@code Encoder} that writes {@link com.google.protobuf.Message}s + * using Google Protocol Buffers. + * + *

Flux are serialized using + * delimited Protobuf messages + * with the size of each message specified before the message itself. Single values are + * serialized using regular Protobuf message format (without the size prepended before the message). + * + *

To generate {@code Message} Java classes, you need to install the {@code protoc} binary. + * + *

This encoder requires Protobuf 3 or higher, and supports + * {@code "application/x-protobuf"} and {@code "application/octet-stream"} with the official + * {@code "com.google.protobuf:protobuf-java"} library. + * + * @author Sébastien Deleuze + * @since 5.1 + * @see ProtobufDecoder + */ +public class ProtobufEncoder extends ProtobufCodecSupport implements HttpMessageEncoder { + + private static final List streamingMediaTypes = MIME_TYPES + .stream() + .map(mimeType -> new MediaType(mimeType.getType(), mimeType.getSubtype(), + Collections.singletonMap(DELIMITED_KEY, DELIMITED_VALUE))) + .collect(Collectors.toList()); + + + @Override + public boolean canEncode(ResolvableType elementType, @Nullable MimeType mimeType) { + return Message.class.isAssignableFrom(elementType.toClass()) && supportsMimeType(mimeType); + } + + @Override + public Flux encode(Publisher inputStream, DataBufferFactory bufferFactory, + ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { + + return Flux.from(inputStream) + .map(message -> { + DataBuffer buffer = bufferFactory.allocateBuffer(); + boolean release = true; + try { + if (!(inputStream instanceof Mono)) { + message.writeDelimitedTo(buffer.asOutputStream()); + } + else { + message.writeTo(buffer.asOutputStream()); + } + release = false; + return buffer; + } + catch (IOException ex) { + throw new IllegalStateException("Unexpected I/O error while writing to data buffer", ex); + } + finally { + if (release) { + DataBufferUtils.release(buffer); + } + } + }); + } + + @Override + public List getStreamingMediaTypes() { + return streamingMediaTypes; + } + + @Override + public List getEncodableMimeTypes() { + return getMimeTypes(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufHttpMessageWriter.java new file mode 100644 index 0000000000000000000000000000000000000000..0825637b9d8b3dc81df7f42790cf85a1ebd4095e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufHttpMessageWriter.java @@ -0,0 +1,117 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.protobuf; + +import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; + +import com.google.protobuf.Descriptors; +import com.google.protobuf.Message; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.DecodingException; +import org.springframework.core.codec.Encoder; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpOutputMessage; +import org.springframework.http.codec.EncoderHttpMessageWriter; +import org.springframework.http.codec.HttpMessageEncoder; +import org.springframework.lang.Nullable; +import org.springframework.util.ConcurrentReferenceHashMap; + +/** + * {@code HttpMessageWriter} that can write a protobuf {@link Message} and adds + * {@code X-Protobuf-Schema}, {@code X-Protobuf-Message} headers and a + * {@code delimited=true} parameter is added to the content type if a flux is serialized. + * + *

For {@code HttpMessageReader}, just use + * {@code new DecoderHttpMessageReader(new ProtobufDecoder())}. + * + * @author Sébastien Deleuze + * @since 5.1 + * @see ProtobufEncoder + */ +public class ProtobufHttpMessageWriter extends EncoderHttpMessageWriter { + + private static final String X_PROTOBUF_SCHEMA_HEADER = "X-Protobuf-Schema"; + + private static final String X_PROTOBUF_MESSAGE_HEADER = "X-Protobuf-Message"; + + private static final ConcurrentMap, Method> methodCache = new ConcurrentReferenceHashMap<>(); + + + /** + * Create a new {@code ProtobufHttpMessageWriter} with a default {@link ProtobufEncoder}. + */ + public ProtobufHttpMessageWriter() { + super(new ProtobufEncoder()); + } + + /** + * Create a new {@code ProtobufHttpMessageWriter} with the given encoder. + * @param encoder the Protobuf message encoder to use + */ + public ProtobufHttpMessageWriter(Encoder encoder) { + super(encoder); + } + + + @SuppressWarnings("unchecked") + @Override + public Mono write(Publisher inputStream, ResolvableType elementType, + @Nullable MediaType mediaType, ReactiveHttpOutputMessage message, Map hints) { + + try { + Message.Builder builder = getMessageBuilder(elementType.toClass()); + Descriptors.Descriptor descriptor = builder.getDescriptorForType(); + message.getHeaders().add(X_PROTOBUF_SCHEMA_HEADER, descriptor.getFile().getName()); + message.getHeaders().add(X_PROTOBUF_MESSAGE_HEADER, descriptor.getFullName()); + if (inputStream instanceof Flux) { + if (mediaType == null) { + message.getHeaders().setContentType(((HttpMessageEncoder)getEncoder()).getStreamingMediaTypes().get(0)); + } + else if (!ProtobufEncoder.DELIMITED_VALUE.equals(mediaType.getParameters().get(ProtobufEncoder.DELIMITED_KEY))) { + Map parameters = new HashMap<>(mediaType.getParameters()); + parameters.put(ProtobufEncoder.DELIMITED_KEY, ProtobufEncoder.DELIMITED_VALUE); + message.getHeaders().setContentType(new MediaType(mediaType.getType(), mediaType.getSubtype(), parameters)); + } + } + return super.write(inputStream, elementType, mediaType, message, hints); + } + catch (Exception ex) { + return Mono.error(new DecodingException("Could not read Protobuf message: " + ex.getMessage(), ex)); + } + } + + /** + * Create a new {@code Message.Builder} instance for the given class. + *

This method uses a ConcurrentHashMap for caching method lookups. + */ + private static Message.Builder getMessageBuilder(Class clazz) throws Exception { + Method method = methodCache.get(clazz); + if (method == null) { + method = clazz.getMethod("newBuilder"); + methodCache.put(clazz, method); + } + return (Message.Builder) method.invoke(clazz); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/protobuf/package-info.java b/spring-web/src/main/java/org/springframework/http/codec/protobuf/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..c9dfcb9196fe198db87da30af8209713c8750730 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/protobuf/package-info.java @@ -0,0 +1,10 @@ +/** + * Provides an encoder and a decoder for + * Google Protocol Buffers. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.codec.protobuf; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/codec/support/BaseCodecConfigurer.java b/spring-web/src/main/java/org/springframework/http/codec/support/BaseCodecConfigurer.java new file mode 100644 index 0000000000000000000000000000000000000000..c82e9ed9064bb51ff712302bbec53739aaeae69a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/support/BaseCodecConfigurer.java @@ -0,0 +1,244 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.support; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Decoder; +import org.springframework.core.codec.Encoder; +import org.springframework.http.codec.CodecConfigurer; +import org.springframework.http.codec.DecoderHttpMessageReader; +import org.springframework.http.codec.EncoderHttpMessageWriter; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.util.Assert; + +/** + * Default implementation of {@link CodecConfigurer} that serves as a base for + * client and server specific variants. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + */ +abstract class BaseCodecConfigurer implements CodecConfigurer { + + protected final BaseDefaultCodecs defaultCodecs; + + protected final DefaultCustomCodecs customCodecs; + + + /** + * Constructor with the base {@link BaseDefaultCodecs} to use, which can be + * a client or server specific variant. + */ + BaseCodecConfigurer(BaseDefaultCodecs defaultCodecs) { + Assert.notNull(defaultCodecs, "'defaultCodecs' is required"); + this.defaultCodecs = defaultCodecs; + this.customCodecs = new DefaultCustomCodecs(); + } + + /** + * Create a deep copy of the given {@link BaseCodecConfigurer}. + * @since 5.1.12 + */ + protected BaseCodecConfigurer(BaseCodecConfigurer other) { + this.defaultCodecs = other.cloneDefaultCodecs(); + this.customCodecs = new DefaultCustomCodecs(other.customCodecs); + } + + /** + * Sub-classes should override this to create deep copy of + * {@link BaseDefaultCodecs} which can can be client or server specific. + * @since 5.1.12 + */ + protected abstract BaseDefaultCodecs cloneDefaultCodecs(); + + + @Override + public DefaultCodecs defaultCodecs() { + return this.defaultCodecs; + } + + @Override + public void registerDefaults(boolean shouldRegister) { + this.defaultCodecs.registerDefaults(shouldRegister); + } + + @Override + public CustomCodecs customCodecs() { + return this.customCodecs; + } + + @Override + public List> getReaders() { + this.defaultCodecs.applyDefaultConfig(this.customCodecs); + + List> result = new ArrayList<>(); + result.addAll(this.defaultCodecs.getTypedReaders()); + result.addAll(this.customCodecs.getTypedReaders().keySet()); + result.addAll(this.defaultCodecs.getObjectReaders()); + result.addAll(this.customCodecs.getObjectReaders().keySet()); + result.addAll(this.defaultCodecs.getCatchAllReaders()); + return result; + } + + @Override + public List> getWriters() { + this.defaultCodecs.applyDefaultConfig(this.customCodecs); + + List> result = new ArrayList<>(); + result.addAll(this.defaultCodecs.getTypedWriters()); + result.addAll(this.customCodecs.getTypedWriters().keySet()); + result.addAll(this.defaultCodecs.getObjectWriters()); + result.addAll(this.customCodecs.getObjectWriters().keySet()); + result.addAll(this.defaultCodecs.getCatchAllWriters()); + return result; + } + + @Override + public abstract CodecConfigurer clone(); + + + /** + * Default implementation of {@code CustomCodecs}. + */ + protected static final class DefaultCustomCodecs implements CustomCodecs { + + private final Map, Boolean> typedReaders = new LinkedHashMap<>(4); + + private final Map, Boolean> typedWriters = new LinkedHashMap<>(4); + + private final Map, Boolean> objectReaders = new LinkedHashMap<>(4); + + private final Map, Boolean> objectWriters = new LinkedHashMap<>(4); + + private final List> defaultConfigConsumers = new ArrayList<>(4); + + DefaultCustomCodecs() { + } + + /** + * Create a deep copy of the given {@link DefaultCustomCodecs}. + * @since 5.1.12 + */ + DefaultCustomCodecs(DefaultCustomCodecs other) { + this.typedReaders.putAll(other.typedReaders); + this.typedWriters.putAll(other.typedWriters); + this.objectReaders.putAll(other.objectReaders); + this.objectWriters.putAll(other.objectWriters); + } + + @Override + public void register(Object codec) { + addCodec(codec, false); + } + + @Override + public void registerWithDefaultConfig(Object codec) { + addCodec(codec, true); + } + + @Override + public void registerWithDefaultConfig(Object codec, Consumer configConsumer) { + addCodec(codec, false); + this.defaultConfigConsumers.add(configConsumer); + } + + @SuppressWarnings("deprecation") + @Override + public void decoder(Decoder decoder) { + addCodec(decoder, false); + } + + @SuppressWarnings("deprecation") + @Override + public void encoder(Encoder encoder) { + addCodec(encoder, false); + } + + @SuppressWarnings("deprecation") + @Override + public void reader(HttpMessageReader reader) { + addCodec(reader, false); + } + + @SuppressWarnings("deprecation") + @Override + public void writer(HttpMessageWriter writer) { + addCodec(writer, false); + } + + @SuppressWarnings("deprecation") + @Override + public void withDefaultCodecConfig(Consumer codecsConfigConsumer) { + this.defaultConfigConsumers.add(codecsConfigConsumer); + } + + private void addCodec(Object codec, boolean applyDefaultConfig) { + + if (codec instanceof Decoder) { + codec = new DecoderHttpMessageReader<>((Decoder) codec); + } + else if (codec instanceof Encoder) { + codec = new EncoderHttpMessageWriter<>((Encoder) codec); + } + + if (codec instanceof HttpMessageReader) { + HttpMessageReader reader = (HttpMessageReader) codec; + boolean canReadToObject = reader.canRead(ResolvableType.forClass(Object.class), null); + (canReadToObject ? this.objectReaders : this.typedReaders).put(reader, applyDefaultConfig); + } + else if (codec instanceof HttpMessageWriter) { + HttpMessageWriter writer = (HttpMessageWriter) codec; + boolean canWriteObject = writer.canWrite(ResolvableType.forClass(Object.class), null); + (canWriteObject ? this.objectWriters : this.typedWriters).put(writer, applyDefaultConfig); + } + else { + throw new IllegalArgumentException("Unexpected codec type: " + codec.getClass().getName()); + } + } + + // Package private accessors... + + Map, Boolean> getTypedReaders() { + return this.typedReaders; + } + + Map, Boolean> getTypedWriters() { + return this.typedWriters; + } + + Map, Boolean> getObjectReaders() { + return this.objectReaders; + } + + Map, Boolean> getObjectWriters() { + return this.objectWriters; + } + + List> getDefaultConfigConsumers() { + return this.defaultConfigConsumers; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java b/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java new file mode 100644 index 0000000000000000000000000000000000000000..e01fb3c29c29615c43ef3e0f04a161c6e78f20c9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java @@ -0,0 +1,476 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.support; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.springframework.core.codec.AbstractDataBufferDecoder; +import org.springframework.core.codec.ByteArrayDecoder; +import org.springframework.core.codec.ByteArrayEncoder; +import org.springframework.core.codec.ByteBufferDecoder; +import org.springframework.core.codec.ByteBufferEncoder; +import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.core.codec.DataBufferDecoder; +import org.springframework.core.codec.DataBufferEncoder; +import org.springframework.core.codec.Decoder; +import org.springframework.core.codec.Encoder; +import org.springframework.core.codec.ResourceDecoder; +import org.springframework.core.codec.StringDecoder; +import org.springframework.http.codec.CodecConfigurer; +import org.springframework.http.codec.DecoderHttpMessageReader; +import org.springframework.http.codec.EncoderHttpMessageWriter; +import org.springframework.http.codec.FormHttpMessageReader; +import org.springframework.http.codec.FormHttpMessageWriter; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.ResourceHttpMessageWriter; +import org.springframework.http.codec.ServerSentEventHttpMessageReader; +import org.springframework.http.codec.json.AbstractJackson2Decoder; +import org.springframework.http.codec.json.Jackson2JsonDecoder; +import org.springframework.http.codec.json.Jackson2JsonEncoder; +import org.springframework.http.codec.json.Jackson2SmileDecoder; +import org.springframework.http.codec.json.Jackson2SmileEncoder; +import org.springframework.http.codec.multipart.MultipartHttpMessageReader; +import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; +import org.springframework.http.codec.multipart.SynchronossPartHttpMessageReader; +import org.springframework.http.codec.protobuf.ProtobufDecoder; +import org.springframework.http.codec.protobuf.ProtobufEncoder; +import org.springframework.http.codec.protobuf.ProtobufHttpMessageWriter; +import org.springframework.http.codec.xml.Jaxb2XmlDecoder; +import org.springframework.http.codec.xml.Jaxb2XmlEncoder; +import org.springframework.lang.Nullable; +import org.springframework.util.ClassUtils; + +/** + * Default implementation of {@link CodecConfigurer.DefaultCodecs} that serves + * as a base for client and server specific variants. + * + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + */ +class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs, CodecConfigurer.DefaultCodecConfig { + + static final boolean jackson2Present; + + private static final boolean jackson2SmilePresent; + + private static final boolean jaxb2Present; + + private static final boolean protobufPresent; + + static final boolean synchronossMultipartPresent; + + static { + ClassLoader classLoader = BaseCodecConfigurer.class.getClassLoader(); + jackson2Present = ClassUtils.isPresent("com.fasterxml.jackson.databind.ObjectMapper", classLoader) && + ClassUtils.isPresent("com.fasterxml.jackson.core.JsonGenerator", classLoader); + jackson2SmilePresent = ClassUtils.isPresent("com.fasterxml.jackson.dataformat.smile.SmileFactory", classLoader); + jaxb2Present = ClassUtils.isPresent("javax.xml.bind.Binder", classLoader); + protobufPresent = ClassUtils.isPresent("com.google.protobuf.Message", classLoader); + synchronossMultipartPresent = ClassUtils.isPresent("org.synchronoss.cloud.nio.multipart.NioMultipartParser", classLoader); + } + + + @Nullable + private Decoder jackson2JsonDecoder; + + @Nullable + private Encoder jackson2JsonEncoder; + + @Nullable + private Decoder protobufDecoder; + + @Nullable + private Encoder protobufEncoder; + + @Nullable + private Decoder jaxb2Decoder; + + @Nullable + private Encoder jaxb2Encoder; + + @Nullable + private Integer maxInMemorySize; + + @Nullable + private Boolean enableLoggingRequestDetails; + + private boolean registerDefaults = true; + + + BaseDefaultCodecs() { + } + + /** + * Create a deep copy of the given {@link BaseDefaultCodecs}. + */ + protected BaseDefaultCodecs(BaseDefaultCodecs other) { + this.jackson2JsonDecoder = other.jackson2JsonDecoder; + this.jackson2JsonEncoder = other.jackson2JsonEncoder; + this.protobufDecoder = other.protobufDecoder; + this.protobufEncoder = other.protobufEncoder; + this.jaxb2Decoder = other.jaxb2Decoder; + this.jaxb2Encoder = other.jaxb2Encoder; + this.maxInMemorySize = other.maxInMemorySize; + this.enableLoggingRequestDetails = other.enableLoggingRequestDetails; + this.registerDefaults = other.registerDefaults; + } + + @Override + public void jackson2JsonDecoder(Decoder decoder) { + this.jackson2JsonDecoder = decoder; + } + + @Override + public void jackson2JsonEncoder(Encoder encoder) { + this.jackson2JsonEncoder = encoder; + } + + @Override + public void protobufDecoder(Decoder decoder) { + this.protobufDecoder = decoder; + } + + @Override + public void protobufEncoder(Encoder encoder) { + this.protobufEncoder = encoder; + } + + @Override + public void jaxb2Decoder(Decoder decoder) { + this.jaxb2Decoder = decoder; + } + + @Override + public void jaxb2Encoder(Encoder encoder) { + this.jaxb2Encoder = encoder; + } + + @Override + public void maxInMemorySize(int byteCount) { + this.maxInMemorySize = byteCount; + } + + @Override + @Nullable + public Integer maxInMemorySize() { + return this.maxInMemorySize; + } + + @Override + public void enableLoggingRequestDetails(boolean enable) { + this.enableLoggingRequestDetails = enable; + } + + @Override + @Nullable + public Boolean isEnableLoggingRequestDetails() { + return this.enableLoggingRequestDetails; + } + + /** + * Delegate method used from {@link BaseCodecConfigurer#registerDefaults}. + */ + void registerDefaults(boolean registerDefaults) { + this.registerDefaults = registerDefaults; + } + + + /** + * Return readers that support specific types. + */ + final List> getTypedReaders() { + if (!this.registerDefaults) { + return Collections.emptyList(); + } + List> readers = new ArrayList<>(); + addCodec(readers, new DecoderHttpMessageReader<>(new ByteArrayDecoder())); + addCodec(readers, new DecoderHttpMessageReader<>(new ByteBufferDecoder())); + addCodec(readers, new DecoderHttpMessageReader<>(new DataBufferDecoder())); + addCodec(readers, new DecoderHttpMessageReader<>(new ResourceDecoder())); + addCodec(readers, new DecoderHttpMessageReader<>(StringDecoder.textPlainOnly())); + if (protobufPresent) { + Decoder decoder = this.protobufDecoder != null ? this.protobufDecoder : new ProtobufDecoder(); + addCodec(readers, new DecoderHttpMessageReader<>(decoder)); + } + addCodec(readers, new FormHttpMessageReader()); + + // client vs server.. + extendTypedReaders(readers); + + return readers; + } + + /** + * Initialize a codec and add it to the List. + * @since 5.1.13 + */ + protected void addCodec(List codecs, T codec) { + initCodec(codec); + codecs.add(codec); + } + + /** + * Apply {@link #maxInMemorySize()} and {@link #enableLoggingRequestDetails}, + * if configured by the application, to the given codec , including any + * codec it contains. + */ + private void initCodec(@Nullable Object codec) { + + if (codec instanceof DecoderHttpMessageReader) { + codec = ((DecoderHttpMessageReader) codec).getDecoder(); + } + else if (codec instanceof ServerSentEventHttpMessageReader) { + codec = ((ServerSentEventHttpMessageReader) codec).getDecoder(); + } + + if (codec == null) { + return; + } + + Integer size = this.maxInMemorySize; + if (size != null) { + if (codec instanceof AbstractDataBufferDecoder) { + ((AbstractDataBufferDecoder) codec).setMaxInMemorySize(size); + } + if (protobufPresent) { + if (codec instanceof ProtobufDecoder) { + ((ProtobufDecoder) codec).setMaxMessageSize(size); + } + } + if (jackson2Present) { + if (codec instanceof AbstractJackson2Decoder) { + ((AbstractJackson2Decoder) codec).setMaxInMemorySize(size); + } + } + if (jaxb2Present) { + if (codec instanceof Jaxb2XmlDecoder) { + ((Jaxb2XmlDecoder) codec).setMaxInMemorySize(size); + } + } + if (codec instanceof FormHttpMessageReader) { + ((FormHttpMessageReader) codec).setMaxInMemorySize(size); + } + if (synchronossMultipartPresent) { + if (codec instanceof SynchronossPartHttpMessageReader) { + ((SynchronossPartHttpMessageReader) codec).setMaxInMemorySize(size); + } + } + } + + Boolean enable = this.enableLoggingRequestDetails; + if (enable != null) { + if (codec instanceof FormHttpMessageReader) { + ((FormHttpMessageReader) codec).setEnableLoggingRequestDetails(enable); + } + if (codec instanceof MultipartHttpMessageReader) { + ((MultipartHttpMessageReader) codec).setEnableLoggingRequestDetails(enable); + } + if (synchronossMultipartPresent) { + if (codec instanceof SynchronossPartHttpMessageReader) { + ((SynchronossPartHttpMessageReader) codec).setEnableLoggingRequestDetails(enable); + } + } + if (codec instanceof FormHttpMessageWriter) { + ((FormHttpMessageWriter) codec).setEnableLoggingRequestDetails(enable); + } + if (codec instanceof MultipartHttpMessageWriter) { + ((MultipartHttpMessageWriter) codec).setEnableLoggingRequestDetails(enable); + } + } + + if (codec instanceof MultipartHttpMessageReader) { + initCodec(((MultipartHttpMessageReader) codec).getPartReader()); + } + else if (codec instanceof MultipartHttpMessageWriter) { + initCodec(((MultipartHttpMessageWriter) codec).getFormWriter()); + } + } + + /** + * Hook for client or server specific typed readers. + */ + protected void extendTypedReaders(List> typedReaders) { + } + + /** + * Return Object readers (JSON, XML, SSE). + */ + final List> getObjectReaders() { + if (!this.registerDefaults) { + return Collections.emptyList(); + } + List> readers = new ArrayList<>(); + if (jackson2Present) { + addCodec(readers, new DecoderHttpMessageReader<>(getJackson2JsonDecoder())); + } + if (jackson2SmilePresent) { + addCodec(readers, new DecoderHttpMessageReader<>(new Jackson2SmileDecoder())); + } + if (jaxb2Present) { + Decoder decoder = this.jaxb2Decoder != null ? this.jaxb2Decoder : new Jaxb2XmlDecoder(); + addCodec(readers, new DecoderHttpMessageReader<>(decoder)); + } + + // client vs server.. + extendObjectReaders(readers); + + return readers; + } + + /** + * Hook for client or server specific Object readers. + */ + protected void extendObjectReaders(List> objectReaders) { + } + + /** + * Return readers that need to be at the end, after all others. + */ + final List> getCatchAllReaders() { + if (!this.registerDefaults) { + return Collections.emptyList(); + } + List> readers = new ArrayList<>(); + addCodec(readers, new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes())); + return readers; + } + + /** + * Return all writers that support specific types. + */ + @SuppressWarnings({"rawtypes" }) + final List> getTypedWriters() { + if (!this.registerDefaults) { + return Collections.emptyList(); + } + List> writers = getBaseTypedWriters(); + extendTypedWriters(writers); + return writers; + } + + /** + * Return "base" typed writers only, i.e. common to client and server. + */ + @SuppressWarnings("unchecked") + final List> getBaseTypedWriters() { + if (!this.registerDefaults) { + return Collections.emptyList(); + } + List> writers = new ArrayList<>(); + writers.add(new EncoderHttpMessageWriter<>(new ByteArrayEncoder())); + writers.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder())); + writers.add(new EncoderHttpMessageWriter<>(new DataBufferEncoder())); + writers.add(new ResourceHttpMessageWriter()); + writers.add(new EncoderHttpMessageWriter<>(CharSequenceEncoder.textPlainOnly())); + if (protobufPresent) { + Encoder encoder = this.protobufEncoder != null ? this.protobufEncoder : new ProtobufEncoder(); + writers.add(new ProtobufHttpMessageWriter((Encoder) encoder)); + } + return writers; + } + + /** + * Hook for client or server specific typed writers. + */ + protected void extendTypedWriters(List> typedWriters) { + } + + /** + * Return Object writers (JSON, XML, SSE). + */ + final List> getObjectWriters() { + if (!this.registerDefaults) { + return Collections.emptyList(); + } + List> writers = getBaseObjectWriters(); + extendObjectWriters(writers); + return writers; + } + + /** + * Return "base" object writers only, i.e. common to client and server. + */ + final List> getBaseObjectWriters() { + List> writers = new ArrayList<>(); + if (jackson2Present) { + writers.add(new EncoderHttpMessageWriter<>(getJackson2JsonEncoder())); + } + if (jackson2SmilePresent) { + writers.add(new EncoderHttpMessageWriter<>(new Jackson2SmileEncoder())); + } + if (jaxb2Present) { + Encoder encoder = this.jaxb2Encoder != null ? this.jaxb2Encoder : new Jaxb2XmlEncoder(); + writers.add(new EncoderHttpMessageWriter<>(encoder)); + } + return writers; + } + + /** + * Hook for client or server specific Object writers. + */ + protected void extendObjectWriters(List> objectWriters) { + } + + /** + * Return writers that need to be at the end, after all others. + */ + List> getCatchAllWriters() { + if (!this.registerDefaults) { + return Collections.emptyList(); + } + List> result = new ArrayList<>(); + result.add(new EncoderHttpMessageWriter<>(CharSequenceEncoder.allMimeTypes())); + return result; + } + + void applyDefaultConfig(BaseCodecConfigurer.DefaultCustomCodecs customCodecs) { + applyDefaultConfig(customCodecs.getTypedReaders()); + applyDefaultConfig(customCodecs.getObjectReaders()); + applyDefaultConfig(customCodecs.getTypedWriters()); + applyDefaultConfig(customCodecs.getObjectWriters()); + customCodecs.getDefaultConfigConsumers().forEach(consumer -> consumer.accept(this)); + } + + private void applyDefaultConfig(Map readers) { + readers.entrySet().stream() + .filter(Map.Entry::getValue) + .map(Map.Entry::getKey) + .forEach(this::initCodec); + } + + + // Accessors for use in subclasses... + + protected Decoder getJackson2JsonDecoder() { + if (this.jackson2JsonDecoder == null) { + this.jackson2JsonDecoder = new Jackson2JsonDecoder(); + } + return this.jackson2JsonDecoder; + } + + protected Encoder getJackson2JsonEncoder() { + if (this.jackson2JsonEncoder == null) { + this.jackson2JsonEncoder = new Jackson2JsonEncoder(); + } + return this.jackson2JsonEncoder; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/support/ClientDefaultCodecsImpl.java b/spring-web/src/main/java/org/springframework/http/codec/support/ClientDefaultCodecsImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..cc1c7f1a439cc9ec320adc7ff33e8c5015a8c891 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/support/ClientDefaultCodecsImpl.java @@ -0,0 +1,157 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.support; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Supplier; + +import org.springframework.core.codec.Decoder; +import org.springframework.core.codec.Encoder; +import org.springframework.http.codec.ClientCodecConfigurer; +import org.springframework.http.codec.EncoderHttpMessageWriter; +import org.springframework.http.codec.FormHttpMessageWriter; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.ServerSentEventHttpMessageReader; +import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; +import org.springframework.lang.Nullable; + +/** + * Default implementation of {@link ClientCodecConfigurer.ClientDefaultCodecs}. + * + * @author Rossen Stoyanchev + */ +class ClientDefaultCodecsImpl extends BaseDefaultCodecs implements ClientCodecConfigurer.ClientDefaultCodecs { + + @Nullable + private DefaultMultipartCodecs multipartCodecs; + + @Nullable + private Decoder sseDecoder; + + @Nullable + private Supplier>> partWritersSupplier; + + + ClientDefaultCodecsImpl() { + } + + ClientDefaultCodecsImpl(ClientDefaultCodecsImpl other) { + super(other); + this.multipartCodecs = (other.multipartCodecs != null ? + new DefaultMultipartCodecs(other.multipartCodecs) : null); + this.sseDecoder = other.sseDecoder; + } + + + /** + * Set a supplier for part writers to use when + * {@link #multipartCodecs()} are not explicitly configured. + * That's the same set of writers as for general except for the multipart + * writer itself. + */ + void setPartWritersSupplier(Supplier>> supplier) { + this.partWritersSupplier = supplier; + } + + + @Override + public ClientCodecConfigurer.MultipartCodecs multipartCodecs() { + if (this.multipartCodecs == null) { + this.multipartCodecs = new DefaultMultipartCodecs(); + } + return this.multipartCodecs; + } + + @Override + public void serverSentEventDecoder(Decoder decoder) { + this.sseDecoder = decoder; + } + + @Override + public ClientDefaultCodecsImpl clone() { + ClientDefaultCodecsImpl codecs = new ClientDefaultCodecsImpl(); + codecs.multipartCodecs = this.multipartCodecs; + codecs.sseDecoder = this.sseDecoder; + codecs.partWritersSupplier = this.partWritersSupplier; + return codecs; + } + + @Override + protected void extendObjectReaders(List> objectReaders) { + + Decoder decoder = (this.sseDecoder != null ? + this.sseDecoder : + jackson2Present ? getJackson2JsonDecoder() : null); + + addCodec(objectReaders, new ServerSentEventHttpMessageReader(decoder)); + } + + @Override + protected void extendTypedWriters(List> typedWriters) { + addCodec(typedWriters, new MultipartHttpMessageWriter(getPartWriters(), new FormHttpMessageWriter())); + } + + private List> getPartWriters() { + if (this.multipartCodecs != null) { + return this.multipartCodecs.getWriters(); + } + else if (this.partWritersSupplier != null) { + return this.partWritersSupplier.get(); + } + else { + return Collections.emptyList(); + } + } + + + /** + * Default implementation of {@link ClientCodecConfigurer.MultipartCodecs}. + */ + private static class DefaultMultipartCodecs implements ClientCodecConfigurer.MultipartCodecs { + + private final List> writers = new ArrayList<>(); + + + DefaultMultipartCodecs() { + } + + DefaultMultipartCodecs(DefaultMultipartCodecs other) { + this.writers.addAll(other.writers); + } + + + @Override + public ClientCodecConfigurer.MultipartCodecs encoder(Encoder encoder) { + writer(new EncoderHttpMessageWriter<>(encoder)); + return this; + } + + @Override + public ClientCodecConfigurer.MultipartCodecs writer(HttpMessageWriter writer) { + this.writers.add(writer); + return this; + } + + List> getWriters() { + return this.writers; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/support/DefaultClientCodecConfigurer.java b/spring-web/src/main/java/org/springframework/http/codec/support/DefaultClientCodecConfigurer.java new file mode 100644 index 0000000000000000000000000000000000000000..382d11bec8c472bc62605d89fb8f2199fb58a1a3 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/support/DefaultClientCodecConfigurer.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.support; + +import java.util.ArrayList; +import java.util.List; + +import org.springframework.http.codec.ClientCodecConfigurer; +import org.springframework.http.codec.HttpMessageWriter; + +/** + * Default implementation of {@link ClientCodecConfigurer}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class DefaultClientCodecConfigurer extends BaseCodecConfigurer implements ClientCodecConfigurer { + + + public DefaultClientCodecConfigurer() { + super(new ClientDefaultCodecsImpl()); + ((ClientDefaultCodecsImpl) defaultCodecs()).setPartWritersSupplier(this::getPartWriters); + } + + private DefaultClientCodecConfigurer(DefaultClientCodecConfigurer other) { + super(other); + ((ClientDefaultCodecsImpl) defaultCodecs()).setPartWritersSupplier(this::getPartWriters); + } + + + @Override + public ClientDefaultCodecs defaultCodecs() { + return (ClientDefaultCodecs) super.defaultCodecs(); + } + + @Override + public DefaultClientCodecConfigurer clone() { + return new DefaultClientCodecConfigurer(this); + } + + @Override + protected BaseDefaultCodecs cloneDefaultCodecs() { + return new ClientDefaultCodecsImpl((ClientDefaultCodecsImpl) defaultCodecs()); + } + + private List> getPartWriters() { + List> result = new ArrayList<>(); + result.addAll(this.customCodecs.getTypedWriters().keySet()); + result.addAll(this.defaultCodecs.getBaseTypedWriters()); + result.addAll(this.customCodecs.getObjectWriters().keySet()); + result.addAll(this.defaultCodecs.getBaseObjectWriters()); + result.addAll(this.defaultCodecs.getCatchAllWriters()); + return result; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/support/DefaultServerCodecConfigurer.java b/spring-web/src/main/java/org/springframework/http/codec/support/DefaultServerCodecConfigurer.java new file mode 100644 index 0000000000000000000000000000000000000000..661d45d66693a7f6b9bf2dfac2139f6724ca9a6c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/support/DefaultServerCodecConfigurer.java @@ -0,0 +1,53 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.support; + +import org.springframework.http.codec.ServerCodecConfigurer; + +/** + * Default implementation of {@link ServerCodecConfigurer}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class DefaultServerCodecConfigurer extends BaseCodecConfigurer implements ServerCodecConfigurer { + + + public DefaultServerCodecConfigurer() { + super(new ServerDefaultCodecsImpl()); + } + + private DefaultServerCodecConfigurer(BaseCodecConfigurer other) { + super(other); + } + + + @Override + public ServerDefaultCodecs defaultCodecs() { + return (ServerDefaultCodecs) super.defaultCodecs(); + } + + @Override + public DefaultServerCodecConfigurer clone() { + return new DefaultServerCodecConfigurer(this); + } + + @Override + protected BaseDefaultCodecs cloneDefaultCodecs() { + return new ServerDefaultCodecsImpl((ServerDefaultCodecsImpl) defaultCodecs()); + } +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/support/ServerDefaultCodecsImpl.java b/spring-web/src/main/java/org/springframework/http/codec/support/ServerDefaultCodecsImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..eaab4e3237c49471c5d5113aa11e1bb5da5bfb4a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/support/ServerDefaultCodecsImpl.java @@ -0,0 +1,87 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.http.codec.support; + +import java.util.List; + +import org.springframework.core.codec.Encoder; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.http.codec.ServerSentEventHttpMessageWriter; +import org.springframework.http.codec.multipart.MultipartHttpMessageReader; +import org.springframework.http.codec.multipart.SynchronossPartHttpMessageReader; +import org.springframework.lang.Nullable; + +/** + * Default implementation of {@link ServerCodecConfigurer.ServerDefaultCodecs}. + * + * @author Rossen Stoyanchev + */ +class ServerDefaultCodecsImpl extends BaseDefaultCodecs implements ServerCodecConfigurer.ServerDefaultCodecs { + + @Nullable + private HttpMessageReader multipartReader; + + @Nullable + private Encoder sseEncoder; + + + ServerDefaultCodecsImpl() { + } + + ServerDefaultCodecsImpl(ServerDefaultCodecsImpl other) { + super(other); + this.multipartReader = other.multipartReader; + this.sseEncoder = other.sseEncoder; + } + + + @Override + public void multipartReader(HttpMessageReader reader) { + this.multipartReader = reader; + } + + @Override + public void serverSentEventEncoder(Encoder encoder) { + this.sseEncoder = encoder; + } + + + @Override + protected void extendTypedReaders(List> typedReaders) { + if (this.multipartReader != null) { + addCodec(typedReaders, this.multipartReader); + return; + } + if (synchronossMultipartPresent) { + SynchronossPartHttpMessageReader partReader = new SynchronossPartHttpMessageReader(); + addCodec(typedReaders, partReader); + addCodec(typedReaders, new MultipartHttpMessageReader(partReader)); + } + } + + @Override + protected void extendObjectWriters(List> objectWriters) { + objectWriters.add(new ServerSentEventHttpMessageWriter(getSseEncoder())); + } + + @Nullable + private Encoder getSseEncoder() { + return this.sseEncoder != null ? this.sseEncoder : jackson2Present ? getJackson2JsonEncoder() : null; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/support/package-info.java b/spring-web/src/main/java/org/springframework/http/codec/support/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..d2658887ebe32bfe4ce6d6af34a7a6741e7169da --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/support/package-info.java @@ -0,0 +1,11 @@ +/** + * Provides implementations of {@link org.springframework.http.codec.ClientCodecConfigurer} + * and {@link org.springframework.http.codec.ServerCodecConfigurer} based on the converter + * implementations from {@code org.springframework.http.codec.json} and co. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.codec.support; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlDecoder.java new file mode 100644 index 0000000000000000000000000000000000000000..d2dc6106e669ba610ca963be551e0ca0cb468284 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlDecoder.java @@ -0,0 +1,310 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.xml; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import javax.xml.XMLConstants; +import javax.xml.bind.JAXBElement; +import javax.xml.bind.JAXBException; +import javax.xml.bind.UnmarshalException; +import javax.xml.bind.Unmarshaller; +import javax.xml.bind.annotation.XmlRootElement; +import javax.xml.bind.annotation.XmlSchema; +import javax.xml.bind.annotation.XmlType; +import javax.xml.namespace.QName; +import javax.xml.stream.XMLEventReader; +import javax.xml.stream.events.XMLEvent; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.AbstractDecoder; +import org.springframework.core.codec.CodecException; +import org.springframework.core.codec.DecodingException; +import org.springframework.core.codec.Hints; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; +import org.springframework.core.log.LogFormatUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.xml.StaxUtils; + +/** + * Decode from a bytes stream containing XML elements to a stream of + * {@code Object}s (POJOs). + * + * @author Sebastien Deleuze + * @author Arjen Poutsma + * @since 5.0 + * @see Jaxb2XmlEncoder + */ +public class Jaxb2XmlDecoder extends AbstractDecoder { + + /** + * The default value for JAXB annotations. + * @see XmlRootElement#name() + * @see XmlRootElement#namespace() + * @see XmlType#name() + * @see XmlType#namespace() + */ + private static final String JAXB_DEFAULT_ANNOTATION_VALUE = "##default"; + + + private final XmlEventDecoder xmlEventDecoder = new XmlEventDecoder(); + + private final JaxbContextContainer jaxbContexts = new JaxbContextContainer(); + + private Function unmarshallerProcessor = Function.identity(); + + private int maxInMemorySize = -1; + + + public Jaxb2XmlDecoder() { + super(MimeTypeUtils.APPLICATION_XML, MimeTypeUtils.TEXT_XML); + } + + /** + * Create a {@code Jaxb2XmlDecoder} with the specified MIME types. + * @param supportedMimeTypes supported MIME types + * @since 5.1.9 + */ + public Jaxb2XmlDecoder(MimeType... supportedMimeTypes) { + super(supportedMimeTypes); + } + + + /** + * Configure a processor function to customize Unmarshaller instances. + * @param processor the function to use + * @since 5.1.3 + */ + public void setUnmarshallerProcessor(Function processor) { + this.unmarshallerProcessor = this.unmarshallerProcessor.andThen(processor); + } + + /** + * Return the configured processor for customizing Unmarshaller instances. + * @since 5.1.3 + */ + public Function getUnmarshallerProcessor() { + return this.unmarshallerProcessor; + } + + /** + * Set the max number of bytes that can be buffered by this decoder. + * This is either the size of the entire input when decoding as a whole, or when + * using async parsing with Aalto XML, it is the size of one top-level XML tree. + * When the limit is exceeded, {@link DataBufferLimitException} is raised. + *

By default in 5.1 this is set to -1, unlimited. In 5.2 the default + * value for this limit is set to 256K. + * @param byteCount the max number of bytes to buffer, or -1 for unlimited + * @since 5.1.11 + */ + public void setMaxInMemorySize(int byteCount) { + this.maxInMemorySize = byteCount; + this.xmlEventDecoder.setMaxInMemorySize(byteCount); + } + + /** + * Return the {@link #setMaxInMemorySize configured} byte count limit. + * @since 5.1.11 + */ + public int getMaxInMemorySize() { + return this.maxInMemorySize; + } + + + @Override + public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) { + Class outputClass = elementType.toClass(); + return (outputClass.isAnnotationPresent(XmlRootElement.class) || + outputClass.isAnnotationPresent(XmlType.class)) && super.canDecode(elementType, mimeType); + } + + @Override + public Flux decode(Publisher inputStream, ResolvableType elementType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + Flux xmlEventFlux = this.xmlEventDecoder.decode( + inputStream, ResolvableType.forClass(XMLEvent.class), mimeType, hints); + + Class outputClass = elementType.toClass(); + QName typeName = toQName(outputClass); + Flux> splitEvents = split(xmlEventFlux, typeName); + + return splitEvents.map(events -> { + Object value = unmarshal(events, outputClass); + LogFormatUtils.traceDebug(logger, traceOn -> { + String formatted = LogFormatUtils.formatValue(value, !traceOn); + return Hints.getLogPrefix(hints) + "Decoded [" + formatted + "]"; + }); + return value; + }); + } + + @Override + public Mono decodeToMono(Publisher inputStream, ResolvableType elementType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + return decode(inputStream, elementType, mimeType, hints).singleOrEmpty(); + } + + private Object unmarshal(List events, Class outputClass) { + try { + Unmarshaller unmarshaller = initUnmarshaller(outputClass); + XMLEventReader eventReader = StaxUtils.createXMLEventReader(events); + if (outputClass.isAnnotationPresent(XmlRootElement.class)) { + return unmarshaller.unmarshal(eventReader); + } + else { + JAXBElement jaxbElement = unmarshaller.unmarshal(eventReader, outputClass); + return jaxbElement.getValue(); + } + } + catch (UnmarshalException ex) { + throw new DecodingException("Could not unmarshal XML to " + outputClass, ex); + } + catch (JAXBException ex) { + throw new CodecException("Invalid JAXB configuration", ex); + } + } + + private Unmarshaller initUnmarshaller(Class outputClass) throws CodecException, JAXBException { + Unmarshaller unmarshaller = this.jaxbContexts.createUnmarshaller(outputClass); + return this.unmarshallerProcessor.apply(unmarshaller); + } + + /** + * Returns the qualified name for the given class, according to the mapping rules + * in the JAXB specification. + */ + QName toQName(Class outputClass) { + String localPart; + String namespaceUri; + + if (outputClass.isAnnotationPresent(XmlRootElement.class)) { + XmlRootElement annotation = outputClass.getAnnotation(XmlRootElement.class); + localPart = annotation.name(); + namespaceUri = annotation.namespace(); + } + else if (outputClass.isAnnotationPresent(XmlType.class)) { + XmlType annotation = outputClass.getAnnotation(XmlType.class); + localPart = annotation.name(); + namespaceUri = annotation.namespace(); + } + else { + throw new IllegalArgumentException("Output class [" + outputClass.getName() + + "] is neither annotated with @XmlRootElement nor @XmlType"); + } + + if (JAXB_DEFAULT_ANNOTATION_VALUE.equals(localPart)) { + localPart = ClassUtils.getShortNameAsProperty(outputClass); + } + if (JAXB_DEFAULT_ANNOTATION_VALUE.equals(namespaceUri)) { + Package outputClassPackage = outputClass.getPackage(); + if (outputClassPackage != null && outputClassPackage.isAnnotationPresent(XmlSchema.class)) { + XmlSchema annotation = outputClassPackage.getAnnotation(XmlSchema.class); + namespaceUri = annotation.namespace(); + } + else { + namespaceUri = XMLConstants.NULL_NS_URI; + } + } + return new QName(namespaceUri, localPart); + } + + /** + * Split a flux of {@link XMLEvent XMLEvents} into a flux of XMLEvent lists, one list + * for each branch of the tree that starts with the given qualified name. + * That is, given the XMLEvents shown {@linkplain XmlEventDecoder here}, + * and the {@code desiredName} "{@code child}", this method returns a flux + * of two lists, each of which containing the events of a particular branch + * of the tree that starts with "{@code child}". + *
    + *
  1. The first list, dealing with the first branch of the tree: + *
      + *
    1. {@link javax.xml.stream.events.StartElement} {@code child}
    2. + *
    3. {@link javax.xml.stream.events.Characters} {@code foo}
    4. + *
    5. {@link javax.xml.stream.events.EndElement} {@code child}
    6. + *
    + *
  2. The second list, dealing with the second branch of the tree: + *
      + *
    1. {@link javax.xml.stream.events.StartElement} {@code child}
    2. + *
    3. {@link javax.xml.stream.events.Characters} {@code bar}
    4. + *
    5. {@link javax.xml.stream.events.EndElement} {@code child}
    6. + *
    + *
  3. + *
+ */ + Flux> split(Flux xmlEventFlux, QName desiredName) { + return xmlEventFlux.flatMap(new SplitFunction(desiredName)); + } + + + private static class SplitFunction implements Function>> { + + private final QName desiredName; + + @Nullable + private List events; + + private int elementDepth = 0; + + private int barrier = Integer.MAX_VALUE; + + public SplitFunction(QName desiredName) { + this.desiredName = desiredName; + } + + @Override + public Publisher> apply(XMLEvent event) { + if (event.isStartElement()) { + if (this.barrier == Integer.MAX_VALUE) { + QName startElementName = event.asStartElement().getName(); + if (this.desiredName.equals(startElementName)) { + this.events = new ArrayList<>(); + this.barrier = this.elementDepth; + } + } + this.elementDepth++; + } + if (this.elementDepth > this.barrier) { + Assert.state(this.events != null, "No XMLEvent List"); + this.events.add(event); + } + if (event.isEndElement()) { + this.elementDepth--; + if (this.elementDepth == this.barrier) { + this.barrier = Integer.MAX_VALUE; + Assert.state(this.events != null, "No XMLEvent List"); + return Mono.just(this.events); + } + } + return Mono.empty(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlEncoder.java b/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlEncoder.java new file mode 100644 index 0000000000000000000000000000000000000000..8441d1afa326b155460856ba8284b9fdf66405c4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlEncoder.java @@ -0,0 +1,145 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.xml; + +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.function.Function; + +import javax.xml.bind.JAXBException; +import javax.xml.bind.MarshalException; +import javax.xml.bind.Marshaller; +import javax.xml.bind.annotation.XmlRootElement; +import javax.xml.bind.annotation.XmlType; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.AbstractSingleValueEncoder; +import org.springframework.core.codec.CodecException; +import org.springframework.core.codec.EncodingException; +import org.springframework.core.codec.Hints; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.log.LogFormatUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.ClassUtils; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; + +/** + * Encode from single value to a byte stream containing XML elements. + * + *

{@link javax.xml.bind.annotation.XmlElements @XmlElements} and + * {@link javax.xml.bind.annotation.XmlElement @XmlElement} can be used + * to specify how collections should be marshalled. + * + * @author Sebastien Deleuze + * @author Arjen Poutsma + * @since 5.0 + * @see Jaxb2XmlDecoder + */ +public class Jaxb2XmlEncoder extends AbstractSingleValueEncoder { + + private final JaxbContextContainer jaxbContexts = new JaxbContextContainer(); + + private Function marshallerProcessor = Function.identity(); + + + public Jaxb2XmlEncoder() { + super(MimeTypeUtils.APPLICATION_XML, MimeTypeUtils.TEXT_XML); + } + + + /** + * Configure a processor function to customize Marshaller instances. + * @param processor the function to use + * @since 5.1.3 + */ + public void setMarshallerProcessor(Function processor) { + this.marshallerProcessor = this.marshallerProcessor.andThen(processor); + } + + /** + * Return the configured processor for customizing Marshaller instances. + * @since 5.1.3 + */ + public Function getMarshallerProcessor() { + return this.marshallerProcessor; + } + + + @Override + public boolean canEncode(ResolvableType elementType, @Nullable MimeType mimeType) { + if (super.canEncode(elementType, mimeType)) { + Class outputClass = elementType.toClass(); + return (outputClass.isAnnotationPresent(XmlRootElement.class) || + outputClass.isAnnotationPresent(XmlType.class)); + } + else { + return false; + } + } + + @Override + protected Flux encode(Object value, DataBufferFactory bufferFactory, + ResolvableType type, @Nullable MimeType mimeType, @Nullable Map hints) { + + if (!Hints.isLoggingSuppressed(hints)) { + LogFormatUtils.traceDebug(logger, traceOn -> { + String formatted = LogFormatUtils.formatValue(value, !traceOn); + return Hints.getLogPrefix(hints) + "Encoding [" + formatted + "]"; + }); + } + + return Mono.fromCallable(() -> { + boolean release = true; + DataBuffer buffer = bufferFactory.allocateBuffer(1024); + try { + OutputStream outputStream = buffer.asOutputStream(); + Class clazz = ClassUtils.getUserClass(value); + Marshaller marshaller = initMarshaller(clazz); + marshaller.marshal(value, outputStream); + release = false; + return buffer; // relying on doOnDiscard in base class + } + catch (MarshalException ex) { + throw new EncodingException( + "Could not marshal " + value.getClass() + " to XML", ex); + } + catch (JAXBException ex) { + throw new CodecException("Invalid JAXB configuration", ex); + } + finally { + if (release) { + DataBufferUtils.release(buffer); + } + } + }).flux(); + } + + private Marshaller initMarshaller(Class clazz) throws CodecException, JAXBException { + Marshaller marshaller = this.jaxbContexts.createMarshaller(clazz); + marshaller.setProperty(Marshaller.JAXB_ENCODING, StandardCharsets.UTF_8.name()); + marshaller = this.marshallerProcessor.apply(marshaller); + return marshaller; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/xml/JaxbContextContainer.java b/spring-web/src/main/java/org/springframework/http/codec/xml/JaxbContextContainer.java new file mode 100644 index 0000000000000000000000000000000000000000..49441c498c38f6fea1489cc0573f461b4d0a23fc --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/xml/JaxbContextContainer.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.xml; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import javax.xml.bind.JAXBContext; +import javax.xml.bind.JAXBException; +import javax.xml.bind.Marshaller; +import javax.xml.bind.Unmarshaller; + +import org.springframework.core.codec.CodecException; + +/** + * Holder for {@link JAXBContext} instances. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 5.0 + */ +final class JaxbContextContainer { + + private final ConcurrentMap, JAXBContext> jaxbContexts = new ConcurrentHashMap<>(64); + + + public Marshaller createMarshaller(Class clazz) throws CodecException, JAXBException { + JAXBContext jaxbContext = getJaxbContext(clazz); + return jaxbContext.createMarshaller(); + } + + public Unmarshaller createUnmarshaller(Class clazz) throws CodecException, JAXBException { + JAXBContext jaxbContext = getJaxbContext(clazz); + return jaxbContext.createUnmarshaller(); + } + + private JAXBContext getJaxbContext(Class clazz) throws CodecException { + return this.jaxbContexts.computeIfAbsent(clazz, key -> { + try { + return JAXBContext.newInstance(clazz); + } + catch (JAXBException ex) { + throw new CodecException( + "Could not create JAXBContext for class [" + clazz + "]: " + ex.getMessage(), ex); + } + }); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/xml/XmlEventDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/xml/XmlEventDecoder.java new file mode 100644 index 0000000000000000000000000000000000000000..46d3c4e139199dac2f38e77d512fd817b2cc000e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/xml/XmlEventDecoder.java @@ -0,0 +1,247 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.xml; + +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import javax.xml.stream.XMLInputFactory; +import javax.xml.stream.XMLStreamException; +import javax.xml.stream.events.XMLEvent; +import javax.xml.stream.util.XMLEventAllocator; + +import com.fasterxml.aalto.AsyncByteBufferFeeder; +import com.fasterxml.aalto.AsyncXMLInputFactory; +import com.fasterxml.aalto.AsyncXMLStreamReader; +import com.fasterxml.aalto.evt.EventAllocatorImpl; +import com.fasterxml.aalto.stax.InputFactoryImpl; +import org.reactivestreams.Publisher; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.AbstractDecoder; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.ClassUtils; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.xml.StaxUtils; + +/** + * Decodes a {@link DataBuffer} stream into a stream of {@link XMLEvent XMLEvents}. + * + *

Given the following XML: + * + *

+ * <root>
+ *     <child>foo</child>
+ *     <child>bar</child>
+ * </root>
+ * 
+ * + * this decoder will produce a {@link Flux} with the following events: + * + *
    + *
  1. {@link javax.xml.stream.events.StartDocument}
  2. + *
  3. {@link javax.xml.stream.events.StartElement} {@code root}
  4. + *
  5. {@link javax.xml.stream.events.StartElement} {@code child}
  6. + *
  7. {@link javax.xml.stream.events.Characters} {@code foo}
  8. + *
  9. {@link javax.xml.stream.events.EndElement} {@code child}
  10. + *
  11. {@link javax.xml.stream.events.StartElement} {@code child}
  12. + *
  13. {@link javax.xml.stream.events.Characters} {@code bar}
  14. + *
  15. {@link javax.xml.stream.events.EndElement} {@code child}
  16. + *
  17. {@link javax.xml.stream.events.EndElement} {@code root}
  18. + *
+ * + *

Note that this decoder is not registered by default but is used internally + * by other decoders which are registered by default. + * + * @author Arjen Poutsma + * @since 5.0 + */ +public class XmlEventDecoder extends AbstractDecoder { + + private static final XMLInputFactory inputFactory = StaxUtils.createDefensiveInputFactory(); + + private static final boolean aaltoPresent = ClassUtils.isPresent( + "com.fasterxml.aalto.AsyncXMLStreamReader", XmlEventDecoder.class.getClassLoader()); + + boolean useAalto = aaltoPresent; + + private int maxInMemorySize = -1; + + + public XmlEventDecoder() { + super(MimeTypeUtils.APPLICATION_XML, MimeTypeUtils.TEXT_XML); + } + + + /** + * Set the max number of bytes that can be buffered by this decoder. This + * is either the size the entire input when decoding as a whole, or when + * using async parsing via Aalto XML, it is size one top-level XML tree. + * When the limit is exceeded, {@link DataBufferLimitException} is raised. + *

By default in 5.1 this is set to -1, unlimited. In 5.2 the default + * value for this limit is set to 256K. + * @param byteCount the max number of bytes to buffer, or -1 for unlimited + * @since 5.1.11 + */ + public void setMaxInMemorySize(int byteCount) { + this.maxInMemorySize = byteCount; + } + + /** + * Return the {@link #setMaxInMemorySize configured} byte count limit. + * @since 5.1.11 + */ + public int getMaxInMemorySize() { + return this.maxInMemorySize; + } + + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) // on JDK 9 where XMLEventReader is Iterator + public Flux decode(Publisher input, ResolvableType elementType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + if (this.useAalto) { + AaltoDataBufferToXmlEvent mapper = new AaltoDataBufferToXmlEvent(this.maxInMemorySize); + return Flux.from(input) + .flatMapIterable(mapper) + .doFinally(signalType -> mapper.endOfInput()); + } + else { + return DataBufferUtils.join(input, getMaxInMemorySize()) + .flatMapIterable(buffer -> { + try { + InputStream is = buffer.asInputStream(); + Iterator eventReader = inputFactory.createXMLEventReader(is); + List result = new ArrayList<>(); + eventReader.forEachRemaining(event -> result.add((XMLEvent) event)); + return result; + } + catch (XMLStreamException ex) { + throw Exceptions.propagate(ex); + } + finally { + DataBufferUtils.release(buffer); + } + }); + } + } + + + /* + * Separate static class to isolate Aalto dependency. + */ + private static class AaltoDataBufferToXmlEvent implements Function> { + + private static final AsyncXMLInputFactory inputFactory = + StaxUtils.createDefensiveInputFactory(InputFactoryImpl::new); + + private final AsyncXMLStreamReader streamReader = + inputFactory.createAsyncForByteBuffer(); + + private final XMLEventAllocator eventAllocator = EventAllocatorImpl.getDefaultInstance(); + + private final int maxInMemorySize; + + private int byteCount; + + private int elementDepth; + + + public AaltoDataBufferToXmlEvent(int maxInMemorySize) { + this.maxInMemorySize = maxInMemorySize; + } + + + @Override + public List apply(DataBuffer dataBuffer) { + try { + increaseByteCount(dataBuffer); + this.streamReader.getInputFeeder().feedInput(dataBuffer.asByteBuffer()); + List events = new ArrayList<>(); + while (true) { + if (this.streamReader.next() == AsyncXMLStreamReader.EVENT_INCOMPLETE) { + // no more events with what currently has been fed to the reader + break; + } + else { + XMLEvent event = this.eventAllocator.allocate(this.streamReader); + events.add(event); + if (event.isEndDocument()) { + break; + } + checkDepthAndResetByteCount(event); + } + } + if (this.maxInMemorySize > 0 && this.byteCount > this.maxInMemorySize) { + raiseLimitException(); + } + return events; + } + catch (XMLStreamException ex) { + throw Exceptions.propagate(ex); + } + finally { + DataBufferUtils.release(dataBuffer); + } + } + + private void increaseByteCount(DataBuffer dataBuffer) { + if (this.maxInMemorySize > 0) { + if (dataBuffer.readableByteCount() > Integer.MAX_VALUE - this.byteCount) { + raiseLimitException(); + } + else { + this.byteCount += dataBuffer.readableByteCount(); + } + } + } + + private void checkDepthAndResetByteCount(XMLEvent event) { + if (this.maxInMemorySize > 0) { + if (event.isStartElement()) { + this.byteCount = this.elementDepth == 1 ? 0 : this.byteCount; + this.elementDepth++; + } + else if (event.isEndElement()) { + this.elementDepth--; + this.byteCount = this.elementDepth == 1 ? 0 : this.byteCount; + } + } + } + + private void raiseLimitException() { + throw new DataBufferLimitException( + "Exceeded limit on max bytes per XML top-level node: " + this.maxInMemorySize); + } + + public void endOfInput() { + this.streamReader.getInputFeeder().endOfInput(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/xml/package-info.java b/spring-web/src/main/java/org/springframework/http/codec/xml/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..2af5e7773e7f6637bea38fc2864cdde8faf97b87 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/xml/package-info.java @@ -0,0 +1,9 @@ +/** + * XML encoder and decoder support. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.codec.xml; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/converter/AbstractGenericHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/AbstractGenericHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..fe17ad9cf46f97a9618e5b8a4f02d8cb0700c27b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/AbstractGenericHttpMessageConverter.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.IOException; +import java.io.OutputStream; +import java.lang.reflect.Type; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.lang.Nullable; + +/** + * Abstract base class for most {@link GenericHttpMessageConverter} implementations. + * + * @author Sebastien Deleuze + * @author Juergen Hoeller + * @since 4.2 + * @param the converted object type + */ +public abstract class AbstractGenericHttpMessageConverter extends AbstractHttpMessageConverter + implements GenericHttpMessageConverter { + + /** + * Construct an {@code AbstractGenericHttpMessageConverter} with no supported media types. + * @see #setSupportedMediaTypes + */ + protected AbstractGenericHttpMessageConverter() { + } + + /** + * Construct an {@code AbstractGenericHttpMessageConverter} with one supported media type. + * @param supportedMediaType the supported media type + */ + protected AbstractGenericHttpMessageConverter(MediaType supportedMediaType) { + super(supportedMediaType); + } + + /** + * Construct an {@code AbstractGenericHttpMessageConverter} with multiple supported media type. + * @param supportedMediaTypes the supported media types + */ + protected AbstractGenericHttpMessageConverter(MediaType... supportedMediaTypes) { + super(supportedMediaTypes); + } + + + @Override + protected boolean supports(Class clazz) { + return true; + } + + @Override + public boolean canRead(Type type, @Nullable Class contextClass, @Nullable MediaType mediaType) { + return (type instanceof Class ? canRead((Class) type, mediaType) : canRead(mediaType)); + } + + @Override + public boolean canWrite(@Nullable Type type, Class clazz, @Nullable MediaType mediaType) { + return canWrite(clazz, mediaType); + } + + /** + * This implementation sets the default headers by calling {@link #addDefaultHeaders}, + * and then calls {@link #writeInternal}. + */ + public final void write(final T t, @Nullable final Type type, @Nullable MediaType contentType, + HttpOutputMessage outputMessage) throws IOException, HttpMessageNotWritableException { + + final HttpHeaders headers = outputMessage.getHeaders(); + addDefaultHeaders(headers, t, contentType); + + if (outputMessage instanceof StreamingHttpOutputMessage) { + StreamingHttpOutputMessage streamingOutputMessage = (StreamingHttpOutputMessage) outputMessage; + streamingOutputMessage.setBody(outputStream -> writeInternal(t, type, new HttpOutputMessage() { + @Override + public OutputStream getBody() { + return outputStream; + } + @Override + public HttpHeaders getHeaders() { + return headers; + } + })); + } + else { + writeInternal(t, type, outputMessage); + outputMessage.getBody().flush(); + } + } + + @Override + protected void writeInternal(T t, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + + writeInternal(t, null, outputMessage); + } + + /** + * Abstract template method that writes the actual body. Invoked from {@link #write}. + * @param t the object to write to the output message + * @param type the type of object to write (may be {@code null}) + * @param outputMessage the HTTP output message to write to + * @throws IOException in case of I/O errors + * @throws HttpMessageNotWritableException in case of conversion errors + */ + protected abstract void writeInternal(T t, @Nullable Type type, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/AbstractHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/AbstractHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..7b873b97b04647604578daca7af3a30af6be830a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/AbstractHttpMessageConverter.java @@ -0,0 +1,323 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.apache.commons.logging.Log; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpLogging; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Abstract base class for most {@link HttpMessageConverter} implementations. + * + *

This base class adds support for setting supported {@code MediaTypes}, through the + * {@link #setSupportedMediaTypes(List) supportedMediaTypes} bean property. It also adds + * support for {@code Content-Type} and {@code Content-Length} when writing to output messages. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @author Sebastien Deleuze + * @since 3.0 + * @param the converted object type + */ +public abstract class AbstractHttpMessageConverter implements HttpMessageConverter { + + /** Logger available to subclasses. */ + protected final Log logger = HttpLogging.forLogName(getClass()); + + private List supportedMediaTypes = Collections.emptyList(); + + @Nullable + private Charset defaultCharset; + + + /** + * Construct an {@code AbstractHttpMessageConverter} with no supported media types. + * @see #setSupportedMediaTypes + */ + protected AbstractHttpMessageConverter() { + } + + /** + * Construct an {@code AbstractHttpMessageConverter} with one supported media type. + * @param supportedMediaType the supported media type + */ + protected AbstractHttpMessageConverter(MediaType supportedMediaType) { + setSupportedMediaTypes(Collections.singletonList(supportedMediaType)); + } + + /** + * Construct an {@code AbstractHttpMessageConverter} with multiple supported media types. + * @param supportedMediaTypes the supported media types + */ + protected AbstractHttpMessageConverter(MediaType... supportedMediaTypes) { + setSupportedMediaTypes(Arrays.asList(supportedMediaTypes)); + } + + /** + * Construct an {@code AbstractHttpMessageConverter} with a default charset and + * multiple supported media types. + * @param defaultCharset the default character set + * @param supportedMediaTypes the supported media types + * @since 4.3 + */ + protected AbstractHttpMessageConverter(Charset defaultCharset, MediaType... supportedMediaTypes) { + this.defaultCharset = defaultCharset; + setSupportedMediaTypes(Arrays.asList(supportedMediaTypes)); + } + + + /** + * Set the list of {@link MediaType} objects supported by this converter. + */ + public void setSupportedMediaTypes(List supportedMediaTypes) { + Assert.notEmpty(supportedMediaTypes, "MediaType List must not be empty"); + this.supportedMediaTypes = new ArrayList<>(supportedMediaTypes); + } + + @Override + public List getSupportedMediaTypes() { + return Collections.unmodifiableList(this.supportedMediaTypes); + } + + /** + * Set the default character set, if any. + * @since 4.3 + */ + public void setDefaultCharset(@Nullable Charset defaultCharset) { + this.defaultCharset = defaultCharset; + } + + /** + * Return the default character set, if any. + * @since 4.3 + */ + @Nullable + public Charset getDefaultCharset() { + return this.defaultCharset; + } + + + /** + * This implementation checks if the given class is {@linkplain #supports(Class) supported}, + * and if the {@linkplain #getSupportedMediaTypes() supported media types} + * {@linkplain MediaType#includes(MediaType) include} the given media type. + */ + @Override + public boolean canRead(Class clazz, @Nullable MediaType mediaType) { + return supports(clazz) && canRead(mediaType); + } + + /** + * Returns {@code true} if any of the {@linkplain #setSupportedMediaTypes(List) + * supported} media types {@link MediaType#includes(MediaType) include} the + * given media type. + * @param mediaType the media type to read, can be {@code null} if not specified. + * Typically the value of a {@code Content-Type} header. + * @return {@code true} if the supported media types include the media type, + * or if the media type is {@code null} + */ + protected boolean canRead(@Nullable MediaType mediaType) { + if (mediaType == null) { + return true; + } + for (MediaType supportedMediaType : getSupportedMediaTypes()) { + if (supportedMediaType.includes(mediaType)) { + return true; + } + } + return false; + } + + /** + * This implementation checks if the given class is + * {@linkplain #supports(Class) supported}, and if the + * {@linkplain #getSupportedMediaTypes() supported} media types + * {@linkplain MediaType#includes(MediaType) include} the given media type. + */ + @Override + public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { + return supports(clazz) && canWrite(mediaType); + } + + /** + * Returns {@code true} if the given media type includes any of the + * {@linkplain #setSupportedMediaTypes(List) supported media types}. + * @param mediaType the media type to write, can be {@code null} if not specified. + * Typically the value of an {@code Accept} header. + * @return {@code true} if the supported media types are compatible with the media type, + * or if the media type is {@code null} + */ + protected boolean canWrite(@Nullable MediaType mediaType) { + if (mediaType == null || MediaType.ALL.equalsTypeAndSubtype(mediaType)) { + return true; + } + for (MediaType supportedMediaType : getSupportedMediaTypes()) { + if (supportedMediaType.isCompatibleWith(mediaType)) { + return true; + } + } + return false; + } + + /** + * This implementation simple delegates to {@link #readInternal(Class, HttpInputMessage)}. + * Future implementations might add some default behavior, however. + */ + @Override + public final T read(Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + return readInternal(clazz, inputMessage); + } + + /** + * This implementation sets the default headers by calling {@link #addDefaultHeaders}, + * and then calls {@link #writeInternal}. + */ + @Override + public final void write(final T t, @Nullable MediaType contentType, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + + final HttpHeaders headers = outputMessage.getHeaders(); + addDefaultHeaders(headers, t, contentType); + + if (outputMessage instanceof StreamingHttpOutputMessage) { + StreamingHttpOutputMessage streamingOutputMessage = (StreamingHttpOutputMessage) outputMessage; + streamingOutputMessage.setBody(outputStream -> writeInternal(t, new HttpOutputMessage() { + @Override + public OutputStream getBody() { + return outputStream; + } + @Override + public HttpHeaders getHeaders() { + return headers; + } + })); + } + else { + writeInternal(t, outputMessage); + outputMessage.getBody().flush(); + } + } + + /** + * Add default headers to the output message. + *

This implementation delegates to {@link #getDefaultContentType(Object)} if a + * content type was not provided, set if necessary the default character set, calls + * {@link #getContentLength}, and sets the corresponding headers. + * @since 4.2 + */ + protected void addDefaultHeaders(HttpHeaders headers, T t, @Nullable MediaType contentType) throws IOException { + if (headers.getContentType() == null) { + MediaType contentTypeToUse = contentType; + if (contentType == null || contentType.isWildcardType() || contentType.isWildcardSubtype()) { + contentTypeToUse = getDefaultContentType(t); + } + else if (MediaType.APPLICATION_OCTET_STREAM.equals(contentType)) { + MediaType mediaType = getDefaultContentType(t); + contentTypeToUse = (mediaType != null ? mediaType : contentTypeToUse); + } + if (contentTypeToUse != null) { + if (contentTypeToUse.getCharset() == null) { + Charset defaultCharset = getDefaultCharset(); + if (defaultCharset != null) { + contentTypeToUse = new MediaType(contentTypeToUse, defaultCharset); + } + } + headers.setContentType(contentTypeToUse); + } + } + if (headers.getContentLength() < 0 && !headers.containsKey(HttpHeaders.TRANSFER_ENCODING)) { + Long contentLength = getContentLength(t, headers.getContentType()); + if (contentLength != null) { + headers.setContentLength(contentLength); + } + } + } + + /** + * Returns the default content type for the given type. Called when {@link #write} + * is invoked without a specified content type parameter. + *

By default, this returns the first element of the + * {@link #setSupportedMediaTypes(List) supportedMediaTypes} property, if any. + * Can be overridden in subclasses. + * @param t the type to return the content type for + * @return the content type, or {@code null} if not known + */ + @Nullable + protected MediaType getDefaultContentType(T t) throws IOException { + List mediaTypes = getSupportedMediaTypes(); + return (!mediaTypes.isEmpty() ? mediaTypes.get(0) : null); + } + + /** + * Returns the content length for the given type. + *

By default, this returns {@code null}, meaning that the content length is unknown. + * Can be overridden in subclasses. + * @param t the type to return the content length for + * @return the content length, or {@code null} if not known + */ + @Nullable + protected Long getContentLength(T t, @Nullable MediaType contentType) throws IOException { + return null; + } + + + /** + * Indicates whether the given class is supported by this converter. + * @param clazz the class to test for support + * @return {@code true} if supported; {@code false} otherwise + */ + protected abstract boolean supports(Class clazz); + + /** + * Abstract template method that reads the actual object. Invoked from {@link #read}. + * @param clazz the type of object to return + * @param inputMessage the HTTP input message to read from + * @return the converted object + * @throws IOException in case of I/O errors + * @throws HttpMessageNotReadableException in case of conversion errors + */ + protected abstract T readInternal(Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException; + + /** + * Abstract template method that writes the actual body. Invoked from {@link #write}. + * @param t the object to write to the output message + * @param outputMessage the HTTP output message to write to + * @throws IOException in case of I/O errors + * @throws HttpMessageNotWritableException in case of conversion errors + */ + protected abstract void writeInternal(T t, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/BufferedImageHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/BufferedImageHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..cdfa5781c99d8238b8ad6078cc0d7babbb567fa1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/BufferedImageHttpMessageConverter.java @@ -0,0 +1,303 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.awt.image.BufferedImage; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +import javax.imageio.IIOImage; +import javax.imageio.ImageIO; +import javax.imageio.ImageReadParam; +import javax.imageio.ImageReader; +import javax.imageio.ImageWriteParam; +import javax.imageio.ImageWriter; +import javax.imageio.stream.FileCacheImageInputStream; +import javax.imageio.stream.FileCacheImageOutputStream; +import javax.imageio.stream.ImageInputStream; +import javax.imageio.stream.ImageOutputStream; +import javax.imageio.stream.MemoryCacheImageInputStream; +import javax.imageio.stream.MemoryCacheImageOutputStream; + +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * Implementation of {@link HttpMessageConverter} that can read and write + * {@link BufferedImage BufferedImages}. + * + *

By default, this converter can read all media types that are supported + * by the {@linkplain ImageIO#getReaderMIMETypes() registered image readers}, + * and writes using the media type of the first available + * {@linkplain javax.imageio.ImageIO#getWriterMIMETypes() registered image writer}. + * The latter can be overridden by setting the + * {@link #setDefaultContentType defaultContentType} property. + * + *

If the {@link #setCacheDir cacheDir} property is set, this converter + * will cache image data. + * + *

The {@link #process(ImageReadParam)} and {@link #process(ImageWriteParam)} + * template methods allow subclasses to override Image I/O parameters. + * + * @author Arjen Poutsma + * @since 3.0 + */ +public class BufferedImageHttpMessageConverter implements HttpMessageConverter { + + private final List readableMediaTypes = new ArrayList<>(); + + @Nullable + private MediaType defaultContentType; + + @Nullable + private File cacheDir; + + + public BufferedImageHttpMessageConverter() { + String[] readerMediaTypes = ImageIO.getReaderMIMETypes(); + for (String mediaType : readerMediaTypes) { + if (StringUtils.hasText(mediaType)) { + this.readableMediaTypes.add(MediaType.parseMediaType(mediaType)); + } + } + + String[] writerMediaTypes = ImageIO.getWriterMIMETypes(); + for (String mediaType : writerMediaTypes) { + if (StringUtils.hasText(mediaType)) { + this.defaultContentType = MediaType.parseMediaType(mediaType); + break; + } + } + } + + + /** + * Sets the default {@code Content-Type} to be used for writing. + * @throws IllegalArgumentException if the given content type is not supported by the Java Image I/O API + */ + public void setDefaultContentType(@Nullable MediaType defaultContentType) { + if (defaultContentType != null) { + Iterator imageWriters = ImageIO.getImageWritersByMIMEType(defaultContentType.toString()); + if (!imageWriters.hasNext()) { + throw new IllegalArgumentException( + "Content-Type [" + defaultContentType + "] is not supported by the Java Image I/O API"); + } + } + + this.defaultContentType = defaultContentType; + } + + /** + * Returns the default {@code Content-Type} to be used for writing. + * Called when {@link #write} is invoked without a specified content type parameter. + */ + @Nullable + public MediaType getDefaultContentType() { + return this.defaultContentType; + } + + /** + * Sets the cache directory. If this property is set to an existing directory, + * this converter will cache image data. + */ + public void setCacheDir(File cacheDir) { + Assert.notNull(cacheDir, "'cacheDir' must not be null"); + Assert.isTrue(cacheDir.isDirectory(), "'cacheDir' is not a directory"); + this.cacheDir = cacheDir; + } + + + @Override + public boolean canRead(Class clazz, @Nullable MediaType mediaType) { + return (BufferedImage.class == clazz && isReadable(mediaType)); + } + + private boolean isReadable(@Nullable MediaType mediaType) { + if (mediaType == null) { + return true; + } + Iterator imageReaders = ImageIO.getImageReadersByMIMEType(mediaType.toString()); + return imageReaders.hasNext(); + } + + @Override + public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { + return (BufferedImage.class == clazz && isWritable(mediaType)); + } + + private boolean isWritable(@Nullable MediaType mediaType) { + if (mediaType == null || MediaType.ALL.equalsTypeAndSubtype(mediaType)) { + return true; + } + Iterator imageWriters = ImageIO.getImageWritersByMIMEType(mediaType.toString()); + return imageWriters.hasNext(); + } + + @Override + public List getSupportedMediaTypes() { + return Collections.unmodifiableList(this.readableMediaTypes); + } + + @Override + public BufferedImage read(@Nullable Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + ImageInputStream imageInputStream = null; + ImageReader imageReader = null; + try { + imageInputStream = createImageInputStream(inputMessage.getBody()); + MediaType contentType = inputMessage.getHeaders().getContentType(); + if (contentType == null) { + throw new HttpMessageNotReadableException("No Content-Type header", inputMessage); + } + Iterator imageReaders = ImageIO.getImageReadersByMIMEType(contentType.toString()); + if (imageReaders.hasNext()) { + imageReader = imageReaders.next(); + ImageReadParam irp = imageReader.getDefaultReadParam(); + process(irp); + imageReader.setInput(imageInputStream, true); + return imageReader.read(0, irp); + } + else { + throw new HttpMessageNotReadableException( + "Could not find javax.imageio.ImageReader for Content-Type [" + contentType + "]", + inputMessage); + } + } + finally { + if (imageReader != null) { + imageReader.dispose(); + } + if (imageInputStream != null) { + try { + imageInputStream.close(); + } + catch (IOException ex) { + // ignore + } + } + } + } + + private ImageInputStream createImageInputStream(InputStream is) throws IOException { + if (this.cacheDir != null) { + return new FileCacheImageInputStream(is, this.cacheDir); + } + else { + return new MemoryCacheImageInputStream(is); + } + } + + @Override + public void write(final BufferedImage image, @Nullable final MediaType contentType, + final HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + + final MediaType selectedContentType = getContentType(contentType); + outputMessage.getHeaders().setContentType(selectedContentType); + + if (outputMessage instanceof StreamingHttpOutputMessage) { + StreamingHttpOutputMessage streamingOutputMessage = (StreamingHttpOutputMessage) outputMessage; + streamingOutputMessage.setBody(outputStream -> writeInternal(image, selectedContentType, outputStream)); + } + else { + writeInternal(image, selectedContentType, outputMessage.getBody()); + } + } + + private MediaType getContentType(@Nullable MediaType contentType) { + if (contentType == null || contentType.isWildcardType() || contentType.isWildcardSubtype()) { + contentType = getDefaultContentType(); + } + Assert.notNull(contentType, "Could not select Content-Type. " + + "Please specify one through the 'defaultContentType' property."); + return contentType; + } + + private void writeInternal(BufferedImage image, MediaType contentType, OutputStream body) + throws IOException, HttpMessageNotWritableException { + + ImageOutputStream imageOutputStream = null; + ImageWriter imageWriter = null; + try { + Iterator imageWriters = ImageIO.getImageWritersByMIMEType(contentType.toString()); + if (imageWriters.hasNext()) { + imageWriter = imageWriters.next(); + ImageWriteParam iwp = imageWriter.getDefaultWriteParam(); + process(iwp); + imageOutputStream = createImageOutputStream(body); + imageWriter.setOutput(imageOutputStream); + imageWriter.write(null, new IIOImage(image, null, null), iwp); + } + else { + throw new HttpMessageNotWritableException( + "Could not find javax.imageio.ImageWriter for Content-Type [" + contentType + "]"); + } + } + finally { + if (imageWriter != null) { + imageWriter.dispose(); + } + if (imageOutputStream != null) { + try { + imageOutputStream.close(); + } + catch (IOException ex) { + // ignore + } + } + } + } + + private ImageOutputStream createImageOutputStream(OutputStream os) throws IOException { + if (this.cacheDir != null) { + return new FileCacheImageOutputStream(os, this.cacheDir); + } + else { + return new MemoryCacheImageOutputStream(os); + } + } + + + /** + * Template method that allows for manipulating the {@link ImageReadParam} + * before it is used to read an image. + *

The default implementation is empty. + */ + protected void process(ImageReadParam irp) { + } + + /** + * Template method that allows for manipulating the {@link ImageWriteParam} + * before it is used to write an image. + *

The default implementation is empty. + */ + protected void process(ImageWriteParam iwp) { + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/ByteArrayHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/ByteArrayHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..43d1f7ed30fc55095648aad7802b3ac4382d17db --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/ByteArrayHttpMessageConverter.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; +import org.springframework.util.StreamUtils; + +/** + * Implementation of {@link HttpMessageConverter} that can read and write byte arrays. + * + *

By default, this converter supports all media types ({@code */*}), and + * writes with a {@code Content-Type} of {@code application/octet-stream}. This can be + * overridden by setting the {@link #setSupportedMediaTypes supportedMediaTypes} property. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.0 + */ +public class ByteArrayHttpMessageConverter extends AbstractHttpMessageConverter { + + /** + * Create a new instance of the {@code ByteArrayHttpMessageConverter}. + */ + public ByteArrayHttpMessageConverter() { + super(MediaType.APPLICATION_OCTET_STREAM, MediaType.ALL); + } + + + @Override + public boolean supports(Class clazz) { + return byte[].class == clazz; + } + + @Override + public byte[] readInternal(Class clazz, HttpInputMessage inputMessage) throws IOException { + long contentLength = inputMessage.getHeaders().getContentLength(); + ByteArrayOutputStream bos = + new ByteArrayOutputStream(contentLength >= 0 ? (int) contentLength : StreamUtils.BUFFER_SIZE); + StreamUtils.copy(inputMessage.getBody(), bos); + return bos.toByteArray(); + } + + @Override + protected Long getContentLength(byte[] bytes, @Nullable MediaType contentType) { + return (long) bytes.length; + } + + @Override + protected void writeInternal(byte[] bytes, HttpOutputMessage outputMessage) throws IOException { + StreamUtils.copy(bytes, outputMessage.getBody()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/FormHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/FormHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..86f8f6c5bb7e397687cc01e62493aa2040487341 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/FormHttpMessageConverter.java @@ -0,0 +1,560 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; +import java.net.URLEncoder; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import javax.mail.internet.MimeUtility; + +import org.springframework.core.io.Resource; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StreamUtils; +import org.springframework.util.StringUtils; + +/** + * Implementation of {@link HttpMessageConverter} to read and write 'normal' HTML + * forms and also to write (but not read) multipart data (e.g. file uploads). + * + *

In other words, this converter can read and write the + * {@code "application/x-www-form-urlencoded"} media type as + * {@link MultiValueMap MultiValueMap<String, String>} and it can also + * write (but not read) the {@code "multipart/form-data"} media type as + * {@link MultiValueMap MultiValueMap<String, Object>}. + * + *

When writing multipart data, this converter uses other + * {@link HttpMessageConverter HttpMessageConverters} to write the respective + * MIME parts. By default, basic converters are registered (for {@code Strings} + * and {@code Resources}). These can be overridden through the + * {@link #setPartConverters partConverters} property. + * + *

For example, the following snippet shows how to submit an HTML form: + *

+ * RestTemplate template = new RestTemplate();
+ * // AllEncompassingFormHttpMessageConverter is configured by default
+ *
+ * MultiValueMap<String, Object> form = new LinkedMultiValueMap<>();
+ * form.add("field 1", "value 1");
+ * form.add("field 2", "value 2");
+ * form.add("field 2", "value 3");
+ * form.add("field 3", 4);  // non-String form values supported as of 5.1.4
+ * template.postForLocation("https://example.com/myForm", form);
+ * 
+ * + *

The following snippet shows how to do a file upload: + *

+ * MultiValueMap<String, Object> parts = new LinkedMultiValueMap<>();
+ * parts.add("field 1", "value 1");
+ * parts.add("file", new ClassPathResource("myFile.jpg"));
+ * template.postForLocation("https://example.com/myFileUpload", parts);
+ * 
+ * + *

Some methods in this class were inspired by + * {@code org.apache.commons.httpclient.methods.multipart.MultipartRequestEntity}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.0 + * @see org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter + * @see org.springframework.util.MultiValueMap + */ +public class FormHttpMessageConverter implements HttpMessageConverter> { + + /** + * The default charset used by the converter. + */ + public static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + + private static final MediaType DEFAULT_FORM_DATA_MEDIA_TYPE = + new MediaType(MediaType.APPLICATION_FORM_URLENCODED, DEFAULT_CHARSET); + + + private List supportedMediaTypes = new ArrayList<>(); + + private List> partConverters = new ArrayList<>(); + + private Charset charset = DEFAULT_CHARSET; + + @Nullable + private Charset multipartCharset; + + + public FormHttpMessageConverter() { + this.supportedMediaTypes.add(MediaType.APPLICATION_FORM_URLENCODED); + this.supportedMediaTypes.add(MediaType.MULTIPART_FORM_DATA); + + StringHttpMessageConverter stringHttpMessageConverter = new StringHttpMessageConverter(); + stringHttpMessageConverter.setWriteAcceptCharset(false); // see SPR-7316 + + this.partConverters.add(new ByteArrayHttpMessageConverter()); + this.partConverters.add(stringHttpMessageConverter); + this.partConverters.add(new ResourceHttpMessageConverter()); + + applyDefaultCharset(); + } + + + /** + * Set the list of {@link MediaType} objects supported by this converter. + */ + public void setSupportedMediaTypes(List supportedMediaTypes) { + this.supportedMediaTypes = supportedMediaTypes; + } + + @Override + public List getSupportedMediaTypes() { + return Collections.unmodifiableList(this.supportedMediaTypes); + } + + /** + * Set the message body converters to use. These converters are used to + * convert objects to MIME parts. + */ + public void setPartConverters(List> partConverters) { + Assert.notEmpty(partConverters, "'partConverters' must not be empty"); + this.partConverters = partConverters; + } + + /** + * Add a message body converter. Such a converter is used to convert objects + * to MIME parts. + */ + public void addPartConverter(HttpMessageConverter partConverter) { + Assert.notNull(partConverter, "'partConverter' must not be null"); + this.partConverters.add(partConverter); + } + + /** + * Set the default character set to use for reading and writing form data when + * the request or response Content-Type header does not explicitly specify it. + *

As of 4.3, this is also used as the default charset for the conversion + * of text bodies in a multipart request. + *

As of 5.0 this is also used for part headers including + * "Content-Disposition" (and its filename parameter) unless (the mutually + * exclusive) {@link #setMultipartCharset} is also set, in which case part + * headers are encoded as ASCII and filename is encoded with the + * "encoded-word" syntax from RFC 2047. + *

By default this is set to "UTF-8". + */ + public void setCharset(@Nullable Charset charset) { + if (charset != this.charset) { + this.charset = (charset != null ? charset : DEFAULT_CHARSET); + applyDefaultCharset(); + } + } + + /** + * Apply the configured charset as a default to registered part converters. + */ + private void applyDefaultCharset() { + for (HttpMessageConverter candidate : this.partConverters) { + if (candidate instanceof AbstractHttpMessageConverter) { + AbstractHttpMessageConverter converter = (AbstractHttpMessageConverter) candidate; + // Only override default charset if the converter operates with a charset to begin with... + if (converter.getDefaultCharset() != null) { + converter.setDefaultCharset(this.charset); + } + } + } + } + + /** + * Set the character set to use when writing multipart data to encode file + * names. Encoding is based on the "encoded-word" syntax defined in RFC 2047 + * and relies on {@code MimeUtility} from "javax.mail". + *

As of 5.0 by default part headers, including Content-Disposition (and + * its filename parameter) will be encoded based on the setting of + * {@link #setCharset(Charset)} or {@code UTF-8} by default. + * @since 4.1.1 + * @see Encoded-Word + */ + public void setMultipartCharset(Charset charset) { + this.multipartCharset = charset; + } + + + @Override + public boolean canRead(Class clazz, @Nullable MediaType mediaType) { + if (!MultiValueMap.class.isAssignableFrom(clazz)) { + return false; + } + if (mediaType == null) { + return true; + } + for (MediaType supportedMediaType : getSupportedMediaTypes()) { + // We can't read multipart.... + if (!supportedMediaType.equals(MediaType.MULTIPART_FORM_DATA) && supportedMediaType.includes(mediaType)) { + return true; + } + } + return false; + } + + @Override + public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { + if (!MultiValueMap.class.isAssignableFrom(clazz)) { + return false; + } + if (mediaType == null || MediaType.ALL.equals(mediaType)) { + return true; + } + for (MediaType supportedMediaType : getSupportedMediaTypes()) { + if (supportedMediaType.isCompatibleWith(mediaType)) { + return true; + } + } + return false; + } + + @Override + public MultiValueMap read(@Nullable Class> clazz, + HttpInputMessage inputMessage) throws IOException, HttpMessageNotReadableException { + + MediaType contentType = inputMessage.getHeaders().getContentType(); + Charset charset = (contentType != null && contentType.getCharset() != null ? + contentType.getCharset() : this.charset); + String body = StreamUtils.copyToString(inputMessage.getBody(), charset); + + String[] pairs = StringUtils.tokenizeToStringArray(body, "&"); + MultiValueMap result = new LinkedMultiValueMap<>(pairs.length); + for (String pair : pairs) { + int idx = pair.indexOf('='); + if (idx == -1) { + result.add(URLDecoder.decode(pair, charset.name()), null); + } + else { + String name = URLDecoder.decode(pair.substring(0, idx), charset.name()); + String value = URLDecoder.decode(pair.substring(idx + 1), charset.name()); + result.add(name, value); + } + } + return result; + } + + @Override + @SuppressWarnings("unchecked") + public void write(MultiValueMap map, @Nullable MediaType contentType, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + + if (!isMultipart(map, contentType)) { + writeForm((MultiValueMap) map, contentType, outputMessage); + } + else { + writeMultipart((MultiValueMap) map, outputMessage); + } + } + + + private boolean isMultipart(MultiValueMap map, @Nullable MediaType contentType) { + if (contentType != null) { + return MediaType.MULTIPART_FORM_DATA.includes(contentType); + } + for (List values : map.values()) { + for (Object value : values) { + if (value != null && !(value instanceof String)) { + return true; + } + } + } + return false; + } + + private void writeForm(MultiValueMap formData, @Nullable MediaType contentType, + HttpOutputMessage outputMessage) throws IOException { + + contentType = getMediaType(contentType); + outputMessage.getHeaders().setContentType(contentType); + + Charset charset = contentType.getCharset(); + Assert.notNull(charset, "No charset"); // should never occur + + final byte[] bytes = serializeForm(formData, charset).getBytes(charset); + outputMessage.getHeaders().setContentLength(bytes.length); + + if (outputMessage instanceof StreamingHttpOutputMessage) { + StreamingHttpOutputMessage streamingOutputMessage = (StreamingHttpOutputMessage) outputMessage; + streamingOutputMessage.setBody(outputStream -> StreamUtils.copy(bytes, outputStream)); + } + else { + StreamUtils.copy(bytes, outputMessage.getBody()); + } + } + + private MediaType getMediaType(@Nullable MediaType mediaType) { + if (mediaType == null) { + return DEFAULT_FORM_DATA_MEDIA_TYPE; + } + else if (mediaType.getCharset() == null) { + return new MediaType(mediaType, this.charset); + } + else { + return mediaType; + } + } + + protected String serializeForm(MultiValueMap formData, Charset charset) { + StringBuilder builder = new StringBuilder(); + formData.forEach((name, values) -> + values.forEach(value -> { + try { + if (builder.length() != 0) { + builder.append('&'); + } + builder.append(URLEncoder.encode(name, charset.name())); + if (value != null) { + builder.append('='); + builder.append(URLEncoder.encode(String.valueOf(value), charset.name())); + } + } + catch (UnsupportedEncodingException ex) { + throw new IllegalStateException(ex); + } + })); + + return builder.toString(); + } + + private void writeMultipart(final MultiValueMap parts, HttpOutputMessage outputMessage) + throws IOException { + + final byte[] boundary = generateMultipartBoundary(); + Map parameters = new LinkedHashMap<>(2); + if (!isFilenameCharsetSet()) { + parameters.put("charset", this.charset.name()); + } + parameters.put("boundary", new String(boundary, StandardCharsets.US_ASCII)); + + MediaType contentType = new MediaType(MediaType.MULTIPART_FORM_DATA, parameters); + HttpHeaders headers = outputMessage.getHeaders(); + headers.setContentType(contentType); + + if (outputMessage instanceof StreamingHttpOutputMessage) { + StreamingHttpOutputMessage streamingOutputMessage = (StreamingHttpOutputMessage) outputMessage; + streamingOutputMessage.setBody(outputStream -> { + writeParts(outputStream, parts, boundary); + writeEnd(outputStream, boundary); + }); + } + else { + writeParts(outputMessage.getBody(), parts, boundary); + writeEnd(outputMessage.getBody(), boundary); + } + } + + /** + * When {@link #setMultipartCharset(Charset)} is configured (i.e. RFC 2047, + * "encoded-word" syntax) we need to use ASCII for part headers or otherwise + * we encode directly using the configured {@link #setCharset(Charset)}. + */ + private boolean isFilenameCharsetSet() { + return (this.multipartCharset != null); + } + + private void writeParts(OutputStream os, MultiValueMap parts, byte[] boundary) throws IOException { + for (Map.Entry> entry : parts.entrySet()) { + String name = entry.getKey(); + for (Object part : entry.getValue()) { + if (part != null) { + writeBoundary(os, boundary); + writePart(name, getHttpEntity(part), os); + writeNewLine(os); + } + } + } + } + + @SuppressWarnings("unchecked") + private void writePart(String name, HttpEntity partEntity, OutputStream os) throws IOException { + Object partBody = partEntity.getBody(); + if (partBody == null) { + throw new IllegalStateException("Empty body for part '" + name + "': " + partEntity); + } + Class partType = partBody.getClass(); + HttpHeaders partHeaders = partEntity.getHeaders(); + MediaType partContentType = partHeaders.getContentType(); + for (HttpMessageConverter messageConverter : this.partConverters) { + if (messageConverter.canWrite(partType, partContentType)) { + Charset charset = isFilenameCharsetSet() ? StandardCharsets.US_ASCII : this.charset; + HttpOutputMessage multipartMessage = new MultipartHttpOutputMessage(os, charset); + multipartMessage.getHeaders().setContentDispositionFormData(name, getFilename(partBody)); + if (!partHeaders.isEmpty()) { + multipartMessage.getHeaders().putAll(partHeaders); + } + ((HttpMessageConverter) messageConverter).write(partBody, partContentType, multipartMessage); + return; + } + } + throw new HttpMessageNotWritableException("Could not write request: no suitable HttpMessageConverter " + + "found for request type [" + partType.getName() + "]"); + } + + /** + * Generate a multipart boundary. + *

This implementation delegates to + * {@link MimeTypeUtils#generateMultipartBoundary()}. + */ + protected byte[] generateMultipartBoundary() { + return MimeTypeUtils.generateMultipartBoundary(); + } + + /** + * Return an {@link HttpEntity} for the given part Object. + * @param part the part to return an {@link HttpEntity} for + * @return the part Object itself it is an {@link HttpEntity}, + * or a newly built {@link HttpEntity} wrapper for that part + */ + protected HttpEntity getHttpEntity(Object part) { + return (part instanceof HttpEntity ? (HttpEntity) part : new HttpEntity<>(part)); + } + + /** + * Return the filename of the given multipart part. This value will be used for the + * {@code Content-Disposition} header. + *

The default implementation returns {@link Resource#getFilename()} if the part is a + * {@code Resource}, and {@code null} in other cases. Can be overridden in subclasses. + * @param part the part to determine the file name for + * @return the filename, or {@code null} if not known + */ + @Nullable + protected String getFilename(Object part) { + if (part instanceof Resource) { + Resource resource = (Resource) part; + String filename = resource.getFilename(); + if (filename != null && this.multipartCharset != null) { + filename = MimeDelegate.encode(filename, this.multipartCharset.name()); + } + return filename; + } + else { + return null; + } + } + + + private void writeBoundary(OutputStream os, byte[] boundary) throws IOException { + os.write('-'); + os.write('-'); + os.write(boundary); + writeNewLine(os); + } + + private static void writeEnd(OutputStream os, byte[] boundary) throws IOException { + os.write('-'); + os.write('-'); + os.write(boundary); + os.write('-'); + os.write('-'); + writeNewLine(os); + } + + private static void writeNewLine(OutputStream os) throws IOException { + os.write('\r'); + os.write('\n'); + } + + + /** + * Implementation of {@link org.springframework.http.HttpOutputMessage} used + * to write a MIME multipart. + */ + private static class MultipartHttpOutputMessage implements HttpOutputMessage { + + private final OutputStream outputStream; + + private final Charset charset; + + private final HttpHeaders headers = new HttpHeaders(); + + private boolean headersWritten = false; + + public MultipartHttpOutputMessage(OutputStream outputStream, Charset charset) { + this.outputStream = outputStream; + this.charset = charset; + } + + @Override + public HttpHeaders getHeaders() { + return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers); + } + + @Override + public OutputStream getBody() throws IOException { + writeHeaders(); + return this.outputStream; + } + + private void writeHeaders() throws IOException { + if (!this.headersWritten) { + for (Map.Entry> entry : this.headers.entrySet()) { + byte[] headerName = getBytes(entry.getKey()); + for (String headerValueString : entry.getValue()) { + byte[] headerValue = getBytes(headerValueString); + this.outputStream.write(headerName); + this.outputStream.write(':'); + this.outputStream.write(' '); + this.outputStream.write(headerValue); + writeNewLine(this.outputStream); + } + } + writeNewLine(this.outputStream); + this.headersWritten = true; + } + } + + private byte[] getBytes(String name) { + return name.getBytes(this.charset); + } + } + + + /** + * Inner class to avoid a hard dependency on the JavaMail API. + */ + private static class MimeDelegate { + + public static String encode(String value, String charset) { + try { + return MimeUtility.encodeText(value, charset, null); + } + catch (UnsupportedEncodingException ex) { + throw new IllegalStateException(ex); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/GenericHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/GenericHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..b3c92ccb3344c372e2557109b04918d005db9786 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/GenericHttpMessageConverter.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.IOException; +import java.lang.reflect.Type; + +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; + +/** + * A specialization of {@link HttpMessageConverter} that can convert an HTTP request + * into a target object of a specified generic type and a source object of a specified + * generic type into an HTTP response. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + * @since 3.2 + * @param the converted object type + * @see org.springframework.core.ParameterizedTypeReference + */ +public interface GenericHttpMessageConverter extends HttpMessageConverter { + + /** + * Indicates whether the given type can be read by this converter. + * This method should perform the same checks than + * {@link HttpMessageConverter#canRead(Class, MediaType)} with additional ones + * related to the generic type. + * @param type the (potentially generic) type to test for readability + * @param contextClass a context class for the target type, for example a class + * in which the target type appears in a method signature (can be {@code null}) + * @param mediaType the media type to read, can be {@code null} if not specified. + * Typically the value of a {@code Content-Type} header. + * @return {@code true} if readable; {@code false} otherwise + */ + boolean canRead(Type type, @Nullable Class contextClass, @Nullable MediaType mediaType); + + /** + * Read an object of the given type form the given input message, and returns it. + * @param type the (potentially generic) type of object to return. This type must have + * previously been passed to the {@link #canRead canRead} method of this interface, + * which must have returned {@code true}. + * @param contextClass a context class for the target type, for example a class + * in which the target type appears in a method signature (can be {@code null}) + * @param inputMessage the HTTP input message to read from + * @return the converted object + * @throws IOException in case of I/O errors + * @throws HttpMessageNotReadableException in case of conversion errors + */ + T read(Type type, @Nullable Class contextClass, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException; + + /** + * Indicates whether the given class can be written by this converter. + *

This method should perform the same checks than + * {@link HttpMessageConverter#canWrite(Class, MediaType)} with additional ones + * related to the generic type. + * @param type the (potentially generic) type to test for writability + * (can be {@code null} if not specified) + * @param clazz the source object class to test for writability + * @param mediaType the media type to write (can be {@code null} if not specified); + * typically the value of an {@code Accept} header. + * @return {@code true} if writable; {@code false} otherwise + * @since 4.2 + */ + boolean canWrite(@Nullable Type type, Class clazz, @Nullable MediaType mediaType); + + /** + * Write an given object to the given output message. + * @param t the object to write to the output message. The type of this object must + * have previously been passed to the {@link #canWrite canWrite} method of this + * interface, which must have returned {@code true}. + * @param type the (potentially generic) type of object to write. This type must have + * previously been passed to the {@link #canWrite canWrite} method of this interface, + * which must have returned {@code true}. Can be {@code null} if not specified. + * @param contentType the content type to use when writing. May be {@code null} to + * indicate that the default content type of the converter must be used. If not + * {@code null}, this media type must have previously been passed to the + * {@link #canWrite canWrite} method of this interface, which must have returned + * {@code true}. + * @param outputMessage the message to write to + * @throws IOException in case of I/O errors + * @throws HttpMessageNotWritableException in case of conversion errors + * @since 4.2 + */ + void write(T t, @Nullable Type type, @Nullable MediaType contentType, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConversionException.java b/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConversionException.java new file mode 100644 index 0000000000000000000000000000000000000000..6a2f9bb8b8fcd76f205abade369b16cb2c9a8911 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConversionException.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import org.springframework.core.NestedRuntimeException; +import org.springframework.lang.Nullable; + +/** + * Thrown by {@link HttpMessageConverter} implementations when a conversion attempt fails. + * + * @author Arjen Poutsma + * @author Sebastien Deleuze + * @since 3.0 + */ +@SuppressWarnings("serial") +public class HttpMessageConversionException extends NestedRuntimeException { + + /** + * Create a new HttpMessageConversionException. + * @param msg the detail message + */ + public HttpMessageConversionException(String msg) { + super(msg); + } + + /** + * Create a new HttpMessageConversionException. + * @param msg the detail message + * @param cause the root cause (if any) + */ + public HttpMessageConversionException(String msg, @Nullable Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..a0583b9c6d57397a1180d97b1bc7ff01bf3cb31b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConverter.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.IOException; +import java.util.List; + +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; + +/** + * Strategy interface for converting from and to HTTP requests and responses. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.0 + * @param the converted object type + */ +public interface HttpMessageConverter { + + /** + * Indicates whether the given class can be read by this converter. + * @param clazz the class to test for readability + * @param mediaType the media type to read (can be {@code null} if not specified); + * typically the value of a {@code Content-Type} header. + * @return {@code true} if readable; {@code false} otherwise + */ + boolean canRead(Class clazz, @Nullable MediaType mediaType); + + /** + * Indicates whether the given class can be written by this converter. + * @param clazz the class to test for writability + * @param mediaType the media type to write (can be {@code null} if not specified); + * typically the value of an {@code Accept} header. + * @return {@code true} if writable; {@code false} otherwise + */ + boolean canWrite(Class clazz, @Nullable MediaType mediaType); + + /** + * Return the list of {@link MediaType} objects supported by this converter. + * @return the list of supported media types, potentially an immutable copy + */ + List getSupportedMediaTypes(); + + /** + * Read an object of the given type from the given input message, and returns it. + * @param clazz the type of object to return. This type must have previously been passed to the + * {@link #canRead canRead} method of this interface, which must have returned {@code true}. + * @param inputMessage the HTTP input message to read from + * @return the converted object + * @throws IOException in case of I/O errors + * @throws HttpMessageNotReadableException in case of conversion errors + */ + T read(Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException; + + /** + * Write an given object to the given output message. + * @param t the object to write to the output message. The type of this object must have previously been + * passed to the {@link #canWrite canWrite} method of this interface, which must have returned {@code true}. + * @param contentType the content type to use when writing. May be {@code null} to indicate that the + * default content type of the converter must be used. If not {@code null}, this media type must have + * previously been passed to the {@link #canWrite canWrite} method of this interface, which must have + * returned {@code true}. + * @param outputMessage the message to write to + * @throws IOException in case of I/O errors + * @throws HttpMessageNotWritableException in case of conversion errors + */ + void write(T t, @Nullable MediaType contentType, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException; + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/HttpMessageNotReadableException.java b/spring-web/src/main/java/org/springframework/http/converter/HttpMessageNotReadableException.java new file mode 100644 index 0000000000000000000000000000000000000000..1b9c7630b3b601f2aff14bca58bbad2b3a77e43a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/HttpMessageNotReadableException.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import org.springframework.http.HttpInputMessage; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Thrown by {@link HttpMessageConverter} implementations when the + * {@link HttpMessageConverter#read} method fails. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.0 + */ +@SuppressWarnings("serial") +public class HttpMessageNotReadableException extends HttpMessageConversionException { + + @Nullable + private final HttpInputMessage httpInputMessage; + + + /** + * Create a new HttpMessageNotReadableException. + * @param msg the detail message + * @deprecated as of 5.1, in favor of {@link #HttpMessageNotReadableException(String, HttpInputMessage)} + */ + @Deprecated + public HttpMessageNotReadableException(String msg) { + super(msg); + this.httpInputMessage = null; + } + + /** + * Create a new HttpMessageNotReadableException. + * @param msg the detail message + * @param cause the root cause (if any) + * @deprecated as of 5.1, in favor of {@link #HttpMessageNotReadableException(String, Throwable, HttpInputMessage)} + */ + @Deprecated + public HttpMessageNotReadableException(String msg, @Nullable Throwable cause) { + super(msg, cause); + this.httpInputMessage = null; + } + + /** + * Create a new HttpMessageNotReadableException. + * @param msg the detail message + * @param httpInputMessage the original HTTP message + * @since 5.1 + */ + public HttpMessageNotReadableException(String msg, HttpInputMessage httpInputMessage) { + super(msg); + this.httpInputMessage = httpInputMessage; + } + + /** + * Create a new HttpMessageNotReadableException. + * @param msg the detail message + * @param cause the root cause (if any) + * @param httpInputMessage the original HTTP message + * @since 5.1 + */ + public HttpMessageNotReadableException(String msg, @Nullable Throwable cause, HttpInputMessage httpInputMessage) { + super(msg, cause); + this.httpInputMessage = httpInputMessage; + } + + + /** + * Return the original HTTP message. + * @since 5.1 + */ + public HttpInputMessage getHttpInputMessage() { + Assert.state(this.httpInputMessage != null, "No HttpInputMessage available - use non-deprecated constructors"); + return this.httpInputMessage; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/HttpMessageNotWritableException.java b/spring-web/src/main/java/org/springframework/http/converter/HttpMessageNotWritableException.java new file mode 100644 index 0000000000000000000000000000000000000000..b6c838fccaaa4bc3467e0df4671630c8e79325ee --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/HttpMessageNotWritableException.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import org.springframework.lang.Nullable; + +/** + * Thrown by {@link HttpMessageConverter} implementations when the + * {@link HttpMessageConverter#write} method fails. + * + * @author Arjen Poutsma + * @since 3.0 + */ +@SuppressWarnings("serial") +public class HttpMessageNotWritableException extends HttpMessageConversionException { + + /** + * Create a new HttpMessageNotWritableException. + * @param msg the detail message + */ + public HttpMessageNotWritableException(String msg) { + super(msg); + } + + /** + * Create a new HttpMessageNotWritableException. + * @param msg the detail message + * @param cause the root cause (if any) + */ + public HttpMessageNotWritableException(String msg, @Nullable Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/ObjectToStringHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/ObjectToStringHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..13d0f4262a6fde77f6447eaf5caa2a7b9c1fef57 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/ObjectToStringHttpMessageConverter.java @@ -0,0 +1,139 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.IOException; +import java.nio.charset.Charset; + +import org.springframework.core.convert.ConversionService; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * An {@code HttpMessageConverter} that uses {@link StringHttpMessageConverter} + * for reading and writing content and a {@link ConversionService} for converting + * the String content to and from the target object type. + * + *

By default, this converter supports the media type {@code text/plain} only. + * This can be overridden through the {@link #setSupportedMediaTypes supportedMediaTypes} + * property. + * + *

A usage example: + * + *

+ * <bean class="org.springframework.http.converter.ObjectToStringHttpMessageConverter">
+ *   <constructor-arg>
+ *     <bean class="org.springframework.context.support.ConversionServiceFactoryBean"/>
+ *   </constructor-arg>
+ * </bean>
+ * 
+ * + * @author Dmitry Katsubo + * @author Rossen Stoyanchev + * @since 3.2 + */ +public class ObjectToStringHttpMessageConverter extends AbstractHttpMessageConverter { + + private final ConversionService conversionService; + + private final StringHttpMessageConverter stringHttpMessageConverter; + + + /** + * A constructor accepting a {@code ConversionService} to use to convert the + * (String) message body to/from the target class type. This constructor uses + * {@link StringHttpMessageConverter#DEFAULT_CHARSET} as the default charset. + * @param conversionService the conversion service + */ + public ObjectToStringHttpMessageConverter(ConversionService conversionService) { + this(conversionService, StringHttpMessageConverter.DEFAULT_CHARSET); + } + + /** + * A constructor accepting a {@code ConversionService} as well as a default charset. + * @param conversionService the conversion service + * @param defaultCharset the default charset + */ + public ObjectToStringHttpMessageConverter(ConversionService conversionService, Charset defaultCharset) { + super(defaultCharset, MediaType.TEXT_PLAIN); + + Assert.notNull(conversionService, "ConversionService is required"); + this.conversionService = conversionService; + this.stringHttpMessageConverter = new StringHttpMessageConverter(defaultCharset); + } + + + /** + * Indicates whether the {@code Accept-Charset} should be written to any outgoing request. + *

Default is {@code true}. + */ + public void setWriteAcceptCharset(boolean writeAcceptCharset) { + this.stringHttpMessageConverter.setWriteAcceptCharset(writeAcceptCharset); + } + + + @Override + public boolean canRead(Class clazz, @Nullable MediaType mediaType) { + return canRead(mediaType) && this.conversionService.canConvert(String.class, clazz); + } + + @Override + public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { + return canWrite(mediaType) && this.conversionService.canConvert(clazz, String.class); + } + + @Override + protected boolean supports(Class clazz) { + // should not be called, since we override canRead/Write + throw new UnsupportedOperationException(); + } + + @Override + protected Object readInternal(Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + String value = this.stringHttpMessageConverter.readInternal(String.class, inputMessage); + Object result = this.conversionService.convert(value, clazz); + if (result == null) { + throw new HttpMessageNotReadableException( + "Unexpected null conversion result for '" + value + "' to " + clazz, + inputMessage); + } + return result; + } + + @Override + protected void writeInternal(Object obj, HttpOutputMessage outputMessage) throws IOException { + String value = this.conversionService.convert(obj, String.class); + if (value != null) { + this.stringHttpMessageConverter.writeInternal(value, outputMessage); + } + } + + @Override + protected Long getContentLength(Object obj, @Nullable MediaType contentType) { + String value = this.conversionService.convert(obj, String.class); + if (value == null) { + return 0L; + } + return this.stringHttpMessageConverter.getContentLength(value, contentType); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/ResourceHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/ResourceHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..582ab859de1692c3babd8cdacbf73f8184cf109f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/ResourceHttpMessageConverter.java @@ -0,0 +1,151 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; + +import org.springframework.core.io.ByteArrayResource; +import org.springframework.core.io.InputStreamResource; +import org.springframework.core.io.Resource; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.MediaTypeFactory; +import org.springframework.lang.Nullable; +import org.springframework.util.StreamUtils; + +/** + * Implementation of {@link HttpMessageConverter} that can read/write {@link Resource Resources} + * and supports byte range requests. + * + *

By default, this converter can read all media types. The {@link MediaTypeFactory} is used + * to determine the {@code Content-Type} of written resources. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @author Kazuki Shimizu + * @since 3.0.2 + */ +public class ResourceHttpMessageConverter extends AbstractHttpMessageConverter { + + private final boolean supportsReadStreaming; + + + /** + * Create a new instance of the {@code ResourceHttpMessageConverter} + * that supports read streaming, i.e. can convert an + * {@code HttpInputMessage} to {@code InputStreamResource}. + */ + public ResourceHttpMessageConverter() { + super(MediaType.ALL); + this.supportsReadStreaming = true; + } + + /** + * Create a new instance of the {@code ResourceHttpMessageConverter}. + * @param supportsReadStreaming whether the converter should support + * read streaming, i.e. convert to {@code InputStreamResource} + * @since 5.0 + */ + public ResourceHttpMessageConverter(boolean supportsReadStreaming) { + super(MediaType.ALL); + this.supportsReadStreaming = supportsReadStreaming; + } + + + @Override + protected boolean supports(Class clazz) { + return Resource.class.isAssignableFrom(clazz); + } + + @Override + protected Resource readInternal(Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + if (this.supportsReadStreaming && InputStreamResource.class == clazz) { + return new InputStreamResource(inputMessage.getBody()) { + @Override + public String getFilename() { + return inputMessage.getHeaders().getContentDisposition().getFilename(); + } + }; + } + else if (Resource.class == clazz || ByteArrayResource.class.isAssignableFrom(clazz)) { + byte[] body = StreamUtils.copyToByteArray(inputMessage.getBody()); + return new ByteArrayResource(body) { + @Override + @Nullable + public String getFilename() { + return inputMessage.getHeaders().getContentDisposition().getFilename(); + } + }; + } + else { + throw new HttpMessageNotReadableException("Unsupported resource class: " + clazz, inputMessage); + } + } + + @Override + protected MediaType getDefaultContentType(Resource resource) { + return MediaTypeFactory.getMediaType(resource).orElse(MediaType.APPLICATION_OCTET_STREAM); + } + + @Override + protected Long getContentLength(Resource resource, @Nullable MediaType contentType) throws IOException { + // Don't try to determine contentLength on InputStreamResource - cannot be read afterwards... + // Note: custom InputStreamResource subclasses could provide a pre-calculated content length! + if (InputStreamResource.class == resource.getClass()) { + return null; + } + long contentLength = resource.contentLength(); + return (contentLength < 0 ? null : contentLength); + } + + @Override + protected void writeInternal(Resource resource, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + + writeContent(resource, outputMessage); + } + + protected void writeContent(Resource resource, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + try { + InputStream in = resource.getInputStream(); + try { + StreamUtils.copy(in, outputMessage.getBody()); + } + catch (NullPointerException ex) { + // ignore, see SPR-13620 + } + finally { + try { + in.close(); + } + catch (Throwable ex) { + // ignore, see SPR-12999 + } + } + } + catch (FileNotFoundException ex) { + // ignore, see SPR-12999 + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/ResourceRegionHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/ResourceRegionHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..be98db9122a19fb4b1fd0f0260d01703a02a2f2c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/ResourceRegionHttpMessageConverter.java @@ -0,0 +1,226 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.nio.charset.StandardCharsets; +import java.util.Collection; + +import org.springframework.core.io.Resource; +import org.springframework.core.io.support.ResourceRegion; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.MediaTypeFactory; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.StreamUtils; + +/** + * Implementation of {@link HttpMessageConverter} that can write a single {@link ResourceRegion}, + * or Collections of {@link ResourceRegion ResourceRegions}. + * + * @author Brian Clozel + * @author Juergen Hoeller + * @since 4.3 + */ +public class ResourceRegionHttpMessageConverter extends AbstractGenericHttpMessageConverter { + + public ResourceRegionHttpMessageConverter() { + super(MediaType.ALL); + } + + + @Override + @SuppressWarnings("unchecked") + protected MediaType getDefaultContentType(Object object) { + Resource resource = null; + if (object instanceof ResourceRegion) { + resource = ((ResourceRegion) object).getResource(); + } + else { + Collection regions = (Collection) object; + if (!regions.isEmpty()) { + resource = regions.iterator().next().getResource(); + } + } + return MediaTypeFactory.getMediaType(resource).orElse(MediaType.APPLICATION_OCTET_STREAM); + } + + @Override + public boolean canRead(Class clazz, @Nullable MediaType mediaType) { + return false; + } + + @Override + public boolean canRead(Type type, @Nullable Class contextClass, @Nullable MediaType mediaType) { + return false; + } + + @Override + public Object read(Type type, @Nullable Class contextClass, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + throw new UnsupportedOperationException(); + } + + @Override + protected ResourceRegion readInternal(Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + throw new UnsupportedOperationException(); + } + + @Override + public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { + return canWrite(clazz, null, mediaType); + } + + @Override + public boolean canWrite(@Nullable Type type, @Nullable Class clazz, @Nullable MediaType mediaType) { + if (!(type instanceof ParameterizedType)) { + return (type instanceof Class && ResourceRegion.class.isAssignableFrom((Class) type)); + } + + ParameterizedType parameterizedType = (ParameterizedType) type; + if (!(parameterizedType.getRawType() instanceof Class)) { + return false; + } + Class rawType = (Class) parameterizedType.getRawType(); + if (!(Collection.class.isAssignableFrom(rawType))) { + return false; + } + if (parameterizedType.getActualTypeArguments().length != 1) { + return false; + } + Type typeArgument = parameterizedType.getActualTypeArguments()[0]; + if (!(typeArgument instanceof Class)) { + return false; + } + + Class typeArgumentClass = (Class) typeArgument; + return ResourceRegion.class.isAssignableFrom(typeArgumentClass); + } + + @Override + @SuppressWarnings("unchecked") + protected void writeInternal(Object object, @Nullable Type type, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + + if (object instanceof ResourceRegion) { + writeResourceRegion((ResourceRegion) object, outputMessage); + } + else { + Collection regions = (Collection) object; + if (regions.size() == 1) { + writeResourceRegion(regions.iterator().next(), outputMessage); + } + else { + writeResourceRegionCollection((Collection) object, outputMessage); + } + } + } + + + protected void writeResourceRegion(ResourceRegion region, HttpOutputMessage outputMessage) throws IOException { + Assert.notNull(region, "ResourceRegion must not be null"); + HttpHeaders responseHeaders = outputMessage.getHeaders(); + + long start = region.getPosition(); + long end = start + region.getCount() - 1; + Long resourceLength = region.getResource().contentLength(); + end = Math.min(end, resourceLength - 1); + long rangeLength = end - start + 1; + responseHeaders.add("Content-Range", "bytes " + start + '-' + end + '/' + resourceLength); + responseHeaders.setContentLength(rangeLength); + + InputStream in = region.getResource().getInputStream(); + try { + StreamUtils.copyRange(in, outputMessage.getBody(), start, end); + } + finally { + try { + in.close(); + } + catch (IOException ex) { + // ignore + } + } + } + + private void writeResourceRegionCollection(Collection resourceRegions, + HttpOutputMessage outputMessage) throws IOException { + + Assert.notNull(resourceRegions, "Collection of ResourceRegion should not be null"); + HttpHeaders responseHeaders = outputMessage.getHeaders(); + + MediaType contentType = responseHeaders.getContentType(); + String boundaryString = MimeTypeUtils.generateMultipartBoundaryString(); + responseHeaders.set(HttpHeaders.CONTENT_TYPE, "multipart/byteranges; boundary=" + boundaryString); + OutputStream out = outputMessage.getBody(); + + for (ResourceRegion region : resourceRegions) { + long start = region.getPosition(); + long end = start + region.getCount() - 1; + InputStream in = region.getResource().getInputStream(); + try { + // Writing MIME header. + println(out); + print(out, "--" + boundaryString); + println(out); + if (contentType != null) { + print(out, "Content-Type: " + contentType.toString()); + println(out); + } + Long resourceLength = region.getResource().contentLength(); + end = Math.min(end, resourceLength - 1); + print(out, "Content-Range: bytes " + start + '-' + end + '/' + resourceLength); + println(out); + println(out); + // Printing content + StreamUtils.copyRange(in, out, start, end); + } + finally { + try { + in.close(); + } + catch (IOException ex) { + // ignore + } + } + } + + println(out); + print(out, "--" + boundaryString + "--"); + } + + private static void println(OutputStream os) throws IOException { + os.write('\r'); + os.write('\n'); + } + + private static void print(OutputStream os, String buf) throws IOException { + os.write(buf.getBytes(StandardCharsets.US_ASCII)); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/StringHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/StringHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..6f393b6235b9d58e9b661e32765530095000cef8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/StringHttpMessageConverter.java @@ -0,0 +1,144 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StreamUtils; + +/** + * Implementation of {@link HttpMessageConverter} that can read and write strings. + * + *

By default, this converter supports all media types ({@code */*}), + * and writes with a {@code Content-Type} of {@code text/plain}. This can be overridden + * by setting the {@link #setSupportedMediaTypes supportedMediaTypes} property. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.0 + */ +public class StringHttpMessageConverter extends AbstractHttpMessageConverter { + + /** + * The default charset used by the converter. + */ + public static final Charset DEFAULT_CHARSET = StandardCharsets.ISO_8859_1; + + + @Nullable + private volatile List availableCharsets; + + private boolean writeAcceptCharset = true; + + + /** + * A default constructor that uses {@code "ISO-8859-1"} as the default charset. + * @see #StringHttpMessageConverter(Charset) + */ + public StringHttpMessageConverter() { + this(DEFAULT_CHARSET); + } + + /** + * A constructor accepting a default charset to use if the requested content + * type does not specify one. + */ + public StringHttpMessageConverter(Charset defaultCharset) { + super(defaultCharset, MediaType.TEXT_PLAIN, MediaType.ALL); + } + + + /** + * Whether the {@code Accept-Charset} header should be written to any outgoing + * request sourced from the value of {@link Charset#availableCharsets()}. + * The behavior is suppressed if the header has already been set. + *

Default is {@code true}. + */ + public void setWriteAcceptCharset(boolean writeAcceptCharset) { + this.writeAcceptCharset = writeAcceptCharset; + } + + + @Override + public boolean supports(Class clazz) { + return String.class == clazz; + } + + @Override + protected String readInternal(Class clazz, HttpInputMessage inputMessage) throws IOException { + Charset charset = getContentTypeCharset(inputMessage.getHeaders().getContentType()); + return StreamUtils.copyToString(inputMessage.getBody(), charset); + } + + @Override + protected Long getContentLength(String str, @Nullable MediaType contentType) { + Charset charset = getContentTypeCharset(contentType); + return (long) str.getBytes(charset).length; + } + + @Override + protected void writeInternal(String str, HttpOutputMessage outputMessage) throws IOException { + HttpHeaders headers = outputMessage.getHeaders(); + if (this.writeAcceptCharset && headers.get(HttpHeaders.ACCEPT_CHARSET) == null) { + headers.setAcceptCharset(getAcceptedCharsets()); + } + Charset charset = getContentTypeCharset(headers.getContentType()); + StreamUtils.copy(str, charset, outputMessage.getBody()); + } + + + /** + * Return the list of supported {@link Charset Charsets}. + *

By default, returns {@link Charset#availableCharsets()}. + * Can be overridden in subclasses. + * @return the list of accepted charsets + */ + protected List getAcceptedCharsets() { + List charsets = this.availableCharsets; + if (charsets == null) { + charsets = new ArrayList<>(Charset.availableCharsets().values()); + this.availableCharsets = charsets; + } + return charsets; + } + + private Charset getContentTypeCharset(@Nullable MediaType contentType) { + if (contentType != null && contentType.getCharset() != null) { + return contentType.getCharset(); + } + else if (contentType != null && contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) { + // Matching to AbstractJackson2HttpMessageConverter#DEFAULT_CHARSET + return StandardCharsets.UTF_8; + } + else { + Charset charset = getDefaultCharset(); + Assert.state(charset != null, "No default charset"); + return charset; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/cbor/MappingJackson2CborHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/cbor/MappingJackson2CborHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..0058abbe49053cb57a4f30ceb44a2367073f74a7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/cbor/MappingJackson2CborHttpMessageConverter.java @@ -0,0 +1,75 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.cbor; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.cbor.CBORFactory; + +import org.springframework.http.MediaType; +import org.springframework.http.converter.json.AbstractJackson2HttpMessageConverter; +import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; +import org.springframework.util.Assert; + +/** + * Implementation of {@link org.springframework.http.converter.HttpMessageConverter HttpMessageConverter} + * that can read and write CBOR data format using + * + * the dedicated Jackson 2.x extension. + * + *

By default, this converter supports {@code "application/cbor"} media type. This can be + * overridden by setting the {@link #setSupportedMediaTypes supportedMediaTypes} property. + * + *

The default constructor uses the default configuration provided by {@link Jackson2ObjectMapperBuilder}. + * + *

Compatible with Jackson 2.9 and higher. + * + * @author Sebastien Deleuze + * @since 5.0 + */ +public class MappingJackson2CborHttpMessageConverter extends AbstractJackson2HttpMessageConverter { + + /** + * Construct a new {@code MappingJackson2CborHttpMessageConverter} using default configuration + * provided by {@code Jackson2ObjectMapperBuilder}. + */ + public MappingJackson2CborHttpMessageConverter() { + this(Jackson2ObjectMapperBuilder.cbor().build()); + } + + /** + * Construct a new {@code MappingJackson2CborHttpMessageConverter} with a custom {@link ObjectMapper} + * (must be configured with a {@code CBORFactory} instance). + * You can use {@link Jackson2ObjectMapperBuilder} to build it easily. + * @see Jackson2ObjectMapperBuilder#cbor() + */ + public MappingJackson2CborHttpMessageConverter(ObjectMapper objectMapper) { + super(objectMapper, new MediaType("application", "cbor")); + Assert.isInstanceOf(CBORFactory.class, objectMapper.getFactory(), "CBORFactory required"); + } + + + /** + * {@inheritDoc} + * The {@code ObjectMapper} must be configured with a {@code CBORFactory} instance. + */ + @Override + public void setObjectMapper(ObjectMapper objectMapper) { + Assert.isInstanceOf(CBORFactory.class, objectMapper.getFactory(), "CBORFactory required"); + super.setObjectMapper(objectMapper); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/cbor/package-info.java b/spring-web/src/main/java/org/springframework/http/converter/cbor/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..c8ffa97818010793c758a4dd97a90bcd338277b3 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/cbor/package-info.java @@ -0,0 +1,9 @@ +/** + * Provides an HttpMessageConverter for the CBOR data format. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.converter.cbor; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/converter/feed/AbstractWireFeedHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/feed/AbstractWireFeedHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..bd2bd774e8b2a8f054e1066f1376736a3cf87234 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/feed/AbstractWireFeedHttpMessageConverter.java @@ -0,0 +1,107 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.feed; + +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.io.Reader; +import java.io.Writer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; + +import com.rometools.rome.feed.WireFeed; +import com.rometools.rome.io.FeedException; +import com.rometools.rome.io.WireFeedInput; +import com.rometools.rome.io.WireFeedOutput; + +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.converter.AbstractHttpMessageConverter; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.HttpMessageNotWritableException; +import org.springframework.util.StringUtils; + +/** + * Abstract base class for Atom and RSS Feed message converters, using the + * ROME tools project. + * + *

NOTE: As of Spring 4.1, this is based on the {@code com.rometools} + * variant of ROME, version 1.5. Please upgrade your build dependency. + * + * @author Arjen Poutsma + * @since 3.0.2 + * @param the converted object type + * @see AtomFeedHttpMessageConverter + * @see RssChannelHttpMessageConverter + */ +public abstract class AbstractWireFeedHttpMessageConverter + extends AbstractHttpMessageConverter { + + /** + * The default charset used by the converter. + */ + public static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + + + protected AbstractWireFeedHttpMessageConverter(MediaType supportedMediaType) { + super(supportedMediaType); + } + + + @Override + @SuppressWarnings("unchecked") + protected T readInternal(Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + WireFeedInput feedInput = new WireFeedInput(); + MediaType contentType = inputMessage.getHeaders().getContentType(); + Charset charset = (contentType != null && contentType.getCharset() != null ? + contentType.getCharset() : DEFAULT_CHARSET); + try { + Reader reader = new InputStreamReader(inputMessage.getBody(), charset); + return (T) feedInput.build(reader); + } + catch (FeedException ex) { + throw new HttpMessageNotReadableException("Could not read WireFeed: " + ex.getMessage(), ex, inputMessage); + } + } + + @Override + protected void writeInternal(T wireFeed, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + + Charset charset = (StringUtils.hasLength(wireFeed.getEncoding()) ? + Charset.forName(wireFeed.getEncoding()) : DEFAULT_CHARSET); + MediaType contentType = outputMessage.getHeaders().getContentType(); + if (contentType != null) { + contentType = new MediaType(contentType.getType(), contentType.getSubtype(), charset); + outputMessage.getHeaders().setContentType(contentType); + } + + WireFeedOutput feedOutput = new WireFeedOutput(); + try { + Writer writer = new OutputStreamWriter(outputMessage.getBody(), charset); + feedOutput.output(wireFeed, writer); + } + catch (FeedException ex) { + throw new HttpMessageNotWritableException("Could not write WireFeed: " + ex.getMessage(), ex); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/feed/AtomFeedHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/feed/AtomFeedHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..88a074752a4e2895dcec43198c3624ae83bdd76c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/feed/AtomFeedHttpMessageConverter.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.feed; + +import com.rometools.rome.feed.atom.Feed; + +import org.springframework.http.MediaType; + +/** + * Implementation of {@link org.springframework.http.converter.HttpMessageConverter} + * that can read and write Atom feeds. Specifically, this converter can handle {@link Feed} + * objects from the ROME project. + * + *

>NOTE: As of Spring 4.1, this is based on the {@code com.rometools} + * variant of ROME, version 1.5. Please upgrade your build dependency. + * + *

By default, this converter reads and writes the media type ({@code application/atom+xml}). + * This can be overridden through the {@link #setSupportedMediaTypes supportedMediaTypes} property. + * + * @author Arjen Poutsma + * @since 3.0.2 + * @see Feed + */ +public class AtomFeedHttpMessageConverter extends AbstractWireFeedHttpMessageConverter { + + public AtomFeedHttpMessageConverter() { + super(new MediaType("application", "atom+xml")); + } + + @Override + protected boolean supports(Class clazz) { + return Feed.class.isAssignableFrom(clazz); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/feed/RssChannelHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/feed/RssChannelHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..0ecf87a703ea509dbdbbf3d5504d857009b6aaef --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/feed/RssChannelHttpMessageConverter.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.feed; + +import com.rometools.rome.feed.rss.Channel; + +import org.springframework.http.MediaType; + +/** + * Implementation of {@link org.springframework.http.converter.HttpMessageConverter} + * that can read and write RSS feeds. Specifically, this converter can handle {@link Channel} + * objects from the ROME project. + * + *

>NOTE: As of Spring 4.1, this is based on the {@code com.rometools} + * variant of ROME, version 1.5. Please upgrade your build dependency. + * + *

By default, this converter reads and writes the media type ({@code application/rss+xml}). + * This can be overridden through the {@link #setSupportedMediaTypes supportedMediaTypes} property. + * + * @author Arjen Poutsma + * @since 3.0.2 + * @see Channel + */ +public class RssChannelHttpMessageConverter extends AbstractWireFeedHttpMessageConverter { + + public RssChannelHttpMessageConverter() { + super(MediaType.APPLICATION_RSS_XML); + } + + @Override + protected boolean supports(Class clazz) { + return Channel.class.isAssignableFrom(clazz); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/feed/package-info.java b/spring-web/src/main/java/org/springframework/http/converter/feed/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..74d0e0656a3513a2b1f2416501672cd096b12cb5 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/feed/package-info.java @@ -0,0 +1,10 @@ +/** + * Provides HttpMessageConverter implementations for handling Atom and RSS feeds. + * Based on the ROME tools project. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.converter.feed; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/AbstractJackson2HttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/json/AbstractJackson2HttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..465db9e4161a4c58edfd9f45b14638818f41fb13 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/json/AbstractJackson2HttpMessageConverter.java @@ -0,0 +1,414 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.Reader; +import java.lang.reflect.Type; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.core.JsonEncoding; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.PrettyPrinter; +import com.fasterxml.jackson.core.util.DefaultIndenter; +import com.fasterxml.jackson.core.util.DefaultPrettyPrinter; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ObjectReader; +import com.fasterxml.jackson.databind.ObjectWriter; +import com.fasterxml.jackson.databind.SerializationConfig; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.databind.exc.InvalidDefinitionException; +import com.fasterxml.jackson.databind.ser.FilterProvider; + +import org.springframework.core.GenericTypeResolver; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.converter.AbstractGenericHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConversionException; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.HttpMessageNotWritableException; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.TypeUtils; + +/** + * Abstract base class for Jackson based and content type independent + * {@link HttpMessageConverter} implementations. + * + *

Compatible with Jackson 2.9 and higher, as of Spring 5.0. + * + * @author Arjen Poutsma + * @author Keith Donald + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @author Sebastien Deleuze + * @since 4.1 + * @see MappingJackson2HttpMessageConverter + */ +public abstract class AbstractJackson2HttpMessageConverter extends AbstractGenericHttpMessageConverter { + + private static final Map ENCODINGS; + + static { + ENCODINGS = new HashMap<>(JsonEncoding.values().length + 1); + for (JsonEncoding encoding : JsonEncoding.values()) { + ENCODINGS.put(encoding.getJavaName(), encoding); + } + ENCODINGS.put("US-ASCII", JsonEncoding.UTF8); + } + + + /** + * The default charset used by the converter. + */ + public static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + + + protected ObjectMapper objectMapper; + + @Nullable + private Boolean prettyPrint; + + @Nullable + private PrettyPrinter ssePrettyPrinter; + + + protected AbstractJackson2HttpMessageConverter(ObjectMapper objectMapper) { + this.objectMapper = objectMapper; + setDefaultCharset(DEFAULT_CHARSET); + DefaultPrettyPrinter prettyPrinter = new DefaultPrettyPrinter(); + prettyPrinter.indentObjectsWith(new DefaultIndenter(" ", "\ndata:")); + this.ssePrettyPrinter = prettyPrinter; + } + + protected AbstractJackson2HttpMessageConverter(ObjectMapper objectMapper, MediaType supportedMediaType) { + this(objectMapper); + setSupportedMediaTypes(Collections.singletonList(supportedMediaType)); + } + + protected AbstractJackson2HttpMessageConverter(ObjectMapper objectMapper, MediaType... supportedMediaTypes) { + this(objectMapper); + setSupportedMediaTypes(Arrays.asList(supportedMediaTypes)); + } + + + /** + * Set the {@code ObjectMapper} for this view. + * If not set, a default {@link ObjectMapper#ObjectMapper() ObjectMapper} is used. + *

Setting a custom-configured {@code ObjectMapper} is one way to take further + * control of the JSON serialization process. For example, an extended + * {@link com.fasterxml.jackson.databind.ser.SerializerFactory} + * can be configured that provides custom serializers for specific types. + * The other option for refining the serialization process is to use Jackson's + * provided annotations on the types to be serialized, in which case a + * custom-configured ObjectMapper is unnecessary. + */ + public void setObjectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + configurePrettyPrint(); + } + + /** + * Return the underlying {@code ObjectMapper} for this view. + */ + public ObjectMapper getObjectMapper() { + return this.objectMapper; + } + + /** + * Whether to use the {@link DefaultPrettyPrinter} when writing JSON. + * This is a shortcut for setting up an {@code ObjectMapper} as follows: + *

+	 * ObjectMapper mapper = new ObjectMapper();
+	 * mapper.configure(SerializationFeature.INDENT_OUTPUT, true);
+	 * converter.setObjectMapper(mapper);
+	 * 
+ */ + public void setPrettyPrint(boolean prettyPrint) { + this.prettyPrint = prettyPrint; + configurePrettyPrint(); + } + + private void configurePrettyPrint() { + if (this.prettyPrint != null) { + this.objectMapper.configure(SerializationFeature.INDENT_OUTPUT, this.prettyPrint); + } + } + + + @Override + public boolean canRead(Class clazz, @Nullable MediaType mediaType) { + return canRead(clazz, null, mediaType); + } + + @Override + public boolean canRead(Type type, @Nullable Class contextClass, @Nullable MediaType mediaType) { + if (!canRead(mediaType)) { + return false; + } + JavaType javaType = getJavaType(type, contextClass); + AtomicReference causeRef = new AtomicReference<>(); + if (this.objectMapper.canDeserialize(javaType, causeRef)) { + return true; + } + logWarningIfNecessary(javaType, causeRef.get()); + return false; + } + + @Override + public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { + if (!canWrite(mediaType)) { + return false; + } + if (mediaType != null && mediaType.getCharset() != null) { + Charset charset = mediaType.getCharset(); + if (!ENCODINGS.containsKey(charset.name())) { + return false; + } + } + AtomicReference causeRef = new AtomicReference<>(); + if (this.objectMapper.canSerialize(clazz, causeRef)) { + return true; + } + logWarningIfNecessary(clazz, causeRef.get()); + return false; + } + + /** + * Determine whether to log the given exception coming from a + * {@link ObjectMapper#canDeserialize} / {@link ObjectMapper#canSerialize} check. + * @param type the class that Jackson tested for (de-)serializability + * @param cause the Jackson-thrown exception to evaluate + * (typically a {@link JsonMappingException}) + * @since 4.3 + */ + protected void logWarningIfNecessary(Type type, @Nullable Throwable cause) { + if (cause == null) { + return; + } + + // Do not log warning for serializer not found (note: different message wording on Jackson 2.9) + boolean debugLevel = (cause instanceof JsonMappingException && cause.getMessage().startsWith("Cannot find")); + + if (debugLevel ? logger.isDebugEnabled() : logger.isWarnEnabled()) { + String msg = "Failed to evaluate Jackson " + (type instanceof JavaType ? "de" : "") + + "serialization for type [" + type + "]"; + if (debugLevel) { + logger.debug(msg, cause); + } + else if (logger.isDebugEnabled()) { + logger.warn(msg, cause); + } + else { + logger.warn(msg + ": " + cause); + } + } + } + + @Override + public Object read(Type type, @Nullable Class contextClass, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + JavaType javaType = getJavaType(type, contextClass); + return readJavaType(javaType, inputMessage); + } + + @Override + protected Object readInternal(Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + JavaType javaType = getJavaType(clazz, null); + return readJavaType(javaType, inputMessage); + } + + private Object readJavaType(JavaType javaType, HttpInputMessage inputMessage) throws IOException { + MediaType contentType = inputMessage.getHeaders().getContentType(); + Charset charset = getCharset(contentType); + + boolean isUnicode = ENCODINGS.containsKey(charset.name()); + try { + if (inputMessage instanceof MappingJacksonInputMessage) { + Class deserializationView = ((MappingJacksonInputMessage) inputMessage).getDeserializationView(); + if (deserializationView != null) { + ObjectReader objectReader = this.objectMapper.readerWithView(deserializationView).forType(javaType); + if (isUnicode) { + return objectReader.readValue(inputMessage.getBody()); + } + else { + Reader reader = new InputStreamReader(inputMessage.getBody(), charset); + return objectReader.readValue(reader); + } + } + } + if (isUnicode) { + return this.objectMapper.readValue(inputMessage.getBody(), javaType); + } + else { + Reader reader = new InputStreamReader(inputMessage.getBody(), charset); + return this.objectMapper.readValue(reader, javaType); + } + } + catch (InvalidDefinitionException ex) { + throw new HttpMessageConversionException("Type definition error: " + ex.getType(), ex); + } + catch (JsonProcessingException ex) { + throw new HttpMessageNotReadableException("JSON parse error: " + ex.getOriginalMessage(), ex, inputMessage); + } + } + + /** + * Determine the charset to use for JSON input. + *

By default this is either the charset from the input {@code MediaType} + * or otherwise falling back on {@code UTF-8}. Can be overridden in subclasses. + * @param contentType the content type of the HTTP input message + * @return the charset to use + * @since 5.1.18 + */ + protected Charset getCharset(@Nullable MediaType contentType) { + if (contentType != null && contentType.getCharset() != null) { + return contentType.getCharset(); + } + else { + return StandardCharsets.UTF_8; + } + } + + @Override + protected void writeInternal(Object object, @Nullable Type type, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + + MediaType contentType = outputMessage.getHeaders().getContentType(); + JsonEncoding encoding = getJsonEncoding(contentType); + JsonGenerator generator = this.objectMapper.getFactory().createGenerator(outputMessage.getBody(), encoding); + try { + writePrefix(generator, object); + + Object value = object; + Class serializationView = null; + FilterProvider filters = null; + JavaType javaType = null; + + if (object instanceof MappingJacksonValue) { + MappingJacksonValue container = (MappingJacksonValue) object; + value = container.getValue(); + serializationView = container.getSerializationView(); + filters = container.getFilters(); + } + if (type != null && TypeUtils.isAssignable(type, value.getClass())) { + javaType = getJavaType(type, null); + } + + ObjectWriter objectWriter = (serializationView != null ? + this.objectMapper.writerWithView(serializationView) : this.objectMapper.writer()); + if (filters != null) { + objectWriter = objectWriter.with(filters); + } + if (javaType != null && javaType.isContainerType()) { + objectWriter = objectWriter.forType(javaType); + } + SerializationConfig config = objectWriter.getConfig(); + if (contentType != null && contentType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM) && + config.isEnabled(SerializationFeature.INDENT_OUTPUT)) { + objectWriter = objectWriter.with(this.ssePrettyPrinter); + } + objectWriter.writeValue(generator, value); + + writeSuffix(generator, object); + generator.flush(); + } + catch (InvalidDefinitionException ex) { + throw new HttpMessageConversionException("Type definition error: " + ex.getType(), ex); + } + catch (JsonProcessingException ex) { + throw new HttpMessageNotWritableException("Could not write JSON: " + ex.getOriginalMessage(), ex); + } + } + + /** + * Write a prefix before the main content. + * @param generator the generator to use for writing content. + * @param object the object to write to the output message. + */ + protected void writePrefix(JsonGenerator generator, Object object) throws IOException { + } + + /** + * Write a suffix after the main content. + * @param generator the generator to use for writing content. + * @param object the object to write to the output message. + */ + protected void writeSuffix(JsonGenerator generator, Object object) throws IOException { + } + + /** + * Return the Jackson {@link JavaType} for the specified type and context class. + * @param type the generic type to return the Jackson JavaType for + * @param contextClass a context class for the target type, for example a class + * in which the target type appears in a method signature (can be {@code null}) + * @return the Jackson JavaType + */ + protected JavaType getJavaType(Type type, @Nullable Class contextClass) { + return this.objectMapper.constructType(GenericTypeResolver.resolveType(type, contextClass)); + } + + /** + * Determine the JSON encoding to use for the given content type. + * @param contentType the media type as requested by the caller + * @return the JSON encoding to use (never {@code null}) + */ + protected JsonEncoding getJsonEncoding(@Nullable MediaType contentType) { + if (contentType != null && contentType.getCharset() != null) { + Charset charset = contentType.getCharset(); + JsonEncoding encoding = ENCODINGS.get(charset.name()); + if (encoding != null) { + return encoding; + } + } + return JsonEncoding.UTF8; + } + + @Override + @Nullable + protected MediaType getDefaultContentType(Object object) throws IOException { + if (object instanceof MappingJacksonValue) { + object = ((MappingJacksonValue) object).getValue(); + } + return super.getDefaultContentType(object); + } + + @Override + protected Long getContentLength(Object object, @Nullable MediaType contentType) throws IOException { + if (object instanceof MappingJacksonValue) { + object = ((MappingJacksonValue) object).getValue(); + } + return super.getContentLength(object, contentType); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/AbstractJsonHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/json/AbstractJsonHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..4584b42913e7b646cc3cc66026f7c01edc76d574 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/json/AbstractJsonHttpMessageConverter.java @@ -0,0 +1,166 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.io.Reader; +import java.io.Writer; +import java.lang.reflect.Type; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; + +import org.springframework.core.GenericTypeResolver; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.converter.AbstractGenericHttpMessageConverter; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.HttpMessageNotWritableException; +import org.springframework.lang.Nullable; + +/** + * Common base class for plain JSON converters, e.g. Gson and JSON-B. + * + *

Note that the Jackson converters have a dedicated class hierarchy + * due to their multi-format support. + * + * @author Juergen Hoeller + * @since 5.0 + * @see GsonHttpMessageConverter + * @see JsonbHttpMessageConverter + * @see #readInternal(Type, Reader) + * @see #writeInternal(Object, Type, Writer) + */ +public abstract class AbstractJsonHttpMessageConverter extends AbstractGenericHttpMessageConverter { + + /** + * The default charset used by the converter. + */ + public static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + + @Nullable + private String jsonPrefix; + + + public AbstractJsonHttpMessageConverter() { + super(MediaType.APPLICATION_JSON, new MediaType("application", "*+json")); + setDefaultCharset(DEFAULT_CHARSET); + } + + + /** + * Specify a custom prefix to use for JSON output. Default is none. + * @see #setPrefixJson + */ + public void setJsonPrefix(String jsonPrefix) { + this.jsonPrefix = jsonPrefix; + } + + /** + * Indicate whether the JSON output by this view should be prefixed with ")]}', ". + * Default is {@code false}. + *

Prefixing the JSON string in this manner is used to help prevent JSON + * Hijacking. The prefix renders the string syntactically invalid as a script + * so that it cannot be hijacked. + * This prefix should be stripped before parsing the string as JSON. + * @see #setJsonPrefix + */ + public void setPrefixJson(boolean prefixJson) { + this.jsonPrefix = (prefixJson ? ")]}', " : null); + } + + + @Override + public final Object read(Type type, @Nullable Class contextClass, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + return readResolved(GenericTypeResolver.resolveType(type, contextClass), inputMessage); + } + + @Override + protected final Object readInternal(Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + return readResolved(clazz, inputMessage); + } + + private Object readResolved(Type resolvedType, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + Reader reader = getReader(inputMessage); + try { + return readInternal(resolvedType, reader); + } + catch (Exception ex) { + throw new HttpMessageNotReadableException("Could not read JSON: " + ex.getMessage(), ex, inputMessage); + } + } + + @Override + protected final void writeInternal(Object object, @Nullable Type type, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + + Writer writer = getWriter(outputMessage); + if (this.jsonPrefix != null) { + writer.append(this.jsonPrefix); + } + try { + writeInternal(object, type, writer); + } + catch (Exception ex) { + throw new HttpMessageNotWritableException("Could not write JSON: " + ex.getMessage(), ex); + } + writer.flush(); + } + + + /** + * Template method that reads the JSON-bound object from the given {@link Reader}. + * @param resolvedType the resolved generic type + * @param reader the {@code} Reader to use + * @return the JSON-bound object + * @throws Exception in case of read/parse failures + */ + protected abstract Object readInternal(Type resolvedType, Reader reader) throws Exception; + + /** + * Template method that writes the JSON-bound object to the given {@link Writer}. + * @param object the object to write to the output message + * @param type the type of object to write (may be {@code null}) + * @param writer the {@code} Writer to use + * @throws Exception in case of write failures + */ + protected abstract void writeInternal(Object object, @Nullable Type type, Writer writer) throws Exception; + + + private static Reader getReader(HttpInputMessage inputMessage) throws IOException { + return new InputStreamReader(inputMessage.getBody(), getCharset(inputMessage.getHeaders())); + } + + private static Writer getWriter(HttpOutputMessage outputMessage) throws IOException { + return new OutputStreamWriter(outputMessage.getBody(), getCharset(outputMessage.getHeaders())); + } + + private static Charset getCharset(HttpHeaders headers) { + Charset charset = (headers.getContentType() != null ? headers.getContentType().getCharset() : null); + return (charset != null ? charset : DEFAULT_CHARSET); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/GsonBuilderUtils.java b/spring-web/src/main/java/org/springframework/http/converter/json/GsonBuilderUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..4d098ad5877b06afb1151a6bd8650034a18e1e93 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/json/GsonBuilderUtils.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.lang.reflect.Type; + +import com.google.gson.GsonBuilder; +import com.google.gson.JsonDeserializationContext; +import com.google.gson.JsonDeserializer; +import com.google.gson.JsonElement; +import com.google.gson.JsonPrimitive; +import com.google.gson.JsonSerializationContext; +import com.google.gson.JsonSerializer; + +import org.springframework.util.Base64Utils; + +/** + * A simple utility class for obtaining a Google Gson 2.x {@link GsonBuilder} + * which Base64-encodes {@code byte[]} properties when reading and writing JSON. + * + * @author Juergen Hoeller + * @author Roy Clarkson + * @since 4.1 + * @see GsonFactoryBean#setBase64EncodeByteArrays + * @see org.springframework.util.Base64Utils + */ +public abstract class GsonBuilderUtils { + + /** + * Obtain a {@link GsonBuilder} which Base64-encodes {@code byte[]} + * properties when reading and writing JSON. + *

A custom {@link com.google.gson.TypeAdapter} will be registered via + * {@link GsonBuilder#registerTypeHierarchyAdapter(Class, Object)} which + * serializes a {@code byte[]} property to and from a Base64-encoded String + * instead of a JSON array. + */ + public static GsonBuilder gsonBuilderWithBase64EncodedByteArrays() { + GsonBuilder builder = new GsonBuilder(); + builder.registerTypeHierarchyAdapter(byte[].class, new Base64TypeAdapter()); + return builder; + } + + + private static class Base64TypeAdapter implements JsonSerializer, JsonDeserializer { + + @Override + public JsonElement serialize(byte[] src, Type typeOfSrc, JsonSerializationContext context) { + return new JsonPrimitive(Base64Utils.encodeToString(src)); + } + + @Override + public byte[] deserialize(JsonElement json, Type type, JsonDeserializationContext cxt) { + return Base64Utils.decodeFromString(json.getAsString()); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/GsonFactoryBean.java b/spring-web/src/main/java/org/springframework/http/converter/json/GsonFactoryBean.java new file mode 100644 index 0000000000000000000000000000000000000000..be63917f52a5aac1951c3a268ac08bd595b411f1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/json/GsonFactoryBean.java @@ -0,0 +1,150 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.text.SimpleDateFormat; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; + +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; + +/** + * A {@link FactoryBean} for creating a Google Gson 2.x {@link Gson} instance. + * + * @author Roy Clarkson + * @author Juergen Hoeller + * @since 4.1 + */ +public class GsonFactoryBean implements FactoryBean, InitializingBean { + + private boolean base64EncodeByteArrays = false; + + private boolean serializeNulls = false; + + private boolean prettyPrinting = false; + + private boolean disableHtmlEscaping = false; + + @Nullable + private String dateFormatPattern; + + @Nullable + private Gson gson; + + + /** + * Whether to Base64-encode {@code byte[]} properties when reading and + * writing JSON. + *

When set to {@code true}, a custom {@link com.google.gson.TypeAdapter} will be + * registered via {@link GsonBuilder#registerTypeHierarchyAdapter(Class, Object)} + * which serializes a {@code byte[]} property to and from a Base64-encoded String + * instead of a JSON array. + * @see GsonBuilderUtils#gsonBuilderWithBase64EncodedByteArrays() + */ + public void setBase64EncodeByteArrays(boolean base64EncodeByteArrays) { + this.base64EncodeByteArrays = base64EncodeByteArrays; + } + + /** + * Whether to use the {@link GsonBuilder#serializeNulls()} option when writing + * JSON. This is a shortcut for setting up a {@code Gson} as follows: + *

+	 * new GsonBuilder().serializeNulls().create();
+	 * 
+ */ + public void setSerializeNulls(boolean serializeNulls) { + this.serializeNulls = serializeNulls; + } + + /** + * Whether to use the {@link GsonBuilder#setPrettyPrinting()} when writing + * JSON. This is a shortcut for setting up a {@code Gson} as follows: + *
+	 * new GsonBuilder().setPrettyPrinting().create();
+	 * 
+ */ + public void setPrettyPrinting(boolean prettyPrinting) { + this.prettyPrinting = prettyPrinting; + } + + /** + * Whether to use the {@link GsonBuilder#disableHtmlEscaping()} when writing + * JSON. Set to {@code true} to disable HTML escaping in JSON. This is a + * shortcut for setting up a {@code Gson} as follows: + *
+	 * new GsonBuilder().disableHtmlEscaping().create();
+	 * 
+ */ + public void setDisableHtmlEscaping(boolean disableHtmlEscaping) { + this.disableHtmlEscaping = disableHtmlEscaping; + } + + /** + * Define the date/time format with a {@link SimpleDateFormat}-style pattern. + * This is a shortcut for setting up a {@code Gson} as follows: + *
+	 * new GsonBuilder().setDateFormat(dateFormatPattern).create();
+	 * 
+ */ + public void setDateFormatPattern(String dateFormatPattern) { + this.dateFormatPattern = dateFormatPattern; + } + + + @Override + public void afterPropertiesSet() { + GsonBuilder builder = (this.base64EncodeByteArrays ? + GsonBuilderUtils.gsonBuilderWithBase64EncodedByteArrays() : new GsonBuilder()); + if (this.serializeNulls) { + builder.serializeNulls(); + } + if (this.prettyPrinting) { + builder.setPrettyPrinting(); + } + if (this.disableHtmlEscaping) { + builder.disableHtmlEscaping(); + } + if (this.dateFormatPattern != null) { + builder.setDateFormat(this.dateFormatPattern); + } + this.gson = builder.create(); + } + + + /** + * Return the created Gson instance. + */ + @Override + @Nullable + public Gson getObject() { + return this.gson; + } + + @Override + public Class getObjectType() { + return Gson.class; + } + + @Override + public boolean isSingleton() { + return true; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/GsonHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/json/GsonHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..8c363b4eae1d918cc9f582c84c49eb0b9538d7cb --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/json/GsonHttpMessageConverter.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.io.Reader; +import java.io.Writer; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; + +import com.google.gson.Gson; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Implementation of {@link org.springframework.http.converter.HttpMessageConverter} + * that can read and write JSON using the + * Google Gson library. + * + *

This converter can be used to bind to typed beans or untyped {@code HashMap}s. + * By default, it supports {@code application/json} and {@code application/*+json} with + * {@code UTF-8} character set. + * + *

Tested against Gson 2.8; compatible with Gson 2.0 and higher. + * + * @author Roy Clarkson + * @author Juergen Hoeller + * @since 4.1 + * @see com.google.gson.Gson + * @see com.google.gson.GsonBuilder + * @see #setGson + */ +public class GsonHttpMessageConverter extends AbstractJsonHttpMessageConverter { + + private Gson gson; + + + /** + * Construct a new {@code GsonHttpMessageConverter} with default configuration. + */ + public GsonHttpMessageConverter() { + this.gson = new Gson(); + } + + /** + * Construct a new {@code GsonHttpMessageConverter} with the given delegate. + * @param gson the Gson instance to use + * @since 5.0 + */ + public GsonHttpMessageConverter(Gson gson) { + Assert.notNull(gson, "A Gson instance is required"); + this.gson = gson; + } + + + /** + * Set the {@code Gson} instance to use. + * If not set, a default {@link Gson#Gson() Gson} instance will be used. + *

Setting a custom-configured {@code Gson} is one way to take further + * control of the JSON serialization process. + * @see #GsonHttpMessageConverter(Gson) + */ + public void setGson(Gson gson) { + Assert.notNull(gson, "A Gson instance is required"); + this.gson = gson; + } + + /** + * Return the configured {@code Gson} instance for this converter. + */ + public Gson getGson() { + return this.gson; + } + + + @Override + protected Object readInternal(Type resolvedType, Reader reader) throws Exception { + return getGson().fromJson(reader, resolvedType); + } + + @Override + protected void writeInternal(Object object, @Nullable Type type, Writer writer) throws Exception { + // In Gson, toJson with a type argument will exclusively use that given type, + // ignoring the actual type of the object... which might be more specific, + // e.g. a subclass of the specified type which includes additional fields. + // As a consequence, we're only passing in parameterized type declarations + // which might contain extra generics that the object instance doesn't retain. + if (type instanceof ParameterizedType) { + getGson().toJson(object, type, writer); + } + else { + getGson().toJson(object, writer); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/Jackson2ObjectMapperBuilder.java b/spring-web/src/main/java/org/springframework/http/converter/json/Jackson2ObjectMapperBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..4709aef1dbe53066278d1f0676bd4560f2844eca --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/json/Jackson2ObjectMapperBuilder.java @@ -0,0 +1,901 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.text.DateFormat; +import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.TimeZone; + +import com.fasterxml.jackson.annotation.JsonAutoDetect; +import com.fasterxml.jackson.annotation.JsonFilter; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.PropertyAccessor; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.AnnotationIntrospector; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.KeyDeserializer; +import com.fasterxml.jackson.databind.MapperFeature; +import com.fasterxml.jackson.databind.Module; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.databind.cfg.HandlerInstantiator; +import com.fasterxml.jackson.databind.jsontype.TypeResolverBuilder; +import com.fasterxml.jackson.databind.module.SimpleModule; +import com.fasterxml.jackson.databind.ser.FilterProvider; +import com.fasterxml.jackson.dataformat.cbor.CBORFactory; +import com.fasterxml.jackson.dataformat.smile.SmileFactory; +import com.fasterxml.jackson.dataformat.xml.JacksonXmlModule; +import com.fasterxml.jackson.dataformat.xml.XmlFactory; +import com.fasterxml.jackson.dataformat.xml.XmlMapper; +import org.apache.commons.logging.Log; + +import org.springframework.beans.BeanUtils; +import org.springframework.beans.FatalBeanException; +import org.springframework.context.ApplicationContext; +import org.springframework.core.KotlinDetector; +import org.springframework.http.HttpLogging; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.util.xml.StaxUtils; + +/** + * A builder used to create {@link ObjectMapper} instances with a fluent API. + * + *

It customizes Jackson's default properties with the following ones: + *

    + *
  • {@link MapperFeature#DEFAULT_VIEW_INCLUSION} is disabled
  • + *
  • {@link DeserializationFeature#FAIL_ON_UNKNOWN_PROPERTIES} is disabled
  • + *
+ * + *

It also automatically registers the following well-known modules if they are + * detected on the classpath: + *

+ * + *

Compatible with Jackson 2.6 and higher, as of Spring 4.3. + * + * @author Sebastien Deleuze + * @author Juergen Hoeller + * @author Tadaya Tsuyukubo + * @author Eddú Meléndez + * @since 4.1.1 + * @see #build() + * @see #configure(ObjectMapper) + * @see Jackson2ObjectMapperFactoryBean + */ +public class Jackson2ObjectMapperBuilder { + + private static volatile boolean kotlinWarningLogged = false; + + private final Log logger = HttpLogging.forLogName(getClass()); + + private final Map, Class> mixIns = new LinkedHashMap<>(); + + private final Map, JsonSerializer> serializers = new LinkedHashMap<>(); + + private final Map, JsonDeserializer> deserializers = new LinkedHashMap<>(); + + private final Map visibilities = new LinkedHashMap<>(); + + private final Map features = new LinkedHashMap<>(); + + private boolean createXmlMapper = false; + + @Nullable + private JsonFactory factory; + + @Nullable + private DateFormat dateFormat; + + @Nullable + private Locale locale; + + @Nullable + private TimeZone timeZone; + + @Nullable + private AnnotationIntrospector annotationIntrospector; + + @Nullable + private PropertyNamingStrategy propertyNamingStrategy; + + @Nullable + private TypeResolverBuilder defaultTyping; + + @Nullable + private JsonInclude.Include serializationInclusion; + + @Nullable + private FilterProvider filters; + + @Nullable + private List modules; + + @Nullable + private Class[] moduleClasses; + + private boolean findModulesViaServiceLoader = false; + + private boolean findWellKnownModules = true; + + private ClassLoader moduleClassLoader = getClass().getClassLoader(); + + @Nullable + private HandlerInstantiator handlerInstantiator; + + @Nullable + private ApplicationContext applicationContext; + + @Nullable + private Boolean defaultUseWrapper; + + + /** + * If set to {@code true}, an {@link XmlMapper} will be created using its + * default constructor. This is only applicable to {@link #build()} calls, + * not to {@link #configure} calls. + */ + public Jackson2ObjectMapperBuilder createXmlMapper(boolean createXmlMapper) { + this.createXmlMapper = createXmlMapper; + return this; + } + + /** + * Define the {@link JsonFactory} to be used to create the {@link ObjectMapper} + * instance. + * @since 5.0 + */ + public Jackson2ObjectMapperBuilder factory(JsonFactory factory) { + this.factory = factory; + return this; + } + + /** + * Define the format for date/time with the given {@link DateFormat}. + *

Note: Setting this property makes the exposed {@link ObjectMapper} + * non-thread-safe, according to Jackson's thread safety rules. + * @see #simpleDateFormat(String) + */ + public Jackson2ObjectMapperBuilder dateFormat(DateFormat dateFormat) { + this.dateFormat = dateFormat; + return this; + } + + /** + * Define the date/time format with a {@link SimpleDateFormat}. + *

Note: Setting this property makes the exposed {@link ObjectMapper} + * non-thread-safe, according to Jackson's thread safety rules. + * @see #dateFormat(DateFormat) + */ + public Jackson2ObjectMapperBuilder simpleDateFormat(String format) { + this.dateFormat = new SimpleDateFormat(format); + return this; + } + + /** + * Override the default {@link Locale} to use for formatting. + * Default value used is {@link Locale#getDefault()}. + * @since 4.1.5 + */ + public Jackson2ObjectMapperBuilder locale(Locale locale) { + this.locale = locale; + return this; + } + + /** + * Override the default {@link Locale} to use for formatting. + * Default value used is {@link Locale#getDefault()}. + * @param localeString the locale ID as a String representation + * @since 4.1.5 + */ + public Jackson2ObjectMapperBuilder locale(String localeString) { + this.locale = StringUtils.parseLocale(localeString); + return this; + } + + /** + * Override the default {@link TimeZone} to use for formatting. + * Default value used is UTC (NOT local timezone). + * @since 4.1.5 + */ + public Jackson2ObjectMapperBuilder timeZone(TimeZone timeZone) { + this.timeZone = timeZone; + return this; + } + + /** + * Override the default {@link TimeZone} to use for formatting. + * Default value used is UTC (NOT local timezone). + * @param timeZoneString the zone ID as a String representation + * @since 4.1.5 + */ + public Jackson2ObjectMapperBuilder timeZone(String timeZoneString) { + this.timeZone = StringUtils.parseTimeZoneString(timeZoneString); + return this; + } + + /** + * Set an {@link AnnotationIntrospector} for both serialization and deserialization. + */ + public Jackson2ObjectMapperBuilder annotationIntrospector(AnnotationIntrospector annotationIntrospector) { + this.annotationIntrospector = annotationIntrospector; + return this; + } + + /** + * Specify a {@link com.fasterxml.jackson.databind.PropertyNamingStrategy} to + * configure the {@link ObjectMapper} with. + */ + public Jackson2ObjectMapperBuilder propertyNamingStrategy(PropertyNamingStrategy propertyNamingStrategy) { + this.propertyNamingStrategy = propertyNamingStrategy; + return this; + } + + /** + * Specify a {@link TypeResolverBuilder} to use for Jackson's default typing. + * @since 4.2.2 + */ + public Jackson2ObjectMapperBuilder defaultTyping(TypeResolverBuilder typeResolverBuilder) { + this.defaultTyping = typeResolverBuilder; + return this; + } + + /** + * Set a custom inclusion strategy for serialization. + * @see com.fasterxml.jackson.annotation.JsonInclude.Include + */ + public Jackson2ObjectMapperBuilder serializationInclusion(JsonInclude.Include serializationInclusion) { + this.serializationInclusion = serializationInclusion; + return this; + } + + /** + * Set the global filters to use in order to support {@link JsonFilter @JsonFilter} annotated POJO. + * @since 4.2 + * @see MappingJacksonValue#setFilters(FilterProvider) + */ + public Jackson2ObjectMapperBuilder filters(FilterProvider filters) { + this.filters = filters; + return this; + } + + /** + * Add mix-in annotations to use for augmenting specified class or interface. + * @param target class (or interface) whose annotations to effectively override + * @param mixinSource class (or interface) whose annotations are to be "added" + * to target's annotations as value + * @since 4.1.2 + * @see com.fasterxml.jackson.databind.ObjectMapper#addMixIn(Class, Class) + */ + public Jackson2ObjectMapperBuilder mixIn(Class target, Class mixinSource) { + this.mixIns.put(target, mixinSource); + return this; + } + + /** + * Add mix-in annotations to use for augmenting specified class or interface. + * @param mixIns a Map of entries with target classes (or interface) whose annotations + * to effectively override as key and mix-in classes (or interface) whose + * annotations are to be "added" to target's annotations as value. + * @since 4.1.2 + * @see com.fasterxml.jackson.databind.ObjectMapper#addMixIn(Class, Class) + */ + public Jackson2ObjectMapperBuilder mixIns(Map, Class> mixIns) { + this.mixIns.putAll(mixIns); + return this; + } + + /** + * Configure custom serializers. Each serializer is registered for the type + * returned by {@link JsonSerializer#handledType()}, which must not be {@code null}. + * @see #serializersByType(Map) + */ + public Jackson2ObjectMapperBuilder serializers(JsonSerializer... serializers) { + for (JsonSerializer serializer : serializers) { + Class handledType = serializer.handledType(); + if (handledType == null || handledType == Object.class) { + throw new IllegalArgumentException("Unknown handled type in " + serializer.getClass().getName()); + } + this.serializers.put(serializer.handledType(), serializer); + } + return this; + } + + /** + * Configure a custom serializer for the given type. + * @since 4.1.2 + * @see #serializers(JsonSerializer...) + */ + public Jackson2ObjectMapperBuilder serializerByType(Class type, JsonSerializer serializer) { + this.serializers.put(type, serializer); + return this; + } + + /** + * Configure custom serializers for the given types. + * @see #serializers(JsonSerializer...) + */ + public Jackson2ObjectMapperBuilder serializersByType(Map, JsonSerializer> serializers) { + this.serializers.putAll(serializers); + return this; + } + + /** + * Configure custom deserializers. Each deserializer is registered for the type + * returned by {@link JsonDeserializer#handledType()}, which must not be {@code null}. + * @since 4.3 + * @see #deserializersByType(Map) + */ + public Jackson2ObjectMapperBuilder deserializers(JsonDeserializer... deserializers) { + for (JsonDeserializer deserializer : deserializers) { + Class handledType = deserializer.handledType(); + if (handledType == null || handledType == Object.class) { + throw new IllegalArgumentException("Unknown handled type in " + deserializer.getClass().getName()); + } + this.deserializers.put(deserializer.handledType(), deserializer); + } + return this; + } + + /** + * Configure a custom deserializer for the given type. + * @since 4.1.2 + */ + public Jackson2ObjectMapperBuilder deserializerByType(Class type, JsonDeserializer deserializer) { + this.deserializers.put(type, deserializer); + return this; + } + + /** + * Configure custom deserializers for the given types. + */ + public Jackson2ObjectMapperBuilder deserializersByType(Map, JsonDeserializer> deserializers) { + this.deserializers.putAll(deserializers); + return this; + } + + /** + * Shortcut for {@link MapperFeature#AUTO_DETECT_FIELDS} option. + */ + public Jackson2ObjectMapperBuilder autoDetectFields(boolean autoDetectFields) { + this.features.put(MapperFeature.AUTO_DETECT_FIELDS, autoDetectFields); + return this; + } + + /** + * Shortcut for {@link MapperFeature#AUTO_DETECT_SETTERS}/ + * {@link MapperFeature#AUTO_DETECT_GETTERS}/{@link MapperFeature#AUTO_DETECT_IS_GETTERS} + * options. + */ + public Jackson2ObjectMapperBuilder autoDetectGettersSetters(boolean autoDetectGettersSetters) { + this.features.put(MapperFeature.AUTO_DETECT_GETTERS, autoDetectGettersSetters); + this.features.put(MapperFeature.AUTO_DETECT_SETTERS, autoDetectGettersSetters); + this.features.put(MapperFeature.AUTO_DETECT_IS_GETTERS, autoDetectGettersSetters); + return this; + } + + /** + * Shortcut for {@link MapperFeature#DEFAULT_VIEW_INCLUSION} option. + */ + public Jackson2ObjectMapperBuilder defaultViewInclusion(boolean defaultViewInclusion) { + this.features.put(MapperFeature.DEFAULT_VIEW_INCLUSION, defaultViewInclusion); + return this; + } + + /** + * Shortcut for {@link DeserializationFeature#FAIL_ON_UNKNOWN_PROPERTIES} option. + */ + public Jackson2ObjectMapperBuilder failOnUnknownProperties(boolean failOnUnknownProperties) { + this.features.put(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, failOnUnknownProperties); + return this; + } + + /** + * Shortcut for {@link SerializationFeature#FAIL_ON_EMPTY_BEANS} option. + */ + public Jackson2ObjectMapperBuilder failOnEmptyBeans(boolean failOnEmptyBeans) { + this.features.put(SerializationFeature.FAIL_ON_EMPTY_BEANS, failOnEmptyBeans); + return this; + } + + /** + * Shortcut for {@link SerializationFeature#INDENT_OUTPUT} option. + */ + public Jackson2ObjectMapperBuilder indentOutput(boolean indentOutput) { + this.features.put(SerializationFeature.INDENT_OUTPUT, indentOutput); + return this; + } + + /** + * Define if a wrapper will be used for indexed (List, array) properties or not by + * default (only applies to {@link XmlMapper}). + * @since 4.3 + */ + public Jackson2ObjectMapperBuilder defaultUseWrapper(boolean defaultUseWrapper) { + this.defaultUseWrapper = defaultUseWrapper; + return this; + } + + /** + * Specify visibility to limit what kind of properties are auto-detected. + * @since 5.1 + * @see com.fasterxml.jackson.annotation.PropertyAccessor + * @see com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility + */ + public Jackson2ObjectMapperBuilder visibility(PropertyAccessor accessor, JsonAutoDetect.Visibility visibility) { + this.visibilities.put(accessor, visibility); + return this; + } + + /** + * Specify features to enable. + * @see com.fasterxml.jackson.core.JsonParser.Feature + * @see com.fasterxml.jackson.core.JsonGenerator.Feature + * @see com.fasterxml.jackson.databind.SerializationFeature + * @see com.fasterxml.jackson.databind.DeserializationFeature + * @see com.fasterxml.jackson.databind.MapperFeature + */ + public Jackson2ObjectMapperBuilder featuresToEnable(Object... featuresToEnable) { + for (Object feature : featuresToEnable) { + this.features.put(feature, Boolean.TRUE); + } + return this; + } + + /** + * Specify features to disable. + * @see com.fasterxml.jackson.core.JsonParser.Feature + * @see com.fasterxml.jackson.core.JsonGenerator.Feature + * @see com.fasterxml.jackson.databind.SerializationFeature + * @see com.fasterxml.jackson.databind.DeserializationFeature + * @see com.fasterxml.jackson.databind.MapperFeature + */ + public Jackson2ObjectMapperBuilder featuresToDisable(Object... featuresToDisable) { + for (Object feature : featuresToDisable) { + this.features.put(feature, Boolean.FALSE); + } + return this; + } + + /** + * Specify one or more modules to be registered with the {@link ObjectMapper}. + * Multiple invocations are not additive, the last one defines the modules to + * register. + *

Note: If this is set, no finding of modules is going to happen - not by + * Jackson, and not by Spring either (see {@link #findModulesViaServiceLoader}). + * As a consequence, specifying an empty list here will suppress any kind of + * module detection. + *

Specify either this or {@link #modulesToInstall}, not both. + * @since 4.1.5 + * @see #modules(List) + * @see com.fasterxml.jackson.databind.Module + */ + public Jackson2ObjectMapperBuilder modules(Module... modules) { + return modules(Arrays.asList(modules)); + } + + /** + * Set a complete list of modules to be registered with the {@link ObjectMapper}. + * Multiple invocations are not additive, the last one defines the modules to + * register. + *

Note: If this is set, no finding of modules is going to happen - not by + * Jackson, and not by Spring either (see {@link #findModulesViaServiceLoader}). + * As a consequence, specifying an empty list here will suppress any kind of + * module detection. + *

Specify either this or {@link #modulesToInstall}, not both. + * @see #modules(Module...) + * @see com.fasterxml.jackson.databind.Module + */ + public Jackson2ObjectMapperBuilder modules(List modules) { + this.modules = new LinkedList<>(modules); + this.findModulesViaServiceLoader = false; + this.findWellKnownModules = false; + return this; + } + + /** + * Specify one or more modules to be registered with the {@link ObjectMapper}. + * Multiple invocations are not additive, the last one defines the modules + * to register. + *

Modules specified here will be registered after + * Spring's autodetection of JSR-310 and Joda-Time, or Jackson's + * finding of modules (see {@link #findModulesViaServiceLoader}), + * allowing to eventually override their configuration. + *

Specify either this or {@link #modules}, not both. + * @since 4.1.5 + * @see com.fasterxml.jackson.databind.Module + */ + public Jackson2ObjectMapperBuilder modulesToInstall(Module... modules) { + this.modules = Arrays.asList(modules); + this.findWellKnownModules = true; + return this; + } + + /** + * Specify one or more modules by class to be registered with + * the {@link ObjectMapper}. Multiple invocations are not additive, + * the last one defines the modules to register. + *

Modules specified here will be registered after + * Spring's autodetection of JSR-310 and Joda-Time, or Jackson's + * finding of modules (see {@link #findModulesViaServiceLoader}), + * allowing to eventually override their configuration. + *

Specify either this or {@link #modules}, not both. + * @see #modulesToInstall(Module...) + * @see com.fasterxml.jackson.databind.Module + */ + @SuppressWarnings("unchecked") + public Jackson2ObjectMapperBuilder modulesToInstall(Class... modules) { + this.moduleClasses = modules; + this.findWellKnownModules = true; + return this; + } + + /** + * Set whether to let Jackson find available modules via the JDK ServiceLoader, + * based on META-INF metadata in the classpath. + *

If this mode is not set, Spring's Jackson2ObjectMapperBuilder itself + * will try to find the JSR-310 and Joda-Time support modules on the classpath - + * provided that Java 8 and Joda-Time themselves are available, respectively. + * @see com.fasterxml.jackson.databind.ObjectMapper#findModules() + */ + public Jackson2ObjectMapperBuilder findModulesViaServiceLoader(boolean findModules) { + this.findModulesViaServiceLoader = findModules; + return this; + } + + /** + * Set the ClassLoader to use for loading Jackson extension modules. + */ + public Jackson2ObjectMapperBuilder moduleClassLoader(ClassLoader moduleClassLoader) { + this.moduleClassLoader = moduleClassLoader; + return this; + } + + /** + * Customize the construction of Jackson handlers ({@link JsonSerializer}, {@link JsonDeserializer}, + * {@link KeyDeserializer}, {@code TypeResolverBuilder} and {@code TypeIdResolver}). + * @since 4.1.3 + * @see Jackson2ObjectMapperBuilder#applicationContext(ApplicationContext) + */ + public Jackson2ObjectMapperBuilder handlerInstantiator(HandlerInstantiator handlerInstantiator) { + this.handlerInstantiator = handlerInstantiator; + return this; + } + + /** + * Set the Spring {@link ApplicationContext} in order to autowire Jackson handlers ({@link JsonSerializer}, + * {@link JsonDeserializer}, {@link KeyDeserializer}, {@code TypeResolverBuilder} and {@code TypeIdResolver}). + * @since 4.1.3 + * @see SpringHandlerInstantiator + */ + public Jackson2ObjectMapperBuilder applicationContext(ApplicationContext applicationContext) { + this.applicationContext = applicationContext; + return this; + } + + + /** + * Build a new {@link ObjectMapper} instance. + *

Each build operation produces an independent {@link ObjectMapper} instance. + * The builder's settings can get modified, with a subsequent build operation + * then producing a new {@link ObjectMapper} based on the most recent settings. + * @return the newly built ObjectMapper + */ + @SuppressWarnings("unchecked") + public T build() { + ObjectMapper mapper; + if (this.createXmlMapper) { + mapper = (this.defaultUseWrapper != null ? + new XmlObjectMapperInitializer().create(this.defaultUseWrapper, this.factory) : + new XmlObjectMapperInitializer().create(this.factory)); + } + else { + mapper = (this.factory != null ? new ObjectMapper(this.factory) : new ObjectMapper()); + } + configure(mapper); + return (T) mapper; + } + + /** + * Configure an existing {@link ObjectMapper} instance with this builder's + * settings. This can be applied to any number of {@code ObjectMappers}. + * @param objectMapper the ObjectMapper to configure + */ + public void configure(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + + MultiValueMap modulesToRegister = new LinkedMultiValueMap<>(); + if (this.findModulesViaServiceLoader) { + ObjectMapper.findModules(this.moduleClassLoader).forEach(module -> registerModule(module, modulesToRegister)); + } + else if (this.findWellKnownModules) { + registerWellKnownModulesIfAvailable(modulesToRegister); + } + + if (this.modules != null) { + this.modules.forEach(module -> registerModule(module, modulesToRegister)); + } + if (this.moduleClasses != null) { + for (Class moduleClass : this.moduleClasses) { + registerModule(BeanUtils.instantiateClass(moduleClass), modulesToRegister); + } + } + List modules = new ArrayList<>(); + for (List nestedModules : modulesToRegister.values()) { + modules.addAll(nestedModules); + } + objectMapper.registerModules(modules); + + if (this.dateFormat != null) { + objectMapper.setDateFormat(this.dateFormat); + } + if (this.locale != null) { + objectMapper.setLocale(this.locale); + } + if (this.timeZone != null) { + objectMapper.setTimeZone(this.timeZone); + } + + if (this.annotationIntrospector != null) { + objectMapper.setAnnotationIntrospector(this.annotationIntrospector); + } + if (this.propertyNamingStrategy != null) { + objectMapper.setPropertyNamingStrategy(this.propertyNamingStrategy); + } + if (this.defaultTyping != null) { + objectMapper.setDefaultTyping(this.defaultTyping); + } + if (this.serializationInclusion != null) { + objectMapper.setSerializationInclusion(this.serializationInclusion); + } + + if (this.filters != null) { + objectMapper.setFilterProvider(this.filters); + } + + this.mixIns.forEach(objectMapper::addMixIn); + + if (!this.serializers.isEmpty() || !this.deserializers.isEmpty()) { + SimpleModule module = new SimpleModule(); + addSerializers(module); + addDeserializers(module); + objectMapper.registerModule(module); + } + + this.visibilities.forEach(objectMapper::setVisibility); + + customizeDefaultFeatures(objectMapper); + this.features.forEach((feature, enabled) -> configureFeature(objectMapper, feature, enabled)); + + if (this.handlerInstantiator != null) { + objectMapper.setHandlerInstantiator(this.handlerInstantiator); + } + else if (this.applicationContext != null) { + objectMapper.setHandlerInstantiator( + new SpringHandlerInstantiator(this.applicationContext.getAutowireCapableBeanFactory())); + } + } + + private void registerModule(Module module, MultiValueMap modulesToRegister) { + if (module.getTypeId() == null) { + modulesToRegister.add(SimpleModule.class.getName(), module); + } + else { + modulesToRegister.set(module.getTypeId(), module); + } + } + + + // Any change to this method should be also applied to spring-jms and spring-messaging + // MappingJackson2MessageConverter default constructors + private void customizeDefaultFeatures(ObjectMapper objectMapper) { + if (!this.features.containsKey(MapperFeature.DEFAULT_VIEW_INCLUSION)) { + configureFeature(objectMapper, MapperFeature.DEFAULT_VIEW_INCLUSION, false); + } + if (!this.features.containsKey(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)) { + configureFeature(objectMapper, DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + } + } + + @SuppressWarnings("unchecked") + private void addSerializers(SimpleModule module) { + this.serializers.forEach((type, serializer) -> + module.addSerializer((Class) type, (JsonSerializer) serializer)); + } + + @SuppressWarnings("unchecked") + private void addDeserializers(SimpleModule module) { + this.deserializers.forEach((type, deserializer) -> + module.addDeserializer((Class) type, (JsonDeserializer) deserializer)); + } + + private void configureFeature(ObjectMapper objectMapper, Object feature, boolean enabled) { + if (feature instanceof JsonParser.Feature) { + objectMapper.configure((JsonParser.Feature) feature, enabled); + } + else if (feature instanceof JsonGenerator.Feature) { + objectMapper.configure((JsonGenerator.Feature) feature, enabled); + } + else if (feature instanceof SerializationFeature) { + objectMapper.configure((SerializationFeature) feature, enabled); + } + else if (feature instanceof DeserializationFeature) { + objectMapper.configure((DeserializationFeature) feature, enabled); + } + else if (feature instanceof MapperFeature) { + objectMapper.configure((MapperFeature) feature, enabled); + } + else { + throw new FatalBeanException("Unknown feature class: " + feature.getClass().getName()); + } + } + + @SuppressWarnings("unchecked") + private void registerWellKnownModulesIfAvailable(MultiValueMap modulesToRegister) { + try { + Class jdk8ModuleClass = (Class) + ClassUtils.forName("com.fasterxml.jackson.datatype.jdk8.Jdk8Module", this.moduleClassLoader); + Module jdk8Module = BeanUtils.instantiateClass(jdk8ModuleClass); + modulesToRegister.set(jdk8Module.getTypeId(), jdk8Module); + } + catch (ClassNotFoundException ex) { + // jackson-datatype-jdk8 not available + } + + try { + Class javaTimeModuleClass = (Class) + ClassUtils.forName("com.fasterxml.jackson.datatype.jsr310.JavaTimeModule", this.moduleClassLoader); + Module javaTimeModule = BeanUtils.instantiateClass(javaTimeModuleClass); + modulesToRegister.set(javaTimeModule.getTypeId(), javaTimeModule); + } + catch (ClassNotFoundException ex) { + // jackson-datatype-jsr310 not available + } + + // Joda-Time present? + if (ClassUtils.isPresent("org.joda.time.LocalDate", this.moduleClassLoader)) { + try { + Class jodaModuleClass = (Class) + ClassUtils.forName("com.fasterxml.jackson.datatype.joda.JodaModule", this.moduleClassLoader); + Module jodaModule = BeanUtils.instantiateClass(jodaModuleClass); + modulesToRegister.set(jodaModule.getTypeId(), jodaModule); + } + catch (ClassNotFoundException ex) { + // jackson-datatype-joda not available + } + } + + // Kotlin present? + if (KotlinDetector.isKotlinPresent()) { + try { + Class kotlinModuleClass = (Class) + ClassUtils.forName("com.fasterxml.jackson.module.kotlin.KotlinModule", this.moduleClassLoader); + Module kotlinModule = BeanUtils.instantiateClass(kotlinModuleClass); + modulesToRegister.set(kotlinModule.getTypeId(), kotlinModule); + } + catch (ClassNotFoundException ex) { + if (!kotlinWarningLogged) { + kotlinWarningLogged = true; + logger.warn("For Jackson Kotlin classes support please add " + + "\"com.fasterxml.jackson.module:jackson-module-kotlin\" to the classpath"); + } + } + } + } + + + // Convenience factory methods + + /** + * Obtain a {@link Jackson2ObjectMapperBuilder} instance in order to + * build a regular JSON {@link ObjectMapper} instance. + */ + public static Jackson2ObjectMapperBuilder json() { + return new Jackson2ObjectMapperBuilder(); + } + + /** + * Obtain a {@link Jackson2ObjectMapperBuilder} instance in order to + * build an {@link XmlMapper} instance. + */ + public static Jackson2ObjectMapperBuilder xml() { + return new Jackson2ObjectMapperBuilder().createXmlMapper(true); + } + + /** + * Obtain a {@link Jackson2ObjectMapperBuilder} instance in order to + * build a Smile data format {@link ObjectMapper} instance. + * @since 5.0 + */ + public static Jackson2ObjectMapperBuilder smile() { + return new Jackson2ObjectMapperBuilder().factory(new SmileFactoryInitializer().create()); + } + + /** + * Obtain a {@link Jackson2ObjectMapperBuilder} instance in order to + * build a CBOR data format {@link ObjectMapper} instance. + * @since 5.0 + */ + public static Jackson2ObjectMapperBuilder cbor() { + return new Jackson2ObjectMapperBuilder().factory(new CborFactoryInitializer().create()); + } + + + private static class XmlObjectMapperInitializer { + + public ObjectMapper create(@Nullable JsonFactory factory) { + if (factory != null) { + return new XmlMapper((XmlFactory) factory); + } + else { + return new XmlMapper(StaxUtils.createDefensiveInputFactory()); + } + } + + public ObjectMapper create(boolean defaultUseWrapper, @Nullable JsonFactory factory) { + JacksonXmlModule module = new JacksonXmlModule(); + module.setDefaultUseWrapper(defaultUseWrapper); + if (factory != null) { + return new XmlMapper((XmlFactory) factory, module); + } + else { + return new XmlMapper(new XmlFactory(StaxUtils.createDefensiveInputFactory()), module); + } + } + } + + + private static class SmileFactoryInitializer { + + public JsonFactory create() { + return new SmileFactory(); + } + } + + + private static class CborFactoryInitializer { + + public JsonFactory create() { + return new CBORFactory(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/Jackson2ObjectMapperFactoryBean.java b/spring-web/src/main/java/org/springframework/http/converter/json/Jackson2ObjectMapperFactoryBean.java new file mode 100644 index 0000000000000000000000000000000000000000..6925ddaf1acb793d60aaa48e37c87744fbd1eea4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/json/Jackson2ObjectMapperFactoryBean.java @@ -0,0 +1,489 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.text.DateFormat; +import java.text.SimpleDateFormat; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.TimeZone; + +import com.fasterxml.jackson.annotation.JsonFilter; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.databind.AnnotationIntrospector; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.KeyDeserializer; +import com.fasterxml.jackson.databind.MapperFeature; +import com.fasterxml.jackson.databind.Module; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.databind.cfg.HandlerInstantiator; +import com.fasterxml.jackson.databind.jsontype.TypeResolverBuilder; +import com.fasterxml.jackson.databind.ser.FilterProvider; +import com.fasterxml.jackson.dataformat.xml.XmlMapper; + +import org.springframework.beans.factory.BeanClassLoaderAware; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.lang.Nullable; + +/** + * A {@link FactoryBean} for creating a Jackson 2.x {@link ObjectMapper} (default) or + * {@link XmlMapper} ({@code createXmlMapper} property set to true) with setters + * to enable or disable Jackson features from within XML configuration. + * + *

It customizes Jackson defaults properties with the following ones: + *

    + *
  • {@link MapperFeature#DEFAULT_VIEW_INCLUSION} is disabled
  • + *
  • {@link DeserializationFeature#FAIL_ON_UNKNOWN_PROPERTIES} is disabled
  • + *
+ * + *

Example usage with + * {@link MappingJackson2HttpMessageConverter}: + * + *

+ * <bean class="org.springframework.http.converter.json.MappingJackson2HttpMessageConverter">
+ *   <property name="objectMapper">
+ *     <bean class="org.springframework.http.converter.json.Jackson2ObjectMapperFactoryBean"
+ *       p:autoDetectFields="false"
+ *       p:autoDetectGettersSetters="false"
+ *       p:annotationIntrospector-ref="jaxbAnnotationIntrospector" />
+ *   </property>
+ * </bean>
+ * 
+ * + *

Example usage with MappingJackson2JsonView: + * + *

+ * <bean class="org.springframework.web.servlet.view.json.MappingJackson2JsonView">
+ *   <property name="objectMapper">
+ *     <bean class="org.springframework.http.converter.json.Jackson2ObjectMapperFactoryBean"
+ *       p:failOnEmptyBeans="false"
+ *       p:indentOutput="true">
+ *       <property name="serializers">
+ *         <array>
+ *           <bean class="org.mycompany.MyCustomSerializer" />
+ *         </array>
+ *       </property>
+ *     </bean>
+ *   </property>
+ * </bean>
+ * 
+ * + *

In case there are no specific setters provided (for some rarely used options), + * you can still use the more general methods {@link #setFeaturesToEnable} and + * {@link #setFeaturesToDisable}. + * + *

+ * <bean class="org.springframework.http.converter.json.Jackson2ObjectMapperFactoryBean">
+ *   <property name="featuresToEnable">
+ *     <array>
+ *       <util:constant static-field="com.fasterxml.jackson.databind.SerializationFeature.WRAP_ROOT_VALUE"/>
+ *       <util:constant static-field="com.fasterxml.jackson.databind.SerializationFeature.CLOSE_CLOSEABLE"/>
+ *     </array>
+ *   </property>
+ *   <property name="featuresToDisable">
+ *     <array>
+ *       <util:constant static-field="com.fasterxml.jackson.databind.MapperFeature.USE_ANNOTATIONS"/>
+ *     </array>
+ *   </property>
+ * </bean>
+ * 
+ * + *

It also automatically registers the following well-known modules if they are + * detected on the classpath: + *

+ * + *

In case you want to configure Jackson's {@link ObjectMapper} with a custom {@link Module}, + * you can register one or more such Modules by class name via {@link #setModulesToInstall}: + * + *

+ * <bean class="org.springframework.http.converter.json.Jackson2ObjectMapperFactoryBean">
+ *   <property name="modulesToInstall" value="myapp.jackson.MySampleModule,myapp.jackson.MyOtherModule"/>
+ * </bean
+ * 
+ * + *

Compatible with Jackson 2.6 and higher, as of Spring 4.3. + * + * @author Dmitry Katsubo + * @author Rossen Stoyanchev + * @author Brian Clozel + * @author Juergen Hoeller + * @author Tadaya Tsuyukubo + * @author Sebastien Deleuze + * @since 3.2 + */ +public class Jackson2ObjectMapperFactoryBean implements FactoryBean, BeanClassLoaderAware, + ApplicationContextAware, InitializingBean { + + private final Jackson2ObjectMapperBuilder builder = new Jackson2ObjectMapperBuilder(); + + @Nullable + private ObjectMapper objectMapper; + + + /** + * Set the {@link ObjectMapper} instance to use. If not set, the {@link ObjectMapper} + * will be created using its default constructor. + */ + public void setObjectMapper(ObjectMapper objectMapper) { + this.objectMapper = objectMapper; + } + + /** + * If set to true and no custom {@link ObjectMapper} has been set, a {@link XmlMapper} + * will be created using its default constructor. + * @since 4.1 + */ + public void setCreateXmlMapper(boolean createXmlMapper) { + this.builder.createXmlMapper(createXmlMapper); + } + + /** + * Define the {@link JsonFactory} to be used to create the {@link ObjectMapper} + * instance. + * @since 5.0 + */ + public void setFactory(JsonFactory factory) { + this.builder.factory(factory); + } + + /** + * Define the format for date/time with the given {@link DateFormat}. + *

Note: Setting this property makes the exposed {@link ObjectMapper} + * non-thread-safe, according to Jackson's thread safety rules. + * @see #setSimpleDateFormat(String) + */ + public void setDateFormat(DateFormat dateFormat) { + this.builder.dateFormat(dateFormat); + } + + /** + * Define the date/time format with a {@link SimpleDateFormat}. + *

Note: Setting this property makes the exposed {@link ObjectMapper} + * non-thread-safe, according to Jackson's thread safety rules. + * @see #setDateFormat(DateFormat) + */ + public void setSimpleDateFormat(String format) { + this.builder.simpleDateFormat(format); + } + + /** + * Override the default {@link Locale} to use for formatting. + * Default value used is {@link Locale#getDefault()}. + * @since 4.1.5 + */ + public void setLocale(Locale locale) { + this.builder.locale(locale); + } + + /** + * Override the default {@link TimeZone} to use for formatting. + * Default value used is UTC (NOT local timezone). + * @since 4.1.5 + */ + public void setTimeZone(TimeZone timeZone) { + this.builder.timeZone(timeZone); + } + + /** + * Set an {@link AnnotationIntrospector} for both serialization and deserialization. + */ + public void setAnnotationIntrospector(AnnotationIntrospector annotationIntrospector) { + this.builder.annotationIntrospector(annotationIntrospector); + } + + /** + * Specify a {@link com.fasterxml.jackson.databind.PropertyNamingStrategy} to + * configure the {@link ObjectMapper} with. + * @since 4.0.2 + */ + public void setPropertyNamingStrategy(PropertyNamingStrategy propertyNamingStrategy) { + this.builder.propertyNamingStrategy(propertyNamingStrategy); + } + + /** + * Specify a {@link TypeResolverBuilder} to use for Jackson's default typing. + * @since 4.2.2 + */ + public void setDefaultTyping(TypeResolverBuilder typeResolverBuilder) { + this.builder.defaultTyping(typeResolverBuilder); + } + + /** + * Set a custom inclusion strategy for serialization. + * @see com.fasterxml.jackson.annotation.JsonInclude.Include + */ + public void setSerializationInclusion(JsonInclude.Include serializationInclusion) { + this.builder.serializationInclusion(serializationInclusion); + } + + /** + * Set the global filters to use in order to support {@link JsonFilter @JsonFilter} annotated POJO. + * @since 4.2 + * @see Jackson2ObjectMapperBuilder#filters(FilterProvider) + */ + public void setFilters(FilterProvider filters) { + this.builder.filters(filters); + } + + /** + * Add mix-in annotations to use for augmenting specified class or interface. + * @param mixIns a Map of entries with target classes (or interface) whose annotations + * to effectively override as key and mix-in classes (or interface) whose + * annotations are to be "added" to target's annotations as value. + * @since 4.1.2 + * @see com.fasterxml.jackson.databind.ObjectMapper#addMixInAnnotations(Class, Class) + */ + public void setMixIns(Map, Class> mixIns) { + this.builder.mixIns(mixIns); + } + + /** + * Configure custom serializers. Each serializer is registered for the type + * returned by {@link JsonSerializer#handledType()}, which must not be {@code null}. + * @see #setSerializersByType(Map) + */ + public void setSerializers(JsonSerializer... serializers) { + this.builder.serializers(serializers); + } + + /** + * Configure custom serializers for the given types. + * @see #setSerializers(JsonSerializer...) + */ + public void setSerializersByType(Map, JsonSerializer> serializers) { + this.builder.serializersByType(serializers); + } + + /** + * Configure custom deserializers. Each deserializer is registered for the type + * returned by {@link JsonDeserializer#handledType()}, which must not be {@code null}. + * @since 4.3 + * @see #setDeserializersByType(Map) + */ + public void setDeserializers(JsonDeserializer... deserializers) { + this.builder.deserializers(deserializers); + } + + /** + * Configure custom deserializers for the given types. + */ + public void setDeserializersByType(Map, JsonDeserializer> deserializers) { + this.builder.deserializersByType(deserializers); + } + + /** + * Shortcut for {@link MapperFeature#AUTO_DETECT_FIELDS} option. + */ + public void setAutoDetectFields(boolean autoDetectFields) { + this.builder.autoDetectFields(autoDetectFields); + } + + /** + * Shortcut for {@link MapperFeature#AUTO_DETECT_SETTERS}/ + * {@link MapperFeature#AUTO_DETECT_GETTERS}/{@link MapperFeature#AUTO_DETECT_IS_GETTERS} + * options. + */ + public void setAutoDetectGettersSetters(boolean autoDetectGettersSetters) { + this.builder.autoDetectGettersSetters(autoDetectGettersSetters); + } + + /** + * Shortcut for {@link MapperFeature#DEFAULT_VIEW_INCLUSION} option. + * @since 4.1 + */ + public void setDefaultViewInclusion(boolean defaultViewInclusion) { + this.builder.defaultViewInclusion(defaultViewInclusion); + } + + /** + * Shortcut for {@link DeserializationFeature#FAIL_ON_UNKNOWN_PROPERTIES} option. + * @since 4.1.1 + */ + public void setFailOnUnknownProperties(boolean failOnUnknownProperties) { + this.builder.failOnUnknownProperties(failOnUnknownProperties); + } + + /** + * Shortcut for {@link SerializationFeature#FAIL_ON_EMPTY_BEANS} option. + */ + public void setFailOnEmptyBeans(boolean failOnEmptyBeans) { + this.builder.failOnEmptyBeans(failOnEmptyBeans); + } + + /** + * Shortcut for {@link SerializationFeature#INDENT_OUTPUT} option. + */ + public void setIndentOutput(boolean indentOutput) { + this.builder.indentOutput(indentOutput); + } + + /** + * Define if a wrapper will be used for indexed (List, array) properties or not by + * default (only applies to {@link XmlMapper}). + * @since 4.3 + */ + public void setDefaultUseWrapper(boolean defaultUseWrapper) { + this.builder.defaultUseWrapper(defaultUseWrapper); + } + + /** + * Specify features to enable. + * @see com.fasterxml.jackson.core.JsonParser.Feature + * @see com.fasterxml.jackson.core.JsonGenerator.Feature + * @see com.fasterxml.jackson.databind.SerializationFeature + * @see com.fasterxml.jackson.databind.DeserializationFeature + * @see com.fasterxml.jackson.databind.MapperFeature + */ + public void setFeaturesToEnable(Object... featuresToEnable) { + this.builder.featuresToEnable(featuresToEnable); + } + + /** + * Specify features to disable. + * @see com.fasterxml.jackson.core.JsonParser.Feature + * @see com.fasterxml.jackson.core.JsonGenerator.Feature + * @see com.fasterxml.jackson.databind.SerializationFeature + * @see com.fasterxml.jackson.databind.DeserializationFeature + * @see com.fasterxml.jackson.databind.MapperFeature + */ + public void setFeaturesToDisable(Object... featuresToDisable) { + this.builder.featuresToDisable(featuresToDisable); + } + + /** + * Set a complete list of modules to be registered with the {@link ObjectMapper}. + *

Note: If this is set, no finding of modules is going to happen - not by + * Jackson, and not by Spring either (see {@link #setFindModulesViaServiceLoader}). + * As a consequence, specifying an empty list here will suppress any kind of + * module detection. + *

Specify either this or {@link #setModulesToInstall}, not both. + * @since 4.0 + * @see com.fasterxml.jackson.databind.Module + */ + public void setModules(List modules) { + this.builder.modules(modules); + } + + /** + * Specify one or more modules by class (or class name in XML) + * to be registered with the {@link ObjectMapper}. + *

Modules specified here will be registered after + * Spring's autodetection of JSR-310 and Joda-Time, or Jackson's + * finding of modules (see {@link #setFindModulesViaServiceLoader}), + * allowing to eventually override their configuration. + *

Specify either this or {@link #setModules}, not both. + * @since 4.0.1 + * @see com.fasterxml.jackson.databind.Module + */ + @SuppressWarnings("unchecked") + public void setModulesToInstall(Class... modules) { + this.builder.modulesToInstall(modules); + } + + /** + * Set whether to let Jackson find available modules via the JDK ServiceLoader, + * based on META-INF metadata in the classpath. Requires Jackson 2.2 or higher. + *

If this mode is not set, Spring's Jackson2ObjectMapperFactoryBean itself + * will try to find the JSR-310 and Joda-Time support modules on the classpath - + * provided that Java 8 and Joda-Time themselves are available, respectively. + * @since 4.0.1 + * @see com.fasterxml.jackson.databind.ObjectMapper#findModules() + */ + public void setFindModulesViaServiceLoader(boolean findModules) { + this.builder.findModulesViaServiceLoader(findModules); + } + + @Override + public void setBeanClassLoader(ClassLoader beanClassLoader) { + this.builder.moduleClassLoader(beanClassLoader); + } + + /** + * Customize the construction of Jackson handlers + * ({@link JsonSerializer}, {@link JsonDeserializer}, {@link KeyDeserializer}, + * {@code TypeResolverBuilder} and {@code TypeIdResolver}). + * @since 4.1.3 + * @see Jackson2ObjectMapperFactoryBean#setApplicationContext(ApplicationContext) + */ + public void setHandlerInstantiator(HandlerInstantiator handlerInstantiator) { + this.builder.handlerInstantiator(handlerInstantiator); + } + + /** + * Set the builder {@link ApplicationContext} in order to autowire Jackson handlers + * ({@link JsonSerializer}, {@link JsonDeserializer}, {@link KeyDeserializer}, + * {@code TypeResolverBuilder} and {@code TypeIdResolver}). + * @since 4.1.3 + * @see Jackson2ObjectMapperBuilder#applicationContext(ApplicationContext) + * @see SpringHandlerInstantiator + */ + @Override + public void setApplicationContext(ApplicationContext applicationContext) { + this.builder.applicationContext(applicationContext); + } + + + @Override + public void afterPropertiesSet() { + if (this.objectMapper != null) { + this.builder.configure(this.objectMapper); + } + else { + this.objectMapper = this.builder.build(); + } + } + + /** + * Return the singleton ObjectMapper. + */ + @Override + @Nullable + public ObjectMapper getObject() { + return this.objectMapper; + } + + @Override + public Class getObjectType() { + return (this.objectMapper != null ? this.objectMapper.getClass() : null); + } + + @Override + public boolean isSingleton() { + return true; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/JsonbHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/json/JsonbHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..bdbc7221cfdafd43aef13fb7c76df29a7f63c5f3 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/json/JsonbHttpMessageConverter.java @@ -0,0 +1,113 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.io.Reader; +import java.io.Writer; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; + +import javax.json.bind.Jsonb; +import javax.json.bind.JsonbBuilder; +import javax.json.bind.JsonbConfig; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Implementation of {@link org.springframework.http.converter.HttpMessageConverter} + * that can read and write JSON using the + * JSON Binding API. + * + *

This converter can be used to bind to typed beans or untyped {@code HashMap}s. + * By default, it supports {@code application/json} and {@code application/*+json} with + * {@code UTF-8} character set. + * + * @author Juergen Hoeller + * @since 5.0 + * @see javax.json.bind.Jsonb + * @see javax.json.bind.JsonbBuilder + * @see #setJsonb + */ +public class JsonbHttpMessageConverter extends AbstractJsonHttpMessageConverter { + + private Jsonb jsonb; + + + /** + * Construct a new {@code JsonbHttpMessageConverter} with default configuration. + */ + public JsonbHttpMessageConverter() { + this(JsonbBuilder.create()); + } + + /** + * Construct a new {@code JsonbHttpMessageConverter} with the given configuration. + * @param config the {@code JsonbConfig} for the underlying delegate + */ + public JsonbHttpMessageConverter(JsonbConfig config) { + this.jsonb = JsonbBuilder.create(config); + } + + /** + * Construct a new {@code JsonbHttpMessageConverter} with the given delegate. + * @param jsonb the Jsonb instance to use + */ + public JsonbHttpMessageConverter(Jsonb jsonb) { + Assert.notNull(jsonb, "A Jsonb instance is required"); + this.jsonb = jsonb; + } + + + /** + * Set the {@code Jsonb} instance to use. + * If not set, a default {@code Jsonb} instance will be created. + *

Setting a custom-configured {@code Jsonb} is one way to take further + * control of the JSON serialization process. + * @see #JsonbHttpMessageConverter(Jsonb) + * @see #JsonbHttpMessageConverter(JsonbConfig) + * @see JsonbBuilder + */ + public void setJsonb(Jsonb jsonb) { + Assert.notNull(jsonb, "A Jsonb instance is required"); + this.jsonb = jsonb; + } + + /** + * Return the configured {@code Jsonb} instance for this converter. + */ + public Jsonb getJsonb() { + return this.jsonb; + } + + + @Override + protected Object readInternal(Type resolvedType, Reader reader) throws Exception { + return getJsonb().fromJson(reader, resolvedType); + } + + @Override + protected void writeInternal(Object object, @Nullable Type type, Writer writer) throws Exception { + if (type instanceof ParameterizedType) { + getJsonb().toJson(object, type, writer); + } + else { + getJsonb().toJson(object, writer); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/MappingJackson2HttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/json/MappingJackson2HttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..b3616ea1971df9cfa003407d8682b91ad5df93c4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/json/MappingJackson2HttpMessageConverter.java @@ -0,0 +1,100 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.io.IOException; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; + +/** + * Implementation of {@link org.springframework.http.converter.HttpMessageConverter} that can read and + * write JSON using Jackson 2.x's {@link ObjectMapper}. + * + *

This converter can be used to bind to typed beans, or untyped {@code HashMap} instances. + * + *

By default, this converter supports {@code application/json} and {@code application/*+json} + * with {@code UTF-8} character set. This can be overridden by setting the + * {@link #setSupportedMediaTypes supportedMediaTypes} property. + * + *

The default constructor uses the default configuration provided by {@link Jackson2ObjectMapperBuilder}. + * + *

Compatible with Jackson 2.9 and higher, as of Spring 5.0. + * + * @author Arjen Poutsma + * @author Keith Donald + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @author Sebastien Deleuze + * @since 3.1.2 + */ +public class MappingJackson2HttpMessageConverter extends AbstractJackson2HttpMessageConverter { + + @Nullable + private String jsonPrefix; + + + /** + * Construct a new {@link MappingJackson2HttpMessageConverter} using default configuration + * provided by {@link Jackson2ObjectMapperBuilder}. + */ + public MappingJackson2HttpMessageConverter() { + this(Jackson2ObjectMapperBuilder.json().build()); + } + + /** + * Construct a new {@link MappingJackson2HttpMessageConverter} with a custom {@link ObjectMapper}. + * You can use {@link Jackson2ObjectMapperBuilder} to build it easily. + * @see Jackson2ObjectMapperBuilder#json() + */ + public MappingJackson2HttpMessageConverter(ObjectMapper objectMapper) { + super(objectMapper, MediaType.APPLICATION_JSON, new MediaType("application", "*+json")); + } + + + /** + * Specify a custom prefix to use for this view's JSON output. + * Default is none. + * @see #setPrefixJson + */ + public void setJsonPrefix(String jsonPrefix) { + this.jsonPrefix = jsonPrefix; + } + + /** + * Indicate whether the JSON output by this view should be prefixed with ")]}', ". Default is false. + *

Prefixing the JSON string in this manner is used to help prevent JSON Hijacking. + * The prefix renders the string syntactically invalid as a script so that it cannot be hijacked. + * This prefix should be stripped before parsing the string as JSON. + * @see #setJsonPrefix + */ + public void setPrefixJson(boolean prefixJson) { + this.jsonPrefix = (prefixJson ? ")]}', " : null); + } + + + @Override + protected void writePrefix(JsonGenerator generator, Object object) throws IOException { + if (this.jsonPrefix != null) { + generator.writeRaw(this.jsonPrefix); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/MappingJacksonInputMessage.java b/spring-web/src/main/java/org/springframework/http/converter/json/MappingJacksonInputMessage.java new file mode 100644 index 0000000000000000000000000000000000000000..d882f3b0bce82735ebc6a0606b2cf77c66fc379f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/json/MappingJacksonInputMessage.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.io.IOException; +import java.io.InputStream; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.lang.Nullable; + +/** + * {@link HttpInputMessage} that can eventually stores a Jackson view that will be used + * to deserialize the message. + * + * @author Sebastien Deleuze + * @since 4.2 + */ +public class MappingJacksonInputMessage implements HttpInputMessage { + + private final InputStream body; + + private final HttpHeaders headers; + + @Nullable + private Class deserializationView; + + + public MappingJacksonInputMessage(InputStream body, HttpHeaders headers) { + this.body = body; + this.headers = headers; + } + + public MappingJacksonInputMessage(InputStream body, HttpHeaders headers, Class deserializationView) { + this(body, headers); + this.deserializationView = deserializationView; + } + + + @Override + public InputStream getBody() throws IOException { + return this.body; + } + + @Override + public HttpHeaders getHeaders() { + return this.headers; + } + + public void setDeserializationView(@Nullable Class deserializationView) { + this.deserializationView = deserializationView; + } + + @Nullable + public Class getDeserializationView() { + return this.deserializationView; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/MappingJacksonValue.java b/spring-web/src/main/java/org/springframework/http/converter/json/MappingJacksonValue.java new file mode 100644 index 0000000000000000000000000000000000000000..fad90875f2bf8eae400bf5e515e7abebb2f33304 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/json/MappingJacksonValue.java @@ -0,0 +1,113 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import com.fasterxml.jackson.databind.ser.FilterProvider; + +import org.springframework.lang.Nullable; + +/** + * A simple holder for the POJO to serialize via + * {@link MappingJackson2HttpMessageConverter} along with further + * serialization instructions to be passed in to the converter. + * + *

On the server side this wrapper is added with a + * {@code ResponseBodyInterceptor} after content negotiation selects the + * converter to use but before the write. + * + *

On the client side, simply wrap the POJO and pass it in to the + * {@code RestTemplate}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class MappingJacksonValue { + + private Object value; + + @Nullable + private Class serializationView; + + @Nullable + private FilterProvider filters; + + + /** + * Create a new instance wrapping the given POJO to be serialized. + * @param value the Object to be serialized + */ + public MappingJacksonValue(Object value) { + this.value = value; + } + + + /** + * Modify the POJO to serialize. + */ + public void setValue(Object value) { + this.value = value; + } + + /** + * Return the POJO that needs to be serialized. + */ + public Object getValue() { + return this.value; + } + + /** + * Set the serialization view to serialize the POJO with. + * @see com.fasterxml.jackson.databind.ObjectMapper#writerWithView(Class) + * @see com.fasterxml.jackson.annotation.JsonView + */ + public void setSerializationView(@Nullable Class serializationView) { + this.serializationView = serializationView; + } + + /** + * Return the serialization view to use. + * @see com.fasterxml.jackson.databind.ObjectMapper#writerWithView(Class) + * @see com.fasterxml.jackson.annotation.JsonView + */ + @Nullable + public Class getSerializationView() { + return this.serializationView; + } + + /** + * Set the Jackson filter provider to serialize the POJO with. + * @since 4.2 + * @see com.fasterxml.jackson.databind.ObjectMapper#writer(FilterProvider) + * @see com.fasterxml.jackson.annotation.JsonFilter + * @see Jackson2ObjectMapperBuilder#filters(FilterProvider) + */ + public void setFilters(@Nullable FilterProvider filters) { + this.filters = filters; + } + + /** + * Return the Jackson filter provider to use. + * @since 4.2 + * @see com.fasterxml.jackson.databind.ObjectMapper#writer(FilterProvider) + * @see com.fasterxml.jackson.annotation.JsonFilter + */ + @Nullable + public FilterProvider getFilters() { + return this.filters; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/SpringHandlerInstantiator.java b/spring-web/src/main/java/org/springframework/http/converter/json/SpringHandlerInstantiator.java new file mode 100644 index 0000000000000000000000000000000000000000..4d5885d39f18b69e50f90bbb93fe5992a2522475 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/json/SpringHandlerInstantiator.java @@ -0,0 +1,151 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import com.fasterxml.jackson.annotation.ObjectIdGenerator; +import com.fasterxml.jackson.annotation.ObjectIdResolver; +import com.fasterxml.jackson.databind.DeserializationConfig; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.KeyDeserializer; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.fasterxml.jackson.databind.SerializationConfig; +import com.fasterxml.jackson.databind.cfg.HandlerInstantiator; +import com.fasterxml.jackson.databind.cfg.MapperConfig; +import com.fasterxml.jackson.databind.deser.ValueInstantiator; +import com.fasterxml.jackson.databind.introspect.Annotated; +import com.fasterxml.jackson.databind.jsontype.TypeIdResolver; +import com.fasterxml.jackson.databind.jsontype.TypeResolverBuilder; +import com.fasterxml.jackson.databind.ser.VirtualBeanPropertyWriter; +import com.fasterxml.jackson.databind.util.Converter; + +import org.springframework.beans.factory.config.AutowireCapableBeanFactory; +import org.springframework.context.ApplicationContext; +import org.springframework.util.Assert; + +/** + * Allows for creating Jackson ({@link JsonSerializer}, {@link JsonDeserializer}, + * {@link KeyDeserializer}, {@link TypeResolverBuilder}, {@link TypeIdResolver}) + * beans with autowiring against a Spring {@link ApplicationContext}. + * + *

As of Spring 4.3, this overrides all factory methods in {@link HandlerInstantiator}, + * including non-abstract ones and recently introduced ones from Jackson 2.4 and 2.5: + * for {@link ValueInstantiator}, {@link ObjectIdGenerator}, {@link ObjectIdResolver}, + * {@link PropertyNamingStrategy}, {@link Converter}, {@link VirtualBeanPropertyWriter}. + * + * @author Sebastien Deleuze + * @author Juergen Hoeller + * @since 4.1.3 + * @see Jackson2ObjectMapperBuilder#handlerInstantiator(HandlerInstantiator) + * @see ApplicationContext#getAutowireCapableBeanFactory() + * @see HandlerInstantiator + */ +public class SpringHandlerInstantiator extends HandlerInstantiator { + + private final AutowireCapableBeanFactory beanFactory; + + + /** + * Create a new SpringHandlerInstantiator for the given BeanFactory. + * @param beanFactory the target BeanFactory + */ + public SpringHandlerInstantiator(AutowireCapableBeanFactory beanFactory) { + Assert.notNull(beanFactory, "BeanFactory must not be null"); + this.beanFactory = beanFactory; + } + + + @Override + public JsonDeserializer deserializerInstance(DeserializationConfig config, + Annotated annotated, Class implClass) { + + return (JsonDeserializer) this.beanFactory.createBean(implClass); + } + + @Override + public KeyDeserializer keyDeserializerInstance(DeserializationConfig config, + Annotated annotated, Class implClass) { + + return (KeyDeserializer) this.beanFactory.createBean(implClass); + } + + @Override + public JsonSerializer serializerInstance(SerializationConfig config, + Annotated annotated, Class implClass) { + + return (JsonSerializer) this.beanFactory.createBean(implClass); + } + + @Override + public TypeResolverBuilder typeResolverBuilderInstance(MapperConfig config, + Annotated annotated, Class implClass) { + + return (TypeResolverBuilder) this.beanFactory.createBean(implClass); + } + + @Override + public TypeIdResolver typeIdResolverInstance(MapperConfig config, Annotated annotated, Class implClass) { + return (TypeIdResolver) this.beanFactory.createBean(implClass); + } + + /** @since 4.3 */ + @Override + public ValueInstantiator valueInstantiatorInstance(MapperConfig config, + Annotated annotated, Class implClass) { + + return (ValueInstantiator) this.beanFactory.createBean(implClass); + } + + /** @since 4.3 */ + @Override + public ObjectIdGenerator objectIdGeneratorInstance(MapperConfig config, + Annotated annotated, Class implClass) { + + return (ObjectIdGenerator) this.beanFactory.createBean(implClass); + } + + /** @since 4.3 */ + @Override + public ObjectIdResolver resolverIdGeneratorInstance(MapperConfig config, + Annotated annotated, Class implClass) { + + return (ObjectIdResolver) this.beanFactory.createBean(implClass); + } + + /** @since 4.3 */ + @Override + public PropertyNamingStrategy namingStrategyInstance(MapperConfig config, + Annotated annotated, Class implClass) { + + return (PropertyNamingStrategy) this.beanFactory.createBean(implClass); + } + + /** @since 4.3 */ + @Override + public Converter converterInstance(MapperConfig config, + Annotated annotated, Class implClass) { + + return (Converter) this.beanFactory.createBean(implClass); + } + + /** @since 4.3 */ + @Override + public VirtualBeanPropertyWriter virtualPropertyWriterInstance(MapperConfig config, Class implClass) { + return (VirtualBeanPropertyWriter) this.beanFactory.createBean(implClass); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/json/package-info.java b/spring-web/src/main/java/org/springframework/http/converter/json/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..5290f00a1c2e222b377aee667510a24f7c4d3180 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/json/package-info.java @@ -0,0 +1,9 @@ +/** + * Provides HttpMessageConverter implementations for handling JSON. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.converter.json; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/converter/package-info.java b/spring-web/src/main/java/org/springframework/http/converter/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..82ddae975e23ce74be9649f6cc7afbb7c53bb98c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/package-info.java @@ -0,0 +1,9 @@ +/** + * Provides an HttpMessageConverter abstraction to convert between Java objects and HTTP input/output messages. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.converter; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/converter/protobuf/ExtensionRegistryInitializer.java b/spring-web/src/main/java/org/springframework/http/converter/protobuf/ExtensionRegistryInitializer.java new file mode 100644 index 0000000000000000000000000000000000000000..9f11533c79645f1c39d5a0050624e1fafdef63c8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/protobuf/ExtensionRegistryInitializer.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.protobuf; + +import com.google.protobuf.ExtensionRegistry; + +/** + * Google Protocol Messages can contain message extensions that can be parsed if + * the appropriate configuration has been registered in the {@code ExtensionRegistry}. + * + *

This interface provides a facility to populate the {@code ExtensionRegistry}. + * + * @author Alex Antonov + * @author Sebastien Deleuze + * @since 4.1 + * @see + * com.google.protobuf.ExtensionRegistry + * @deprecated as of Spring Framework 5.1, use {@link ExtensionRegistry} based constructors instead + */ +@Deprecated +public interface ExtensionRegistryInitializer { + + /** + * Initializes the {@code ExtensionRegistry} with Protocol Message extensions. + * @param registry the registry to populate + */ + void initializeExtensionRegistry(ExtensionRegistry registry); + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/protobuf/ProtobufHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/protobuf/ProtobufHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..7cfa95211294feb8c30674abb4f2606d9125903f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/protobuf/ProtobufHttpMessageConverter.java @@ -0,0 +1,419 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.protobuf; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.lang.reflect.Method; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Map; + +import com.google.protobuf.CodedOutputStream; +import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.Message; +import com.google.protobuf.TextFormat; +import com.google.protobuf.util.JsonFormat; +import com.googlecode.protobuf.format.FormatFactory; +import com.googlecode.protobuf.format.ProtobufFormatter; + +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.converter.AbstractHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConversionException; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.HttpMessageNotWritableException; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ConcurrentReferenceHashMap; + +import static org.springframework.http.MediaType.APPLICATION_JSON; +import static org.springframework.http.MediaType.APPLICATION_XML; +import static org.springframework.http.MediaType.TEXT_HTML; +import static org.springframework.http.MediaType.TEXT_PLAIN; + +/** + * An {@code HttpMessageConverter} that reads and writes + * {@link com.google.protobuf.Message com.google.protobuf.Messages} using + * Google Protocol Buffers. + * + *

To generate {@code Message} Java classes, you need to install the {@code protoc} binary. + * + *

This converter supports by default {@code "application/x-protobuf"} and {@code "text/plain"} + * with the official {@code "com.google.protobuf:protobuf-java"} library. Other formats can be + * supported with one of the following additional libraries on the classpath: + *

    + *
  • {@code "application/json"}, {@code "application/xml"}, and {@code "text/html"} (write-only) + * with the {@code "com.googlecode.protobuf-java-format:protobuf-java-format"} third-party library + *
  • {@code "application/json"} with the official {@code "com.google.protobuf:protobuf-java-util"} + * for Protobuf 3 (see {@link ProtobufJsonFormatHttpMessageConverter} for a configurable variant) + *
+ * + *

Requires Protobuf 2.6 or higher (and Protobuf Java Format 1.4 or higher for formatting). + * This converter will auto-adapt to Protobuf 3 and its default {@code protobuf-java-util} JSON + * format if the Protobuf 2 based {@code protobuf-java-format} isn't present; however, for more + * explicit JSON setup on Protobuf 3, consider {@link ProtobufJsonFormatHttpMessageConverter}. + * + * @author Alex Antonov + * @author Brian Clozel + * @author Juergen Hoeller + * @author Sebastien Deleuze + * @since 4.1 + * @see FormatFactory + * @see JsonFormat + * @see ProtobufJsonFormatHttpMessageConverter + */ +public class ProtobufHttpMessageConverter extends AbstractHttpMessageConverter { + + /** + * The default charset used by the converter. + */ + public static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + + /** + * The media-type for protobuf {@code application/x-protobuf}. + */ + public static final MediaType PROTOBUF = new MediaType("application", "x-protobuf", DEFAULT_CHARSET); + + /** + * The HTTP header containing the protobuf schema. + */ + public static final String X_PROTOBUF_SCHEMA_HEADER = "X-Protobuf-Schema"; + + /** + * The HTTP header containing the protobuf message. + */ + public static final String X_PROTOBUF_MESSAGE_HEADER = "X-Protobuf-Message"; + + + private static final Map, Method> methodCache = new ConcurrentReferenceHashMap<>(); + + final ExtensionRegistry extensionRegistry; + + @Nullable + private final ProtobufFormatSupport protobufFormatSupport; + + + /** + * Construct a new {@code ProtobufHttpMessageConverter}. + */ + public ProtobufHttpMessageConverter() { + this(null, null); + } + + /** + * Construct a new {@code ProtobufHttpMessageConverter} with an + * initializer that allows the registration of message extensions. + * @param registryInitializer an initializer for message extensions + * @deprecated as of Spring Framework 5.1, use {@link #ProtobufHttpMessageConverter(ExtensionRegistry)} instead + */ + @Deprecated + public ProtobufHttpMessageConverter(@Nullable ExtensionRegistryInitializer registryInitializer) { + this(null, null); + if (registryInitializer != null) { + registryInitializer.initializeExtensionRegistry(this.extensionRegistry); + } + } + + /** + * Construct a new {@code ProtobufHttpMessageConverter} with a registry that specifies + * protocol message extensions. + * @param extensionRegistry the registry to populate + */ + public ProtobufHttpMessageConverter(ExtensionRegistry extensionRegistry) { + this(null, extensionRegistry); + } + + ProtobufHttpMessageConverter(@Nullable ProtobufFormatSupport formatSupport, + @Nullable ExtensionRegistry extensionRegistry) { + + if (formatSupport != null) { + this.protobufFormatSupport = formatSupport; + } + else if (ClassUtils.isPresent("com.googlecode.protobuf.format.FormatFactory", getClass().getClassLoader())) { + this.protobufFormatSupport = new ProtobufJavaFormatSupport(); + } + else if (ClassUtils.isPresent("com.google.protobuf.util.JsonFormat", getClass().getClassLoader())) { + this.protobufFormatSupport = new ProtobufJavaUtilSupport(null, null); + } + else { + this.protobufFormatSupport = null; + } + + setSupportedMediaTypes(Arrays.asList(this.protobufFormatSupport != null ? + this.protobufFormatSupport.supportedMediaTypes() : new MediaType[] {PROTOBUF, TEXT_PLAIN})); + + this.extensionRegistry = (extensionRegistry == null ? ExtensionRegistry.newInstance() : extensionRegistry); + } + + + @Override + protected boolean supports(Class clazz) { + return Message.class.isAssignableFrom(clazz); + } + + @Override + protected MediaType getDefaultContentType(Message message) { + return PROTOBUF; + } + + @Override + protected Message readInternal(Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + MediaType contentType = inputMessage.getHeaders().getContentType(); + if (contentType == null) { + contentType = PROTOBUF; + } + Charset charset = contentType.getCharset(); + if (charset == null) { + charset = DEFAULT_CHARSET; + } + + Message.Builder builder = getMessageBuilder(clazz); + if (PROTOBUF.isCompatibleWith(contentType)) { + builder.mergeFrom(inputMessage.getBody(), this.extensionRegistry); + } + else if (TEXT_PLAIN.isCompatibleWith(contentType)) { + InputStreamReader reader = new InputStreamReader(inputMessage.getBody(), charset); + TextFormat.merge(reader, this.extensionRegistry, builder); + } + else if (this.protobufFormatSupport != null) { + this.protobufFormatSupport.merge( + inputMessage.getBody(), charset, contentType, this.extensionRegistry, builder); + } + return builder.build(); + } + + /** + * Create a new {@code Message.Builder} instance for the given class. + *

This method uses a ConcurrentReferenceHashMap for caching method lookups. + */ + private Message.Builder getMessageBuilder(Class clazz) { + try { + Method method = methodCache.get(clazz); + if (method == null) { + method = clazz.getMethod("newBuilder"); + methodCache.put(clazz, method); + } + return (Message.Builder) method.invoke(clazz); + } + catch (Exception ex) { + throw new HttpMessageConversionException( + "Invalid Protobuf Message type: no invocable newBuilder() method on " + clazz, ex); + } + } + + + @Override + protected boolean canWrite(@Nullable MediaType mediaType) { + return (super.canWrite(mediaType) || + (this.protobufFormatSupport != null && this.protobufFormatSupport.supportsWriteOnly(mediaType))); + } + + @Override + protected void writeInternal(Message message, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + + MediaType contentType = outputMessage.getHeaders().getContentType(); + if (contentType == null) { + contentType = getDefaultContentType(message); + Assert.state(contentType != null, "No content type"); + } + Charset charset = contentType.getCharset(); + if (charset == null) { + charset = DEFAULT_CHARSET; + } + + if (PROTOBUF.isCompatibleWith(contentType)) { + setProtoHeader(outputMessage, message); + CodedOutputStream codedOutputStream = CodedOutputStream.newInstance(outputMessage.getBody()); + message.writeTo(codedOutputStream); + codedOutputStream.flush(); + } + else if (TEXT_PLAIN.isCompatibleWith(contentType)) { + OutputStreamWriter outputStreamWriter = new OutputStreamWriter(outputMessage.getBody(), charset); + TextFormat.print(message, outputStreamWriter); + outputStreamWriter.flush(); + outputMessage.getBody().flush(); + } + else if (this.protobufFormatSupport != null) { + this.protobufFormatSupport.print(message, outputMessage.getBody(), contentType, charset); + outputMessage.getBody().flush(); + } + } + + /** + * Set the "X-Protobuf-*" HTTP headers when responding with a message of + * content type "application/x-protobuf" + *

Note: outputMessage.getBody() should not have been called + * before because it writes HTTP headers (making them read only).

+ */ + private void setProtoHeader(HttpOutputMessage response, Message message) { + response.getHeaders().set(X_PROTOBUF_SCHEMA_HEADER, message.getDescriptorForType().getFile().getName()); + response.getHeaders().set(X_PROTOBUF_MESSAGE_HEADER, message.getDescriptorForType().getFullName()); + } + + + /** + * Protobuf format support. + */ + interface ProtobufFormatSupport { + + MediaType[] supportedMediaTypes(); + + boolean supportsWriteOnly(@Nullable MediaType mediaType); + + void merge(InputStream input, Charset charset, MediaType contentType, + ExtensionRegistry extensionRegistry, Message.Builder builder) + throws IOException, HttpMessageConversionException; + + void print(Message message, OutputStream output, MediaType contentType, Charset charset) + throws IOException, HttpMessageConversionException; + } + + + /** + * {@link ProtobufFormatSupport} implementation used when + * {@code com.googlecode.protobuf.format.FormatFactory} is available. + */ + static class ProtobufJavaFormatSupport implements ProtobufFormatSupport { + + private final ProtobufFormatter jsonFormatter; + + private final ProtobufFormatter xmlFormatter; + + private final ProtobufFormatter htmlFormatter; + + public ProtobufJavaFormatSupport() { + FormatFactory formatFactory = new FormatFactory(); + this.jsonFormatter = formatFactory.createFormatter(FormatFactory.Formatter.JSON); + this.xmlFormatter = formatFactory.createFormatter(FormatFactory.Formatter.XML); + this.htmlFormatter = formatFactory.createFormatter(FormatFactory.Formatter.HTML); + } + + @Override + public MediaType[] supportedMediaTypes() { + return new MediaType[] {PROTOBUF, TEXT_PLAIN, APPLICATION_XML, APPLICATION_JSON}; + } + + @Override + public boolean supportsWriteOnly(@Nullable MediaType mediaType) { + return TEXT_HTML.isCompatibleWith(mediaType); + } + + @Override + public void merge(InputStream input, Charset charset, MediaType contentType, + ExtensionRegistry extensionRegistry, Message.Builder builder) + throws IOException, HttpMessageConversionException { + + if (contentType.isCompatibleWith(APPLICATION_JSON)) { + this.jsonFormatter.merge(input, charset, extensionRegistry, builder); + } + else if (contentType.isCompatibleWith(APPLICATION_XML)) { + this.xmlFormatter.merge(input, charset, extensionRegistry, builder); + } + else { + throw new HttpMessageConversionException( + "protobuf-java-format does not support parsing " + contentType); + } + } + + @Override + public void print(Message message, OutputStream output, MediaType contentType, Charset charset) + throws IOException, HttpMessageConversionException { + + if (contentType.isCompatibleWith(APPLICATION_JSON)) { + this.jsonFormatter.print(message, output, charset); + } + else if (contentType.isCompatibleWith(APPLICATION_XML)) { + this.xmlFormatter.print(message, output, charset); + } + else if (contentType.isCompatibleWith(TEXT_HTML)) { + this.htmlFormatter.print(message, output, charset); + } + else { + throw new HttpMessageConversionException( + "protobuf-java-format does not support printing " + contentType); + } + } + } + + + /** + * {@link ProtobufFormatSupport} implementation used when + * {@code com.google.protobuf.util.JsonFormat} is available. + */ + static class ProtobufJavaUtilSupport implements ProtobufFormatSupport { + + private final JsonFormat.Parser parser; + + private final JsonFormat.Printer printer; + + public ProtobufJavaUtilSupport(@Nullable JsonFormat.Parser parser, @Nullable JsonFormat.Printer printer) { + this.parser = (parser != null ? parser : JsonFormat.parser()); + this.printer = (printer != null ? printer : JsonFormat.printer()); + } + + @Override + public MediaType[] supportedMediaTypes() { + return new MediaType[] {PROTOBUF, TEXT_PLAIN, APPLICATION_JSON}; + } + + @Override + public boolean supportsWriteOnly(@Nullable MediaType mediaType) { + return false; + } + + @Override + public void merge(InputStream input, Charset charset, MediaType contentType, + ExtensionRegistry extensionRegistry, Message.Builder builder) + throws IOException, HttpMessageConversionException { + + if (contentType.isCompatibleWith(APPLICATION_JSON)) { + InputStreamReader reader = new InputStreamReader(input, charset); + this.parser.merge(reader, builder); + } + else { + throw new HttpMessageConversionException( + "protobuf-java-util does not support parsing " + contentType); + } + } + + @Override + public void print(Message message, OutputStream output, MediaType contentType, Charset charset) + throws IOException, HttpMessageConversionException { + + if (contentType.isCompatibleWith(APPLICATION_JSON)) { + OutputStreamWriter writer = new OutputStreamWriter(output, charset); + this.printer.appendTo(message, writer); + writer.flush(); + } + else { + throw new HttpMessageConversionException( + "protobuf-java-util does not support printing " + contentType); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/protobuf/ProtobufJsonFormatHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/protobuf/ProtobufJsonFormatHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..47bc2ad91a5d9df2ae0e93e59b06ba11b18d9b2b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/protobuf/ProtobufJsonFormatHttpMessageConverter.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.protobuf; + +import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.util.JsonFormat; + +import org.springframework.lang.Nullable; + +/** + * Subclass of {@link ProtobufHttpMessageConverter} which enforces the use of Protobuf 3 and + * its official library {@code "com.google.protobuf:protobuf-java-util"} for JSON processing. + * + *

Most importantly, this class allows for custom JSON parser and printer configurations + * through the {@link JsonFormat} utility. If no special parser or printer configuration is + * given, default variants will be used instead. + * + *

Requires Protobuf 3.x and {@code "com.google.protobuf:protobuf-java-util"} 3.x, + * with 3.3 or higher recommended. + * + * @author Juergen Hoeller + * @author Sebastien Deleuze + * @since 5.0 + * @see JsonFormat#parser() + * @see JsonFormat#printer() + * @see #ProtobufJsonFormatHttpMessageConverter(JsonFormat.Parser, JsonFormat.Printer) + */ +public class ProtobufJsonFormatHttpMessageConverter extends ProtobufHttpMessageConverter { + + /** + * Construct a new {@code ProtobufJsonFormatHttpMessageConverter} with default + * {@code JsonFormat.Parser} and {@code JsonFormat.Printer} configuration. + */ + public ProtobufJsonFormatHttpMessageConverter() { + this(null, null, (ExtensionRegistry)null); + } + + /** + * Construct a new {@code ProtobufJsonFormatHttpMessageConverter} with the given + * {@code JsonFormat.Parser} and {@code JsonFormat.Printer} configuration. + * @param parser the JSON parser configuration + * @param printer the JSON printer configuration + */ + public ProtobufJsonFormatHttpMessageConverter( + @Nullable JsonFormat.Parser parser, @Nullable JsonFormat.Printer printer) { + + this(parser, printer, (ExtensionRegistry)null); + } + + /** + * Construct a new {@code ProtobufJsonFormatHttpMessageConverter} with the given + * {@code JsonFormat.Parser} and {@code JsonFormat.Printer} configuration, also + * accepting a registry that specifies protocol message extensions. + * @param parser the JSON parser configuration + * @param printer the JSON printer configuration + * @param extensionRegistry the registry to populate + * @since 5.1 + */ + public ProtobufJsonFormatHttpMessageConverter(@Nullable JsonFormat.Parser parser, + @Nullable JsonFormat.Printer printer, @Nullable ExtensionRegistry extensionRegistry) { + + super(new ProtobufJavaUtilSupport(parser, printer), extensionRegistry); + } + + /** + * Construct a new {@code ProtobufJsonFormatHttpMessageConverter} with the given + * {@code JsonFormat.Parser} and {@code JsonFormat.Printer} configuration, also + * accepting an initializer that allows the registration of message extensions. + * @param parser the JSON parser configuration + * @param printer the JSON printer configuration + * @param registryInitializer an initializer for message extensions + * @deprecated as of 5.1, in favor of + * {@link #ProtobufJsonFormatHttpMessageConverter(JsonFormat.Parser, JsonFormat.Printer, ExtensionRegistry)} + */ + @Deprecated + public ProtobufJsonFormatHttpMessageConverter(@Nullable JsonFormat.Parser parser, + @Nullable JsonFormat.Printer printer, @Nullable ExtensionRegistryInitializer registryInitializer) { + + super(new ProtobufJavaUtilSupport(parser, printer), null); + if (registryInitializer != null) { + registryInitializer.initializeExtensionRegistry(this.extensionRegistry); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/protobuf/package-info.java b/spring-web/src/main/java/org/springframework/http/converter/protobuf/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..313e9089c38263ee21d4a8b2cfedacf78c1ac4e3 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/protobuf/package-info.java @@ -0,0 +1,10 @@ +/** + * Provides an HttpMessageConverter implementation for handling + * Google Protocol Buffers. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.converter.protobuf; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/converter/smile/MappingJackson2SmileHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/smile/MappingJackson2SmileHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..9eead5a3e188c6608774a05c29a6a229ebaabe9b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/smile/MappingJackson2SmileHttpMessageConverter.java @@ -0,0 +1,75 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.smile; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.smile.SmileFactory; + +import org.springframework.http.MediaType; +import org.springframework.http.converter.json.AbstractJackson2HttpMessageConverter; +import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; +import org.springframework.util.Assert; + +/** + * Implementation of {@link org.springframework.http.converter.HttpMessageConverter HttpMessageConverter} + * that can read and write Smile data format ("binary JSON") using + * + * the dedicated Jackson 2.x extension. + * + *

By default, this converter supports {@code "application/x-jackson-smile"} media type. + * This can be overridden by setting the {@link #setSupportedMediaTypes supportedMediaTypes} property. + * + *

The default constructor uses the default configuration provided by {@link Jackson2ObjectMapperBuilder}. + * + *

Compatible with Jackson 2.9 and higher. + * + * @author Sebastien Deleuze + * @since 5.0 + */ +public class MappingJackson2SmileHttpMessageConverter extends AbstractJackson2HttpMessageConverter { + + /** + * Construct a new {@code MappingJackson2SmileHttpMessageConverter} using default configuration + * provided by {@code Jackson2ObjectMapperBuilder}. + */ + public MappingJackson2SmileHttpMessageConverter() { + this(Jackson2ObjectMapperBuilder.smile().build()); + } + + /** + * Construct a new {@code MappingJackson2SmileHttpMessageConverter} with a custom {@link ObjectMapper} + * (must be configured with a {@code SmileFactory} instance). + * You can use {@link Jackson2ObjectMapperBuilder} to build it easily. + * @see Jackson2ObjectMapperBuilder#smile() + */ + public MappingJackson2SmileHttpMessageConverter(ObjectMapper objectMapper) { + super(objectMapper, new MediaType("application", "x-jackson-smile")); + Assert.isInstanceOf(SmileFactory.class, objectMapper.getFactory(), "SmileFactory required"); + } + + + /** + * {@inheritDoc} + * The {@code ObjectMapper} must be configured with a {@code SmileFactory} instance. + */ + @Override + public void setObjectMapper(ObjectMapper objectMapper) { + Assert.isInstanceOf(SmileFactory.class, objectMapper.getFactory(), "SmileFactory required"); + super.setObjectMapper(objectMapper); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/smile/package-info.java b/spring-web/src/main/java/org/springframework/http/converter/smile/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..fee55e553be727689b6e69ee6fa2589c358295cb --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/smile/package-info.java @@ -0,0 +1,9 @@ +/** + * Provides an HttpMessageConverter for the Smile data format ("binary JSON"). + */ +@NonNullApi +@NonNullFields +package org.springframework.http.converter.smile; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/converter/support/AllEncompassingFormHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/support/AllEncompassingFormHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..ee7d13f2e10051e292c5c05bb7fc9659bfe73fff --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/support/AllEncompassingFormHttpMessageConverter.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.support; + +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.http.converter.json.GsonHttpMessageConverter; +import org.springframework.http.converter.json.JsonbHttpMessageConverter; +import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; +import org.springframework.http.converter.smile.MappingJackson2SmileHttpMessageConverter; +import org.springframework.http.converter.xml.Jaxb2RootElementHttpMessageConverter; +import org.springframework.http.converter.xml.MappingJackson2XmlHttpMessageConverter; +import org.springframework.http.converter.xml.SourceHttpMessageConverter; +import org.springframework.util.ClassUtils; + +/** + * Extension of {@link org.springframework.http.converter.FormHttpMessageConverter}, + * adding support for XML and JSON-based parts. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.2 + */ +public class AllEncompassingFormHttpMessageConverter extends FormHttpMessageConverter { + + private static final boolean jaxb2Present; + + private static final boolean jackson2Present; + + private static final boolean jackson2XmlPresent; + + private static final boolean jackson2SmilePresent; + + private static final boolean gsonPresent; + + private static final boolean jsonbPresent; + + static { + ClassLoader classLoader = AllEncompassingFormHttpMessageConverter.class.getClassLoader(); + jaxb2Present = ClassUtils.isPresent("javax.xml.bind.Binder", classLoader); + jackson2Present = ClassUtils.isPresent("com.fasterxml.jackson.databind.ObjectMapper", classLoader) && + ClassUtils.isPresent("com.fasterxml.jackson.core.JsonGenerator", classLoader); + jackson2XmlPresent = ClassUtils.isPresent("com.fasterxml.jackson.dataformat.xml.XmlMapper", classLoader); + jackson2SmilePresent = ClassUtils.isPresent("com.fasterxml.jackson.dataformat.smile.SmileFactory", classLoader); + gsonPresent = ClassUtils.isPresent("com.google.gson.Gson", classLoader); + jsonbPresent = ClassUtils.isPresent("javax.json.bind.Jsonb", classLoader); + } + + + public AllEncompassingFormHttpMessageConverter() { + try { + addPartConverter(new SourceHttpMessageConverter<>()); + } + catch (Error err) { + // Ignore when no TransformerFactory implementation is available + } + + if (jaxb2Present && !jackson2XmlPresent) { + addPartConverter(new Jaxb2RootElementHttpMessageConverter()); + } + + if (jackson2Present) { + addPartConverter(new MappingJackson2HttpMessageConverter()); + } + else if (gsonPresent) { + addPartConverter(new GsonHttpMessageConverter()); + } + else if (jsonbPresent) { + addPartConverter(new JsonbHttpMessageConverter()); + } + + if (jackson2XmlPresent) { + addPartConverter(new MappingJackson2XmlHttpMessageConverter()); + } + + if (jackson2SmilePresent) { + addPartConverter(new MappingJackson2SmileHttpMessageConverter()); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/support/package-info.java b/spring-web/src/main/java/org/springframework/http/converter/support/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..4dbf1c570a43eed56b64220aca9978ed8b61f78d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/support/package-info.java @@ -0,0 +1,9 @@ +/** + * Provides a comprehensive HttpMessageConverter variant for form handling. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.converter.support; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/converter/xml/AbstractJaxb2HttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/xml/AbstractJaxb2HttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..811536c316576837229bc5b05e9b322ca682c436 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/xml/AbstractJaxb2HttpMessageConverter.java @@ -0,0 +1,119 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.xml; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import javax.xml.bind.JAXBContext; +import javax.xml.bind.JAXBException; +import javax.xml.bind.Marshaller; +import javax.xml.bind.Unmarshaller; + +import org.springframework.http.converter.HttpMessageConversionException; + +/** + * Abstract base class for {@link org.springframework.http.converter.HttpMessageConverter HttpMessageConverters} + * that use JAXB2. Creates {@link JAXBContext} object lazily. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @since 3.0 + * @param the converted object type + */ +public abstract class AbstractJaxb2HttpMessageConverter extends AbstractXmlHttpMessageConverter { + + private final ConcurrentMap, JAXBContext> jaxbContexts = new ConcurrentHashMap<>(64); + + + /** + * Create a new {@link Marshaller} for the given class. + * @param clazz the class to create the marshaller for + * @return the {@code Marshaller} + * @throws HttpMessageConversionException in case of JAXB errors + */ + protected final Marshaller createMarshaller(Class clazz) { + try { + JAXBContext jaxbContext = getJaxbContext(clazz); + Marshaller marshaller = jaxbContext.createMarshaller(); + customizeMarshaller(marshaller); + return marshaller; + } + catch (JAXBException ex) { + throw new HttpMessageConversionException( + "Could not create Marshaller for class [" + clazz + "]: " + ex.getMessage(), ex); + } + } + + /** + * Customize the {@link Marshaller} created by this + * message converter before using it to write the object to the output. + * @param marshaller the marshaller to customize + * @since 4.0.3 + * @see #createMarshaller(Class) + */ + protected void customizeMarshaller(Marshaller marshaller) { + } + + /** + * Create a new {@link Unmarshaller} for the given class. + * @param clazz the class to create the unmarshaller for + * @return the {@code Unmarshaller} + * @throws HttpMessageConversionException in case of JAXB errors + */ + protected final Unmarshaller createUnmarshaller(Class clazz) { + try { + JAXBContext jaxbContext = getJaxbContext(clazz); + Unmarshaller unmarshaller = jaxbContext.createUnmarshaller(); + customizeUnmarshaller(unmarshaller); + return unmarshaller; + } + catch (JAXBException ex) { + throw new HttpMessageConversionException( + "Could not create Unmarshaller for class [" + clazz + "]: " + ex.getMessage(), ex); + } + } + + /** + * Customize the {@link Unmarshaller} created by this + * message converter before using it to read the object from the input. + * @param unmarshaller the unmarshaller to customize + * @since 4.0.3 + * @see #createUnmarshaller(Class) + */ + protected void customizeUnmarshaller(Unmarshaller unmarshaller) { + } + + /** + * Return a {@link JAXBContext} for the given class. + * @param clazz the class to return the context for + * @return the {@code JAXBContext} + * @throws HttpMessageConversionException in case of JAXB errors + */ + protected final JAXBContext getJaxbContext(Class clazz) { + return this.jaxbContexts.computeIfAbsent(clazz, key -> { + try { + return JAXBContext.newInstance(clazz); + } + catch (JAXBException ex) { + throw new HttpMessageConversionException( + "Could not create JAXBContext for class [" + clazz + "]: " + ex.getMessage(), ex); + } + }); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/xml/AbstractXmlHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/xml/AbstractXmlHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..9da6bcb42c11bc3cceb2f82236f00b6fef165313 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/xml/AbstractXmlHttpMessageConverter.java @@ -0,0 +1,125 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.xml; + +import java.io.IOException; + +import javax.xml.transform.Result; +import javax.xml.transform.Source; +import javax.xml.transform.TransformerException; +import javax.xml.transform.TransformerFactory; +import javax.xml.transform.stream.StreamResult; +import javax.xml.transform.stream.StreamSource; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.converter.AbstractHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConversionException; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.HttpMessageNotWritableException; + +/** + * Abstract base class for {@link org.springframework.http.converter.HttpMessageConverter HttpMessageConverters} + * that convert from/to XML. + * + *

By default, subclasses of this converter support {@code text/xml}, {@code application/xml}, and {@code + * application/*-xml}. This can be overridden by setting the {@link #setSupportedMediaTypes(java.util.List) + * supportedMediaTypes} property. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.0 + * @param the converted object type + */ +public abstract class AbstractXmlHttpMessageConverter extends AbstractHttpMessageConverter { + + private final TransformerFactory transformerFactory = TransformerFactory.newInstance(); + + + /** + * Protected constructor that sets the {@link #setSupportedMediaTypes(java.util.List) supportedMediaTypes} + * to {@code text/xml} and {@code application/xml}, and {@code application/*-xml}. + */ + protected AbstractXmlHttpMessageConverter() { + super(MediaType.APPLICATION_XML, MediaType.TEXT_XML, new MediaType("application", "*+xml")); + } + + + @Override + public final T readInternal(Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + try { + return readFromSource(clazz, inputMessage.getHeaders(), new StreamSource(inputMessage.getBody())); + } + catch (IOException | HttpMessageConversionException ex) { + throw ex; + } + catch (Exception ex) { + throw new HttpMessageNotReadableException("Could not unmarshal to [" + clazz + "]: " + ex.getMessage(), + ex, inputMessage); + } + } + + @Override + protected final void writeInternal(T t, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + + try { + writeToResult(t, outputMessage.getHeaders(), new StreamResult(outputMessage.getBody())); + } + catch (IOException | HttpMessageConversionException ex) { + throw ex; + } + catch (Exception ex) { + throw new HttpMessageNotWritableException("Could not marshal [" + t + "]: " + ex.getMessage(), ex); + } + } + + /** + * Transforms the given {@code Source} to the {@code Result}. + * @param source the source to transform from + * @param result the result to transform to + * @throws TransformerException in case of transformation errors + */ + protected void transform(Source source, Result result) throws TransformerException { + this.transformerFactory.newTransformer().transform(source, result); + } + + + /** + * Abstract template method called from {@link #read(Class, HttpInputMessage)}. + * @param clazz the type of object to return + * @param headers the HTTP input headers + * @param source the HTTP input body + * @return the converted object + * @throws Exception in case of I/O or conversion errors + */ + protected abstract T readFromSource(Class clazz, HttpHeaders headers, Source source) throws Exception; + + /** + * Abstract template method called from {@link #writeInternal(Object, HttpOutputMessage)}. + * @param t the object to write to the output message + * @param headers the HTTP output headers + * @param result the HTTP output body + * @throws Exception in case of I/O or conversion errors + */ + protected abstract void writeToResult(T t, HttpHeaders headers, Result result) throws Exception; + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2CollectionHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2CollectionHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..ab3850c42bf7e8eb8cd7ad71288758aaa8e220ee --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2CollectionHttpMessageConverter.java @@ -0,0 +1,259 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.xml; + +import java.io.IOException; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.SortedSet; +import java.util.TreeSet; + +import javax.xml.bind.JAXBException; +import javax.xml.bind.UnmarshalException; +import javax.xml.bind.Unmarshaller; +import javax.xml.bind.annotation.XmlRootElement; +import javax.xml.bind.annotation.XmlType; +import javax.xml.stream.XMLInputFactory; +import javax.xml.stream.XMLStreamException; +import javax.xml.stream.XMLStreamReader; +import javax.xml.transform.Result; +import javax.xml.transform.Source; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.converter.GenericHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConversionException; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.HttpMessageNotWritableException; +import org.springframework.lang.Nullable; +import org.springframework.util.ReflectionUtils; +import org.springframework.util.xml.StaxUtils; + +/** + * An {@code HttpMessageConverter} that can read XML collections using JAXB2. + * + *

This converter can read {@linkplain Collection collections} that contain classes + * annotated with {@link XmlRootElement} and {@link XmlType}. Note that this converter + * does not support writing. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @since 3.2 + * @param the converted object type + */ +@SuppressWarnings("rawtypes") +public class Jaxb2CollectionHttpMessageConverter + extends AbstractJaxb2HttpMessageConverter implements GenericHttpMessageConverter { + + private final XMLInputFactory inputFactory = createXmlInputFactory(); + + + /** + * Always returns {@code false} since Jaxb2CollectionHttpMessageConverter + * required generic type information in order to read a Collection. + */ + @Override + public boolean canRead(Class clazz, @Nullable MediaType mediaType) { + return false; + } + + /** + * {@inheritDoc} + *

Jaxb2CollectionHttpMessageConverter can read a generic + * {@link Collection} where the generic type is a JAXB type annotated with + * {@link XmlRootElement} or {@link XmlType}. + */ + @Override + public boolean canRead(Type type, @Nullable Class contextClass, @Nullable MediaType mediaType) { + if (!(type instanceof ParameterizedType)) { + return false; + } + ParameterizedType parameterizedType = (ParameterizedType) type; + if (!(parameterizedType.getRawType() instanceof Class)) { + return false; + } + Class rawType = (Class) parameterizedType.getRawType(); + if (!(Collection.class.isAssignableFrom(rawType))) { + return false; + } + if (parameterizedType.getActualTypeArguments().length != 1) { + return false; + } + Type typeArgument = parameterizedType.getActualTypeArguments()[0]; + if (!(typeArgument instanceof Class)) { + return false; + } + Class typeArgumentClass = (Class) typeArgument; + return (typeArgumentClass.isAnnotationPresent(XmlRootElement.class) || + typeArgumentClass.isAnnotationPresent(XmlType.class)) && canRead(mediaType); + } + + /** + * Always returns {@code false} since Jaxb2CollectionHttpMessageConverter + * does not convert collections to XML. + */ + @Override + public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { + return false; + } + + /** + * Always returns {@code false} since Jaxb2CollectionHttpMessageConverter + * does not convert collections to XML. + */ + @Override + public boolean canWrite(@Nullable Type type, @Nullable Class clazz, @Nullable MediaType mediaType) { + return false; + } + + @Override + protected boolean supports(Class clazz) { + // should not be called, since we override canRead/Write + throw new UnsupportedOperationException(); + } + + @Override + protected T readFromSource(Class clazz, HttpHeaders headers, Source source) throws Exception { + // should not be called, since we return false for canRead(Class) + throw new UnsupportedOperationException(); + } + + @Override + @SuppressWarnings("unchecked") + public T read(Type type, @Nullable Class contextClass, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + ParameterizedType parameterizedType = (ParameterizedType) type; + T result = createCollection((Class) parameterizedType.getRawType()); + Class elementClass = (Class) parameterizedType.getActualTypeArguments()[0]; + + try { + Unmarshaller unmarshaller = createUnmarshaller(elementClass); + XMLStreamReader streamReader = this.inputFactory.createXMLStreamReader(inputMessage.getBody()); + int event = moveToFirstChildOfRootElement(streamReader); + + while (event != XMLStreamReader.END_DOCUMENT) { + if (elementClass.isAnnotationPresent(XmlRootElement.class)) { + result.add(unmarshaller.unmarshal(streamReader)); + } + else if (elementClass.isAnnotationPresent(XmlType.class)) { + result.add(unmarshaller.unmarshal(streamReader, elementClass).getValue()); + } + else { + // should not happen, since we check in canRead(Type) + throw new HttpMessageNotReadableException( + "Cannot unmarshal to [" + elementClass + "]", inputMessage); + } + event = moveToNextElement(streamReader); + } + return result; + } + catch (XMLStreamException ex) { + throw new HttpMessageNotReadableException( + "Failed to read XML stream: " + ex.getMessage(), ex, inputMessage); + } + catch (UnmarshalException ex) { + throw new HttpMessageNotReadableException( + "Could not unmarshal to [" + elementClass + "]: " + ex.getMessage(), ex, inputMessage); + } + catch (JAXBException ex) { + throw new HttpMessageConversionException("Invalid JAXB setup: " + ex.getMessage(), ex); + } + } + + /** + * Create a Collection of the given type, with the given initial capacity + * (if supported by the Collection type). + * @param collectionClass the type of Collection to instantiate + * @return the created Collection instance + */ + @SuppressWarnings("unchecked") + protected T createCollection(Class collectionClass) { + if (!collectionClass.isInterface()) { + try { + return (T) ReflectionUtils.accessibleConstructor(collectionClass).newInstance(); + } + catch (Throwable ex) { + throw new IllegalArgumentException( + "Could not instantiate collection class: " + collectionClass.getName(), ex); + } + } + else if (List.class == collectionClass) { + return (T) new ArrayList(); + } + else if (SortedSet.class == collectionClass) { + return (T) new TreeSet(); + } + else { + return (T) new LinkedHashSet(); + } + } + + private int moveToFirstChildOfRootElement(XMLStreamReader streamReader) throws XMLStreamException { + // root + int event = streamReader.next(); + while (event != XMLStreamReader.START_ELEMENT) { + event = streamReader.next(); + } + + // first child + event = streamReader.next(); + while ((event != XMLStreamReader.START_ELEMENT) && (event != XMLStreamReader.END_DOCUMENT)) { + event = streamReader.next(); + } + return event; + } + + private int moveToNextElement(XMLStreamReader streamReader) throws XMLStreamException { + int event = streamReader.getEventType(); + while (event != XMLStreamReader.START_ELEMENT && event != XMLStreamReader.END_DOCUMENT) { + event = streamReader.next(); + } + return event; + } + + @Override + public void write(T t, @Nullable Type type, @Nullable MediaType contentType, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + + throw new UnsupportedOperationException(); + } + + @Override + protected void writeToResult(T t, HttpHeaders headers, Result result) throws Exception { + throw new UnsupportedOperationException(); + } + + /** + * Create an {@code XMLInputFactory} that this converter will use to create + * {@link javax.xml.stream.XMLStreamReader} and {@link javax.xml.stream.XMLEventReader} + * objects. + *

Can be overridden in subclasses, adding further initialization of the factory. + * The resulting factory is cached, so this method will only be called once. + * @see StaxUtils#createDefensiveInputFactory() + */ + protected XMLInputFactory createXmlInputFactory() { + return StaxUtils.createDefensiveInputFactory(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..ae4dbd9b69996ae4f195a19a12f6e2f3dccae65a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverter.java @@ -0,0 +1,203 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.xml; + +import java.io.StringReader; + +import javax.xml.bind.JAXBElement; +import javax.xml.bind.JAXBException; +import javax.xml.bind.MarshalException; +import javax.xml.bind.Marshaller; +import javax.xml.bind.PropertyException; +import javax.xml.bind.UnmarshalException; +import javax.xml.bind.Unmarshaller; +import javax.xml.bind.annotation.XmlRootElement; +import javax.xml.bind.annotation.XmlType; +import javax.xml.transform.Result; +import javax.xml.transform.Source; +import javax.xml.transform.sax.SAXSource; +import javax.xml.transform.stream.StreamSource; + +import org.xml.sax.EntityResolver; +import org.xml.sax.InputSource; +import org.xml.sax.SAXException; +import org.xml.sax.XMLReader; + +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.converter.HttpMessageConversionException; +import org.springframework.lang.Nullable; +import org.springframework.util.ClassUtils; + +/** + * Implementation of {@link org.springframework.http.converter.HttpMessageConverter + * HttpMessageConverter} that can read and write XML using JAXB2. + * + *

This converter can read classes annotated with {@link XmlRootElement} and + * {@link XmlType}, and write classes annotated with {@link XmlRootElement}, + * or subclasses thereof. + * + *

Note: When using Spring's Marshaller/Unmarshaller abstractions from {@code spring-oxm}, + * you should use the {@link MarshallingHttpMessageConverter} instead. + * + * @author Arjen Poutsma + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 3.0 + * @see MarshallingHttpMessageConverter + */ +public class Jaxb2RootElementHttpMessageConverter extends AbstractJaxb2HttpMessageConverter { + + private boolean supportDtd = false; + + private boolean processExternalEntities = false; + + + /** + * Indicate whether DTD parsing should be supported. + *

Default is {@code false} meaning that DTD is disabled. + */ + public void setSupportDtd(boolean supportDtd) { + this.supportDtd = supportDtd; + } + + /** + * Return whether DTD parsing is supported. + */ + public boolean isSupportDtd() { + return this.supportDtd; + } + + /** + * Indicate whether external XML entities are processed when converting to a Source. + *

Default is {@code false}, meaning that external entities are not resolved. + *

Note: setting this option to {@code true} also + * automatically sets {@link #setSupportDtd} to {@code true}. + */ + public void setProcessExternalEntities(boolean processExternalEntities) { + this.processExternalEntities = processExternalEntities; + if (processExternalEntities) { + this.supportDtd = true; + } + } + + /** + * Return whether XML external entities are allowed. + */ + public boolean isProcessExternalEntities() { + return this.processExternalEntities; + } + + + @Override + public boolean canRead(Class clazz, @Nullable MediaType mediaType) { + return (clazz.isAnnotationPresent(XmlRootElement.class) || clazz.isAnnotationPresent(XmlType.class)) && + canRead(mediaType); + } + + @Override + public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { + return (AnnotationUtils.findAnnotation(clazz, XmlRootElement.class) != null && canWrite(mediaType)); + } + + @Override + protected boolean supports(Class clazz) { + // should not be called, since we override canRead/Write + throw new UnsupportedOperationException(); + } + + @Override + protected Object readFromSource(Class clazz, HttpHeaders headers, Source source) throws Exception { + try { + source = processSource(source); + Unmarshaller unmarshaller = createUnmarshaller(clazz); + if (clazz.isAnnotationPresent(XmlRootElement.class)) { + return unmarshaller.unmarshal(source); + } + else { + JAXBElement jaxbElement = unmarshaller.unmarshal(source, clazz); + return jaxbElement.getValue(); + } + } + catch (NullPointerException ex) { + if (!isSupportDtd()) { + throw new IllegalStateException("NPE while unmarshalling. " + + "This can happen due to the presence of DTD declarations which are disabled.", ex); + } + throw ex; + } + catch (UnmarshalException ex) { + throw ex; + } + catch (JAXBException ex) { + throw new HttpMessageConversionException("Invalid JAXB setup: " + ex.getMessage(), ex); + } + } + + @SuppressWarnings("deprecation") // on JDK 9 + protected Source processSource(Source source) { + if (source instanceof StreamSource) { + StreamSource streamSource = (StreamSource) source; + InputSource inputSource = new InputSource(streamSource.getInputStream()); + try { + XMLReader xmlReader = org.xml.sax.helpers.XMLReaderFactory.createXMLReader(); + xmlReader.setFeature("http://apache.org/xml/features/disallow-doctype-decl", !isSupportDtd()); + String featureName = "http://xml.org/sax/features/external-general-entities"; + xmlReader.setFeature(featureName, isProcessExternalEntities()); + if (!isProcessExternalEntities()) { + xmlReader.setEntityResolver(NO_OP_ENTITY_RESOLVER); + } + return new SAXSource(xmlReader, inputSource); + } + catch (SAXException ex) { + logger.warn("Processing of external entities could not be disabled", ex); + return source; + } + } + else { + return source; + } + } + + @Override + protected void writeToResult(Object o, HttpHeaders headers, Result result) throws Exception { + try { + Class clazz = ClassUtils.getUserClass(o); + Marshaller marshaller = createMarshaller(clazz); + setCharset(headers.getContentType(), marshaller); + marshaller.marshal(o, result); + } + catch (MarshalException ex) { + throw ex; + } + catch (JAXBException ex) { + throw new HttpMessageConversionException("Invalid JAXB setup: " + ex.getMessage(), ex); + } + } + + private void setCharset(@Nullable MediaType contentType, Marshaller marshaller) throws PropertyException { + if (contentType != null && contentType.getCharset() != null) { + marshaller.setProperty(Marshaller.JAXB_ENCODING, contentType.getCharset().name()); + } + } + + + private static final EntityResolver NO_OP_ENTITY_RESOLVER = + (publicId, systemId) -> new InputSource(new StringReader("")); + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/xml/MappingJackson2XmlHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/xml/MappingJackson2XmlHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..40bfc169fb2b9b9b792f019d183eb3da89d13d3f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/xml/MappingJackson2XmlHttpMessageConverter.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.xml; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.xml.XmlMapper; + +import org.springframework.http.MediaType; +import org.springframework.http.converter.json.AbstractJackson2HttpMessageConverter; +import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; +import org.springframework.util.Assert; + +/** + * Implementation of {@link org.springframework.http.converter.HttpMessageConverter HttpMessageConverter} + * that can read and write XML using + * Jackson 2.x extension component for reading and writing XML encoded data. + * + *

By default, this converter supports {@code application/xml}, {@code text/xml}, and + * {@code application/*+xml} with {@code UTF-8} character set. This can be overridden by + * setting the {@link #setSupportedMediaTypes supportedMediaTypes} property. + * + *

The default constructor uses the default configuration provided by {@link Jackson2ObjectMapperBuilder}. + * + *

Compatible with Jackson 2.9 and higher, as of Spring 5.0. + * + * @author Sebastien Deleuze + * @since 4.1 + */ +public class MappingJackson2XmlHttpMessageConverter extends AbstractJackson2HttpMessageConverter { + + /** + * Construct a new {@code MappingJackson2XmlHttpMessageConverter} using default configuration + * provided by {@code Jackson2ObjectMapperBuilder}. + */ + public MappingJackson2XmlHttpMessageConverter() { + this(Jackson2ObjectMapperBuilder.xml().build()); + } + + /** + * Construct a new {@code MappingJackson2XmlHttpMessageConverter} with a custom {@link ObjectMapper} + * (must be a {@link XmlMapper} instance). + * You can use {@link Jackson2ObjectMapperBuilder} to build it easily. + * @see Jackson2ObjectMapperBuilder#xml() + */ + public MappingJackson2XmlHttpMessageConverter(ObjectMapper objectMapper) { + super(objectMapper, new MediaType("application", "xml"), + new MediaType("text", "xml"), + new MediaType("application", "*+xml")); + Assert.isInstanceOf(XmlMapper.class, objectMapper, "XmlMapper required"); + } + + + /** + * {@inheritDoc} + * The {@code ObjectMapper} parameter must be a {@link XmlMapper} instance. + */ + @Override + public void setObjectMapper(ObjectMapper objectMapper) { + Assert.isInstanceOf(XmlMapper.class, objectMapper, "XmlMapper required"); + super.setObjectMapper(objectMapper); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..b06ddc8e3105082d3a2cb5e70f70bc0570023965 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverter.java @@ -0,0 +1,138 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.xml; + +import javax.xml.transform.Result; +import javax.xml.transform.Source; + +import org.springframework.beans.TypeMismatchException; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; +import org.springframework.oxm.Marshaller; +import org.springframework.oxm.Unmarshaller; +import org.springframework.util.Assert; + +/** + * Implementation of {@link org.springframework.http.converter.HttpMessageConverter HttpMessageConverter} + * that can read and write XML using Spring's {@link Marshaller} and {@link Unmarshaller} abstractions. + * + *

This converter requires a {@code Marshaller} and {@code Unmarshaller} before it can be used. + * These can be injected by the {@linkplain #MarshallingHttpMessageConverter(Marshaller) constructor} + * or {@linkplain #setMarshaller(Marshaller) bean properties}. + * + *

By default, this converter supports {@code text/xml} and {@code application/xml}. This can be + * overridden by setting the {@link #setSupportedMediaTypes(java.util.List) supportedMediaTypes} property. + * + * @author Arjen Poutsma + * @since 3.0 + */ +public class MarshallingHttpMessageConverter extends AbstractXmlHttpMessageConverter { + + @Nullable + private Marshaller marshaller; + + @Nullable + private Unmarshaller unmarshaller; + + + /** + * Construct a new {@code MarshallingHttpMessageConverter} with no {@link Marshaller} or + * {@link Unmarshaller} set. The Marshaller and Unmarshaller must be set after construction + * by invoking {@link #setMarshaller(Marshaller)} and {@link #setUnmarshaller(Unmarshaller)}. + */ + public MarshallingHttpMessageConverter() { + } + + /** + * Construct a new {@code MarshallingMessageConverter} with the given {@link Marshaller} set. + *

If the given {@link Marshaller} also implements the {@link Unmarshaller} interface, + * it is used for both marshalling and unmarshalling. Otherwise, an exception is thrown. + *

Note that all {@code Marshaller} implementations in Spring also implement the + * {@code Unmarshaller} interface, so that you can safely use this constructor. + * @param marshaller object used as marshaller and unmarshaller + */ + public MarshallingHttpMessageConverter(Marshaller marshaller) { + Assert.notNull(marshaller, "Marshaller must not be null"); + this.marshaller = marshaller; + if (marshaller instanceof Unmarshaller) { + this.unmarshaller = (Unmarshaller) marshaller; + } + } + + /** + * Construct a new {@code MarshallingMessageConverter} with the given + * {@code Marshaller} and {@code Unmarshaller}. + * @param marshaller the Marshaller to use + * @param unmarshaller the Unmarshaller to use + */ + public MarshallingHttpMessageConverter(Marshaller marshaller, Unmarshaller unmarshaller) { + Assert.notNull(marshaller, "Marshaller must not be null"); + Assert.notNull(unmarshaller, "Unmarshaller must not be null"); + this.marshaller = marshaller; + this.unmarshaller = unmarshaller; + } + + + /** + * Set the {@link Marshaller} to be used by this message converter. + */ + public void setMarshaller(Marshaller marshaller) { + this.marshaller = marshaller; + } + + /** + * Set the {@link Unmarshaller} to be used by this message converter. + */ + public void setUnmarshaller(Unmarshaller unmarshaller) { + this.unmarshaller = unmarshaller; + } + + + @Override + public boolean canRead(Class clazz, @Nullable MediaType mediaType) { + return (canRead(mediaType) && this.unmarshaller != null && this.unmarshaller.supports(clazz)); + } + + @Override + public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { + return (canWrite(mediaType) && this.marshaller != null && this.marshaller.supports(clazz)); + } + + @Override + protected boolean supports(Class clazz) { + // should not be called, since we override canRead()/canWrite() + throw new UnsupportedOperationException(); + } + + @Override + protected Object readFromSource(Class clazz, HttpHeaders headers, Source source) throws Exception { + Assert.notNull(this.unmarshaller, "Property 'unmarshaller' is required"); + Object result = this.unmarshaller.unmarshal(source); + if (!clazz.isInstance(result)) { + throw new TypeMismatchException(result, clazz); + } + return result; + } + + @Override + protected void writeToResult(Object o, HttpHeaders headers, Result result) throws Exception { + Assert.notNull(this.marshaller, "Property 'marshaller' is required"); + this.marshaller.marshal(o, result); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/xml/SourceHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/xml/SourceHttpMessageConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..63c70e30c64335944d29760a4ce72c1503a163a7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/xml/SourceHttpMessageConverter.java @@ -0,0 +1,293 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.xml; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.StringReader; +import java.util.HashSet; +import java.util.Set; + +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.parsers.ParserConfigurationException; +import javax.xml.stream.XMLInputFactory; +import javax.xml.stream.XMLResolver; +import javax.xml.stream.XMLStreamException; +import javax.xml.stream.XMLStreamReader; +import javax.xml.transform.Result; +import javax.xml.transform.Source; +import javax.xml.transform.TransformerException; +import javax.xml.transform.TransformerFactory; +import javax.xml.transform.dom.DOMSource; +import javax.xml.transform.sax.SAXSource; +import javax.xml.transform.stax.StAXSource; +import javax.xml.transform.stream.StreamResult; +import javax.xml.transform.stream.StreamSource; + +import org.w3c.dom.Document; +import org.xml.sax.EntityResolver; +import org.xml.sax.InputSource; +import org.xml.sax.SAXException; +import org.xml.sax.XMLReader; + +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.converter.AbstractHttpMessageConverter; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.HttpMessageNotWritableException; +import org.springframework.lang.Nullable; +import org.springframework.util.StreamUtils; + +/** + * Implementation of {@link org.springframework.http.converter.HttpMessageConverter} + * that can read and write {@link Source} objects. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @since 3.0 + * @param the converted object type + */ +public class SourceHttpMessageConverter extends AbstractHttpMessageConverter { + + private static final EntityResolver NO_OP_ENTITY_RESOLVER = + (publicId, systemId) -> new InputSource(new StringReader("")); + + private static final XMLResolver NO_OP_XML_RESOLVER = + (publicID, systemID, base, ns) -> StreamUtils.emptyInput(); + + private static final Set> SUPPORTED_CLASSES = new HashSet<>(8); + + static { + SUPPORTED_CLASSES.add(DOMSource.class); + SUPPORTED_CLASSES.add(SAXSource.class); + SUPPORTED_CLASSES.add(StAXSource.class); + SUPPORTED_CLASSES.add(StreamSource.class); + SUPPORTED_CLASSES.add(Source.class); + } + + + private final TransformerFactory transformerFactory = TransformerFactory.newInstance(); + + private boolean supportDtd = false; + + private boolean processExternalEntities = false; + + + /** + * Sets the {@link #setSupportedMediaTypes(java.util.List) supportedMediaTypes} + * to {@code text/xml} and {@code application/xml}, and {@code application/*-xml}. + */ + public SourceHttpMessageConverter() { + super(MediaType.APPLICATION_XML, MediaType.TEXT_XML, new MediaType("application", "*+xml")); + } + + + /** + * Indicate whether DTD parsing should be supported. + *

Default is {@code false} meaning that DTD is disabled. + */ + public void setSupportDtd(boolean supportDtd) { + this.supportDtd = supportDtd; + } + + /** + * Return whether DTD parsing is supported. + */ + public boolean isSupportDtd() { + return this.supportDtd; + } + + /** + * Indicate whether external XML entities are processed when converting to a Source. + *

Default is {@code false}, meaning that external entities are not resolved. + *

Note: setting this option to {@code true} also + * automatically sets {@link #setSupportDtd} to {@code true}. + */ + public void setProcessExternalEntities(boolean processExternalEntities) { + this.processExternalEntities = processExternalEntities; + if (processExternalEntities) { + this.supportDtd = true; + } + } + + /** + * Return whether XML external entities are allowed. + */ + public boolean isProcessExternalEntities() { + return this.processExternalEntities; + } + + + @Override + public boolean supports(Class clazz) { + return SUPPORTED_CLASSES.contains(clazz); + } + + @Override + @SuppressWarnings("unchecked") + protected T readInternal(Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + InputStream body = inputMessage.getBody(); + if (DOMSource.class == clazz) { + return (T) readDOMSource(body, inputMessage); + } + else if (SAXSource.class == clazz) { + return (T) readSAXSource(body, inputMessage); + } + else if (StAXSource.class == clazz) { + return (T) readStAXSource(body, inputMessage); + } + else if (StreamSource.class == clazz || Source.class == clazz) { + return (T) readStreamSource(body); + } + else { + throw new HttpMessageNotReadableException("Could not read class [" + clazz + + "]. Only DOMSource, SAXSource, StAXSource, and StreamSource are supported.", inputMessage); + } + } + + private DOMSource readDOMSource(InputStream body, HttpInputMessage inputMessage) throws IOException { + try { + DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance(); + documentBuilderFactory.setNamespaceAware(true); + documentBuilderFactory.setFeature( + "http://apache.org/xml/features/disallow-doctype-decl", !isSupportDtd()); + documentBuilderFactory.setFeature( + "http://xml.org/sax/features/external-general-entities", isProcessExternalEntities()); + DocumentBuilder documentBuilder = documentBuilderFactory.newDocumentBuilder(); + if (!isProcessExternalEntities()) { + documentBuilder.setEntityResolver(NO_OP_ENTITY_RESOLVER); + } + Document document = documentBuilder.parse(body); + return new DOMSource(document); + } + catch (NullPointerException ex) { + if (!isSupportDtd()) { + throw new HttpMessageNotReadableException("NPE while unmarshalling: This can happen " + + "due to the presence of DTD declarations which are disabled.", ex, inputMessage); + } + throw ex; + } + catch (ParserConfigurationException ex) { + throw new HttpMessageNotReadableException( + "Could not set feature: " + ex.getMessage(), ex, inputMessage); + } + catch (SAXException ex) { + throw new HttpMessageNotReadableException( + "Could not parse document: " + ex.getMessage(), ex, inputMessage); + } + } + + @SuppressWarnings("deprecation") // on JDK 9 + private SAXSource readSAXSource(InputStream body, HttpInputMessage inputMessage) throws IOException { + try { + XMLReader xmlReader = org.xml.sax.helpers.XMLReaderFactory.createXMLReader(); + xmlReader.setFeature("http://apache.org/xml/features/disallow-doctype-decl", !isSupportDtd()); + xmlReader.setFeature("http://xml.org/sax/features/external-general-entities", isProcessExternalEntities()); + if (!isProcessExternalEntities()) { + xmlReader.setEntityResolver(NO_OP_ENTITY_RESOLVER); + } + byte[] bytes = StreamUtils.copyToByteArray(body); + return new SAXSource(xmlReader, new InputSource(new ByteArrayInputStream(bytes))); + } + catch (SAXException ex) { + throw new HttpMessageNotReadableException( + "Could not parse document: " + ex.getMessage(), ex, inputMessage); + } + } + + private Source readStAXSource(InputStream body, HttpInputMessage inputMessage) { + try { + XMLInputFactory inputFactory = XMLInputFactory.newInstance(); + inputFactory.setProperty(XMLInputFactory.SUPPORT_DTD, isSupportDtd()); + inputFactory.setProperty(XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES, isProcessExternalEntities()); + if (!isProcessExternalEntities()) { + inputFactory.setXMLResolver(NO_OP_XML_RESOLVER); + } + XMLStreamReader streamReader = inputFactory.createXMLStreamReader(body); + return new StAXSource(streamReader); + } + catch (XMLStreamException ex) { + throw new HttpMessageNotReadableException( + "Could not parse document: " + ex.getMessage(), ex, inputMessage); + } + } + + private StreamSource readStreamSource(InputStream body) throws IOException { + byte[] bytes = StreamUtils.copyToByteArray(body); + return new StreamSource(new ByteArrayInputStream(bytes)); + } + + @Override + @Nullable + protected Long getContentLength(T t, @Nullable MediaType contentType) { + if (t instanceof DOMSource) { + try { + CountingOutputStream os = new CountingOutputStream(); + transform(t, new StreamResult(os)); + return os.count; + } + catch (TransformerException ex) { + // ignore + } + } + return null; + } + + @Override + protected void writeInternal(T t, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + try { + Result result = new StreamResult(outputMessage.getBody()); + transform(t, result); + } + catch (TransformerException ex) { + throw new HttpMessageNotWritableException("Could not transform [" + t + "] to output message", ex); + } + } + + private void transform(Source source, Result result) throws TransformerException { + this.transformerFactory.newTransformer().transform(source, result); + } + + + private static class CountingOutputStream extends OutputStream { + + long count = 0; + + @Override + public void write(int b) throws IOException { + this.count++; + } + + @Override + public void write(byte[] b) throws IOException { + this.count += b.length; + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + this.count += len; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/xml/package-info.java b/spring-web/src/main/java/org/springframework/http/converter/xml/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..34734d540353433d851087c8a0d8899abb259515 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/xml/package-info.java @@ -0,0 +1,9 @@ +/** + * Provides HttpMessageConverter implementations for handling XML. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.converter.xml; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/package-info.java b/spring-web/src/main/java/org/springframework/http/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..8c8e6c19f41b238b4a253e1f896a2786246f8c90 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/package-info.java @@ -0,0 +1,10 @@ +/** + * Contains a basic abstraction over client/server-side HTTP. This package contains + * the {@code HttpInputMessage} and {@code HttpOutputMessage} interfaces. + */ +@NonNullApi +@NonNullFields +package org.springframework.http; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/server/DefaultPathContainer.java b/spring-web/src/main/java/org/springframework/http/server/DefaultPathContainer.java new file mode 100644 index 0000000000000000000000000000000000000000..3ed1bb8de20eaf66b6562a4b7ea1248a2fd2a6be --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/DefaultPathContainer.java @@ -0,0 +1,253 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * Default implementation of {@link PathContainer}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +final class DefaultPathContainer implements PathContainer { + + private static final MultiValueMap EMPTY_MAP = new LinkedMultiValueMap<>(); + + private static final PathContainer EMPTY_PATH = new DefaultPathContainer("", Collections.emptyList()); + + private static final PathContainer.Separator SEPARATOR = () -> "/"; + + + private final String path; + + private final List elements; + + + private DefaultPathContainer(String path, List elements) { + this.path = path; + this.elements = Collections.unmodifiableList(elements); + } + + + @Override + public String value() { + return this.path; + } + + @Override + public List elements() { + return this.elements; + } + + + @Override + public boolean equals(@Nullable Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + return this.path.equals(((DefaultPathContainer) other).path); + } + + @Override + public int hashCode() { + return this.path.hashCode(); + } + + @Override + public String toString() { + return value(); + } + + + static PathContainer createFromUrlPath(String path) { + if (path.equals("")) { + return EMPTY_PATH; + } + String separator = "/"; + Separator separatorElement = separator.equals(SEPARATOR.value()) ? SEPARATOR : () -> separator; + List elements = new ArrayList<>(); + int begin; + if (path.length() > 0 && path.startsWith(separator)) { + begin = separator.length(); + elements.add(separatorElement); + } + else { + begin = 0; + } + while (begin < path.length()) { + int end = path.indexOf(separator, begin); + String segment = (end != -1 ? path.substring(begin, end) : path.substring(begin)); + if (!segment.equals("")) { + elements.add(parsePathSegment(segment)); + } + if (end == -1) { + break; + } + elements.add(separatorElement); + begin = end + separator.length(); + } + return new DefaultPathContainer(path, elements); + } + + private static PathSegment parsePathSegment(String segment) { + Charset charset = StandardCharsets.UTF_8; + int index = segment.indexOf(';'); + if (index == -1) { + String valueToMatch = StringUtils.uriDecode(segment, charset); + return new DefaultPathSegment(segment, valueToMatch, EMPTY_MAP); + } + else { + String valueToMatch = StringUtils.uriDecode(segment.substring(0, index), charset); + String pathParameterContent = segment.substring(index); + MultiValueMap parameters = parsePathParams(pathParameterContent, charset); + return new DefaultPathSegment(segment, valueToMatch, parameters); + } + } + + private static MultiValueMap parsePathParams(String input, Charset charset) { + MultiValueMap result = new LinkedMultiValueMap<>(); + int begin = 1; + while (begin < input.length()) { + int end = input.indexOf(';', begin); + String param = (end != -1 ? input.substring(begin, end) : input.substring(begin)); + parsePathParamValues(param, charset, result); + if (end == -1) { + break; + } + begin = end + 1; + } + return result; + } + + private static void parsePathParamValues(String input, Charset charset, MultiValueMap output) { + if (StringUtils.hasText(input)) { + int index = input.indexOf('='); + if (index != -1) { + String name = input.substring(0, index); + String value = input.substring(index + 1); + for (String v : StringUtils.commaDelimitedListToStringArray(value)) { + name = StringUtils.uriDecode(name, charset); + if (StringUtils.hasText(name)) { + output.add(name, StringUtils.uriDecode(v, charset)); + } + } + } + else { + String name = StringUtils.uriDecode(input, charset); + if (StringUtils.hasText(name)) { + output.add(input, ""); + } + } + } + } + + static PathContainer subPath(PathContainer container, int fromIndex, int toIndex) { + List elements = container.elements(); + if (fromIndex == 0 && toIndex == elements.size()) { + return container; + } + if (fromIndex == toIndex) { + return EMPTY_PATH; + } + + Assert.isTrue(fromIndex >= 0 && fromIndex < elements.size(), () -> "Invalid fromIndex: " + fromIndex); + Assert.isTrue(toIndex >= 0 && toIndex <= elements.size(), () -> "Invalid toIndex: " + toIndex); + Assert.isTrue(fromIndex < toIndex, () -> "fromIndex: " + fromIndex + " should be < toIndex " + toIndex); + + List subList = elements.subList(fromIndex, toIndex); + String path = subList.stream().map(Element::value).collect(Collectors.joining("")); + return new DefaultPathContainer(path, subList); + } + + + private static class DefaultPathSegment implements PathSegment { + + private final String value; + + private final String valueToMatch; + + private final char[] valueToMatchAsChars; + + private final MultiValueMap parameters; + + public DefaultPathSegment(String value, String valueToMatch, MultiValueMap params) { + Assert.isTrue(!value.contains("/"), () -> "Invalid path segment value: " + value); + this.value = value; + this.valueToMatch = valueToMatch; + this.valueToMatchAsChars = valueToMatch.toCharArray(); + this.parameters = CollectionUtils.unmodifiableMultiValueMap(params); + } + + @Override + public String value() { + return this.value; + } + + @Override + public String valueToMatch() { + return this.valueToMatch; + } + + @Override + public char[] valueToMatchAsChars() { + return this.valueToMatchAsChars; + } + + @Override + public MultiValueMap parameters() { + return this.parameters; + } + + @Override + public boolean equals(@Nullable Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + return this.value.equals(((DefaultPathSegment) other).value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + + public String toString() { + return "[value='" + this.value + "']"; + } + } + +} + diff --git a/spring-web/src/main/java/org/springframework/http/server/DefaultRequestPath.java b/spring-web/src/main/java/org/springframework/http/server/DefaultRequestPath.java new file mode 100644 index 0000000000000000000000000000000000000000..ff5355a76c57238607c698e32ef82a5e8e6d597d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/DefaultRequestPath.java @@ -0,0 +1,154 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server; + +import java.net.URI; +import java.util.List; + +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +/** + * Default implementation of {@link RequestPath}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +class DefaultRequestPath implements RequestPath { + + private final PathContainer fullPath; + + private final PathContainer contextPath; + + private final PathContainer pathWithinApplication; + + + DefaultRequestPath(URI uri, @Nullable String contextPath) { + this.fullPath = PathContainer.parsePath(uri.getRawPath()); + this.contextPath = initContextPath(this.fullPath, contextPath); + this.pathWithinApplication = extractPathWithinApplication(this.fullPath, this.contextPath); + } + + private DefaultRequestPath(RequestPath requestPath, String contextPath) { + this.fullPath = requestPath; + this.contextPath = initContextPath(this.fullPath, contextPath); + this.pathWithinApplication = extractPathWithinApplication(this.fullPath, this.contextPath); + } + + private static PathContainer initContextPath(PathContainer path, @Nullable String contextPath) { + if (!StringUtils.hasText(contextPath) || "/".equals(contextPath)) { + return PathContainer.parsePath(""); + } + + validateContextPath(path.value(), contextPath); + + int length = contextPath.length(); + int counter = 0; + + for (int i=0; i < path.elements().size(); i++) { + PathContainer.Element element = path.elements().get(i); + counter += element.value().length(); + if (length == counter) { + return path.subPath(0, i + 1); + } + } + + // Should not happen.. + throw new IllegalStateException("Failed to initialize contextPath '" + contextPath + "'" + + " for requestPath '" + path.value() + "'"); + } + + private static void validateContextPath(String fullPath, String contextPath) { + int length = contextPath.length(); + if (contextPath.charAt(0) != '/' || contextPath.charAt(length - 1) == '/') { + throw new IllegalArgumentException("Invalid contextPath: '" + contextPath + "': " + + "must start with '/' and not end with '/'"); + } + if (!fullPath.startsWith(contextPath)) { + throw new IllegalArgumentException("Invalid contextPath '" + contextPath + "': " + + "must match the start of requestPath: '" + fullPath + "'"); + } + if (fullPath.length() > length && fullPath.charAt(length) != '/') { + throw new IllegalArgumentException("Invalid contextPath '" + contextPath + "': " + + "must match to full path segments for requestPath: '" + fullPath + "'"); + } + } + + private static PathContainer extractPathWithinApplication(PathContainer fullPath, PathContainer contextPath) { + return fullPath.subPath(contextPath.elements().size()); + } + + + // PathContainer methods.. + + @Override + public String value() { + return this.fullPath.value(); + } + + @Override + public List elements() { + return this.fullPath.elements(); + } + + + // RequestPath methods.. + + @Override + public PathContainer contextPath() { + return this.contextPath; + } + + @Override + public PathContainer pathWithinApplication() { + return this.pathWithinApplication; + } + + @Override + public RequestPath modifyContextPath(String contextPath) { + return new DefaultRequestPath(this, contextPath); + } + + + @Override + public boolean equals(@Nullable Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + DefaultRequestPath otherPath= (DefaultRequestPath) other; + return (this.fullPath.equals(otherPath.fullPath) && + this.contextPath.equals(otherPath.contextPath) && + this.pathWithinApplication.equals(otherPath.pathWithinApplication)); + } + + @Override + public int hashCode() { + int result = this.fullPath.hashCode(); + result = 31 * result + this.contextPath.hashCode(); + result = 31 * result + this.pathWithinApplication.hashCode(); + return result; + } + + @Override + public String toString() { + return this.fullPath.toString(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/PathContainer.java b/spring-web/src/main/java/org/springframework/http/server/PathContainer.java new file mode 100644 index 0000000000000000000000000000000000000000..93599a78c7ad837ca2a0518cba50c4887a1df9c3 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/PathContainer.java @@ -0,0 +1,122 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server; + +import java.util.List; + +import org.springframework.util.MultiValueMap; + +/** + * Structured representation of a URI path whose elements have been pre-parsed + * into a sequence of {@link Separator Separator} and {@link PathSegment + * PathSegment} elements. + * + *

An instance of this class can be created via {@link #parsePath(String)}. + * Each {@link PathSegment PathSegment} exposes its structure decoded + * safely without the risk of encoded reserved characters altering the path or + * segment structure and without path parameters for path matching purposes. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface PathContainer { + + /** + * The original (raw, encoded) path that this instance was parsed from. + */ + String value(); + + /** + * The list of path elements, either {@link Separator} or {@link PathSegment}. + */ + List elements(); + + /** + * Extract a sub-path from the given offset into the elements list. + * @param index the start element index (inclusive) + * @return the sub-path + */ + default PathContainer subPath(int index) { + return subPath(index, elements().size()); + } + + /** + * Extract a sub-path from the given start offset (inclusive) into the + * element list and to the end offset (exclusive). + * @param startIndex the start element index (inclusive) + * @param endIndex the end element index (exclusive) + * @return the sub-path + */ + default PathContainer subPath(int startIndex, int endIndex) { + return DefaultPathContainer.subPath(this, startIndex, endIndex); + } + + + /** + * Parse the path value into a sequence of {@link Separator Separator} and + * {@link PathSegment PathSegment} elements. + * @param path the encoded, raw URL path value to parse + * @return the parsed path + */ + static PathContainer parsePath(String path) { + return DefaultPathContainer.createFromUrlPath(path); + } + + + /** + * Common representation of a path element, e.g. separator or segment. + */ + interface Element { + + /** + * Return the original (raw, encoded) value of this path element. + */ + String value(); + } + + + /** + * Path separator element. + */ + interface Separator extends Element { + } + + + /** + * Path segment element. + */ + interface PathSegment extends Element { + + /** + * Return the path segment value to use for pattern matching purposes. + * By default this is the same as {@link #value()} but may also differ + * in sub-interfaces (e.g. decoded, sanitized, etc.). + */ + String valueToMatch(); + + /** + * The same as {@link #valueToMatch()} but as a {@code char[]}. + */ + char[] valueToMatchAsChars(); + + /** + * Path parameters parsed from the path segment. + */ + MultiValueMap parameters(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/RequestPath.java b/spring-web/src/main/java/org/springframework/http/server/RequestPath.java new file mode 100644 index 0000000000000000000000000000000000000000..309537eca9f463814186abb1b91aa06b2d1680e5 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/RequestPath.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.http.server; + +import java.net.URI; + +import org.springframework.lang.Nullable; + +/** + * Represents the complete path for a request. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface RequestPath extends PathContainer { + + /** + * Returns the portion of the URL path that represents the application. + * The context path is always at the beginning of the path and starts but + * does not end with "/". It is shared for URLs of the same application. + *

The context path may come from the underlying runtime API such as + * when deploying as a WAR to a Servlet container or it may be assigned in + * a WebFlux application through the use of + * {@link org.springframework.http.server.reactive.ContextPathCompositeHandler + * ContextPathCompositeHandler}. + */ + PathContainer contextPath(); + + /** + * The portion of the request path after the context path. + */ + PathContainer pathWithinApplication(); + + /** + * Return a new {@code RequestPath} instance with a modified context path. + * The new context path must match 0 or more path segments at the start. + * @param contextPath the new context path + * @return a new {@code RequestPath} instance + */ + RequestPath modifyContextPath(String contextPath); + + + /** + * Create a new {@code RequestPath} with the given parameters. + */ + static RequestPath parse(URI uri, @Nullable String contextPath) { + return new DefaultRequestPath(uri, contextPath); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/ServerHttpAsyncRequestControl.java b/spring-web/src/main/java/org/springframework/http/server/ServerHttpAsyncRequestControl.java new file mode 100644 index 0000000000000000000000000000000000000000..eea1c325a8e3c887d63cc7958353461d60894561 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/ServerHttpAsyncRequestControl.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server; + +/** + * A control that can put the processing of an HTTP request in asynchronous mode during + * which the response remains open until explicitly closed. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public interface ServerHttpAsyncRequestControl { + + /** + * Enable asynchronous processing after which the response remains open until a call + * to {@link #complete()} is made or the server times out the request. Once enabled, + * additional calls to this method are ignored. + */ + void start(); + + /** + * A variation on {@link #start()} that allows specifying a timeout value to use to + * use for asynchronous processing. If {@link #complete()} is not called within the + * specified value, the request times out. + */ + void start(long timeout); + + /** + * Return whether asynchronous request processing has been started. + */ + boolean isStarted(); + + /** + * Mark asynchronous request processing as completed. + */ + void complete(); + + /** + * Return whether asynchronous request processing has been completed. + */ + boolean isCompleted(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/ServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/ServerHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..c8b1a3a971739741bae465c9589e529f7ff31787 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/ServerHttpRequest.java @@ -0,0 +1,59 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server; + +import java.net.InetSocketAddress; +import java.security.Principal; + +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpRequest; +import org.springframework.lang.Nullable; + +/** + * Represents a server-side HTTP request. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @since 3.0 + */ +public interface ServerHttpRequest extends HttpRequest, HttpInputMessage { + + /** + * Return a {@link java.security.Principal} instance containing the name of the + * authenticated user. + *

If the user has not been authenticated, the method returns null. + */ + @Nullable + Principal getPrincipal(); + + /** + * Return the address on which the request was received. + */ + InetSocketAddress getLocalAddress(); + + /** + * Return the address of the remote client. + */ + InetSocketAddress getRemoteAddress(); + + /** + * Return a control that allows putting the request in asynchronous mode so the + * response remains open until closed explicitly from the current or another thread. + */ + ServerHttpAsyncRequestControl getAsyncRequestControl(ServerHttpResponse response); + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/ServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/ServerHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..e47d93c313b9c54b7674047b6575c33e3251cbe2 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/ServerHttpResponse.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server; + +import java.io.Closeable; +import java.io.Flushable; +import java.io.IOException; + +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.HttpStatus; + +/** + * Represents a server-side HTTP response. + * + * @author Arjen Poutsma + * @since 3.0 + */ +public interface ServerHttpResponse extends HttpOutputMessage, Flushable, Closeable { + + /** + * Set the HTTP status code of the response. + * @param status the HTTP status as an HttpStatus enum value + */ + void setStatusCode(HttpStatus status); + + /** + * Ensure that the headers and the content of the response are written out. + *

After the first flush, headers can no longer be changed. + * Only further content writing and content flushing is possible. + */ + @Override + void flush() throws IOException; + + /** + * Close this response, freeing any resources created. + */ + @Override + void close(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpAsyncRequestControl.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpAsyncRequestControl.java new file mode 100644 index 0000000000000000000000000000000000000000..ef080c84a8db8a1b8f625392dbbad0973ff67b66 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpAsyncRequestControl.java @@ -0,0 +1,135 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server; + +import java.io.IOException; +import java.util.concurrent.atomic.AtomicBoolean; + +import javax.servlet.AsyncContext; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * A {@link ServerHttpAsyncRequestControl} to use on Servlet containers (Servlet 3.0+). + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class ServletServerHttpAsyncRequestControl implements ServerHttpAsyncRequestControl, AsyncListener { + + private static final long NO_TIMEOUT_VALUE = Long.MIN_VALUE; + + + private final ServletServerHttpRequest request; + + private final ServletServerHttpResponse response; + + @Nullable + private AsyncContext asyncContext; + + private AtomicBoolean asyncCompleted = new AtomicBoolean(false); + + + /** + * Constructor accepting a request and response pair that are expected to be of type + * {@link ServletServerHttpRequest} and {@link ServletServerHttpResponse} + * respectively. + */ + public ServletServerHttpAsyncRequestControl(ServletServerHttpRequest request, ServletServerHttpResponse response) { + Assert.notNull(request, "request is required"); + Assert.notNull(response, "response is required"); + + Assert.isTrue(request.getServletRequest().isAsyncSupported(), + "Async support must be enabled on a servlet and for all filters involved " + + "in async request processing. This is done in Java code using the Servlet API " + + "or by adding \"true\" to servlet and " + + "filter declarations in web.xml. Also you must use a Servlet 3.0+ container"); + + this.request = request; + this.response = response; + } + + + @Override + public boolean isStarted() { + return (this.asyncContext != null && this.request.getServletRequest().isAsyncStarted()); + } + + @Override + public boolean isCompleted() { + return this.asyncCompleted.get(); + } + + @Override + public void start() { + start(NO_TIMEOUT_VALUE); + } + + @Override + public void start(long timeout) { + Assert.state(!isCompleted(), "Async processing has already completed"); + if (isStarted()) { + return; + } + + HttpServletRequest servletRequest = this.request.getServletRequest(); + HttpServletResponse servletResponse = this.response.getServletResponse(); + + this.asyncContext = servletRequest.startAsync(servletRequest, servletResponse); + this.asyncContext.addListener(this); + + if (timeout != NO_TIMEOUT_VALUE) { + this.asyncContext.setTimeout(timeout); + } + } + + @Override + public void complete() { + if (this.asyncContext != null && isStarted() && !isCompleted()) { + this.asyncContext.complete(); + } + } + + + // --------------------------------------------------------------------- + // Implementation of AsyncListener methods + // --------------------------------------------------------------------- + + @Override + public void onComplete(AsyncEvent event) throws IOException { + this.asyncContext = null; + this.asyncCompleted.set(true); + } + + @Override + public void onStartAsync(AsyncEvent event) throws IOException { + } + + @Override + public void onError(AsyncEvent event) throws IOException { + } + + @Override + public void onTimeout(AsyncEvent event) throws IOException { + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..dd646692efc6057484d2002514a324c0c1856a73 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java @@ -0,0 +1,270 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStreamWriter; +import java.io.Writer; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URLEncoder; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.security.Principal; +import java.util.Arrays; +import java.util.Enumeration; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.InvalidMediaTypeException; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.StringUtils; + +/** + * {@link ServerHttpRequest} implementation that is based on a {@link HttpServletRequest}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.0 + */ +public class ServletServerHttpRequest implements ServerHttpRequest { + + protected static final String FORM_CONTENT_TYPE = "application/x-www-form-urlencoded"; + + protected static final Charset FORM_CHARSET = StandardCharsets.UTF_8; + + + private final HttpServletRequest servletRequest; + + @Nullable + private URI uri; + + @Nullable + private HttpHeaders headers; + + @Nullable + private ServerHttpAsyncRequestControl asyncRequestControl; + + + /** + * Construct a new instance of the ServletServerHttpRequest based on the + * given {@link HttpServletRequest}. + * @param servletRequest the servlet request + */ + public ServletServerHttpRequest(HttpServletRequest servletRequest) { + Assert.notNull(servletRequest, "HttpServletRequest must not be null"); + this.servletRequest = servletRequest; + } + + + /** + * Returns the {@code HttpServletRequest} this object is based on. + */ + public HttpServletRequest getServletRequest() { + return this.servletRequest; + } + + @Override + @Nullable + public HttpMethod getMethod() { + return HttpMethod.resolve(this.servletRequest.getMethod()); + } + + @Override + public String getMethodValue() { + return this.servletRequest.getMethod(); + } + + @Override + public URI getURI() { + if (this.uri == null) { + String urlString = null; + boolean hasQuery = false; + try { + StringBuffer url = this.servletRequest.getRequestURL(); + String query = this.servletRequest.getQueryString(); + hasQuery = StringUtils.hasText(query); + if (hasQuery) { + url.append('?').append(query); + } + urlString = url.toString(); + this.uri = new URI(urlString); + } + catch (URISyntaxException ex) { + if (!hasQuery) { + throw new IllegalStateException( + "Could not resolve HttpServletRequest as URI: " + urlString, ex); + } + // Maybe a malformed query string... try plain request URL + try { + urlString = this.servletRequest.getRequestURL().toString(); + this.uri = new URI(urlString); + } + catch (URISyntaxException ex2) { + throw new IllegalStateException( + "Could not resolve HttpServletRequest as URI: " + urlString, ex2); + } + } + } + return this.uri; + } + + @Override + public HttpHeaders getHeaders() { + if (this.headers == null) { + this.headers = new HttpHeaders(); + + for (Enumeration names = this.servletRequest.getHeaderNames(); names.hasMoreElements();) { + String headerName = (String) names.nextElement(); + for (Enumeration headerValues = this.servletRequest.getHeaders(headerName); + headerValues.hasMoreElements();) { + String headerValue = (String) headerValues.nextElement(); + this.headers.add(headerName, headerValue); + } + } + + // HttpServletRequest exposes some headers as properties: + // we should include those if not already present + try { + MediaType contentType = this.headers.getContentType(); + if (contentType == null) { + String requestContentType = this.servletRequest.getContentType(); + if (StringUtils.hasLength(requestContentType)) { + contentType = MediaType.parseMediaType(requestContentType); + this.headers.setContentType(contentType); + } + } + if (contentType != null && contentType.getCharset() == null) { + String requestEncoding = this.servletRequest.getCharacterEncoding(); + if (StringUtils.hasLength(requestEncoding)) { + Charset charSet = Charset.forName(requestEncoding); + Map params = new LinkedCaseInsensitiveMap<>(); + params.putAll(contentType.getParameters()); + params.put("charset", charSet.toString()); + MediaType mediaType = new MediaType(contentType.getType(), contentType.getSubtype(), params); + this.headers.setContentType(mediaType); + } + } + } + catch (InvalidMediaTypeException ex) { + // Ignore: simply not exposing an invalid content type in HttpHeaders... + } + + if (this.headers.getContentLength() < 0) { + int requestContentLength = this.servletRequest.getContentLength(); + if (requestContentLength != -1) { + this.headers.setContentLength(requestContentLength); + } + } + } + + return this.headers; + } + + @Override + public Principal getPrincipal() { + return this.servletRequest.getUserPrincipal(); + } + + @Override + public InetSocketAddress getLocalAddress() { + return new InetSocketAddress(this.servletRequest.getLocalName(), this.servletRequest.getLocalPort()); + } + + @Override + public InetSocketAddress getRemoteAddress() { + return new InetSocketAddress(this.servletRequest.getRemoteHost(), this.servletRequest.getRemotePort()); + } + + @Override + public InputStream getBody() throws IOException { + if (isFormPost(this.servletRequest)) { + return getBodyFromServletRequestParameters(this.servletRequest); + } + else { + return this.servletRequest.getInputStream(); + } + } + + @Override + public ServerHttpAsyncRequestControl getAsyncRequestControl(ServerHttpResponse response) { + if (this.asyncRequestControl == null) { + if (!ServletServerHttpResponse.class.isInstance(response)) { + throw new IllegalArgumentException( + "Response must be a ServletServerHttpResponse: " + response.getClass()); + } + ServletServerHttpResponse servletServerResponse = (ServletServerHttpResponse) response; + this.asyncRequestControl = new ServletServerHttpAsyncRequestControl(this, servletServerResponse); + } + return this.asyncRequestControl; + } + + + private static boolean isFormPost(HttpServletRequest request) { + String contentType = request.getContentType(); + return (contentType != null && contentType.contains(FORM_CONTENT_TYPE) && + HttpMethod.POST.matches(request.getMethod())); + } + + /** + * Use {@link javax.servlet.ServletRequest#getParameterMap()} to reconstruct the + * body of a form 'POST' providing a predictable outcome as opposed to reading + * from the body, which can fail if any other code has used the ServletRequest + * to access a parameter, thus causing the input stream to be "consumed". + */ + private static InputStream getBodyFromServletRequestParameters(HttpServletRequest request) throws IOException { + ByteArrayOutputStream bos = new ByteArrayOutputStream(1024); + Writer writer = new OutputStreamWriter(bos, FORM_CHARSET); + + Map form = request.getParameterMap(); + for (Iterator nameIterator = form.keySet().iterator(); nameIterator.hasNext();) { + String name = nameIterator.next(); + List values = Arrays.asList(form.get(name)); + for (Iterator valueIterator = values.iterator(); valueIterator.hasNext();) { + String value = valueIterator.next(); + writer.write(URLEncoder.encode(name, FORM_CHARSET.name())); + if (value != null) { + writer.write('='); + writer.write(URLEncoder.encode(value, FORM_CHARSET.name())); + if (valueIterator.hasNext()) { + writer.write('&'); + } + } + } + if (nameIterator.hasNext()) { + writer.append('&'); + } + } + writer.flush(); + + return new ByteArrayInputStream(bos.toByteArray()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..d718971e68b08b7e95df0e9d35a5f78e0338dd50 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java @@ -0,0 +1,180 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import javax.servlet.http.HttpServletResponse; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +/** + * {@link ServerHttpResponse} implementation that is based on a {@link HttpServletResponse}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @since 3.0 + */ +public class ServletServerHttpResponse implements ServerHttpResponse { + + private final HttpServletResponse servletResponse; + + private final HttpHeaders headers; + + private boolean headersWritten = false; + + private boolean bodyUsed = false; + + + /** + * Construct a new instance of the ServletServerHttpResponse based on the given {@link HttpServletResponse}. + * @param servletResponse the servlet response + */ + public ServletServerHttpResponse(HttpServletResponse servletResponse) { + Assert.notNull(servletResponse, "HttpServletResponse must not be null"); + this.servletResponse = servletResponse; + this.headers = new ServletResponseHttpHeaders(); + } + + + /** + * Return the {@code HttpServletResponse} this object is based on. + */ + public HttpServletResponse getServletResponse() { + return this.servletResponse; + } + + @Override + public void setStatusCode(HttpStatus status) { + Assert.notNull(status, "HttpStatus must not be null"); + this.servletResponse.setStatus(status.value()); + } + + @Override + public HttpHeaders getHeaders() { + return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers); + } + + @Override + public OutputStream getBody() throws IOException { + this.bodyUsed = true; + writeHeaders(); + return this.servletResponse.getOutputStream(); + } + + @Override + public void flush() throws IOException { + writeHeaders(); + if (this.bodyUsed) { + this.servletResponse.flushBuffer(); + } + } + + @Override + public void close() { + writeHeaders(); + } + + private void writeHeaders() { + if (!this.headersWritten) { + getHeaders().forEach((headerName, headerValues) -> { + for (String headerValue : headerValues) { + this.servletResponse.addHeader(headerName, headerValue); + } + }); + // HttpServletResponse exposes some headers as properties: we should include those if not already present + if (this.servletResponse.getContentType() == null && this.headers.getContentType() != null) { + this.servletResponse.setContentType(this.headers.getContentType().toString()); + } + if (this.servletResponse.getCharacterEncoding() == null && this.headers.getContentType() != null && + this.headers.getContentType().getCharset() != null) { + this.servletResponse.setCharacterEncoding(this.headers.getContentType().getCharset().name()); + } + this.headersWritten = true; + } + } + + + /** + * Extends HttpHeaders with the ability to look up headers already present in + * the underlying HttpServletResponse. + * + *

The intent is merely to expose what is available through the HttpServletResponse + * i.e. the ability to look up specific header values by name. All other + * map-related operations (e.g. iteration, removal, etc) apply only to values + * added directly through HttpHeaders methods. + * + * @since 4.0.3 + */ + private class ServletResponseHttpHeaders extends HttpHeaders { + + private static final long serialVersionUID = 3410708522401046302L; + + @Override + public boolean containsKey(Object key) { + return (super.containsKey(key) || (get(key) != null)); + } + + @Override + @Nullable + public String getFirst(String headerName) { + String value = servletResponse.getHeader(headerName); + if (value != null) { + return value; + } + else { + return super.getFirst(headerName); + } + } + + @Override + public List get(Object key) { + Assert.isInstanceOf(String.class, key, "Key must be a String-based header name"); + + Collection values1 = servletResponse.getHeaders((String) key); + if (headersWritten) { + return new ArrayList<>(values1); + } + boolean isEmpty1 = CollectionUtils.isEmpty(values1); + + List values2 = super.get(key); + boolean isEmpty2 = CollectionUtils.isEmpty(values2); + + if (isEmpty1 && isEmpty2) { + return null; + } + + List values = new ArrayList<>(); + if (!isEmpty1) { + values.addAll(values1); + } + if (!isEmpty2) { + values.addAll(values2); + } + return values; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/package-info.java b/spring-web/src/main/java/org/springframework/http/server/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..8c7e71cd98ec4afab28201feb8799daa1d8e5205 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/package-info.java @@ -0,0 +1,11 @@ +/** + * Contains an abstraction over server-side HTTP. This package + * contains the {@code ServerHttpRequest} and {@code ServerHttpResponse}, + * as well as a Servlet-based implementation of these interfaces. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.server; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java new file mode 100644 index 0000000000000000000000000000000000000000..b28b6e47a05e7de0ae3184c5f4e5c0769df1dae2 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java @@ -0,0 +1,474 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.IOException; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.commons.logging.Log; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.publisher.Operators; + +import org.springframework.core.log.LogDelegateFactory; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Abstract base class for {@code Publisher} implementations that bridge between + * event-listener read APIs and Reactive Streams. + * + *

Specifically a base class for reading from the HTTP request body with + * Servlet 3.1 non-blocking I/O and Undertow XNIO as well as handling incoming + * WebSocket messages with standard Java WebSocket (JSR-356), Jetty, and + * Undertow. + * + * @author Arjen Poutsma + * @author Violeta Georgieva + * @author Rossen Stoyanchev + * @since 5.0 + * @param the type of element signaled + */ +public abstract class AbstractListenerReadPublisher implements Publisher { + + /** + * Special logger for debugging Reactive Streams signals. + * @see LogDelegateFactory#getHiddenLog(Class) + * @see AbstractListenerWriteProcessor#rsWriteLogger + * @see AbstractListenerWriteFlushProcessor#rsWriteFlushLogger + * @see WriteResultPublisher#rsWriteResultLogger + */ + protected static Log rsReadLogger = LogDelegateFactory.getHiddenLog(AbstractListenerReadPublisher.class); + + + private final AtomicReference state = new AtomicReference<>(State.UNSUBSCRIBED); + + private volatile long demand; + + @SuppressWarnings("rawtypes") + private static final AtomicLongFieldUpdater DEMAND_FIELD_UPDATER = + AtomicLongFieldUpdater.newUpdater(AbstractListenerReadPublisher.class, "demand"); + + @Nullable + private volatile Subscriber subscriber; + + private volatile boolean completionBeforeDemand; + + @Nullable + private volatile Throwable errorBeforeDemand; + + private final String logPrefix; + + + public AbstractListenerReadPublisher() { + this(""); + } + + /** + * Create an instance with the given log prefix. + * @since 5.1 + */ + public AbstractListenerReadPublisher(String logPrefix) { + this.logPrefix = logPrefix; + } + + + /** + * Return the configured log message prefix. + * @since 5.1 + */ + public String getLogPrefix() { + return this.logPrefix; + } + + + // Publisher implementation... + + @Override + public void subscribe(Subscriber subscriber) { + this.state.get().subscribe(this, subscriber); + } + + + // Async I/O notification methods... + + /** + * Invoked when reading is possible, either in the same thread after a check + * via {@link #checkOnDataAvailable()}, or as a callback from the underlying + * container. + */ + public final void onDataAvailable() { + rsReadLogger.trace(getLogPrefix() + "onDataAvailable"); + this.state.get().onDataAvailable(this); + } + + /** + * Sub-classes can call this method to delegate a contain notification when + * all data has been read. + */ + public void onAllDataRead() { + rsReadLogger.trace(getLogPrefix() + "onAllDataRead"); + this.state.get().onAllDataRead(this); + } + + /** + * Sub-classes can call this to delegate container error notifications. + */ + public final void onError(Throwable ex) { + if (rsReadLogger.isTraceEnabled()) { + rsReadLogger.trace(getLogPrefix() + "Connection error: " + ex); + } + this.state.get().onError(this, ex); + } + + + // Read API methods to be implemented or template methods to override... + + /** + * Check if data is available and either call {@link #onDataAvailable()} + * immediately or schedule a notification. + */ + protected abstract void checkOnDataAvailable(); + + /** + * Read once from the input, if possible. + * @return the item that was read; or {@code null} + */ + @Nullable + protected abstract T read() throws IOException; + + /** + * Invoked when reading is paused due to a lack of demand. + *

Note: This method is guaranteed not to compete with + * {@link #checkOnDataAvailable()} so it can be used to safely suspend + * reading, if the underlying API supports it, i.e. without competing with + * an implicit call to resume via {@code checkOnDataAvailable()}. + * @since 5.0.2 + */ + protected abstract void readingPaused(); + + /** + * Invoked after an I/O read error from the underlying server or after a + * cancellation signal from the downstream consumer to allow sub-classes + * to discard any current cached data they might have. + * @since 5.0.11 + */ + protected abstract void discardData(); + + + // Private methods for use in State... + + /** + * Read and publish data one at a time until there is no more data, no more + * demand, or perhaps we completed in the mean time. + * @return {@code true} if there is more demand; {@code false} if there is + * no more demand or we have completed. + */ + private boolean readAndPublish() throws IOException { + long r; + while ((r = this.demand) > 0 && !this.state.get().equals(State.COMPLETED)) { + T data = read(); + if (data != null) { + if (r != Long.MAX_VALUE) { + DEMAND_FIELD_UPDATER.addAndGet(this, -1L); + } + Subscriber subscriber = this.subscriber; + Assert.state(subscriber != null, "No subscriber"); + if (rsReadLogger.isTraceEnabled()) { + rsReadLogger.trace(getLogPrefix() + "Publishing data read"); + } + subscriber.onNext(data); + } + else { + if (rsReadLogger.isTraceEnabled()) { + rsReadLogger.trace(getLogPrefix() + "No more data to read"); + } + return true; + } + } + return false; + } + + private boolean changeState(State oldState, State newState) { + boolean result = this.state.compareAndSet(oldState, newState); + if (result && rsReadLogger.isTraceEnabled()) { + rsReadLogger.trace(getLogPrefix() + oldState + " -> " + newState); + } + return result; + } + + private void changeToDemandState(State oldState) { + if (changeState(oldState, State.DEMAND)) { + // Protect from infinite recursion in Undertow, where we can't check if data + // is available, so all we can do is to try to read. + // Generally, no need to check if we just came out of readAndPublish()... + if (!oldState.equals(State.READING)) { + checkOnDataAvailable(); + } + } + } + + private void handleCompletionOrErrorBeforeDemand() { + State state = this.state.get(); + if (!state.equals(State.UNSUBSCRIBED) && !state.equals(State.SUBSCRIBING)) { + if (this.completionBeforeDemand) { + rsReadLogger.trace(getLogPrefix() + "Completed before demand"); + this.state.get().onAllDataRead(this); + } + Throwable ex = this.errorBeforeDemand; + if (ex != null) { + if (rsReadLogger.isTraceEnabled()) { + rsReadLogger.trace(getLogPrefix() + "Completed with error before demand: " + ex); + } + this.state.get().onError(this, ex); + } + } + } + + private Subscription createSubscription() { + return new ReadSubscription(); + } + + + /** + * Subscription that delegates signals to State. + */ + private final class ReadSubscription implements Subscription { + + + @Override + public final void request(long n) { + if (rsReadLogger.isTraceEnabled()) { + rsReadLogger.trace(getLogPrefix() + n + " requested"); + } + state.get().request(AbstractListenerReadPublisher.this, n); + } + + @Override + public final void cancel() { + if (rsReadLogger.isTraceEnabled()) { + rsReadLogger.trace(getLogPrefix() + "Cancellation"); + } + state.get().cancel(AbstractListenerReadPublisher.this); + } + } + + + /** + * Represents a state for the {@link Publisher} to be in. + *

+	 *        UNSUBSCRIBED
+	 *             |
+	 *             v
+	 *        SUBSCRIBING
+	 *             |
+	 *             v
+	 *    +---- NO_DEMAND ---------------> DEMAND ---+
+	 *    |        ^                         ^       |
+	 *    |        |                         |       |
+	 *    |        +------- READING <--------+       |
+	 *    |                    |                     |
+	 *    |                    v                     |
+	 *    +--------------> COMPLETED <---------------+
+	 * 
+ */ + private enum State { + + UNSUBSCRIBED { + @Override + void subscribe(AbstractListenerReadPublisher publisher, Subscriber subscriber) { + Assert.notNull(publisher, "Publisher must not be null"); + Assert.notNull(subscriber, "Subscriber must not be null"); + if (publisher.changeState(this, SUBSCRIBING)) { + Subscription subscription = publisher.createSubscription(); + publisher.subscriber = subscriber; + subscriber.onSubscribe(subscription); + publisher.changeState(SUBSCRIBING, NO_DEMAND); + publisher.handleCompletionOrErrorBeforeDemand(); + } + else { + throw new IllegalStateException("Failed to transition to SUBSCRIBING, " + + "subscriber: " + subscriber); + } + } + + @Override + void onAllDataRead(AbstractListenerReadPublisher publisher) { + publisher.completionBeforeDemand = true; + publisher.handleCompletionOrErrorBeforeDemand(); + } + + @Override + void onError(AbstractListenerReadPublisher publisher, Throwable ex) { + publisher.errorBeforeDemand = ex; + publisher.handleCompletionOrErrorBeforeDemand(); + } + }, + + /** + * Very brief state where we know we have a Subscriber but must not + * send onComplete and onError until we after onSubscribe. + */ + SUBSCRIBING { + @Override + void request(AbstractListenerReadPublisher publisher, long n) { + if (Operators.validate(n)) { + Operators.addCap(DEMAND_FIELD_UPDATER, publisher, n); + publisher.changeToDemandState(this); + } + } + + @Override + void onAllDataRead(AbstractListenerReadPublisher publisher) { + publisher.completionBeforeDemand = true; + publisher.handleCompletionOrErrorBeforeDemand(); + } + + @Override + void onError(AbstractListenerReadPublisher publisher, Throwable ex) { + publisher.errorBeforeDemand = ex; + publisher.handleCompletionOrErrorBeforeDemand(); + } + }, + + NO_DEMAND { + @Override + void request(AbstractListenerReadPublisher publisher, long n) { + if (Operators.validate(n)) { + Operators.addCap(DEMAND_FIELD_UPDATER, publisher, n); + publisher.changeToDemandState(this); + } + } + }, + + DEMAND { + @Override + void request(AbstractListenerReadPublisher publisher, long n) { + if (Operators.validate(n)) { + Operators.addCap(DEMAND_FIELD_UPDATER, publisher, n); + // Did a concurrent read transition to NO_DEMAND just before us? + publisher.changeToDemandState(NO_DEMAND); + } + } + + @Override + void onDataAvailable(AbstractListenerReadPublisher publisher) { + if (publisher.changeState(this, READING)) { + try { + boolean demandAvailable = publisher.readAndPublish(); + if (demandAvailable) { + publisher.changeToDemandState(READING); + } + else { + publisher.readingPaused(); + if (publisher.changeState(READING, NO_DEMAND)) { + // Demand may have arrived since readAndPublish returned + long r = publisher.demand; + if (r > 0) { + publisher.changeToDemandState(NO_DEMAND); + } + } + } + } + catch (IOException ex) { + publisher.onError(ex); + } + } + // Else, either competing onDataAvailable (request vs container), or concurrent completion + } + }, + + READING { + @Override + void request(AbstractListenerReadPublisher publisher, long n) { + if (Operators.validate(n)) { + Operators.addCap(DEMAND_FIELD_UPDATER, publisher, n); + // Did a concurrent read transition to NO_DEMAND just before us? + publisher.changeToDemandState(NO_DEMAND); + } + } + }, + + COMPLETED { + @Override + void request(AbstractListenerReadPublisher publisher, long n) { + // ignore + } + @Override + void cancel(AbstractListenerReadPublisher publisher) { + // ignore + } + @Override + void onAllDataRead(AbstractListenerReadPublisher publisher) { + // ignore + } + @Override + void onError(AbstractListenerReadPublisher publisher, Throwable t) { + // ignore + } + }; + + void subscribe(AbstractListenerReadPublisher publisher, Subscriber subscriber) { + throw new IllegalStateException(toString()); + } + + void request(AbstractListenerReadPublisher publisher, long n) { + throw new IllegalStateException(toString()); + } + + void cancel(AbstractListenerReadPublisher publisher) { + if (publisher.changeState(this, COMPLETED)) { + publisher.discardData(); + } + else { + publisher.state.get().cancel(publisher); + } + } + + void onDataAvailable(AbstractListenerReadPublisher publisher) { + // ignore + } + + void onAllDataRead(AbstractListenerReadPublisher publisher) { + if (publisher.changeState(this, COMPLETED)) { + Subscriber s = publisher.subscriber; + if (s != null) { + s.onComplete(); + } + } + else { + publisher.state.get().onAllDataRead(publisher); + } + } + + void onError(AbstractListenerReadPublisher publisher, Throwable t) { + if (publisher.changeState(this, COMPLETED)) { + publisher.discardData(); + Subscriber s = publisher.subscriber; + if (s != null) { + s.onError(t); + } + } + else { + publisher.state.get().onError(publisher, t); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..0eed513e27959562616cecc300f56000c59b7f32 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.concurrent.atomic.AtomicBoolean; + +import org.reactivestreams.Processor; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.http.HttpHeaders; + +/** + * Abstract base class for listener-based server responses, e.g. Servlet 3.1 + * and Undertow. + * + * @author Arjen Poutsma + * @since 5.0 + */ +public abstract class AbstractListenerServerHttpResponse extends AbstractServerHttpResponse { + + private final AtomicBoolean writeCalled = new AtomicBoolean(); + + + public AbstractListenerServerHttpResponse(DataBufferFactory dataBufferFactory) { + super(dataBufferFactory); + } + + public AbstractListenerServerHttpResponse(DataBufferFactory dataBufferFactory, HttpHeaders headers) { + super(dataBufferFactory, headers); + } + + + @Override + protected final Mono writeWithInternal(Publisher body) { + return writeAndFlushWithInternal(Mono.just(body)); + } + + @Override + protected final Mono writeAndFlushWithInternal( + Publisher> body) { + + if (this.writeCalled.compareAndSet(false, true)) { + Processor, Void> processor = createBodyFlushProcessor(); + return Mono.from(subscriber -> { + body.subscribe(processor); + processor.subscribe(subscriber); + }); + } + return Mono.error(new IllegalStateException( + "writeWith() or writeAndFlushWith() has already been called")); + } + + /** + * Abstract template method to create a {@code Processor, Void>} + * that will write the response body with flushes to the underlying output. Called from + * {@link #writeAndFlushWithInternal(Publisher)}. + */ + protected abstract Processor, Void> createBodyFlushProcessor(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerWriteFlushProcessor.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerWriteFlushProcessor.java new file mode 100644 index 0000000000000000000000000000000000000000..81c8098d32a763322cae6c7374adce2f4be403df --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerWriteFlushProcessor.java @@ -0,0 +1,439 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.IOException; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.commons.logging.Log; +import org.reactivestreams.Processor; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import org.springframework.core.log.LogDelegateFactory; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * An alternative to {@link AbstractListenerWriteProcessor} but instead writing + * a {@code Publisher>} with flush boundaries enforces after + * the completion of each nested Publisher. + * + * @author Arjen Poutsma + * @author Violeta Georgieva + * @author Rossen Stoyanchev + * @since 5.0 + * @param the type of element signaled to the {@link Subscriber} + */ +public abstract class AbstractListenerWriteFlushProcessor implements Processor, Void> { + + /** + * Special logger for debugging Reactive Streams signals. + * @see LogDelegateFactory#getHiddenLog(Class) + * @see AbstractListenerReadPublisher#rsReadLogger + * @see AbstractListenerWriteProcessor#rsWriteLogger + * @see WriteResultPublisher#rsWriteResultLogger + */ + protected static final Log rsWriteFlushLogger = + LogDelegateFactory.getHiddenLog(AbstractListenerWriteFlushProcessor.class); + + + private final AtomicReference state = new AtomicReference<>(State.UNSUBSCRIBED); + + @Nullable + private Subscription subscription; + + private volatile boolean subscriberCompleted; + + private final WriteResultPublisher resultPublisher; + + private final String logPrefix; + + + public AbstractListenerWriteFlushProcessor() { + this(""); + } + + /** + * Create an instance with the given log prefix. + * @since 5.1 + */ + public AbstractListenerWriteFlushProcessor(String logPrefix) { + this.logPrefix = logPrefix; + this.resultPublisher = new WriteResultPublisher(logPrefix); + } + + + /** + * Create an instance with the given log prefix. + * @since 5.1 + */ + public String getLogPrefix() { + return this.logPrefix; + } + + + // Subscriber methods and async I/O notification methods... + + @Override + public final void onSubscribe(Subscription subscription) { + this.state.get().onSubscribe(this, subscription); + } + + @Override + public final void onNext(Publisher publisher) { + if (rsWriteFlushLogger.isTraceEnabled()) { + rsWriteFlushLogger.trace(getLogPrefix() + "Received onNext publisher"); + } + this.state.get().onNext(this, publisher); + } + + /** + * Error signal from the upstream, write Publisher. This is also used by + * sub-classes to delegate error notifications from the container. + */ + @Override + public final void onError(Throwable ex) { + if (rsWriteFlushLogger.isTraceEnabled()) { + rsWriteFlushLogger.trace(getLogPrefix() + "Received onError: " + ex); + } + this.state.get().onError(this, ex); + } + + /** + * Completion signal from the upstream, write Publisher. This is also used + * by sub-classes to delegate completion notifications from the container. + */ + @Override + public final void onComplete() { + if (rsWriteFlushLogger.isTraceEnabled()) { + rsWriteFlushLogger.trace(getLogPrefix() + "Received onComplete"); + } + this.state.get().onComplete(this); + } + + /** + * Invoked when flushing is possible, either in the same thread after a check + * via {@link #isWritePossible()}, or as a callback from the underlying + * container. + */ + protected final void onFlushPossible() { + this.state.get().onFlushPossible(this); + } + + /** + * Invoked during an error or completion callback from the underlying + * container to cancel the upstream subscription. + */ + protected void cancel() { + if (rsWriteFlushLogger.isTraceEnabled()) { + rsWriteFlushLogger.trace(getLogPrefix() + "Received request to cancel"); + } + if (this.subscription != null) { + this.subscription.cancel(); + } + } + + + // Publisher implementation for result notifications... + + @Override + public final void subscribe(Subscriber subscriber) { + this.resultPublisher.subscribe(subscriber); + } + + + // Write API methods to be implemented or template methods to override... + + /** + * Create a new processor for the current flush boundary. + */ + protected abstract Processor createWriteProcessor(); + + /** + * Whether writing/flushing is possible. + */ + protected abstract boolean isWritePossible(); + + /** + * Flush the output if ready, or otherwise {@link #isFlushPending()} should + * return true after. + *

This is primarily for the Servlet non-blocking I/O API where flush + * cannot be called without a readyToWrite check. + */ + protected abstract void flush() throws IOException; + + /** + * Whether flushing is pending. + *

This is primarily for the Servlet non-blocking I/O API where flush + * cannot be called without a readyToWrite check. + */ + protected abstract boolean isFlushPending(); + + /** + * Invoked when an error happens while flushing. Sub-classes may choose + * to ignore this if they know the underlying API will provide an error + * notification in a container thread. + *

Defaults to no-op. + */ + protected void flushingFailed(Throwable t) { + } + + + // Private methods for use in State... + + private boolean changeState(State oldState, State newState) { + boolean result = this.state.compareAndSet(oldState, newState); + if (result && rsWriteFlushLogger.isTraceEnabled()) { + rsWriteFlushLogger.trace(getLogPrefix() + oldState + " -> " + newState); + } + return result; + } + + private void flushIfPossible() { + boolean result = isWritePossible(); + if (rsWriteFlushLogger.isTraceEnabled()) { + rsWriteFlushLogger.trace(getLogPrefix() + "isWritePossible[" + result + "]"); + } + if (result) { + onFlushPossible(); + } + } + + + /** + * Represents a state for the {@link Processor} to be in. + * + *

+	 *       UNSUBSCRIBED
+	 *            |
+	 *            v
+	 *        REQUESTED <---> RECEIVED ------+
+	 *            |              |           |
+	 *            |              v           |
+	 *            |           FLUSHING       |
+	 *            |              |           |
+	 *            |              v           |
+	 *            +--------> COMPLETED <-----+
+	 * 
+ */ + private enum State { + + UNSUBSCRIBED { + @Override + public void onSubscribe(AbstractListenerWriteFlushProcessor processor, Subscription subscription) { + Assert.notNull(subscription, "Subscription must not be null"); + if (processor.changeState(this, REQUESTED)) { + processor.subscription = subscription; + subscription.request(1); + } + else { + super.onSubscribe(processor, subscription); + } + } + + @Override + public void onComplete(AbstractListenerWriteFlushProcessor processor) { + // This can happen on (very early) completion notification from container.. + if (processor.changeState(this, COMPLETED)) { + processor.resultPublisher.publishComplete(); + } + else { + processor.state.get().onComplete(processor); + } + } + }, + + REQUESTED { + @Override + public void onNext(AbstractListenerWriteFlushProcessor processor, + Publisher currentPublisher) { + + if (processor.changeState(this, RECEIVED)) { + Processor currentProcessor = processor.createWriteProcessor(); + currentPublisher.subscribe(currentProcessor); + currentProcessor.subscribe(new WriteResultSubscriber(processor)); + } + } + @Override + public void onComplete(AbstractListenerWriteFlushProcessor processor) { + if (processor.changeState(this, COMPLETED)) { + processor.resultPublisher.publishComplete(); + } + else { + processor.state.get().onComplete(processor); + } + } + }, + + RECEIVED { + @Override + public void writeComplete(AbstractListenerWriteFlushProcessor processor) { + try { + processor.flush(); + } + catch (Throwable ex) { + processor.flushingFailed(ex); + return; + } + if (processor.changeState(this, REQUESTED)) { + if (processor.subscriberCompleted) { + handleSubscriberCompleted(processor); + } + else { + Assert.state(processor.subscription != null, "No subscription"); + processor.subscription.request(1); + } + } + } + @Override + public void onComplete(AbstractListenerWriteFlushProcessor processor) { + processor.subscriberCompleted = true; + // A competing write might have completed very quickly + if (processor.state.get().equals(State.REQUESTED)) { + handleSubscriberCompleted(processor); + } + } + + private void handleSubscriberCompleted(AbstractListenerWriteFlushProcessor processor) { + if (processor.isFlushPending()) { + // Ensure the final flush + processor.changeState(State.REQUESTED, State.FLUSHING); + processor.flushIfPossible(); + } + else if (processor.changeState(State.REQUESTED, State.COMPLETED)) { + processor.resultPublisher.publishComplete(); + } + else { + processor.state.get().onComplete(processor); + } + } + }, + + FLUSHING { + @Override + public void onFlushPossible(AbstractListenerWriteFlushProcessor processor) { + try { + processor.flush(); + } + catch (Throwable ex) { + processor.flushingFailed(ex); + return; + } + if (processor.changeState(this, COMPLETED)) { + processor.resultPublisher.publishComplete(); + } + else { + processor.state.get().onComplete(processor); + } + } + @Override + public void onNext(AbstractListenerWriteFlushProcessor proc, Publisher pub) { + // ignore + } + @Override + public void onComplete(AbstractListenerWriteFlushProcessor processor) { + // ignore + } + }, + + COMPLETED { + @Override + public void onNext(AbstractListenerWriteFlushProcessor proc, Publisher pub) { + // ignore + } + @Override + public void onError(AbstractListenerWriteFlushProcessor processor, Throwable t) { + // ignore + } + @Override + public void onComplete(AbstractListenerWriteFlushProcessor processor) { + // ignore + } + }; + + + public void onSubscribe(AbstractListenerWriteFlushProcessor proc, Subscription subscription) { + subscription.cancel(); + } + + public void onNext(AbstractListenerWriteFlushProcessor proc, Publisher pub) { + throw new IllegalStateException(toString()); + } + + public void onError(AbstractListenerWriteFlushProcessor processor, Throwable ex) { + if (processor.changeState(this, COMPLETED)) { + processor.resultPublisher.publishError(ex); + } + else { + processor.state.get().onError(processor, ex); + } + } + + public void onComplete(AbstractListenerWriteFlushProcessor processor) { + throw new IllegalStateException(toString()); + } + + public void writeComplete(AbstractListenerWriteFlushProcessor processor) { + throw new IllegalStateException(toString()); + } + + public void onFlushPossible(AbstractListenerWriteFlushProcessor processor) { + // ignore + } + + + /** + * Subscriber to receive and delegate completion notifications for from + * the current Publisher, i.e. for the current flush boundary. + */ + private static class WriteResultSubscriber implements Subscriber { + + private final AbstractListenerWriteFlushProcessor processor; + + + public WriteResultSubscriber(AbstractListenerWriteFlushProcessor processor) { + this.processor = processor; + } + + @Override + public void onSubscribe(Subscription subscription) { + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void aVoid) { + } + + @Override + public void onError(Throwable ex) { + this.processor.cancel(); + this.processor.onError(ex); + } + + @Override + public void onComplete() { + if (rsWriteFlushLogger.isTraceEnabled()) { + rsWriteFlushLogger.trace(this.processor.getLogPrefix() + this.processor.state + " writeComplete"); + } + this.processor.state.get().writeComplete(this.processor); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerWriteProcessor.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerWriteProcessor.java new file mode 100644 index 0000000000000000000000000000000000000000..3f37a955806845a6fbd78321cc9633acdd6b03bc --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerWriteProcessor.java @@ -0,0 +1,462 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.IOException; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.commons.logging.Log; +import org.reactivestreams.Processor; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import org.springframework.core.log.LogDelegateFactory; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Abstract base class for {@code Processor} implementations that bridge between + * event-listener write APIs and Reactive Streams. + * + *

Specifically a base class for writing to the HTTP response body with + * Servlet 3.1 non-blocking I/O and Undertow XNIO as well for writing WebSocket + * messages through the Java WebSocket API (JSR-356), Jetty, and Undertow. + * + * @author Arjen Poutsma + * @author Violeta Georgieva + * @author Rossen Stoyanchev + * @since 5.0 + * @param the type of element signaled to the {@link Subscriber} + */ +public abstract class AbstractListenerWriteProcessor implements Processor { + + /** + * Special logger for debugging Reactive Streams signals. + * @see LogDelegateFactory#getHiddenLog(Class) + * @see AbstractListenerReadPublisher#rsReadLogger + * @see AbstractListenerWriteFlushProcessor#rsWriteFlushLogger + * @see WriteResultPublisher#rsWriteResultLogger + */ + protected static final Log rsWriteLogger = LogDelegateFactory.getHiddenLog(AbstractListenerWriteProcessor.class); + + + private final AtomicReference state = new AtomicReference<>(State.UNSUBSCRIBED); + + @Nullable + private Subscription subscription; + + @Nullable + private volatile T currentData; + + /* Indicates "onComplete" was received during the (last) write. */ + private volatile boolean subscriberCompleted; + + /** + * Indicates we're waiting for one last isReady-onWritePossible cycle + * after "onComplete" because some Servlet containers expect this to take + * place prior to calling AsyncContext.complete(). + * See https://github.com/eclipse-ee4j/servlet-api/issues/273 + */ + private volatile boolean readyToCompleteAfterLastWrite; + + private final WriteResultPublisher resultPublisher; + + private final String logPrefix; + + + public AbstractListenerWriteProcessor() { + this(""); + } + + /** + * Create an instance with the given log prefix. + * @since 5.1 + */ + public AbstractListenerWriteProcessor(String logPrefix) { + this.logPrefix = logPrefix; + this.resultPublisher = new WriteResultPublisher(logPrefix); + } + + + /** + * Create an instance with the given log prefix. + * @since 5.1 + */ + public String getLogPrefix() { + return this.logPrefix; + } + + + // Subscriber methods and async I/O notification methods... + + @Override + public final void onSubscribe(Subscription subscription) { + this.state.get().onSubscribe(this, subscription); + } + + @Override + public final void onNext(T data) { + if (rsWriteLogger.isTraceEnabled()) { + rsWriteLogger.trace(getLogPrefix() + "Item to write"); + } + this.state.get().onNext(this, data); + } + + /** + * Error signal from the upstream, write Publisher. This is also used by + * sub-classes to delegate error notifications from the container. + */ + @Override + public final void onError(Throwable ex) { + if (rsWriteLogger.isTraceEnabled()) { + rsWriteLogger.trace(getLogPrefix() + "Write source error: " + ex); + } + this.state.get().onError(this, ex); + } + + /** + * Completion signal from the upstream, write Publisher. This is also used + * by sub-classes to delegate completion notifications from the container. + */ + @Override + public final void onComplete() { + if (rsWriteLogger.isTraceEnabled()) { + rsWriteLogger.trace(getLogPrefix() + "No more items to write"); + } + this.state.get().onComplete(this); + } + + /** + * Invoked when writing is possible, either in the same thread after a check + * via {@link #isWritePossible()}, or as a callback from the underlying + * container. + */ + public final void onWritePossible() { + if (rsWriteLogger.isTraceEnabled()) { + rsWriteLogger.trace(getLogPrefix() + "onWritePossible"); + } + this.state.get().onWritePossible(this); + } + + /** + * Invoked during an error or completion callback from the underlying + * container to cancel the upstream subscription. + */ + public void cancel() { + rsWriteLogger.trace(getLogPrefix() + "Cancellation"); + if (this.subscription != null) { + this.subscription.cancel(); + } + } + + // Publisher implementation for result notifications... + + @Override + public final void subscribe(Subscriber subscriber) { + // Technically, cancellation from the result subscriber should be propagated + // to the upstream subscription. In practice, HttpHandler server adapters + // don't have a reason to cancel the result subscription. + this.resultPublisher.subscribe(subscriber); + } + + + // Write API methods to be implemented or template methods to override... + + /** + * Whether the given data item has any content to write. + * If false the item is not written. + */ + protected abstract boolean isDataEmpty(T data); + + /** + * Template method invoked after a data item to write is received via + * {@link Subscriber#onNext(Object)}. The default implementation saves the + * data item for writing once that is possible. + */ + protected void dataReceived(T data) { + T prev = this.currentData; + if (prev != null) { + // This shouldn't happen: + // 1. dataReceived can only be called from REQUESTED state + // 2. currentData is cleared before requesting + discardData(data); + cancel(); + onError(new IllegalStateException("Received new data while current not processed yet.")); + } + this.currentData = data; + } + + /** + * Whether writing is possible. + */ + protected abstract boolean isWritePossible(); + + /** + * Write the given item. + *

Note: Sub-classes are responsible for releasing any + * data buffer associated with the item, once fully written, if pooled + * buffers apply to the underlying container. + * @param data the item to write + * @return whether the current data item was written and another one + * requested ({@code true}), or or otherwise if more writes are required. + */ + protected abstract boolean write(T data) throws IOException; + + /** + * Invoked after the current data has been written and before requesting + * the next item from the upstream, write Publisher. + *

The default implementation is a no-op. + * @deprecated originally introduced for Undertow to stop write notifications + * when no data is available, but deprecated as of as of 5.0.6 since constant + * switching on every requested item causes a significant slowdown. + */ + @Deprecated + protected void writingPaused() { + } + + /** + * Invoked after onComplete or onError notification. + *

The default implementation is a no-op. + */ + protected void writingComplete() { + } + + /** + * Invoked when an I/O error occurs during a write. Sub-classes may choose + * to ignore this if they know the underlying API will provide an error + * notification in a container thread. + *

Defaults to no-op. + */ + protected void writingFailed(Throwable ex) { + } + + /** + * Invoked after any error (either from the upstream write Publisher, or + * from I/O operations to the underlying server) and cancellation + * to discard in-flight data that was in + * the process of being written when the error took place. + * @param data the data to be released + * @since 5.0.11 + */ + protected abstract void discardData(T data); + + + // Private methods for use from State's... + + private boolean changeState(State oldState, State newState) { + boolean result = this.state.compareAndSet(oldState, newState); + if (result && rsWriteLogger.isTraceEnabled()) { + rsWriteLogger.trace(getLogPrefix() + oldState + " -> " + newState); + } + return result; + } + + private void changeStateToReceived(State oldState) { + if (changeState(oldState, State.RECEIVED)) { + writeIfPossible(); + } + } + + private void changeStateToComplete(State oldState) { + if (changeState(oldState, State.COMPLETED)) { + discardCurrentData(); + writingComplete(); + this.resultPublisher.publishComplete(); + } + else { + this.state.get().onComplete(this); + } + } + + private void writeIfPossible() { + boolean result = isWritePossible(); + if (!result && rsWriteLogger.isTraceEnabled()) { + rsWriteLogger.trace(getLogPrefix() + "isWritePossible: false"); + } + if (result) { + onWritePossible(); + } + } + + private void discardCurrentData() { + T data = this.currentData; + this.currentData = null; + if (data != null) { + discardData(data); + } + } + + + /** + * Represents a state for the {@link Processor} to be in. + * + *

+	 *        UNSUBSCRIBED
+	 *             |
+	 *             v
+	 *   +--- REQUESTED -------------> RECEIVED ---+
+	 *   |        ^                       ^        |
+	 *   |        |                       |        |
+	 *   |        + ------ WRITING <------+        |
+	 *   |                    |                    |
+	 *   |                    v                    |
+	 *   +--------------> COMPLETED <--------------+
+	 * 
+ */ + private enum State { + + UNSUBSCRIBED { + @Override + public void onSubscribe(AbstractListenerWriteProcessor processor, Subscription subscription) { + Assert.notNull(subscription, "Subscription must not be null"); + if (processor.changeState(this, REQUESTED)) { + processor.subscription = subscription; + subscription.request(1); + } + else { + super.onSubscribe(processor, subscription); + } + } + + @Override + public void onComplete(AbstractListenerWriteProcessor processor) { + // This can happen on (very early) completion notification from container.. + processor.changeStateToComplete(this); + } + }, + + REQUESTED { + @Override + public void onNext(AbstractListenerWriteProcessor processor, T data) { + if (processor.isDataEmpty(data)) { + Assert.state(processor.subscription != null, "No subscription"); + processor.subscription.request(1); + } + else { + processor.dataReceived(data); + processor.changeStateToReceived(this); + } + } + @Override + public void onComplete(AbstractListenerWriteProcessor processor) { + processor.readyToCompleteAfterLastWrite = true; + processor.changeStateToReceived(this); + } + }, + + RECEIVED { + @SuppressWarnings("deprecation") + @Override + public void onWritePossible(AbstractListenerWriteProcessor processor) { + if (processor.readyToCompleteAfterLastWrite) { + processor.changeStateToComplete(RECEIVED); + } + else if (processor.changeState(this, WRITING)) { + T data = processor.currentData; + Assert.state(data != null, "No data"); + try { + if (processor.write(data)) { + if (processor.changeState(WRITING, REQUESTED)) { + processor.currentData = null; + if (processor.subscriberCompleted) { + processor.readyToCompleteAfterLastWrite = true; + processor.changeStateToReceived(REQUESTED); + } + else { + processor.writingPaused(); + Assert.state(processor.subscription != null, "No subscription"); + processor.subscription.request(1); + } + } + } + else { + processor.changeStateToReceived(WRITING); + } + } + catch (IOException ex) { + processor.writingFailed(ex); + } + } + } + + @Override + public void onComplete(AbstractListenerWriteProcessor processor) { + processor.subscriberCompleted = true; + // A competing write might have completed very quickly + if (processor.state.get().equals(State.REQUESTED)) { + processor.changeStateToComplete(State.REQUESTED); + } + } + }, + + WRITING { + @Override + public void onComplete(AbstractListenerWriteProcessor processor) { + processor.subscriberCompleted = true; + // A competing write might have completed very quickly + if (processor.state.get().equals(State.REQUESTED)) { + processor.changeStateToComplete(State.REQUESTED); + } + } + }, + + COMPLETED { + @Override + public void onNext(AbstractListenerWriteProcessor processor, T data) { + // ignore + } + @Override + public void onError(AbstractListenerWriteProcessor processor, Throwable ex) { + // ignore + } + @Override + public void onComplete(AbstractListenerWriteProcessor processor) { + // ignore + } + }; + + public void onSubscribe(AbstractListenerWriteProcessor processor, Subscription subscription) { + subscription.cancel(); + } + + public void onNext(AbstractListenerWriteProcessor processor, T data) { + processor.discardData(data); + processor.cancel(); + processor.onError(new IllegalStateException("Illegal onNext without demand")); + } + + public void onError(AbstractListenerWriteProcessor processor, Throwable ex) { + if (processor.changeState(this, COMPLETED)) { + processor.discardCurrentData(); + processor.writingComplete(); + processor.resultPublisher.publishError(ex); + } + else { + processor.state.get().onError(processor, ex); + } + } + + public void onComplete(AbstractListenerWriteProcessor processor) { + throw new IllegalStateException(toString()); + } + + public void onWritePossible(AbstractListenerWriteProcessor processor) { + // ignore + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..84c2e9597ad5427b2e4cd3473f125cdd8d15a02a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpRequest.java @@ -0,0 +1,220 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.UnsupportedEncodingException; +import java.net.URI; +import java.net.URLDecoder; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.apache.commons.logging.Log; + +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpLogging; +import org.springframework.http.server.RequestPath; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; + +/** + * Common base class for {@link ServerHttpRequest} implementations. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public abstract class AbstractServerHttpRequest implements ServerHttpRequest { + + private static final Pattern QUERY_PATTERN = Pattern.compile("([^&=]+)(=?)([^&]+)?"); + + + protected final Log logger = HttpLogging.forLogName(getClass()); + + private final URI uri; + + private final RequestPath path; + + private final HttpHeaders headers; + + @Nullable + private MultiValueMap queryParams; + + @Nullable + private MultiValueMap cookies; + + @Nullable + private SslInfo sslInfo; + + @Nullable + private String id; + + @Nullable + private String logPrefix; + + + /** + * Constructor with the URI and headers for the request. + * @param uri the URI for the request + * @param contextPath the context path for the request + * @param headers the headers for the request + */ + public AbstractServerHttpRequest(URI uri, @Nullable String contextPath, HttpHeaders headers) { + this.uri = uri; + this.path = RequestPath.parse(uri, contextPath); + this.headers = HttpHeaders.readOnlyHttpHeaders(headers); + } + + + public String getId() { + if (this.id == null) { + this.id = initId(); + if (this.id == null) { + this.id = ObjectUtils.getIdentityHexString(this); + } + } + return this.id; + } + + /** + * Obtain the request id to use, or {@code null} in which case the Object + * identity of this request instance is used. + * @since 5.1 + */ + @Nullable + protected String initId() { + return null; + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public RequestPath getPath() { + return this.path; + } + + @Override + public HttpHeaders getHeaders() { + return this.headers; + } + + @Override + public MultiValueMap getQueryParams() { + if (this.queryParams == null) { + this.queryParams = CollectionUtils.unmodifiableMultiValueMap(initQueryParams()); + } + return this.queryParams; + } + + /** + * A method for parsing of the query into name-value pairs. The return + * value is turned into an immutable map and cached. + *

Note that this method is invoked lazily on first access to + * {@link #getQueryParams()}. The invocation is not synchronized but the + * parsing is thread-safe nevertheless. + */ + protected MultiValueMap initQueryParams() { + MultiValueMap queryParams = new LinkedMultiValueMap<>(); + String query = getURI().getRawQuery(); + if (query != null) { + Matcher matcher = QUERY_PATTERN.matcher(query); + while (matcher.find()) { + String name = decodeQueryParam(matcher.group(1)); + String eq = matcher.group(2); + String value = matcher.group(3); + value = (value != null ? decodeQueryParam(value) : (StringUtils.hasLength(eq) ? "" : null)); + queryParams.add(name, value); + } + } + return queryParams; + } + + @SuppressWarnings("deprecation") + private String decodeQueryParam(String value) { + try { + return URLDecoder.decode(value, "UTF-8"); + } + catch (UnsupportedEncodingException ex) { + if (logger.isWarnEnabled()) { + logger.warn(getLogPrefix() + "Could not decode query value [" + value + "] as 'UTF-8'. " + + "Falling back on default encoding: " + ex.getMessage()); + } + return URLDecoder.decode(value); + } + } + + @Override + public MultiValueMap getCookies() { + if (this.cookies == null) { + this.cookies = CollectionUtils.unmodifiableMultiValueMap(initCookies()); + } + return this.cookies; + } + + /** + * Obtain the cookies from the underlying "native" request and adapt those to + * an {@link HttpCookie} map. The return value is turned into an immutable + * map and cached. + *

Note that this method is invoked lazily on access to + * {@link #getCookies()}. Sub-classes should synchronize cookie + * initialization if the underlying "native" request does not provide + * thread-safe access to cookie data. + */ + protected abstract MultiValueMap initCookies(); + + @Nullable + @Override + public SslInfo getSslInfo() { + if (this.sslInfo == null) { + this.sslInfo = initSslInfo(); + } + return this.sslInfo; + } + + /** + * Obtain SSL session information from the underlying "native" request. + * @return the session information, or {@code null} if none available + * @since 5.0.2 + */ + @Nullable + protected abstract SslInfo initSslInfo(); + + /** + * Return the underlying server response. + *

Note: This is exposed mainly for internal framework + * use such as WebSocket upgrades in the spring-webflux module. + */ + public abstract T getNativeRequest(); + + /** + * For internal use in logging at the HTTP adapter layer. + * @since 5.1 + */ + String getLogPrefix() { + if (this.logPrefix == null) { + this.logPrefix = "[" + getId() + "] "; + } + return this.logPrefix; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..cd07408b9569f4faf899ac6c4f479b0e637ed779 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java @@ -0,0 +1,301 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import org.apache.commons.logging.Log; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.PooledDataBuffer; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpLogging; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseCookie; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * Base class for {@link ServerHttpResponse} implementations. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @author Sebastien Deleuze + * @author Brian Clozel + * @since 5.0 + */ +public abstract class AbstractServerHttpResponse implements ServerHttpResponse { + + /** + * COMMITTING -> COMMITTED is the period after doCommit is called but before + * the response status and headers have been applied to the underlying + * response during which time pre-commit actions can still make changes to + * the response status and headers. + */ + private enum State {NEW, COMMITTING, COMMITTED} + + protected final Log logger = HttpLogging.forLogName(getClass()); + + + private final DataBufferFactory dataBufferFactory; + + @Nullable + private Integer statusCode; + + private final HttpHeaders headers; + + private final MultiValueMap cookies; + + private final AtomicReference state = new AtomicReference<>(State.NEW); + + private final List>> commitActions = new ArrayList<>(4); + + + public AbstractServerHttpResponse(DataBufferFactory dataBufferFactory) { + this(dataBufferFactory, new HttpHeaders()); + } + + public AbstractServerHttpResponse(DataBufferFactory dataBufferFactory, HttpHeaders headers) { + Assert.notNull(dataBufferFactory, "DataBufferFactory must not be null"); + Assert.notNull(headers, "HttpHeaders must not be null"); + this.dataBufferFactory = dataBufferFactory; + this.headers = headers; + this.cookies = new LinkedMultiValueMap<>(); + } + + + @Override + public final DataBufferFactory bufferFactory() { + return this.dataBufferFactory; + } + + @Override + public boolean setStatusCode(@Nullable HttpStatus status) { + if (this.state.get() == State.COMMITTED) { + return false; + } + else { + this.statusCode = (status != null ? status.value() : null); + return true; + } + } + + @Override + @Nullable + public HttpStatus getStatusCode() { + return (this.statusCode != null ? HttpStatus.resolve(this.statusCode) : null); + } + + /** + * Set the HTTP status code of the response. + * @param statusCode the HTTP status as an integer value + * @since 5.0.1 + */ + public void setStatusCodeValue(@Nullable Integer statusCode) { + this.statusCode = statusCode; + } + + /** + * Return the HTTP status code of the response. + * @return the HTTP status as an integer value + * @since 5.0.1 + */ + @Nullable + public Integer getStatusCodeValue() { + return this.statusCode; + } + + @Override + public HttpHeaders getHeaders() { + return (this.state.get() == State.COMMITTED ? + HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers); + } + + @Override + public MultiValueMap getCookies() { + return (this.state.get() == State.COMMITTED ? + CollectionUtils.unmodifiableMultiValueMap(this.cookies) : this.cookies); + } + + @Override + public void addCookie(ResponseCookie cookie) { + Assert.notNull(cookie, "ResponseCookie must not be null"); + + if (this.state.get() == State.COMMITTED) { + throw new IllegalStateException("Can't add the cookie " + cookie + + "because the HTTP response has already been committed"); + } + else { + getCookies().add(cookie.getName(), cookie); + } + } + + /** + * Return the underlying server response. + *

Note: This is exposed mainly for internal framework + * use such as WebSocket upgrades in the spring-webflux module. + */ + public abstract T getNativeResponse(); + + + @Override + public void beforeCommit(Supplier> action) { + this.commitActions.add(action); + } + + @Override + public boolean isCommitted() { + return this.state.get() != State.NEW; + } + + @Override + @SuppressWarnings("unchecked") + public final Mono writeWith(Publisher body) { + // Write as Mono if possible as an optimization hint to Reactor Netty + // ChannelSendOperator not necessary for Mono + if (body instanceof Mono) { + return ((Mono) body) + .flatMap(buffer -> { + AtomicReference subscribed = new AtomicReference<>(false); + return doCommit( + () -> { + try { + return writeWithInternal(Mono.fromCallable(() -> buffer) + .doOnSubscribe(s -> subscribed.set(true)) + .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release)); + } + catch (Throwable ex) { + return Mono.error(ex); + } + }) + .doOnError(ex -> DataBufferUtils.release(buffer)) + .doOnCancel(() -> { + if (!subscribed.get()) { + DataBufferUtils.release(buffer); + } + }); + }) + .doOnError(t -> removeContentLength()) + .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release); + } + return new ChannelSendOperator<>(body, inner -> doCommit(() -> writeWithInternal(inner))) + .doOnError(t -> removeContentLength()); + } + + @Override + public final Mono writeAndFlushWith(Publisher> body) { + return new ChannelSendOperator<>(body, inner -> doCommit(() -> writeAndFlushWithInternal(inner))) + .doOnError(t -> removeContentLength()); + } + + private void removeContentLength() { + if (!this.isCommitted()) { + this.getHeaders().remove(HttpHeaders.CONTENT_LENGTH); + } + } + + @Override + public Mono setComplete() { + return !isCommitted() ? doCommit(null) : Mono.empty(); + } + + /** + * A variant of {@link #doCommit(Supplier)} for a response without no body. + * @return a completion publisher + */ + protected Mono doCommit() { + return doCommit(null); + } + + /** + * Apply {@link #beforeCommit(Supplier) beforeCommit} actions, apply the + * response status and headers/cookies, and write the response body. + * @param writeAction the action to write the response body (may be {@code null}) + * @return a completion publisher + */ + protected Mono doCommit(@Nullable Supplier> writeAction) { + if (!this.state.compareAndSet(State.NEW, State.COMMITTING)) { + return Mono.empty(); + } + + this.commitActions.add(() -> + Mono.fromRunnable(() -> { + applyStatusCode(); + applyHeaders(); + applyCookies(); + this.state.set(State.COMMITTED); + })); + + if (writeAction != null) { + this.commitActions.add(writeAction); + } + + List> actions = this.commitActions.stream() + .map(Supplier::get).collect(Collectors.toList()); + + return Flux.concat(actions).then(); + } + + + /** + * Write to the underlying the response. + * @param body the publisher to write with + */ + protected abstract Mono writeWithInternal(Publisher body); + + /** + * Write to the underlying the response, and flush after each {@code Publisher}. + * @param body the publisher to write and flush with + */ + protected abstract Mono writeAndFlushWithInternal(Publisher> body); + + /** + * Write the status code to the underlying response. + * This method is called once only. + */ + protected abstract void applyStatusCode(); + + /** + * Invoked when the response is getting committed allowing sub-classes to + * make apply header values to the underlying response. + *

Note that most sub-classes use an {@link HttpHeaders} instance that + * wraps an adapter to the native response headers such that changes are + * propagated to the underlying response on the go. That means this callback + * is typically not used other than for specialized updates such as setting + * the contentType or characterEncoding fields in a Servlet response. + */ + protected abstract void applyHeaders(); + + /** + * Add cookies from {@link #getHeaders()} to the underlying response. + * This method is called once only. + */ + protected abstract void applyCookies(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java new file mode 100644 index 0000000000000000000000000000000000000000..da46443edc2f9edc0bca9935618ed78423ed4d99 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java @@ -0,0 +1,447 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.function.Function; + +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Given a write function that accepts a source {@code Publisher} to write + * with and returns {@code Publisher} for the result, this operator helps + * to defer the invocation of the write function, until we know if the source + * publisher will begin publishing without an error. If the first emission is + * an error, the write function is bypassed, and the error is sent directly + * through the result publisher. Otherwise the write function is invoked. + * + * @author Rossen Stoyanchev + * @author Stephane Maldini + * @since 5.0 + * @param the type of element signaled + */ +public class ChannelSendOperator extends Mono implements Scannable { + + private final Function, Publisher> writeFunction; + + private final Flux source; + + + public ChannelSendOperator(Publisher source, Function, Publisher> writeFunction) { + this.source = Flux.from(source); + this.writeFunction = writeFunction; + } + + + @Override + @Nullable + @SuppressWarnings("rawtypes") + public Object scanUnsafe(Attr key) { + if (key == Attr.PREFETCH) { + return Integer.MAX_VALUE; + } + if (key == Attr.PARENT) { + return this.source; + } + return null; + } + + @Override + public void subscribe(CoreSubscriber actual) { + this.source.subscribe(new WriteBarrier(actual)); + } + + + private enum State { + + /** No emissions from the upstream source yet. */ + NEW, + + /** + * At least one signal of any kind has been received; we're ready to + * call the write function and proceed with actual writing. + */ + FIRST_SIGNAL_RECEIVED, + + /** + * The write subscriber has subscribed and requested; we're going to + * emit the cached signals. + */ + EMITTING_CACHED_SIGNALS, + + /** + * The write subscriber has subscribed, and cached signals have been + * emitted to it; we're ready to switch to a simple pass-through mode + * for all remaining signals. + **/ + READY_TO_WRITE + + } + + + /** + * A barrier inserted between the write source and the write subscriber + * (i.e. the HTTP server adapter) that pre-fetches and waits for the first + * signal before deciding whether to hook in to the write subscriber. + * + *

Acts as: + *

    + *
  • Subscriber to the write source. + *
  • Subscription to the write subscriber. + *
  • Publisher to the write subscriber. + *
+ * + *

Also uses {@link WriteCompletionBarrier} to communicate completion + * and detect cancel signals from the completion subscriber. + */ + private class WriteBarrier implements CoreSubscriber, Subscription, Publisher { + + /* Bridges signals to and from the completionSubscriber */ + private final WriteCompletionBarrier writeCompletionBarrier; + + /* Upstream write source subscription */ + @Nullable + private Subscription subscription; + + /** Cached data item before readyToWrite. */ + @Nullable + private T item; + + /** Cached error signal before readyToWrite. */ + @Nullable + private Throwable error; + + /** Cached onComplete signal before readyToWrite. */ + private boolean completed = false; + + /** Recursive demand while emitting cached signals. */ + private long demandBeforeReadyToWrite; + + /** Current state. */ + private State state = State.NEW; + + /** The actual writeSubscriber from the HTTP server adapter. */ + @Nullable + private Subscriber writeSubscriber; + + + WriteBarrier(CoreSubscriber completionSubscriber) { + this.writeCompletionBarrier = new WriteCompletionBarrier(completionSubscriber, this); + } + + + // Subscriber methods (we're the subscriber to the write source).. + + @Override + public final void onSubscribe(Subscription s) { + if (Operators.validate(this.subscription, s)) { + this.subscription = s; + this.writeCompletionBarrier.connect(); + s.request(1); + } + } + + @Override + public final void onNext(T item) { + if (this.state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onNext(item); + return; + } + //FIXME revisit in case of reentrant sync deadlock + synchronized (this) { + if (this.state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onNext(item); + } + else if (this.state == State.NEW) { + this.item = item; + this.state = State.FIRST_SIGNAL_RECEIVED; + Publisher result; + try { + result = writeFunction.apply(this); + } + catch (Throwable ex) { + this.writeCompletionBarrier.onError(ex); + return; + } + result.subscribe(this.writeCompletionBarrier); + } + else { + if (this.subscription != null) { + this.subscription.cancel(); + } + this.writeCompletionBarrier.onError(new IllegalStateException("Unexpected item.")); + } + } + } + + private Subscriber requiredWriteSubscriber() { + Assert.state(this.writeSubscriber != null, "No write subscriber"); + return this.writeSubscriber; + } + + @Override + public final void onError(Throwable ex) { + if (this.state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onError(ex); + return; + } + synchronized (this) { + if (this.state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onError(ex); + } + else if (this.state == State.NEW) { + this.state = State.FIRST_SIGNAL_RECEIVED; + this.writeCompletionBarrier.onError(ex); + } + else { + this.error = ex; + } + } + } + + @Override + public final void onComplete() { + if (this.state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onComplete(); + return; + } + synchronized (this) { + if (this.state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onComplete(); + } + else if (this.state == State.NEW) { + this.completed = true; + this.state = State.FIRST_SIGNAL_RECEIVED; + Publisher result; + try { + result = writeFunction.apply(this); + } + catch (Throwable ex) { + this.writeCompletionBarrier.onError(ex); + return; + } + result.subscribe(this.writeCompletionBarrier); + } + else { + this.completed = true; + } + } + } + + @Override + public Context currentContext() { + return this.writeCompletionBarrier.currentContext(); + } + + + // Subscription methods (we're the Subscription to the writeSubscriber).. + + @Override + public void request(long n) { + Subscription s = this.subscription; + if (s == null) { + return; + } + if (this.state == State.READY_TO_WRITE) { + s.request(n); + return; + } + synchronized (this) { + if (this.writeSubscriber != null) { + if (this.state == State.EMITTING_CACHED_SIGNALS) { + this.demandBeforeReadyToWrite = n; + return; + } + try { + this.state = State.EMITTING_CACHED_SIGNALS; + if (emitCachedSignals()) { + return; + } + n = n + this.demandBeforeReadyToWrite - 1; + if (n == 0) { + return; + } + } + finally { + this.state = State.READY_TO_WRITE; + } + } + } + s.request(n); + } + + private boolean emitCachedSignals() { + if (this.error != null) { + try { + requiredWriteSubscriber().onError(this.error); + } + finally { + releaseCachedItem(); + } + return true; + } + T item = this.item; + this.item = null; + if (item != null) { + requiredWriteSubscriber().onNext(item); + } + if (this.completed) { + requiredWriteSubscriber().onComplete(); + return true; + } + return false; + } + + @Override + public void cancel() { + Subscription s = this.subscription; + if (s != null) { + this.subscription = null; + try { + s.cancel(); + } + finally { + releaseCachedItem(); + } + } + } + + private void releaseCachedItem() { + synchronized (this) { + Object item = this.item; + if (item instanceof DataBuffer) { + DataBufferUtils.release((DataBuffer) item); + } + this.item = null; + } + } + + + // Publisher methods (we're the Publisher to the writeSubscriber).. + + @Override + public void subscribe(Subscriber writeSubscriber) { + synchronized (this) { + Assert.state(this.writeSubscriber == null, "Only one write subscriber supported"); + this.writeSubscriber = writeSubscriber; + if (this.error != null || this.completed) { + this.writeSubscriber.onSubscribe(Operators.emptySubscription()); + emitCachedSignals(); + } + else { + this.writeSubscriber.onSubscribe(this); + } + } + } + } + + + /** + * We need an extra barrier between the WriteBarrier itself and the actual + * completion subscriber. + * + *

The completionSubscriber is subscribed initially to the WriteBarrier. + * Later after the first signal is received, we need one more subscriber + * instance (per spec can only subscribe once) to subscribe to the write + * function and switch to delegating completion signals from it. + */ + private class WriteCompletionBarrier implements CoreSubscriber, Subscription { + + /* Downstream write completion subscriber */ + private final CoreSubscriber completionSubscriber; + + private final WriteBarrier writeBarrier; + + @Nullable + private Subscription subscription; + + + public WriteCompletionBarrier(CoreSubscriber subscriber, WriteBarrier writeBarrier) { + this.completionSubscriber = subscriber; + this.writeBarrier = writeBarrier; + } + + + /** + * Connect the underlying completion subscriber to this barrier in order + * to track cancel signals and pass them on to the write barrier. + */ + public void connect() { + this.completionSubscriber.onSubscribe(this); + } + + // Subscriber methods (we're the subscriber to the write function).. + + @Override + public void onSubscribe(Subscription subscription) { + this.subscription = subscription; + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void aVoid) { + } + + @Override + public void onError(Throwable ex) { + try { + this.completionSubscriber.onError(ex); + } + finally { + this.writeBarrier.releaseCachedItem(); + } + } + + @Override + public void onComplete() { + this.completionSubscriber.onComplete(); + } + + @Override + public Context currentContext() { + return this.completionSubscriber.currentContext(); + } + + + @Override + public void request(long n) { + // Ignore: we don't produce data + } + + @Override + public void cancel() { + this.writeBarrier.cancel(); + Subscription subscription = this.subscription; + if (subscription != null) { + subscription.cancel(); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ContextPathCompositeHandler.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ContextPathCompositeHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..42d63aea5258be81acafbcf97c7ec92afddbd4de --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ContextPathCompositeHandler.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.LinkedHashMap; +import java.util.Map; + +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpStatus; +import org.springframework.util.Assert; + +/** + * {@code HttpHandler} delegating requests to one of several {@code HttpHandler}'s + * based on simple, prefix-based mappings. + * + *

This is intended as a coarse-grained mechanism for delegating requests to + * one of several applications -- each represented by an {@code HttpHandler}, with + * the application "context path" (the prefix-based mapping) exposed via + * {@link ServerHttpRequest#getPath()}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ContextPathCompositeHandler implements HttpHandler { + + private final Map handlerMap; + + + public ContextPathCompositeHandler(Map handlerMap) { + Assert.notEmpty(handlerMap, "Handler map must not be empty"); + this.handlerMap = initHandlers(handlerMap); + } + + private static Map initHandlers(Map map) { + map.keySet().forEach(ContextPathCompositeHandler::assertValidContextPath); + return new LinkedHashMap<>(map); + } + + private static void assertValidContextPath(String contextPath) { + Assert.hasText(contextPath, "Context path must not be empty"); + if (contextPath.equals("/")) { + return; + } + Assert.isTrue(contextPath.startsWith("/"), "Context path must begin with '/'"); + Assert.isTrue(!contextPath.endsWith("/"), "Context path must not end with '/'"); + } + + + @Override + public Mono handle(ServerHttpRequest request, ServerHttpResponse response) { + // Remove underlying context path first (e.g. Servlet container) + String path = request.getPath().pathWithinApplication().value(); + return this.handlerMap.entrySet().stream() + .filter(entry -> path.startsWith(entry.getKey())) + .findFirst() + .map(entry -> { + String contextPath = request.getPath().contextPath().value() + entry.getKey(); + ServerHttpRequest newRequest = request.mutate().contextPath(contextPath).build(); + return entry.getValue().handle(newRequest, response); + }) + .orElseGet(() -> { + response.setStatusCode(HttpStatus.NOT_FOUND); + return response.setComplete(); + }); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/DefaultServerHttpRequestBuilder.java b/spring-web/src/main/java/org/springframework/http/server/reactive/DefaultServerHttpRequestBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..8ffce598443b65f8d9abdbc71cf284bfeb0c00cd --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/DefaultServerHttpRequestBuilder.java @@ -0,0 +1,247 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.LinkedList; +import java.util.function.Consumer; + +import reactor.core.publisher.Flux; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * Package-private default implementation of {@link ServerHttpRequest.Builder}. + * + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + * @since 5.0 + */ +class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder { + + private URI uri; + + private HttpHeaders httpHeaders; + + private String httpMethodValue; + + private final MultiValueMap cookies; + + @Nullable + private String uriPath; + + @Nullable + private String contextPath; + + @Nullable + private SslInfo sslInfo; + + private Flux body; + + private final ServerHttpRequest originalRequest; + + + public DefaultServerHttpRequestBuilder(ServerHttpRequest original) { + Assert.notNull(original, "ServerHttpRequest is required"); + + this.uri = original.getURI(); + this.httpMethodValue = original.getMethodValue(); + this.body = original.getBody(); + + this.httpHeaders = HttpHeaders.writableHttpHeaders(original.getHeaders()); + + this.cookies = new LinkedMultiValueMap<>(original.getCookies().size()); + copyMultiValueMap(original.getCookies(), this.cookies); + + this.originalRequest = original; + } + + private static void copyMultiValueMap(MultiValueMap source, MultiValueMap target) { + source.forEach((key, value) -> target.put(key, new LinkedList<>(value))); + } + + + @Override + public ServerHttpRequest.Builder method(HttpMethod httpMethod) { + this.httpMethodValue = httpMethod.name(); + return this; + } + + @Override + public ServerHttpRequest.Builder uri(URI uri) { + this.uri = uri; + return this; + } + + @Override + public ServerHttpRequest.Builder path(String path) { + Assert.isTrue(path.startsWith("/"), "The path does not have a leading slash."); + this.uriPath = path; + return this; + } + + @Override + public ServerHttpRequest.Builder contextPath(String contextPath) { + this.contextPath = contextPath; + return this; + } + + @Override + @Deprecated + public ServerHttpRequest.Builder header(String key, String value) { + this.httpHeaders.add(key, value); + return this; + } + + @Override + public ServerHttpRequest.Builder headers(Consumer headersConsumer) { + Assert.notNull(headersConsumer, "'headersConsumer' must not be null"); + headersConsumer.accept(this.httpHeaders); + return this; + } + + @Override + public ServerHttpRequest.Builder sslInfo(SslInfo sslInfo) { + this.sslInfo = sslInfo; + return this; + } + + @Override + public ServerHttpRequest build() { + return new MutatedServerHttpRequest(getUriToUse(), this.contextPath, this.httpHeaders, + this.httpMethodValue, this.cookies, this.sslInfo, this.body, this.originalRequest); + } + + private URI getUriToUse() { + if (this.uriPath == null) { + return this.uri; + } + + StringBuilder uriBuilder = new StringBuilder(); + if (this.uri.getScheme() != null) { + uriBuilder.append(this.uri.getScheme()).append(':'); + } + if (this.uri.getRawUserInfo() != null || this.uri.getHost() != null) { + uriBuilder.append("//"); + if (this.uri.getRawUserInfo() != null) { + uriBuilder.append(this.uri.getRawUserInfo()).append('@'); + } + if (this.uri.getHost() != null) { + uriBuilder.append(this.uri.getHost()); + } + if (this.uri.getPort() != -1) { + uriBuilder.append(':').append(this.uri.getPort()); + } + } + if (StringUtils.hasLength(this.uriPath)) { + uriBuilder.append(this.uriPath); + } + if (this.uri.getRawQuery() != null) { + uriBuilder.append('?').append(this.uri.getRawQuery()); + } + if (this.uri.getRawFragment() != null) { + uriBuilder.append('#').append(this.uri.getRawFragment()); + } + try { + return new URI(uriBuilder.toString()); + } + catch (URISyntaxException ex) { + throw new IllegalStateException("Invalid URI path: \"" + this.uriPath + "\"", ex); + } + } + + + private static class MutatedServerHttpRequest extends AbstractServerHttpRequest { + + private final String methodValue; + + private final MultiValueMap cookies; + + @Nullable + private final InetSocketAddress remoteAddress; + + @Nullable + private final SslInfo sslInfo; + + private final Flux body; + + private final ServerHttpRequest originalRequest; + + + public MutatedServerHttpRequest(URI uri, @Nullable String contextPath, + HttpHeaders headers, String methodValue, MultiValueMap cookies, + @Nullable SslInfo sslInfo, Flux body, ServerHttpRequest originalRequest) { + + super(uri, contextPath, headers); + this.methodValue = methodValue; + this.cookies = cookies; + this.remoteAddress = originalRequest.getRemoteAddress(); + this.sslInfo = sslInfo != null ? sslInfo : originalRequest.getSslInfo(); + this.body = body; + this.originalRequest = originalRequest; + } + + @Override + public String getMethodValue() { + return this.methodValue; + } + + @Override + protected MultiValueMap initCookies() { + return this.cookies; + } + + @Override + @Nullable + public InetSocketAddress getRemoteAddress() { + return this.remoteAddress; + } + + @Override + @Nullable + protected SslInfo initSslInfo() { + return this.sslInfo; + } + + @Override + public Flux getBody() { + return this.body; + } + + @SuppressWarnings("unchecked") + @Override + public T getNativeRequest() { + return (T) this.originalRequest; + } + + @Override + public String getId() { + return this.originalRequest.getId(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/DefaultSslInfo.java b/spring-web/src/main/java/org/springframework/http/server/reactive/DefaultSslInfo.java new file mode 100644 index 0000000000000000000000000000000000000000..3533ad70c9cef6315a01c2da2dcffe468ceef35a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/DefaultSslInfo.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.List; + +import javax.net.ssl.SSLSession; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Default implementation of {@link SslInfo}. + * + * @author Rossen Stoyanchev + * @since 5.0.2 + */ +final class DefaultSslInfo implements SslInfo { + + @Nullable + private final String sessionId; + + @Nullable + private final X509Certificate[] peerCertificates; + + + DefaultSslInfo(@Nullable String sessionId, X509Certificate[] peerCertificates) { + Assert.notNull(peerCertificates, "No SSL certificates"); + this.sessionId = sessionId; + this.peerCertificates = peerCertificates; + } + + DefaultSslInfo(SSLSession session) { + Assert.notNull(session, "SSLSession is required"); + this.sessionId = initSessionId(session); + this.peerCertificates = initCertificates(session); + } + + + @Override + @Nullable + public String getSessionId() { + return this.sessionId; + } + + @Override + @Nullable + public X509Certificate[] getPeerCertificates() { + return this.peerCertificates; + } + + + @Nullable + private static String initSessionId(SSLSession session) { + byte [] bytes = session.getId(); + if (bytes == null) { + return null; + } + + StringBuilder sb = new StringBuilder(); + for (byte b : bytes) { + String digit = Integer.toHexString(b); + if (digit.length() < 2) { + sb.append('0'); + } + if (digit.length() > 2) { + digit = digit.substring(digit.length() - 2); + } + sb.append(digit); + } + return sb.toString(); + } + + @Nullable + private static X509Certificate[] initCertificates(SSLSession session) { + Certificate[] certificates; + try { + certificates = session.getPeerCertificates(); + } + catch (Throwable ex) { + return null; + } + + List result = new ArrayList<>(certificates.length); + for (Certificate certificate : certificates) { + if (certificate instanceof X509Certificate) { + result.add((X509Certificate) certificate); + } + } + return (!result.isEmpty() ? result.toArray(new X509Certificate[0]) : null); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/HttpHandler.java b/spring-web/src/main/java/org/springframework/http/server/reactive/HttpHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..8cb9a48afad44f31054202a42fe003f5f4ebe276 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/HttpHandler.java @@ -0,0 +1,53 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import reactor.core.publisher.Mono; + +/** + * Lowest level contract for reactive HTTP request handling that serves as a + * common denominator across different runtimes. + * + *

Higher-level, but still generic, building blocks for applications such as + * {@code WebFilter}, {@code WebSession}, {@code ServerWebExchange}, and others + * are available in the {@code org.springframework.web.server} package. + * + *

Application level programming models such as annotated controllers and + * functional handlers are available in the {@code spring-webflux} module. + * + *

Typically an {@link HttpHandler} represents an entire application with + * higher-level programming models bridged via + * {@link org.springframework.web.server.adapter.WebHttpHandlerBuilder}. + * Multiple applications at unique context paths can be plugged in with the + * help of the {@link ContextPathCompositeHandler}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @since 5.0 + * @see ContextPathCompositeHandler + */ +public interface HttpHandler { + + /** + * Handle the given request and write to the response. + * @param request current request + * @param response current response + * @return indicates completion of request handling + */ + Mono handle(ServerHttpRequest request, ServerHttpResponse response); + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/HttpHeadResponseDecorator.java b/spring-web/src/main/java/org/springframework/http/server/reactive/HttpHeadResponseDecorator.java new file mode 100644 index 0000000000000000000000000000000000000000..d387abd329bf08dc61909264089cca03ff65aa6f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/HttpHeadResponseDecorator.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.function.BiFunction; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; + +/** + * {@link ServerHttpResponse} decorator for HTTP HEAD requests. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class HttpHeadResponseDecorator extends ServerHttpResponseDecorator { + + + public HttpHeadResponseDecorator(ServerHttpResponse delegate) { + super(delegate); + } + + + /** + * Apply {@link Flux#reduce(Object, BiFunction) reduce} on the body, count + * the number of bytes produced, release data buffers without writing, and + * set the {@literal Content-Length} header. + */ + @Override + public final Mono writeWith(Publisher body) { + return Flux.from(body) + .reduce(0, (current, buffer) -> { + int next = current + buffer.readableByteCount(); + DataBufferUtils.release(buffer); + return next; + }) + .doOnNext(length -> { + if (length > 0 || getHeaders().getFirst(HttpHeaders.CONTENT_LENGTH) == null) { + getHeaders().setContentLength(length); + } + }) + .then(); + } + + /** + * Invoke {@link #setComplete()} without writing. + *

RFC 7302 allows HTTP HEAD response without content-length and it's not + * something that can be computed on a streaming response. + */ + @Override + public final Mono writeAndFlushWith(Publisher> body) { + // Not feasible to count bytes on potentially streaming response. + // RFC 7302 allows HEAD without content-length. + return setComplete(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHeadersAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..198e549bb65d7bdd5d41e58bfeb1846e40919968 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHeadersAdapter.java @@ -0,0 +1,228 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Enumeration; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.eclipse.jetty.http.HttpField; +import org.eclipse.jetty.http.HttpFields; + +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Jetty HTTP headers. + * + * @author Brian Clozel + * @since 5.1.1 + */ +class JettyHeadersAdapter implements MultiValueMap { + + private final HttpFields headers; + + + JettyHeadersAdapter(HttpFields headers) { + this.headers = headers; + } + + + @Override + public String getFirst(String key) { + return this.headers.get(key); + } + + @Override + public void add(String key, @Nullable String value) { + this.headers.add(key, value); + } + + @Override + public void addAll(String key, List values) { + values.forEach(value -> add(key, value)); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach(this::addAll); + } + + @Override + public void set(String key, @Nullable String value) { + this.headers.put(key, value); + } + + @Override + public void setAll(Map values) { + values.forEach(this::set); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + Iterator iterator = this.headers.iterator(); + iterator.forEachRemaining(field -> { + if (!singleValueMap.containsKey(field.getName())) { + singleValueMap.put(field.getName(), field.getValue()); + } + }); + return singleValueMap; + } + + @Override + public int size() { + return this.headers.getFieldNamesCollection().size(); + } + + @Override + public boolean isEmpty() { + return (this.headers.size() == 0); + } + + @Override + public boolean containsKey(Object key) { + return (key instanceof String && this.headers.containsKey((String) key)); + } + + @Override + public boolean containsValue(Object value) { + return (value instanceof String && + this.headers.stream().anyMatch(field -> field.contains((String) value))); + } + + @Nullable + @Override + public List get(Object key) { + if (containsKey(key)) { + return this.headers.getValuesList((String) key); + } + return null; + } + + @Nullable + @Override + public List put(String key, List value) { + List oldValues = get(key); + this.headers.put(key, value); + return oldValues; + } + + @Nullable + @Override + public List remove(Object key) { + if (key instanceof String) { + List oldValues = get(key); + this.headers.remove((String) key); + return oldValues; + } + return null; + } + + @Override + public void putAll(Map> map) { + map.forEach(this::put); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + return this.headers.getFieldNamesCollection(); + } + + @Override + public Collection> values() { + return this.headers.getFieldNamesCollection().stream() + .map(this.headers::getValuesList).collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + + @Override + public String toString() { + return HttpHeaders.formatHeaders(this); + } + + + private class EntryIterator implements Iterator>> { + + private Enumeration names = headers.getFieldNames(); + + @Override + public boolean hasNext() { + return this.names.hasMoreElements(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.nextElement()); + } + } + + + private class HeaderEntry implements Entry> { + + private final String key; + + HeaderEntry(String key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key; + } + + @Override + public List getValue() { + return headers.getValuesList(this.key); + } + + @Override + public List setValue(List value) { + List previousValues = headers.getValuesList(this.key); + headers.put(this.key, value); + return previousValues; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHttpHandlerAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHttpHandlerAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..e0c538bc47220262312b2a33d22d8e33e5d4e7ae --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHttpHandlerAdapter.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; + +import javax.servlet.AsyncContext; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.eclipse.jetty.http.HttpFields; +import org.eclipse.jetty.server.HttpOutput; +import org.eclipse.jetty.server.Request; +import org.eclipse.jetty.server.Response; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.util.Assert; + +/** + * {@link ServletHttpHandlerAdapter} extension that uses Jetty APIs for writing + * to the response with {@link ByteBuffer}. + * + * @author Violeta Georgieva + * @author Brian Clozel + * @since 5.0 + * @see org.springframework.web.server.adapter.AbstractReactiveWebInitializer + */ +public class JettyHttpHandlerAdapter extends ServletHttpHandlerAdapter { + + public JettyHttpHandlerAdapter(HttpHandler httpHandler) { + super(httpHandler); + } + + + @Override + protected ServletServerHttpRequest createRequest(HttpServletRequest request, AsyncContext context) + throws IOException, URISyntaxException { + + Assert.notNull(getServletPath(), "Servlet path is not initialized"); + return new JettyServerHttpRequest(request, context, getServletPath(), getDataBufferFactory(), getBufferSize()); + } + + @Override + protected ServletServerHttpResponse createResponse(HttpServletResponse response, + AsyncContext context, ServletServerHttpRequest request) throws IOException { + + return new JettyServerHttpResponse( + response, context, getDataBufferFactory(), getBufferSize(), request); + } + + + private static final class JettyServerHttpRequest extends ServletServerHttpRequest { + + JettyServerHttpRequest(HttpServletRequest request, AsyncContext asyncContext, + String servletPath, DataBufferFactory bufferFactory, int bufferSize) + throws IOException, URISyntaxException { + + super(createHeaders(request), request, asyncContext, servletPath, bufferFactory, bufferSize); + } + + private static HttpHeaders createHeaders(HttpServletRequest request) { + HttpFields fields = ((Request) request).getMetaData().getFields(); + return new HttpHeaders(new JettyHeadersAdapter(fields)); + } + } + + + private static final class JettyServerHttpResponse extends ServletServerHttpResponse { + + JettyServerHttpResponse(HttpServletResponse response, AsyncContext asyncContext, + DataBufferFactory bufferFactory, int bufferSize, ServletServerHttpRequest request) + throws IOException { + + super(createHeaders(response), response, asyncContext, bufferFactory, bufferSize, request); + } + + private static HttpHeaders createHeaders(HttpServletResponse response) { + HttpFields fields = ((Response) response).getHttpFields(); + return new HttpHeaders(new JettyHeadersAdapter(fields)); + } + + @Override + protected void applyHeaders() { + HttpServletResponse response = getNativeResponse(); + MediaType contentType = null; + try { + contentType = getHeaders().getContentType(); + } + catch (Exception ex) { + String rawContentType = getHeaders().getFirst(HttpHeaders.CONTENT_TYPE); + response.setContentType(rawContentType); + } + if (response.getContentType() == null && contentType != null) { + response.setContentType(contentType.toString()); + } + Charset charset = (contentType != null ? contentType.getCharset() : null); + if (response.getCharacterEncoding() == null && charset != null) { + response.setCharacterEncoding(charset.name()); + } + long contentLength = getHeaders().getContentLength(); + if (contentLength != -1) { + response.setContentLengthLong(contentLength); + } + } + + @Override + protected int writeToOutputStream(DataBuffer dataBuffer) throws IOException { + ByteBuffer input = dataBuffer.asByteBuffer(); + int len = input.remaining(); + ServletResponse response = getNativeResponse(); + ((HttpOutput) response.getOutputStream()).write(input); + return len; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/NettyHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/NettyHeadersAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..c32cb5426fdc4fbdbb6c347e370d6cd1c3317b16 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/NettyHeadersAdapter.java @@ -0,0 +1,227 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import io.netty.handler.codec.http.HttpHeaders; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Netty HTTP headers. + * + * @author Brian Clozel + * @since 5.1.1 + */ +class NettyHeadersAdapter implements MultiValueMap { + + private final HttpHeaders headers; + + + NettyHeadersAdapter(HttpHeaders headers) { + this.headers = headers; + } + + + @Override + @Nullable + public String getFirst(String key) { + return this.headers.get(key); + } + + @Override + public void add(String key, @Nullable String value) { + this.headers.add(key, value); + } + + @Override + public void addAll(String key, List values) { + this.headers.add(key, values); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach(this.headers::add); + } + + @Override + public void set(String key, @Nullable String value) { + this.headers.set(key, value); + } + + @Override + public void setAll(Map values) { + values.forEach(this.headers::set); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + this.headers.entries() + .forEach(entry -> { + if (!singleValueMap.containsKey(entry.getKey())) { + singleValueMap.put(entry.getKey(), entry.getValue()); + } + }); + return singleValueMap; + } + + @Override + public int size() { + return this.headers.names().size(); + } + + @Override + public boolean isEmpty() { + return this.headers.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + return (key instanceof String && this.headers.contains((String) key)); + } + + @Override + public boolean containsValue(Object value) { + return (value instanceof String && + this.headers.entries().stream() + .anyMatch(entry -> value.equals(entry.getValue()))); + } + + @Override + @Nullable + public List get(Object key) { + if (containsKey(key)) { + return this.headers.getAll((String) key); + } + return null; + } + + @Nullable + @Override + public List put(String key, @Nullable List value) { + List previousValues = this.headers.getAll(key); + this.headers.set(key, value); + return previousValues; + } + + @Nullable + @Override + public List remove(Object key) { + if (key instanceof String) { + List previousValues = this.headers.getAll((String) key); + this.headers.remove((String) key); + return previousValues; + } + return null; + } + + @Override + public void putAll(Map> map) { + map.forEach(this.headers::add); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + return this.headers.names(); + } + + @Override + public Collection> values() { + return this.headers.names().stream() + .map(this.headers::getAll).collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + + @Override + public String toString() { + return org.springframework.http.HttpHeaders.formatHeaders(this); + } + + + private class EntryIterator implements Iterator>> { + + private Iterator names = headers.names().iterator(); + + @Override + public boolean hasNext() { + return this.names.hasNext(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.next()); + } + } + + + private class HeaderEntry implements Entry> { + + private final String key; + + HeaderEntry(String key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key; + } + + @Override + public List getValue() { + return headers.getAll(this.key); + } + + @Override + public List setValue(List value) { + List previousValues = headers.getAll(this.key); + headers.set(this.key, value); + return previousValues; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorHttpHandlerAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorHttpHandlerAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..bc473d6464dceab4ea5f27403966a043e2e4a29d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorHttpHandlerAdapter.java @@ -0,0 +1,78 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.URISyntaxException; +import java.util.function.BiFunction; + +import io.netty.handler.codec.http.HttpResponseStatus; +import org.apache.commons.logging.Log; +import reactor.core.publisher.Mono; +import reactor.netty.http.server.HttpServerRequest; +import reactor.netty.http.server.HttpServerResponse; + +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.http.HttpLogging; +import org.springframework.http.HttpMethod; +import org.springframework.util.Assert; + +/** + * Adapt {@link HttpHandler} to the Reactor Netty channel handling function. + * + * @author Stephane Maldini + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ReactorHttpHandlerAdapter implements BiFunction> { + + private static final Log logger = HttpLogging.forLogName(ReactorHttpHandlerAdapter.class); + + + private final HttpHandler httpHandler; + + + public ReactorHttpHandlerAdapter(HttpHandler httpHandler) { + Assert.notNull(httpHandler, "HttpHandler must not be null"); + this.httpHandler = httpHandler; + } + + + @Override + public Mono apply(HttpServerRequest reactorRequest, HttpServerResponse reactorResponse) { + NettyDataBufferFactory bufferFactory = new NettyDataBufferFactory(reactorResponse.alloc()); + try { + ReactorServerHttpRequest request = new ReactorServerHttpRequest(reactorRequest, bufferFactory); + ServerHttpResponse response = new ReactorServerHttpResponse(reactorResponse, bufferFactory); + + if (request.getMethod() == HttpMethod.HEAD) { + response = new HttpHeadResponseDecorator(response); + } + + return this.httpHandler.handle(request, response) + .doOnError(ex -> logger.trace(request.getLogPrefix() + "Failed to complete: " + ex.getMessage())) + .doOnSuccess(aVoid -> logger.trace(request.getLogPrefix() + "Handling completed")); + } + catch (URISyntaxException ex) { + if (logger.isDebugEnabled()) { + logger.debug("Failed to get request URI: " + ex.getMessage()); + } + reactorResponse.status(HttpResponseStatus.BAD_REQUEST); + return Mono.empty(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..cd3a2b34578903a7c3ded35c0c0181eb628ab6e2 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java @@ -0,0 +1,189 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; + +import javax.net.ssl.SSLSession; + +import io.netty.channel.Channel; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.cookie.Cookie; +import io.netty.handler.ssl.SslHandler; +import reactor.core.publisher.Flux; +import reactor.netty.Connection; +import reactor.netty.http.server.HttpServerRequest; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * Adapt {@link ServerHttpRequest} to the Reactor {@link HttpServerRequest}. + * + * @author Stephane Maldini + * @author Rossen Stoyanchev + * @since 5.0 + */ +class ReactorServerHttpRequest extends AbstractServerHttpRequest { + + private final HttpServerRequest request; + + private final NettyDataBufferFactory bufferFactory; + + + public ReactorServerHttpRequest(HttpServerRequest request, NettyDataBufferFactory bufferFactory) + throws URISyntaxException { + + super(initUri(request), "", initHeaders(request)); + Assert.notNull(bufferFactory, "DataBufferFactory must not be null"); + this.request = request; + this.bufferFactory = bufferFactory; + } + + private static URI initUri(HttpServerRequest request) throws URISyntaxException { + Assert.notNull(request, "HttpServerRequest must not be null"); + return new URI(resolveBaseUrl(request).toString() + resolveRequestUri(request)); + } + + private static URI resolveBaseUrl(HttpServerRequest request) throws URISyntaxException { + String scheme = getScheme(request); + String header = request.requestHeaders().get(HttpHeaderNames.HOST); + if (header != null) { + final int portIndex; + if (header.startsWith("[")) { + portIndex = header.indexOf(':', header.indexOf(']')); + } + else { + portIndex = header.indexOf(':'); + } + if (portIndex != -1) { + try { + return new URI(scheme, null, header.substring(0, portIndex), + Integer.parseInt(header.substring(portIndex + 1)), null, null, null); + } + catch (NumberFormatException ex) { + throw new URISyntaxException(header, "Unable to parse port", portIndex); + } + } + else { + return new URI(scheme, header, null, null); + } + } + else { + InetSocketAddress localAddress = request.hostAddress(); + Assert.state(localAddress != null, "No host address available"); + return new URI(scheme, null, localAddress.getHostString(), + localAddress.getPort(), null, null, null); + } + } + + private static String getScheme(HttpServerRequest request) { + return request.scheme(); + } + + private static String resolveRequestUri(HttpServerRequest request) { + String uri = request.uri(); + for (int i = 0; i < uri.length(); i++) { + char c = uri.charAt(i); + if (c == '/' || c == '?' || c == '#') { + break; + } + if (c == ':' && (i + 2 < uri.length())) { + if (uri.charAt(i + 1) == '/' && uri.charAt(i + 2) == '/') { + for (int j = i + 3; j < uri.length(); j++) { + c = uri.charAt(j); + if (c == '/' || c == '?' || c == '#') { + return uri.substring(j); + } + } + return ""; + } + } + } + return uri; + } + + private static HttpHeaders initHeaders(HttpServerRequest channel) { + NettyHeadersAdapter headersMap = new NettyHeadersAdapter(channel.requestHeaders()); + return new HttpHeaders(headersMap); + } + + + @Override + public String getMethodValue() { + return this.request.method().name(); + } + + @Override + protected MultiValueMap initCookies() { + MultiValueMap cookies = new LinkedMultiValueMap<>(); + for (CharSequence name : this.request.cookies().keySet()) { + for (Cookie cookie : this.request.cookies().get(name)) { + HttpCookie httpCookie = new HttpCookie(name.toString(), cookie.value()); + cookies.add(name.toString(), httpCookie); + } + } + return cookies; + } + + @Override + public InetSocketAddress getRemoteAddress() { + return this.request.remoteAddress(); + } + + @Override + @Nullable + protected SslInfo initSslInfo() { + Channel channel = ((Connection) this.request).channel(); + SslHandler sslHandler = channel.pipeline().get(SslHandler.class); + if (sslHandler == null && channel.parent() != null) { // HTTP/2 + sslHandler = channel.parent().pipeline().get(SslHandler.class); + } + if (sslHandler != null) { + SSLSession session = sslHandler.engine().getSession(); + return new DefaultSslInfo(session); + } + return null; + } + + @Override + public Flux getBody() { + return this.request.receive().retain().map(this.bufferFactory::wrap); + } + + @SuppressWarnings("unchecked") + @Override + public T getNativeRequest() { + return (T) this.request; + } + + @Override + @Nullable + protected String initId() { + return this.request instanceof Connection ? + ((Connection) this.request).channel().id().asShortText() : null; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..43842bdb28e07bcea7365d5f4278953842c8c534 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java @@ -0,0 +1,113 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.nio.file.Path; +import java.util.List; + +import io.netty.buffer.ByteBuf; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.http.server.HttpServerResponse; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseCookie; +import org.springframework.http.ZeroCopyHttpOutputMessage; +import org.springframework.util.Assert; + +/** + * Adapt {@link ServerHttpResponse} to the {@link HttpServerResponse}. + * + * @author Stephane Maldini + * @author Rossen Stoyanchev + * @since 5.0 + */ +class ReactorServerHttpResponse extends AbstractServerHttpResponse implements ZeroCopyHttpOutputMessage { + + private final HttpServerResponse response; + + + public ReactorServerHttpResponse(HttpServerResponse response, DataBufferFactory bufferFactory) { + super(bufferFactory, new HttpHeaders(new NettyHeadersAdapter(response.responseHeaders()))); + Assert.notNull(response, "HttpServerResponse must not be null"); + this.response = response; + } + + + @SuppressWarnings("unchecked") + @Override + public T getNativeResponse() { + return (T) this.response; + } + + @Override + public HttpStatus getStatusCode() { + HttpStatus httpStatus = super.getStatusCode(); + return (httpStatus != null ? httpStatus : HttpStatus.resolve(this.response.status().code())); + } + + + @Override + protected void applyStatusCode() { + Integer statusCode = getStatusCodeValue(); + if (statusCode != null) { + this.response.status(statusCode); + } + } + + @Override + protected Mono writeWithInternal(Publisher publisher) { + return this.response.send(toByteBufs(publisher)).then(); + } + + @Override + protected Mono writeAndFlushWithInternal(Publisher> publisher) { + return this.response.sendGroups(Flux.from(publisher).map(this::toByteBufs)).then(); + } + + @Override + protected void applyHeaders() { + } + + @Override + protected void applyCookies() { + // Netty Cookie doesn't support sameSite. When this is resolved, we can adapt to it again: + // https://github.com/netty/netty/issues/8161 + for (List cookies : getCookies().values()) { + for (ResponseCookie cookie : cookies) { + this.response.addHeader(HttpHeaders.SET_COOKIE, cookie.toString()); + } + } + } + + @Override + public Mono writeWith(Path file, long position, long count) { + return doCommit(() -> this.response.sendFile(file, position, count).then()); + } + + private Publisher toByteBufs(Publisher dataBuffers) { + return dataBuffers instanceof Mono ? + Mono.from(dataBuffers).map(NettyDataBufferFactory::toByteBuf) : + Flux.from(dataBuffers).map(NettyDataBufferFactory::toByteBuf); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..c71cc0636a996fae3854686d520a30114e4a5efd --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequest.java @@ -0,0 +1,190 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.util.Arrays; +import java.util.function.Consumer; + +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRequest; +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.http.server.RequestPath; +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * Represents a reactive server-side HTTP request. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Sam Brannen + * @since 5.0 + */ +public interface ServerHttpRequest extends HttpRequest, ReactiveHttpInputMessage { + + /** + * Return an id that represents the underlying connection, if available, + * or the request for the purpose of correlating log messages. + * @since 5.1 + * @see org.springframework.web.server.ServerWebExchange#getLogPrefix() + */ + String getId(); + + /** + * Returns a structured representation of the request path including the + * context path + path within application portions, path segments with + * encoded and decoded values, and path parameters. + */ + RequestPath getPath(); + + /** + * Return a read-only map with parsed and decoded query parameter values. + */ + MultiValueMap getQueryParams(); + + /** + * Return a read-only map of cookies sent by the client. + */ + MultiValueMap getCookies(); + + /** + * Return the remote address where this request is connected to, if available. + */ + @Nullable + default InetSocketAddress getRemoteAddress() { + return null; + } + + /** + * Return the SSL session information if the request has been transmitted + * over a secure protocol including SSL certificates, if available. + * @return the session information, or {@code null} if none available + * @since 5.0.2 + */ + @Nullable + default SslInfo getSslInfo() { + return null; + } + + /** + * Return a builder to mutate properties of this request by wrapping it + * with {@link ServerHttpRequestDecorator} and returning either mutated + * values or delegating back to this instance. + */ + default ServerHttpRequest.Builder mutate() { + return new DefaultServerHttpRequestBuilder(this); + } + + + /** + * Builder for mutating an existing {@link ServerHttpRequest}. + */ + interface Builder { + + /** + * Set the HTTP method to return. + */ + Builder method(HttpMethod httpMethod); + + /** + * Set the URI to use with the following conditions: + *

    + *
  • If {@link #path(String) path} is also set, it overrides the path + * of the URI provided here. + *
  • If {@link #contextPath(String) contextPath} is also set, or + * already present, it must match the start of the path of the URI + * provided here. + *
+ */ + Builder uri(URI uri); + + /** + * Set the path to use instead of the {@code "rawPath"} of the URI of + * the request with the following conditions: + *
    + *
  • If {@link #uri(URI) uri} is also set, the path given here + * overrides the path of the given URI. + *
  • If {@link #contextPath(String) contextPath} is also set, or + * already present, it must match the start of the path given here. + *
  • The given value must begin with a slash. + *
+ */ + Builder path(String path); + + /** + * Set the contextPath to use. + *

The given value must be a valid {@link RequestPath#contextPath() + * contextPath} and it must match the start of the path of the URI of + * the request. That means changing the contextPath, implies also + * changing the path via {@link #path(String)}. + */ + Builder contextPath(String contextPath); + + /** + * Add the given, single header value under the given name. + * @param headerName the header name + * @param headerValue the header value + * @deprecated This method will be removed in Spring Framework 5.2 in + * favor of {@link #header(String, String...)}. + */ + @Deprecated + Builder header(String headerName, String headerValue); + + /** + * Set or override the specified header values under the given name. + *

If you need to set a single header value, you may invoke this + * method with an explicit one-element array — for example, + * header("key", new String[] { "value" }) — or you + * may choose to use {@link #headers(Consumer)} for greater control. + * @param headerName the header name + * @param headerValues the header values + * @since 5.1.9 + * @see #headers(Consumer) + */ + default Builder header(String headerName, String... headerValues) { + return headers(httpHeaders -> httpHeaders.put(headerName, Arrays.asList(headerValues))); + } + + /** + * Manipulate request headers. The provided {@code HttpHeaders} contains + * current request headers, so that the {@code Consumer} can + * {@linkplain HttpHeaders#set(String, String) overwrite} or + * {@linkplain HttpHeaders#remove(Object) remove} existing values, or + * use any other {@link HttpHeaders} methods. + * @see #header(String, String...) + */ + Builder headers(Consumer headersConsumer); + + /** + * Set the SSL session information. This may be useful in environments + * where TLS termination is done at the router, but SSL information is + * made available in some other way such as through a header. + * @since 5.0.7 + */ + Builder sslInfo(SslInfo sslInfo); + + /** + * Build a {@link ServerHttpRequest} decorator with the mutated properties. + */ + ServerHttpRequest build(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequestDecorator.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequestDecorator.java new file mode 100644 index 0000000000000000000000000000000000000000..f8c26b4271b1f73b3462f4cc16755d13ed54e0a5 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequestDecorator.java @@ -0,0 +1,122 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.InetSocketAddress; +import java.net.URI; + +import reactor.core.publisher.Flux; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.server.RequestPath; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; + +/** + * Wraps another {@link ServerHttpRequest} and delegates all methods to it. + * Sub-classes can override specific methods selectively. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ServerHttpRequestDecorator implements ServerHttpRequest { + + private final ServerHttpRequest delegate; + + + public ServerHttpRequestDecorator(ServerHttpRequest delegate) { + Assert.notNull(delegate, "Delegate is required"); + this.delegate = delegate; + } + + + public ServerHttpRequest getDelegate() { + return this.delegate; + } + + + // ServerHttpRequest delegation methods... + + @Override + public String getId() { + return getDelegate().getId(); + } + + @Override + @Nullable + public HttpMethod getMethod() { + return getDelegate().getMethod(); + } + + @Override + public String getMethodValue() { + return getDelegate().getMethodValue(); + } + + @Override + public URI getURI() { + return getDelegate().getURI(); + } + + @Override + public RequestPath getPath() { + return getDelegate().getPath(); + } + + @Override + public MultiValueMap getQueryParams() { + return getDelegate().getQueryParams(); + } + + @Override + public HttpHeaders getHeaders() { + return getDelegate().getHeaders(); + } + + @Override + public MultiValueMap getCookies() { + return getDelegate().getCookies(); + } + + @Override + @Nullable + public InetSocketAddress getRemoteAddress() { + return getDelegate().getRemoteAddress(); + } + + @Override + @Nullable + public SslInfo getSslInfo() { + return getDelegate().getSslInfo(); + } + + @Override + public Flux getBody() { + return getDelegate().getBody(); + } + + + @Override + public String toString() { + return getClass().getSimpleName() + " [delegate=" + getDelegate() + "]"; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..2bbfba900f4e702a059b4227ff7d685fa5fc8432 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpResponse.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import org.springframework.http.HttpStatus; +import org.springframework.http.ReactiveHttpOutputMessage; +import org.springframework.http.ResponseCookie; +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * Represents a reactive server-side HTTP response. + * + * @author Arjen Poutsma + * @author Sebastien Deleuze + * @since 5.0 + */ +public interface ServerHttpResponse extends ReactiveHttpOutputMessage { + + /** + * Set the HTTP status code of the response. + * @param status the HTTP status as an {@link HttpStatus} enum value + * @return {@code false} if the status code has not been set because the + * HTTP response is already committed, {@code true} if successfully set. + */ + boolean setStatusCode(@Nullable HttpStatus status); + + /** + * Return the status code set via {@link #setStatusCode}, or if the status + * has not been set, return the default status code from the underlying + * server response. The return value may be {@code null} if the status code + * value is outside the {@link HttpStatus} enum range, or if the underlying + * server response does not have a default value. + */ + @Nullable + HttpStatus getStatusCode(); + + /** + * Return a mutable map with the cookies to send to the server. + */ + MultiValueMap getCookies(); + + /** + * Add the given {@code ResponseCookie}. + * @param cookie the cookie to add + * @throws IllegalStateException if the response has already been committed + */ + void addCookie(ResponseCookie cookie); + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpResponseDecorator.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpResponseDecorator.java new file mode 100644 index 0000000000000000000000000000000000000000..274cea9676311b04553ae62fe29927c7122efb4c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpResponseDecorator.java @@ -0,0 +1,118 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.http.server.reactive; + +import java.util.function.Supplier; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseCookie; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; + +/** + * Wraps another {@link ServerHttpResponse} and delegates all methods to it. + * Sub-classes can override specific methods selectively. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ServerHttpResponseDecorator implements ServerHttpResponse { + + private final ServerHttpResponse delegate; + + + public ServerHttpResponseDecorator(ServerHttpResponse delegate) { + Assert.notNull(delegate, "Delegate is required"); + this.delegate = delegate; + } + + + public ServerHttpResponse getDelegate() { + return this.delegate; + } + + + // ServerHttpResponse delegation methods... + + @Override + public boolean setStatusCode(@Nullable HttpStatus status) { + return getDelegate().setStatusCode(status); + } + + @Override + public HttpStatus getStatusCode() { + return getDelegate().getStatusCode(); + } + + @Override + public HttpHeaders getHeaders() { + return getDelegate().getHeaders(); + } + + @Override + public MultiValueMap getCookies() { + return getDelegate().getCookies(); + } + + @Override + public void addCookie(ResponseCookie cookie) { + getDelegate().addCookie(cookie); + } + + @Override + public DataBufferFactory bufferFactory() { + return getDelegate().bufferFactory(); + } + + @Override + public void beforeCommit(Supplier> action) { + getDelegate().beforeCommit(action); + } + + @Override + public boolean isCommitted() { + return getDelegate().isCommitted(); + } + + @Override + public Mono writeWith(Publisher body) { + return getDelegate().writeWith(body); + } + + @Override + public Mono writeAndFlushWith(Publisher> body) { + return getDelegate().writeAndFlushWith(body); + } + + @Override + public Mono setComplete() { + return getDelegate().setComplete(); + } + + + @Override + public String toString() { + return getClass().getSimpleName() + " [delegate=" + getDelegate() + "]"; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..dd9144d0368d82977b2c9cc8ee621aa5634188a4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java @@ -0,0 +1,336 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.Collection; +import java.util.concurrent.atomic.AtomicBoolean; + +import javax.servlet.AsyncContext; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; +import javax.servlet.DispatcherType; +import javax.servlet.Servlet; +import javax.servlet.ServletConfig; +import javax.servlet.ServletException; +import javax.servlet.ServletRegistration; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.apache.commons.logging.Log; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpLogging; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Adapt {@link HttpHandler} to an {@link HttpServlet} using Servlet Async support + * and Servlet 3.1 non-blocking I/O. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @since 5.0 + * @see org.springframework.web.server.adapter.AbstractReactiveWebInitializer + */ +@SuppressWarnings("serial") +public class ServletHttpHandlerAdapter implements Servlet { + + private static final Log logger = HttpLogging.forLogName(ServletHttpHandlerAdapter.class); + + private static final int DEFAULT_BUFFER_SIZE = 8192; + + private static final String WRITE_ERROR_ATTRIBUTE_NAME = ServletHttpHandlerAdapter.class.getName() + ".ERROR"; + + + private final HttpHandler httpHandler; + + private int bufferSize = DEFAULT_BUFFER_SIZE; + + @Nullable + private String servletPath; + + private DataBufferFactory dataBufferFactory = new DefaultDataBufferFactory(false); + + + public ServletHttpHandlerAdapter(HttpHandler httpHandler) { + Assert.notNull(httpHandler, "HttpHandler must not be null"); + this.httpHandler = httpHandler; + } + + + /** + * Set the size of the input buffer used for reading in bytes. + *

By default this is set to 8192. + */ + public void setBufferSize(int bufferSize) { + Assert.isTrue(bufferSize > 0, "Buffer size must be larger than zero"); + this.bufferSize = bufferSize; + } + + /** + * Return the configured input buffer size. + */ + public int getBufferSize() { + return this.bufferSize; + } + + /** + * Return the Servlet path under which the Servlet is deployed by checking + * the Servlet registration from {@link #init(ServletConfig)}. + * @return the path, or an empty string if the Servlet is deployed without + * a prefix (i.e. "/" or "/*"), or {@code null} if this method is invoked + * before the {@link #init(ServletConfig)} Servlet container callback. + */ + @Nullable + public String getServletPath() { + return this.servletPath; + } + + public void setDataBufferFactory(DataBufferFactory dataBufferFactory) { + Assert.notNull(dataBufferFactory, "DataBufferFactory must not be null"); + this.dataBufferFactory = dataBufferFactory; + } + + public DataBufferFactory getDataBufferFactory() { + return this.dataBufferFactory; + } + + + // Servlet methods... + + @Override + public void init(ServletConfig config) { + this.servletPath = getServletPath(config); + } + + private String getServletPath(ServletConfig config) { + String name = config.getServletName(); + ServletRegistration registration = config.getServletContext().getServletRegistration(name); + if (registration == null) { + throw new IllegalStateException("ServletRegistration not found for Servlet '" + name + "'"); + } + + Collection mappings = registration.getMappings(); + if (mappings.size() == 1) { + String mapping = mappings.iterator().next(); + if (mapping.equals("/")) { + return ""; + } + if (mapping.endsWith("/*")) { + String path = mapping.substring(0, mapping.length() - 2); + if (!path.isEmpty() && logger.isDebugEnabled()) { + logger.debug("Found servlet mapping prefix '" + path + "' for '" + name + "'"); + } + return path; + } + } + + throw new IllegalArgumentException("Expected a single Servlet mapping: " + + "either the default Servlet mapping (i.e. '/'), " + + "or a path based mapping (e.g. '/*', '/foo/*'). " + + "Actual mappings: " + mappings + " for Servlet '" + name + "'"); + } + + + @Override + public void service(ServletRequest request, ServletResponse response) throws ServletException, IOException { + // Check for existing error attribute first + if (DispatcherType.ASYNC.equals(request.getDispatcherType())) { + Throwable ex = (Throwable) request.getAttribute(WRITE_ERROR_ATTRIBUTE_NAME); + throw new ServletException("Failed to create response content", ex); + } + + // Start async before Read/WriteListener registration + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(-1); + + ServletServerHttpRequest httpRequest; + try { + httpRequest = createRequest(((HttpServletRequest) request), asyncContext); + } + catch (URISyntaxException ex) { + if (logger.isDebugEnabled()) { + logger.debug("Failed to get request URL: " + ex.getMessage()); + } + ((HttpServletResponse) response).setStatus(400); + asyncContext.complete(); + return; + } + + ServerHttpResponse httpResponse = createResponse(((HttpServletResponse) response), asyncContext, httpRequest); + if (httpRequest.getMethod() == HttpMethod.HEAD) { + httpResponse = new HttpHeadResponseDecorator(httpResponse); + } + + AtomicBoolean isCompleted = new AtomicBoolean(); + HandlerResultAsyncListener listener = new HandlerResultAsyncListener(isCompleted, httpRequest); + asyncContext.addListener(listener); + + HandlerResultSubscriber subscriber = new HandlerResultSubscriber(asyncContext, isCompleted, httpRequest); + this.httpHandler.handle(httpRequest, httpResponse).subscribe(subscriber); + } + + protected ServletServerHttpRequest createRequest(HttpServletRequest request, AsyncContext context) + throws IOException, URISyntaxException { + + Assert.notNull(this.servletPath, "Servlet path is not initialized"); + return new ServletServerHttpRequest( + request, context, this.servletPath, getDataBufferFactory(), getBufferSize()); + } + + protected ServletServerHttpResponse createResponse(HttpServletResponse response, + AsyncContext context, ServletServerHttpRequest request) throws IOException { + + return new ServletServerHttpResponse(response, context, getDataBufferFactory(), getBufferSize(), request); + } + + @Override + public String getServletInfo() { + return ""; + } + + @Override + @Nullable + public ServletConfig getServletConfig() { + return null; + } + + @Override + public void destroy() { + } + + + /** + * We cannot combine ERROR_LISTENER and HandlerResultSubscriber due to: + * https://issues.jboss.org/browse/WFLY-8515. + */ + private static void runIfAsyncNotComplete(AsyncContext asyncContext, AtomicBoolean isCompleted, Runnable task) { + try { + if (asyncContext.getRequest().isAsyncStarted() && isCompleted.compareAndSet(false, true)) { + task.run(); + } + } + catch (IllegalStateException ex) { + // Ignore: AsyncContext recycled and should not be used + // e.g. TIMEOUT_LISTENER (above) may have completed the AsyncContext + } + } + + + private static class HandlerResultAsyncListener implements AsyncListener { + + private final AtomicBoolean isCompleted; + + private final String logPrefix; + + public HandlerResultAsyncListener(AtomicBoolean isCompleted, ServletServerHttpRequest httpRequest) { + this.isCompleted = isCompleted; + this.logPrefix = httpRequest.getLogPrefix(); + } + + @Override + public void onTimeout(AsyncEvent event) { + logger.debug(this.logPrefix + "Timeout notification"); + AsyncContext context = event.getAsyncContext(); + runIfAsyncNotComplete(context, this.isCompleted, context::complete); + } + + @Override + public void onError(AsyncEvent event) { + Throwable ex = event.getThrowable(); + logger.debug(this.logPrefix + "Error notification: " + (ex != null ? ex : "")); + AsyncContext context = event.getAsyncContext(); + runIfAsyncNotComplete(context, this.isCompleted, context::complete); + } + + @Override + public void onStartAsync(AsyncEvent event) { + // no-op + } + + @Override + public void onComplete(AsyncEvent event) { + // no-op + } + } + + + private class HandlerResultSubscriber implements Subscriber { + + private final AsyncContext asyncContext; + + private final AtomicBoolean isCompleted; + + private final String logPrefix; + + public HandlerResultSubscriber( + AsyncContext asyncContext, AtomicBoolean isCompleted, ServletServerHttpRequest httpRequest) { + + this.asyncContext = asyncContext; + this.isCompleted = isCompleted; + this.logPrefix = httpRequest.getLogPrefix(); + } + + @Override + public void onSubscribe(Subscription subscription) { + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void aVoid) { + // no-op + } + + @Override + public void onError(Throwable ex) { + logger.trace(this.logPrefix + "Failed to complete: " + ex.getMessage()); + runIfAsyncNotComplete(this.asyncContext, this.isCompleted, () -> { + if (this.asyncContext.getResponse().isCommitted()) { + logger.trace(this.logPrefix + "Dispatch to container, to raise the error on servlet thread"); + this.asyncContext.getRequest().setAttribute(WRITE_ERROR_ATTRIBUTE_NAME, ex); + this.asyncContext.dispatch(); + } + else { + try { + logger.trace(this.logPrefix + "Setting ServletResponse status to 500 Server Error"); + this.asyncContext.getResponse().resetBuffer(); + ((HttpServletResponse) this.asyncContext.getResponse()).setStatus(500); + } + finally { + this.asyncContext.complete(); + } + } + }); + } + + @Override + public void onComplete() { + logger.trace(this.logPrefix + "Handling completed"); + runIfAsyncNotComplete(this.asyncContext, this.isCompleted, this.asyncContext::complete); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..e29a79a3b38d47265b058dc55dfc0fcf9e46b539 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java @@ -0,0 +1,335 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.Charset; +import java.security.cert.X509Certificate; +import java.util.Enumeration; +import java.util.Map; + +import javax.servlet.AsyncContext; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; +import javax.servlet.ReadListener; +import javax.servlet.ServletInputStream; +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletRequest; + +import org.apache.commons.logging.Log; +import reactor.core.publisher.Flux; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * Adapt {@link ServerHttpRequest} to the Servlet {@link HttpServletRequest}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +class ServletServerHttpRequest extends AbstractServerHttpRequest { + + static final DataBuffer EOF_BUFFER = new DefaultDataBufferFactory().allocateBuffer(0); + + + private final HttpServletRequest request; + + private final RequestBodyPublisher bodyPublisher; + + private final Object cookieLock = new Object(); + + private final DataBufferFactory bufferFactory; + + private final byte[] buffer; + + public ServletServerHttpRequest(HttpServletRequest request, AsyncContext asyncContext, + String servletPath, DataBufferFactory bufferFactory, int bufferSize) + throws IOException, URISyntaxException { + + this(createDefaultHttpHeaders(request), request, asyncContext, servletPath, bufferFactory, bufferSize); + } + + public ServletServerHttpRequest(HttpHeaders headers, HttpServletRequest request, AsyncContext asyncContext, + String servletPath, DataBufferFactory bufferFactory, int bufferSize) + throws IOException, URISyntaxException { + + super(initUri(request), request.getContextPath() + servletPath, initHeaders(headers, request)); + + Assert.notNull(bufferFactory, "'bufferFactory' must not be null"); + Assert.isTrue(bufferSize > 0, "'bufferSize' must be higher than 0"); + + this.request = request; + this.bufferFactory = bufferFactory; + this.buffer = new byte[bufferSize]; + + asyncContext.addListener(new RequestAsyncListener()); + + // Tomcat expects ReadListener registration on initial thread + ServletInputStream inputStream = request.getInputStream(); + this.bodyPublisher = new RequestBodyPublisher(inputStream); + this.bodyPublisher.registerReadListener(); + } + + + private static HttpHeaders createDefaultHttpHeaders(HttpServletRequest request) { + HttpHeaders headers = new HttpHeaders(); + for (Enumeration names = request.getHeaderNames(); names.hasMoreElements(); ) { + String name = (String) names.nextElement(); + for (Enumeration values = request.getHeaders(name); values.hasMoreElements(); ) { + headers.add(name, (String) values.nextElement()); + } + } + return headers; + } + + private static URI initUri(HttpServletRequest request) throws URISyntaxException { + Assert.notNull(request, "'request' must not be null"); + StringBuffer url = request.getRequestURL(); + String query = request.getQueryString(); + if (StringUtils.hasText(query)) { + url.append('?').append(query); + } + return new URI(url.toString()); + } + + private static HttpHeaders initHeaders(HttpHeaders headers, HttpServletRequest request) { + MediaType contentType = headers.getContentType(); + if (contentType == null) { + String requestContentType = request.getContentType(); + if (StringUtils.hasLength(requestContentType)) { + contentType = MediaType.parseMediaType(requestContentType); + headers.setContentType(contentType); + } + } + if (contentType != null && contentType.getCharset() == null) { + String encoding = request.getCharacterEncoding(); + if (StringUtils.hasLength(encoding)) { + Charset charset = Charset.forName(encoding); + Map params = new LinkedCaseInsensitiveMap<>(); + params.putAll(contentType.getParameters()); + params.put("charset", charset.toString()); + headers.setContentType( + new MediaType(contentType.getType(), contentType.getSubtype(), + params)); + } + } + if (headers.getContentLength() == -1) { + int contentLength = request.getContentLength(); + if (contentLength != -1) { + headers.setContentLength(contentLength); + } + } + return headers; + } + + + @Override + public String getMethodValue() { + return this.request.getMethod(); + } + + @Override + protected MultiValueMap initCookies() { + MultiValueMap httpCookies = new LinkedMultiValueMap<>(); + Cookie[] cookies; + synchronized (this.cookieLock) { + cookies = this.request.getCookies(); + } + if (cookies != null) { + for (Cookie cookie : cookies) { + String name = cookie.getName(); + HttpCookie httpCookie = new HttpCookie(name, cookie.getValue()); + httpCookies.add(name, httpCookie); + } + } + return httpCookies; + } + + @Override + @NonNull + public InetSocketAddress getRemoteAddress() { + return new InetSocketAddress(this.request.getRemoteHost(), this.request.getRemotePort()); + } + + @Override + @Nullable + protected SslInfo initSslInfo() { + X509Certificate[] certificates = getX509Certificates(); + return certificates != null ? new DefaultSslInfo(getSslSessionId(), certificates) : null; + } + + @Nullable + private String getSslSessionId() { + return (String) this.request.getAttribute("javax.servlet.request.ssl_session_id"); + } + + @Nullable + private X509Certificate[] getX509Certificates() { + String name = "javax.servlet.request.X509Certificate"; + return (X509Certificate[]) this.request.getAttribute(name); + } + + @Override + public Flux getBody() { + return Flux.from(this.bodyPublisher); + } + + /** + * Read from the request body InputStream and return a DataBuffer. + * Invoked only when {@link ServletInputStream#isReady()} returns "true". + * @return a DataBuffer with data read, or {@link #EOF_BUFFER} if the input + * stream returned -1, or null if 0 bytes were read. + */ + @Nullable + DataBuffer readFromInputStream() throws IOException { + int read = this.request.getInputStream().read(this.buffer); + logBytesRead(read); + + if (read > 0) { + DataBuffer dataBuffer = this.bufferFactory.allocateBuffer(read); + dataBuffer.write(this.buffer, 0, read); + return dataBuffer; + } + + if (read == -1) { + return EOF_BUFFER; + } + + return null; + } + + protected final void logBytesRead(int read) { + Log rsReadLogger = AbstractListenerReadPublisher.rsReadLogger; + if (rsReadLogger.isTraceEnabled()) { + rsReadLogger.trace(getLogPrefix() + "Read " + read + (read != -1 ? " bytes" : "")); + } + } + + @SuppressWarnings("unchecked") + @Override + public T getNativeRequest() { + return (T) this.request; + } + + + private final class RequestAsyncListener implements AsyncListener { + + @Override + public void onStartAsync(AsyncEvent event) { + } + + @Override + public void onTimeout(AsyncEvent event) { + Throwable ex = event.getThrowable(); + ex = ex != null ? ex : new IllegalStateException("Async operation timeout."); + bodyPublisher.onError(ex); + } + + @Override + public void onError(AsyncEvent event) { + bodyPublisher.onError(event.getThrowable()); + } + + @Override + public void onComplete(AsyncEvent event) { + bodyPublisher.onAllDataRead(); + } + } + + + private class RequestBodyPublisher extends AbstractListenerReadPublisher { + + private final ServletInputStream inputStream; + + public RequestBodyPublisher(ServletInputStream inputStream) { + super(ServletServerHttpRequest.this.getLogPrefix()); + this.inputStream = inputStream; + } + + public void registerReadListener() throws IOException { + this.inputStream.setReadListener(new RequestBodyPublisherReadListener()); + } + + @Override + protected void checkOnDataAvailable() { + if (this.inputStream.isReady() && !this.inputStream.isFinished()) { + onDataAvailable(); + } + } + + @Override + @Nullable + protected DataBuffer read() throws IOException { + if (this.inputStream.isReady()) { + DataBuffer dataBuffer = readFromInputStream(); + if (dataBuffer == EOF_BUFFER) { + // No need to wait for container callback... + onAllDataRead(); + dataBuffer = null; + } + return dataBuffer; + } + return null; + } + + @Override + protected void readingPaused() { + // no-op + } + + @Override + protected void discardData() { + // Nothing to discard since we pass data buffers on immediately.. + } + + + private class RequestBodyPublisherReadListener implements ReadListener { + + @Override + public void onDataAvailable() throws IOException { + RequestBodyPublisher.this.onDataAvailable(); + } + + @Override + public void onAllDataRead() throws IOException { + RequestBodyPublisher.this.onAllDataRead(); + } + + @Override + public void onError(Throwable throwable) { + RequestBodyPublisher.this.onError(throwable); + + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..b82233977e2a1b27a88051a4e907d9beef823caa --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java @@ -0,0 +1,387 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; +import java.util.List; + +import javax.servlet.AsyncContext; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; +import javax.servlet.ServletOutputStream; +import javax.servlet.WriteListener; +import javax.servlet.http.HttpServletResponse; + +import org.reactivestreams.Processor; +import org.reactivestreams.Publisher; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseCookie; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Adapt {@link ServerHttpResponse} to the Servlet {@link HttpServletResponse}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +class ServletServerHttpResponse extends AbstractListenerServerHttpResponse { + + private final HttpServletResponse response; + + private final ServletOutputStream outputStream; + + private final int bufferSize; + + @Nullable + private volatile ResponseBodyFlushProcessor bodyFlushProcessor; + + @Nullable + private volatile ResponseBodyProcessor bodyProcessor; + + private volatile boolean flushOnNext; + + private final ServletServerHttpRequest request; + + public ServletServerHttpResponse(HttpServletResponse response, AsyncContext asyncContext, + DataBufferFactory bufferFactory, int bufferSize, ServletServerHttpRequest request) throws IOException { + + this(new HttpHeaders(), response, asyncContext, bufferFactory, bufferSize, request); + } + + public ServletServerHttpResponse(HttpHeaders headers, HttpServletResponse response, AsyncContext asyncContext, + DataBufferFactory bufferFactory, int bufferSize, ServletServerHttpRequest request) throws IOException { + + super(bufferFactory, headers); + + Assert.notNull(response, "HttpServletResponse must not be null"); + Assert.notNull(bufferFactory, "DataBufferFactory must not be null"); + Assert.isTrue(bufferSize > 0, "Buffer size must be greater than 0"); + + this.response = response; + this.outputStream = response.getOutputStream(); + this.bufferSize = bufferSize; + this.request = request; + + asyncContext.addListener(new ResponseAsyncListener()); + + // Tomcat expects WriteListener registration on initial thread + response.getOutputStream().setWriteListener(new ResponseBodyWriteListener()); + } + + + @SuppressWarnings("unchecked") + @Override + public T getNativeResponse() { + return (T) this.response; + } + + @Override + public HttpStatus getStatusCode() { + HttpStatus httpStatus = super.getStatusCode(); + return (httpStatus != null ? httpStatus : HttpStatus.resolve(this.response.getStatus())); + } + + @Override + protected void applyStatusCode() { + Integer statusCode = getStatusCodeValue(); + if (statusCode != null) { + this.response.setStatus(statusCode); + } + } + + @Override + protected void applyHeaders() { + getHeaders().forEach((headerName, headerValues) -> { + for (String headerValue : headerValues) { + this.response.addHeader(headerName, headerValue); + } + }); + MediaType contentType = null; + try { + contentType = getHeaders().getContentType(); + } + catch (Exception ex) { + String rawContentType = getHeaders().getFirst(HttpHeaders.CONTENT_TYPE); + this.response.setContentType(rawContentType); + } + if (this.response.getContentType() == null && contentType != null) { + this.response.setContentType(contentType.toString()); + } + Charset charset = (contentType != null ? contentType.getCharset() : null); + if (this.response.getCharacterEncoding() == null && charset != null) { + this.response.setCharacterEncoding(charset.name()); + } + long contentLength = getHeaders().getContentLength(); + if (contentLength != -1) { + this.response.setContentLengthLong(contentLength); + } + } + + @Override + protected void applyCookies() { + + // Servlet Cookie doesn't support same site: + // https://github.com/eclipse-ee4j/servlet-api/issues/175 + + // For Jetty, starting 9.4.21+ we could adapt to HttpCookie: + // https://github.com/eclipse/jetty.project/issues/3040 + + // For Tomcat it seems to be a global option only: + // https://tomcat.apache.org/tomcat-8.5-doc/config/cookie-processor.html + + for (List cookies : getCookies().values()) { + for (ResponseCookie cookie : cookies) { + this.response.addHeader(HttpHeaders.SET_COOKIE, cookie.toString()); + } + } + } + + @Override + protected Processor, Void> createBodyFlushProcessor() { + ResponseBodyFlushProcessor processor = new ResponseBodyFlushProcessor(); + this.bodyFlushProcessor = processor; + return processor; + } + + /** + * Write the DataBuffer to the response body OutputStream. + * Invoked only when {@link ServletOutputStream#isReady()} returns "true" + * and the readable bytes in the DataBuffer is greater than 0. + * @return the number of bytes written + */ + protected int writeToOutputStream(DataBuffer dataBuffer) throws IOException { + ServletOutputStream outputStream = this.outputStream; + InputStream input = dataBuffer.asInputStream(); + int bytesWritten = 0; + byte[] buffer = new byte[this.bufferSize]; + int bytesRead; + while (outputStream.isReady() && (bytesRead = input.read(buffer)) != -1) { + outputStream.write(buffer, 0, bytesRead); + bytesWritten += bytesRead; + } + return bytesWritten; + } + + private void flush() throws IOException { + ServletOutputStream outputStream = this.outputStream; + if (outputStream.isReady()) { + try { + outputStream.flush(); + this.flushOnNext = false; + } + catch (IOException ex) { + this.flushOnNext = true; + throw ex; + } + } + else { + this.flushOnNext = true; + } + } + + private boolean isWritePossible() { + return this.outputStream.isReady(); + } + + + private final class ResponseAsyncListener implements AsyncListener { + + @Override + public void onStartAsync(AsyncEvent event) {} + + @Override + public void onTimeout(AsyncEvent event) { + Throwable ex = event.getThrowable(); + ex = (ex != null ? ex : new IllegalStateException("Async operation timeout.")); + handleError(ex); + } + + @Override + public void onError(AsyncEvent event) { + handleError(event.getThrowable()); + } + + void handleError(Throwable ex) { + ResponseBodyFlushProcessor flushProcessor = bodyFlushProcessor; + if (flushProcessor != null) { + flushProcessor.cancel(); + flushProcessor.onError(ex); + } + + ResponseBodyProcessor processor = bodyProcessor; + if (processor != null) { + processor.cancel(); + processor.onError(ex); + } + } + + @Override + public void onComplete(AsyncEvent event) { + ResponseBodyFlushProcessor flushProcessor = bodyFlushProcessor; + if (flushProcessor != null) { + flushProcessor.cancel(); + flushProcessor.onComplete(); + } + + ResponseBodyProcessor processor = bodyProcessor; + if (processor != null) { + processor.cancel(); + processor.onComplete(); + } + } + } + + + private class ResponseBodyWriteListener implements WriteListener { + + @Override + public void onWritePossible() throws IOException { + ResponseBodyProcessor processor = bodyProcessor; + if (processor != null) { + processor.onWritePossible(); + } + else { + ResponseBodyFlushProcessor flushProcessor = bodyFlushProcessor; + if (flushProcessor != null) { + flushProcessor.onFlushPossible(); + } + } + } + + @Override + public void onError(Throwable ex) { + ResponseBodyProcessor processor = bodyProcessor; + if (processor != null) { + processor.cancel(); + processor.onError(ex); + } + else { + ResponseBodyFlushProcessor flushProcessor = bodyFlushProcessor; + if (flushProcessor != null) { + flushProcessor.cancel(); + flushProcessor.onError(ex); + } + } + } + } + + + private class ResponseBodyFlushProcessor extends AbstractListenerWriteFlushProcessor { + + public ResponseBodyFlushProcessor() { + super(request.getLogPrefix()); + } + + @Override + protected Processor createWriteProcessor() { + ResponseBodyProcessor processor = new ResponseBodyProcessor(); + bodyProcessor = processor; + return processor; + } + + @Override + protected void flush() throws IOException { + if (rsWriteFlushLogger.isTraceEnabled()) { + rsWriteFlushLogger.trace(getLogPrefix() + "Flush attempt"); + } + ServletServerHttpResponse.this.flush(); + } + + @Override + protected boolean isWritePossible() { + return ServletServerHttpResponse.this.isWritePossible(); + } + + @Override + protected boolean isFlushPending() { + return flushOnNext; + } + } + + + private class ResponseBodyProcessor extends AbstractListenerWriteProcessor { + + + public ResponseBodyProcessor() { + super(request.getLogPrefix()); + } + + @Override + protected boolean isWritePossible() { + return ServletServerHttpResponse.this.isWritePossible(); + } + + @Override + protected boolean isDataEmpty(DataBuffer dataBuffer) { + return dataBuffer.readableByteCount() == 0; + } + + @Override + protected boolean write(DataBuffer dataBuffer) throws IOException { + if (ServletServerHttpResponse.this.flushOnNext) { + if (rsWriteLogger.isTraceEnabled()) { + rsWriteLogger.trace(getLogPrefix() + "Flush attempt"); + } + flush(); + } + + boolean ready = ServletServerHttpResponse.this.isWritePossible(); + int remaining = dataBuffer.readableByteCount(); + if (ready && remaining > 0) { + // In case of IOException, onError handling should call discardData(DataBuffer).. + int written = writeToOutputStream(dataBuffer); + if (logger.isTraceEnabled()) { + logger.trace(getLogPrefix() + "Wrote " + written + " of " + remaining + " bytes"); + } + else if (rsWriteLogger.isTraceEnabled()) { + rsWriteLogger.trace(getLogPrefix() + "Wrote " + written + " of " + remaining + " bytes"); + } + if (written == remaining) { + DataBufferUtils.release(dataBuffer); + return true; + } + } + else { + if (rsWriteLogger.isTraceEnabled()) { + rsWriteLogger.trace(getLogPrefix() + "ready: " + ready + ", remaining: " + remaining); + } + } + + return false; + } + + @Override + protected void writingComplete() { + bodyProcessor = null; + } + + @Override + protected void discardData(DataBuffer dataBuffer) { + DataBufferUtils.release(dataBuffer); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/SslInfo.java b/spring-web/src/main/java/org/springframework/http/server/reactive/SslInfo.java new file mode 100644 index 0000000000000000000000000000000000000000..47e4f9a1bf1310953f7bb691fb9886b6bbade64c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/SslInfo.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.security.cert.X509Certificate; + +import org.springframework.lang.Nullable; + +/** + * A holder for SSL session information. + * + * @author Rossen Stoyanchev + * @since 5.0.2 + */ +public interface SslInfo { + + /** + * Return the SSL session id, if any. + */ + @Nullable + String getSessionId(); + + /** + * Return SSL certificates associated with the request, if any. + */ + @Nullable + X509Certificate[] getPeerCertificates(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHeadersAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..10ca124e79b748b7e6f1e7a62746b03b46e1ce70 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHeadersAdapter.java @@ -0,0 +1,250 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.tomcat.util.buf.MessageBytes; +import org.apache.tomcat.util.http.MimeHeaders; + +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Tomcat HTTP headers. + * + * @author Brian Clozel + * @since 5.1.1 + */ +class TomcatHeadersAdapter implements MultiValueMap { + + private final MimeHeaders headers; + + + TomcatHeadersAdapter(MimeHeaders headers) { + this.headers = headers; + } + + + @Override + public String getFirst(String key) { + return this.headers.getHeader(key); + } + + @Override + public void add(String key, @Nullable String value) { + this.headers.addValue(key).setString(value); + } + + @Override + public void addAll(String key, List values) { + values.forEach(value -> add(key, value)); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach(this::addAll); + } + + @Override + public void set(String key, @Nullable String value) { + this.headers.setValue(key).setString(value); + } + + @Override + public void setAll(Map values) { + values.forEach(this::set); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + this.keySet().forEach(key -> singleValueMap.put(key, getFirst(key))); + return singleValueMap; + } + + @Override + public int size() { + Enumeration names = this.headers.names(); + int size = 0; + while (names.hasMoreElements()) { + size++; + names.nextElement(); + } + return size; + } + + @Override + public boolean isEmpty() { + return (this.headers.size() == 0); + } + + @Override + public boolean containsKey(Object key) { + if (key instanceof String) { + return (this.headers.findHeader((String) key, 0) != -1); + } + return false; + } + + @Override + public boolean containsValue(Object value) { + if (value instanceof String) { + MessageBytes needle = MessageBytes.newInstance(); + needle.setString((String) value); + for (int i = 0; i < this.headers.size(); i++) { + if (this.headers.getValue(i).equals(needle)) { + return true; + } + } + } + return false; + } + + @Override + @Nullable + public List get(Object key) { + if (containsKey(key)) { + return Collections.list(this.headers.values((String) key)); + } + return null; + } + + @Override + @Nullable + public List put(String key, List value) { + List previousValues = get(key); + this.headers.removeHeader(key); + value.forEach(v -> this.headers.addValue(key).setString(v)); + return previousValues; + } + + @Override + @Nullable + public List remove(Object key) { + if (key instanceof String) { + List previousValues = get(key); + this.headers.removeHeader((String) key); + return previousValues; + } + return null; + } + + @Override + public void putAll(Map> map) { + map.forEach(this::put); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + Set result = new HashSet<>(8); + Enumeration names = this.headers.names(); + while (names.hasMoreElements()) { + result.add(names.nextElement()); + } + return result; + } + + @Override + public Collection> values() { + return keySet().stream().map(this::get).collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + + @Override + public String toString() { + return HttpHeaders.formatHeaders(this); + } + + + private class EntryIterator implements Iterator>> { + + private Enumeration names = headers.names(); + + @Override + public boolean hasNext() { + return this.names.hasMoreElements(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.nextElement()); + } + } + + + private final class HeaderEntry implements Entry> { + + private final String key; + + HeaderEntry(String key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key; + } + + @Nullable + @Override + public List getValue() { + return get(this.key); + } + + @Nullable + @Override + public List setValue(List value) { + List previous = getValue(); + headers.removeHeader(this.key); + addAll(this.key, value); + return previous; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHttpHandlerAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHttpHandlerAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..c01a11b5a7ef7f4730ca1add9947f9b0d1b84e3d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHttpHandlerAdapter.java @@ -0,0 +1,241 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.IOException; +import java.lang.reflect.Field; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; + +import javax.servlet.AsyncContext; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; + +import org.apache.catalina.connector.CoyoteInputStream; +import org.apache.catalina.connector.CoyoteOutputStream; +import org.apache.catalina.connector.RequestFacade; +import org.apache.catalina.connector.ResponseFacade; +import org.apache.coyote.Request; +import org.apache.coyote.Response; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; + +/** + * {@link ServletHttpHandlerAdapter} extension that uses Tomcat APIs for reading + * from the request and writing to the response with {@link ByteBuffer}. + * + * @author Violeta Georgieva + * @author Brian Clozel + * @since 5.0 + * @see org.springframework.web.server.adapter.AbstractReactiveWebInitializer + */ +public class TomcatHttpHandlerAdapter extends ServletHttpHandlerAdapter { + + + public TomcatHttpHandlerAdapter(HttpHandler httpHandler) { + super(httpHandler); + } + + + @Override + protected ServletServerHttpRequest createRequest(HttpServletRequest request, AsyncContext asyncContext) + throws IOException, URISyntaxException { + + Assert.notNull(getServletPath(), "Servlet path is not initialized"); + return new TomcatServerHttpRequest( + request, asyncContext, getServletPath(), getDataBufferFactory(), getBufferSize()); + } + + @Override + protected ServletServerHttpResponse createResponse(HttpServletResponse response, + AsyncContext asyncContext, ServletServerHttpRequest request) throws IOException { + + return new TomcatServerHttpResponse( + response, asyncContext, getDataBufferFactory(), getBufferSize(), request); + } + + + private static final class TomcatServerHttpRequest extends ServletServerHttpRequest { + + private static final Field COYOTE_REQUEST_FIELD; + + private final int bufferSize; + + private final DataBufferFactory factory; + + static { + Field field = ReflectionUtils.findField(RequestFacade.class, "request"); + Assert.state(field != null, "Incompatible Tomcat implementation"); + ReflectionUtils.makeAccessible(field); + COYOTE_REQUEST_FIELD = field; + } + + TomcatServerHttpRequest(HttpServletRequest request, AsyncContext context, + String servletPath, DataBufferFactory factory, int bufferSize) + throws IOException, URISyntaxException { + + super(createTomcatHttpHeaders(request), request, context, servletPath, factory, bufferSize); + this.factory = factory; + this.bufferSize = bufferSize; + } + + private static HttpHeaders createTomcatHttpHeaders(HttpServletRequest request) { + RequestFacade requestFacade = getRequestFacade(request); + org.apache.catalina.connector.Request connectorRequest = (org.apache.catalina.connector.Request) + ReflectionUtils.getField(COYOTE_REQUEST_FIELD, requestFacade); + Assert.state(connectorRequest != null, "No Tomcat connector request"); + Request tomcatRequest = connectorRequest.getCoyoteRequest(); + TomcatHeadersAdapter headers = new TomcatHeadersAdapter(tomcatRequest.getMimeHeaders()); + return new HttpHeaders(headers); + } + + private static RequestFacade getRequestFacade(HttpServletRequest request) { + if (request instanceof RequestFacade) { + return (RequestFacade) request; + } + else if (request instanceof HttpServletRequestWrapper) { + HttpServletRequestWrapper wrapper = (HttpServletRequestWrapper) request; + HttpServletRequest wrappedRequest = (HttpServletRequest) wrapper.getRequest(); + return getRequestFacade(wrappedRequest); + } + else { + throw new IllegalArgumentException("Cannot convert [" + request.getClass() + + "] to org.apache.catalina.connector.RequestFacade"); + } + } + + @Override + protected DataBuffer readFromInputStream() throws IOException { + boolean release = true; + int capacity = this.bufferSize; + DataBuffer dataBuffer = this.factory.allocateBuffer(capacity); + try { + ByteBuffer byteBuffer = dataBuffer.asByteBuffer(0, capacity); + ServletRequest request = getNativeRequest(); + int read = ((CoyoteInputStream) request.getInputStream()).read(byteBuffer); + logBytesRead(read); + if (read > 0) { + dataBuffer.writePosition(read); + release = false; + return dataBuffer; + } + else if (read == -1) { + return EOF_BUFFER; + } + else { + return null; + } + } + finally { + if (release) { + DataBufferUtils.release(dataBuffer); + } + } + } + } + + + private static final class TomcatServerHttpResponse extends ServletServerHttpResponse { + + private static final Field COYOTE_RESPONSE_FIELD; + + static { + Field field = ReflectionUtils.findField(ResponseFacade.class, "response"); + Assert.state(field != null, "Incompatible Tomcat implementation"); + ReflectionUtils.makeAccessible(field); + COYOTE_RESPONSE_FIELD = field; + } + + TomcatServerHttpResponse(HttpServletResponse response, AsyncContext context, + DataBufferFactory factory, int bufferSize, ServletServerHttpRequest request) throws IOException { + + super(createTomcatHttpHeaders(response), response, context, factory, bufferSize, request); + } + + private static HttpHeaders createTomcatHttpHeaders(HttpServletResponse response) { + ResponseFacade responseFacade = getResponseFacade(response); + org.apache.catalina.connector.Response connectorResponse = (org.apache.catalina.connector.Response) + ReflectionUtils.getField(COYOTE_RESPONSE_FIELD, responseFacade); + Assert.state(connectorResponse != null, "No Tomcat connector response"); + Response tomcatResponse = connectorResponse.getCoyoteResponse(); + TomcatHeadersAdapter headers = new TomcatHeadersAdapter(tomcatResponse.getMimeHeaders()); + return new HttpHeaders(headers); + } + + private static ResponseFacade getResponseFacade(HttpServletResponse response) { + if (response instanceof ResponseFacade) { + return (ResponseFacade) response; + } + else if (response instanceof HttpServletResponseWrapper) { + HttpServletResponseWrapper wrapper = (HttpServletResponseWrapper) response; + HttpServletResponse wrappedResponse = (HttpServletResponse) wrapper.getResponse(); + return getResponseFacade(wrappedResponse); + } + else { + throw new IllegalArgumentException("Cannot convert [" + response.getClass() + + "] to org.apache.catalina.connector.ResponseFacade"); + } + } + + @Override + protected void applyHeaders() { + HttpServletResponse response = getNativeResponse(); + MediaType contentType = null; + try { + contentType = getHeaders().getContentType(); + } + catch (Exception ex) { + String rawContentType = getHeaders().getFirst(HttpHeaders.CONTENT_TYPE); + response.setContentType(rawContentType); + } + if (response.getContentType() == null && contentType != null) { + response.setContentType(contentType.toString()); + } + getHeaders().remove(HttpHeaders.CONTENT_TYPE); + Charset charset = (contentType != null ? contentType.getCharset() : null); + if (response.getCharacterEncoding() == null && charset != null) { + response.setCharacterEncoding(charset.name()); + } + long contentLength = getHeaders().getContentLength(); + if (contentLength != -1) { + response.setContentLengthLong(contentLength); + } + getHeaders().remove(HttpHeaders.CONTENT_LENGTH); + } + + @Override + protected int writeToOutputStream(DataBuffer dataBuffer) throws IOException { + ByteBuffer input = dataBuffer.asByteBuffer(); + int len = input.remaining(); + ServletResponse response = getNativeResponse(); + ((CoyoteOutputStream) response.getOutputStream()).write(input); + return len; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHeadersAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..bbe1f3df6e0fa0fac6f740e6af79214523bb76cc --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHeadersAdapter.java @@ -0,0 +1,228 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import io.undertow.util.HeaderMap; +import io.undertow.util.HeaderValues; +import io.undertow.util.HttpString; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Undertow HTTP headers. + * + * @author Brian Clozel + * @since 5.1.1 + */ +class UndertowHeadersAdapter implements MultiValueMap { + + private final HeaderMap headers; + + + UndertowHeadersAdapter(HeaderMap headers) { + this.headers = headers; + } + + + @Override + public String getFirst(String key) { + return this.headers.getFirst(key); + } + + @Override + public void add(String key, @Nullable String value) { + this.headers.add(HttpString.tryFromString(key), value); + } + + @Override + @SuppressWarnings("unchecked") + public void addAll(String key, List values) { + this.headers.addAll(HttpString.tryFromString(key), (List) values); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach((key, list) -> this.headers.addAll(HttpString.tryFromString(key), list)); + } + + @Override + public void set(String key, @Nullable String value) { + this.headers.put(HttpString.tryFromString(key), value); + } + + @Override + public void setAll(Map values) { + values.forEach((key, list) -> this.headers.put(HttpString.tryFromString(key), list)); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + this.headers.forEach(values -> + singleValueMap.put(values.getHeaderName().toString(), values.getFirst())); + return singleValueMap; + } + + @Override + public int size() { + return this.headers.size(); + } + + @Override + public boolean isEmpty() { + return (this.headers.size() == 0); + } + + @Override + public boolean containsKey(Object key) { + return (key instanceof String && this.headers.contains((String) key)); + } + + @Override + public boolean containsValue(Object value) { + return (value instanceof String && + this.headers.getHeaderNames().stream() + .map(this.headers::get) + .anyMatch(values -> values.contains(value))); + } + + @Override + @Nullable + public List get(Object key) { + if (key instanceof String) { + return this.headers.get((String) key); + } + return null; + } + + @Override + @Nullable + public List put(String key, List value) { + HeaderValues previousValues = this.headers.get(key); + this.headers.putAll(HttpString.tryFromString(key), value); + return previousValues; + } + + @Override + @Nullable + public List remove(Object key) { + if (key instanceof String) { + this.headers.remove((String) key); + } + return null; + } + + @Override + public void putAll(Map> map) { + map.forEach((key, values) -> + this.headers.putAll(HttpString.tryFromString(key), values)); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + return this.headers.getHeaderNames().stream() + .map(HttpString::toString) + .collect(Collectors.toSet()); + } + + @Override + public Collection> values() { + return this.headers.getHeaderNames().stream() + .map(this.headers::get) + .collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + + @Override + public String toString() { + return org.springframework.http.HttpHeaders.formatHeaders(this); + } + + + private class EntryIterator implements Iterator>> { + + private Iterator names = headers.getHeaderNames().iterator(); + + @Override + public boolean hasNext() { + return this.names.hasNext(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.next()); + } + } + + + private class HeaderEntry implements Entry> { + + private final HttpString key; + + HeaderEntry(HttpString key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key.toString(); + } + + @Override + public List getValue() { + return headers.get(this.key); + } + + @Override + public List setValue(List value) { + List previousValues = headers.get(this.key); + headers.putAll(this.key, value); + return previousValues; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..49867f2b013d754a52776675b1f52cb753dbbc29 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java @@ -0,0 +1,139 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.IOException; +import java.net.URISyntaxException; + +import io.undertow.server.HttpServerExchange; +import org.apache.commons.logging.Log; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpLogging; +import org.springframework.http.HttpMethod; +import org.springframework.util.Assert; + +/** + * Adapt {@link HttpHandler} to the Undertow {@link io.undertow.server.HttpHandler}. + * + * @author Marek Hawrylczak + * @author Rossen Stoyanchev + * @author Arjen Poutsma + * @since 5.0 + */ +public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandler { + + private static final Log logger = HttpLogging.forLogName(UndertowHttpHandlerAdapter.class); + + + private final HttpHandler httpHandler; + + private DataBufferFactory bufferFactory = new DefaultDataBufferFactory(false); + + + public UndertowHttpHandlerAdapter(HttpHandler httpHandler) { + Assert.notNull(httpHandler, "HttpHandler must not be null"); + this.httpHandler = httpHandler; + } + + + public void setDataBufferFactory(DataBufferFactory bufferFactory) { + Assert.notNull(bufferFactory, "DataBufferFactory must not be null"); + this.bufferFactory = bufferFactory; + } + + public DataBufferFactory getDataBufferFactory() { + return this.bufferFactory; + } + + + @Override + public void handleRequest(HttpServerExchange exchange) { + UndertowServerHttpRequest request = null; + try { + request = new UndertowServerHttpRequest(exchange, getDataBufferFactory()); + } + catch (URISyntaxException ex) { + if (logger.isWarnEnabled()) { + logger.debug("Failed to get request URI: " + ex.getMessage()); + } + exchange.setStatusCode(400); + return; + } + ServerHttpResponse response = new UndertowServerHttpResponse(exchange, getDataBufferFactory(), request); + + if (request.getMethod() == HttpMethod.HEAD) { + response = new HttpHeadResponseDecorator(response); + } + + HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(exchange, request); + this.httpHandler.handle(request, response).subscribe(resultSubscriber); + } + + + private class HandlerResultSubscriber implements Subscriber { + + private final HttpServerExchange exchange; + + private final String logPrefix; + + + public HandlerResultSubscriber(HttpServerExchange exchange, UndertowServerHttpRequest request) { + this.exchange = exchange; + this.logPrefix = request.getLogPrefix(); + } + + @Override + public void onSubscribe(Subscription subscription) { + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void aVoid) { + // no-op + } + + @Override + public void onError(Throwable ex) { + logger.trace(this.logPrefix + "Failed to complete: " + ex.getMessage()); + if (this.exchange.isResponseStarted()) { + try { + logger.debug(this.logPrefix + "Closing connection"); + this.exchange.getConnection().close(); + } + catch (IOException ex2) { + // ignore + } + } + else { + logger.debug(this.logPrefix + "Setting HttpServerExchange status to 500 Server Error"); + this.exchange.setStatusCode(500); + this.exchange.endExchange(); + } + } + + @Override + public void onComplete() { + logger.trace(this.logPrefix + "Handling completed"); + this.exchange.endExchange(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..d01600eabbfce413effcabbbc376fac34e5712ff --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java @@ -0,0 +1,411 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntPredicate; + +import javax.net.ssl.SSLSession; + +import io.undertow.connector.ByteBufferPool; +import io.undertow.connector.PooledByteBuffer; +import io.undertow.server.HttpServerExchange; +import io.undertow.server.handlers.Cookie; +import org.xnio.channels.StreamSourceChannel; +import reactor.core.publisher.Flux; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.PooledDataBuffer; +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; + +/** + * Adapt {@link ServerHttpRequest} to the Undertow {@link HttpServerExchange}. + * + * @author Marek Hawrylczak + * @author Rossen Stoyanchev + * @since 5.0 + */ +class UndertowServerHttpRequest extends AbstractServerHttpRequest { + + private final HttpServerExchange exchange; + + private final RequestBodyPublisher body; + + + public UndertowServerHttpRequest(HttpServerExchange exchange, DataBufferFactory bufferFactory) + throws URISyntaxException { + + super(initUri(exchange), "", initHeaders(exchange)); + this.exchange = exchange; + this.body = new RequestBodyPublisher(exchange, bufferFactory); + this.body.registerListeners(exchange); + } + + private static URI initUri(HttpServerExchange exchange) throws URISyntaxException { + Assert.notNull(exchange, "HttpServerExchange is required"); + String requestURL = exchange.getRequestURL(); + String query = exchange.getQueryString(); + String requestUriAndQuery = (StringUtils.hasLength(query) ? requestURL + "?" + query : requestURL); + return new URI(requestUriAndQuery); + } + + private static HttpHeaders initHeaders(HttpServerExchange exchange) { + return new HttpHeaders(new UndertowHeadersAdapter(exchange.getRequestHeaders())); + } + + @Override + public String getMethodValue() { + return this.exchange.getRequestMethod().toString(); + } + + @Override + protected MultiValueMap initCookies() { + MultiValueMap cookies = new LinkedMultiValueMap<>(); + for (String name : this.exchange.getRequestCookies().keySet()) { + Cookie cookie = this.exchange.getRequestCookies().get(name); + HttpCookie httpCookie = new HttpCookie(name, cookie.getValue()); + cookies.add(name, httpCookie); + } + return cookies; + } + + @Override + @Nullable + public InetSocketAddress getRemoteAddress() { + return this.exchange.getSourceAddress(); + } + + @Nullable + @Override + protected SslInfo initSslInfo() { + SSLSession session = this.exchange.getConnection().getSslSession(); + if (session != null) { + return new DefaultSslInfo(session); + } + return null; + } + + @Override + public Flux getBody() { + return Flux.from(this.body); + } + + @SuppressWarnings("unchecked") + @Override + public T getNativeRequest() { + return (T) this.exchange; + } + + @Override + protected String initId() { + return ObjectUtils.getIdentityHexString(this.exchange.getConnection()); + } + + + private class RequestBodyPublisher extends AbstractListenerReadPublisher { + + private final StreamSourceChannel channel; + + private final DataBufferFactory bufferFactory; + + private final ByteBufferPool byteBufferPool; + + public RequestBodyPublisher(HttpServerExchange exchange, DataBufferFactory bufferFactory) { + super(UndertowServerHttpRequest.this.getLogPrefix()); + this.channel = exchange.getRequestChannel(); + this.bufferFactory = bufferFactory; + this.byteBufferPool = exchange.getConnection().getByteBufferPool(); + } + + private void registerListeners(HttpServerExchange exchange) { + exchange.addExchangeCompleteListener((ex, next) -> { + onAllDataRead(); + next.proceed(); + }); + this.channel.getReadSetter().set(c -> onDataAvailable()); + this.channel.getCloseSetter().set(c -> onAllDataRead()); + this.channel.resumeReads(); + } + + @Override + protected void checkOnDataAvailable() { + this.channel.resumeReads(); + // We are allowed to try, it will return null if data is not available + onDataAvailable(); + } + + @Override + protected void readingPaused() { + this.channel.suspendReads(); + } + + @Override + @Nullable + protected DataBuffer read() throws IOException { + PooledByteBuffer pooledByteBuffer = this.byteBufferPool.allocate(); + boolean release = true; + try { + ByteBuffer byteBuffer = pooledByteBuffer.getBuffer(); + int read = this.channel.read(byteBuffer); + + if (rsReadLogger.isTraceEnabled()) { + rsReadLogger.trace(getLogPrefix() + "Read " + read + (read != -1 ? " bytes" : "")); + } + + if (read > 0) { + byteBuffer.flip(); + DataBuffer dataBuffer = this.bufferFactory.wrap(byteBuffer); + release = false; + return new UndertowDataBuffer(dataBuffer, pooledByteBuffer); + } + else if (read == -1) { + onAllDataRead(); + } + return null; + } + finally { + if (release && pooledByteBuffer.isOpen()) { + pooledByteBuffer.close(); + } + } + } + + @Override + protected void discardData() { + // Nothing to discard since we pass data buffers on immediately.. + } + } + + + private static class UndertowDataBuffer implements PooledDataBuffer { + + private final DataBuffer dataBuffer; + + private final PooledByteBuffer pooledByteBuffer; + + private final AtomicInteger refCount; + + public UndertowDataBuffer(DataBuffer dataBuffer, PooledByteBuffer pooledByteBuffer) { + this.dataBuffer = dataBuffer; + this.pooledByteBuffer = pooledByteBuffer; + this.refCount = new AtomicInteger(1); + } + + private UndertowDataBuffer(DataBuffer dataBuffer, PooledByteBuffer pooledByteBuffer, + AtomicInteger refCount) { + this.refCount = refCount; + this.dataBuffer = dataBuffer; + this.pooledByteBuffer = pooledByteBuffer; + } + + @Override + public boolean isAllocated() { + return this.refCount.get() > 0; + } + + @Override + public PooledDataBuffer retain() { + this.refCount.incrementAndGet(); + DataBufferUtils.retain(this.dataBuffer); + return this; + } + + @Override + public boolean release() { + int refCount = this.refCount.decrementAndGet(); + if (refCount == 0) { + try { + return DataBufferUtils.release(this.dataBuffer); + } + finally { + this.pooledByteBuffer.close(); + } + } + return false; + } + + @Override + public DataBufferFactory factory() { + return this.dataBuffer.factory(); + } + + @Override + public int indexOf(IntPredicate predicate, int fromIndex) { + return this.dataBuffer.indexOf(predicate, fromIndex); + } + + @Override + public int lastIndexOf(IntPredicate predicate, int fromIndex) { + return this.dataBuffer.lastIndexOf(predicate, fromIndex); + } + + @Override + public int readableByteCount() { + return this.dataBuffer.readableByteCount(); + } + + @Override + public int writableByteCount() { + return this.dataBuffer.writableByteCount(); + } + + @Override + public int readPosition() { + return this.dataBuffer.readPosition(); + } + + @Override + public DataBuffer readPosition(int readPosition) { + return this.dataBuffer.readPosition(readPosition); + } + + @Override + public int writePosition() { + return this.dataBuffer.writePosition(); + } + + @Override + public DataBuffer writePosition(int writePosition) { + this.dataBuffer.writePosition(writePosition); + return this; + } + + @Override + public int capacity() { + return this.dataBuffer.capacity(); + } + + @Override + public DataBuffer capacity(int newCapacity) { + this.dataBuffer.capacity(newCapacity); + return this; + } + + @Override + public DataBuffer ensureCapacity(int capacity) { + this.dataBuffer.ensureCapacity(capacity); + return this; + } + + @Override + public byte getByte(int index) { + return this.dataBuffer.getByte(index); + } + + @Override + public byte read() { + return this.dataBuffer.read(); + } + + @Override + public DataBuffer read(byte[] destination) { + this.dataBuffer.read(destination); + return this; + } + + @Override + public DataBuffer read(byte[] destination, int offset, int length) { + this.dataBuffer.read(destination, offset, length); + return this; + } + + @Override + public DataBuffer write(byte b) { + this.dataBuffer.write(b); + return this; + } + + @Override + public DataBuffer write(byte[] source) { + this.dataBuffer.write(source); + return this; + } + + @Override + public DataBuffer write(byte[] source, int offset, int length) { + this.dataBuffer.write(source, offset, length); + return this; + } + + @Override + public DataBuffer write(DataBuffer... buffers) { + this.dataBuffer.write(buffers); + return this; + } + + @Override + public DataBuffer write(ByteBuffer... byteBuffers) { + this.dataBuffer.write(byteBuffers); + return this; + } + + @Override + public DataBuffer write(CharSequence charSequence, Charset charset) { + this.dataBuffer.write(charSequence, charset); + return this; + } + + @Override + public DataBuffer slice(int index, int length) { + DataBuffer slice = this.dataBuffer.slice(index, length); + return new UndertowDataBuffer(slice, this.pooledByteBuffer, this.refCount); + } + + @Override + public ByteBuffer asByteBuffer() { + return this.dataBuffer.asByteBuffer(); + } + + @Override + public ByteBuffer asByteBuffer(int index, int length) { + return this.dataBuffer.asByteBuffer(index, length); + } + + @Override + public InputStream asInputStream() { + return this.dataBuffer.asInputStream(); + } + + @Override + public InputStream asInputStream(boolean releaseOnClose) { + return this.dataBuffer.asInputStream(releaseOnClose); + } + + @Override + public OutputStream asOutputStream() { + return this.dataBuffer.asOutputStream(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..d19f839cb2b8e328893b76e08518d9c1c2aeabcb --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java @@ -0,0 +1,356 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; + +import io.undertow.server.HttpServerExchange; +import io.undertow.server.handlers.Cookie; +import io.undertow.server.handlers.CookieImpl; +import org.reactivestreams.Processor; +import org.reactivestreams.Publisher; +import org.xnio.channels.StreamSinkChannel; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseCookie; +import org.springframework.http.ZeroCopyHttpOutputMessage; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Adapt {@link ServerHttpResponse} to the Undertow {@link HttpServerExchange}. + * + * @author Marek Hawrylczak + * @author Rossen Stoyanchev + * @author Arjen Poutsma + * @since 5.0 + */ +class UndertowServerHttpResponse extends AbstractListenerServerHttpResponse implements ZeroCopyHttpOutputMessage { + + private final HttpServerExchange exchange; + + private final UndertowServerHttpRequest request; + + @Nullable + private StreamSinkChannel responseChannel; + + + UndertowServerHttpResponse( + HttpServerExchange exchange, DataBufferFactory bufferFactory, UndertowServerHttpRequest request) { + + super(bufferFactory, createHeaders(exchange)); + Assert.notNull(exchange, "HttpServerExchange must not be null"); + this.exchange = exchange; + this.request = request; + } + + private static HttpHeaders createHeaders(HttpServerExchange exchange) { + UndertowHeadersAdapter headersMap = new UndertowHeadersAdapter(exchange.getResponseHeaders()); + return new HttpHeaders(headersMap); + } + + + @SuppressWarnings("unchecked") + @Override + public T getNativeResponse() { + return (T) this.exchange; + } + + @Override + public HttpStatus getStatusCode() { + HttpStatus httpStatus = super.getStatusCode(); + return (httpStatus != null ? httpStatus : HttpStatus.resolve(this.exchange.getStatusCode())); + } + + + @Override + protected void applyStatusCode() { + Integer statusCode = getStatusCodeValue(); + if (statusCode != null) { + this.exchange.setStatusCode(statusCode); + } + } + + @Override + protected void applyHeaders() { + } + + @Override + protected void applyCookies() { + for (String name : getCookies().keySet()) { + for (ResponseCookie httpCookie : getCookies().get(name)) { + Cookie cookie = new CookieImpl(name, httpCookie.getValue()); + if (!httpCookie.getMaxAge().isNegative()) { + cookie.setMaxAge((int) httpCookie.getMaxAge().getSeconds()); + } + if (httpCookie.getDomain() != null) { + cookie.setDomain(httpCookie.getDomain()); + } + if (httpCookie.getPath() != null) { + cookie.setPath(httpCookie.getPath()); + } + cookie.setSecure(httpCookie.isSecure()); + cookie.setHttpOnly(httpCookie.isHttpOnly()); + cookie.setSameSiteMode(httpCookie.getSameSite()); + this.exchange.getResponseCookies().putIfAbsent(name, cookie); + } + } + } + + @Override + public Mono writeWith(Path file, long position, long count) { + return doCommit(() -> + Mono.create(sink -> { + try { + FileChannel source = FileChannel.open(file, StandardOpenOption.READ); + + TransferBodyListener listener = new TransferBodyListener(source, position, + count, sink); + sink.onDispose(listener::closeSource); + + StreamSinkChannel destination = this.exchange.getResponseChannel(); + destination.getWriteSetter().set(listener::transfer); + + listener.transfer(destination); + } + catch (IOException ex) { + sink.error(ex); + } + })); + } + + @Override + protected Processor, Void> createBodyFlushProcessor() { + return new ResponseBodyFlushProcessor(); + } + + private ResponseBodyProcessor createBodyProcessor() { + if (this.responseChannel == null) { + this.responseChannel = this.exchange.getResponseChannel(); + } + return new ResponseBodyProcessor(this.responseChannel); + } + + + private class ResponseBodyProcessor extends AbstractListenerWriteProcessor { + + private final StreamSinkChannel channel; + + @Nullable + private volatile ByteBuffer byteBuffer; + + /** Keep track of write listener calls, for {@link #writePossible}. */ + private volatile boolean writePossible; + + + public ResponseBodyProcessor(StreamSinkChannel channel) { + super(request.getLogPrefix()); + Assert.notNull(channel, "StreamSinkChannel must not be null"); + this.channel = channel; + this.channel.getWriteSetter().set(c -> { + this.writePossible = true; + onWritePossible(); + }); + this.channel.suspendWrites(); + } + + @Override + protected boolean isWritePossible() { + this.channel.resumeWrites(); + return this.writePossible; + } + + @Override + protected boolean write(DataBuffer dataBuffer) throws IOException { + ByteBuffer buffer = this.byteBuffer; + if (buffer == null) { + return false; + } + + // Track write listener calls from here on.. + this.writePossible = false; + + // In case of IOException, onError handling should call discardData(DataBuffer).. + int total = buffer.remaining(); + int written = writeByteBuffer(buffer); + + if (logger.isTraceEnabled()) { + logger.trace(getLogPrefix() + "Wrote " + written + " of " + total + " bytes"); + } + else if (rsWriteLogger.isTraceEnabled()) { + rsWriteLogger.trace(getLogPrefix() + "Wrote " + written + " of " + total + " bytes"); + } + if (written != total) { + return false; + } + + // We wrote all, so can still write more.. + this.writePossible = true; + + DataBufferUtils.release(dataBuffer); + this.byteBuffer = null; + return true; + } + + private int writeByteBuffer(ByteBuffer byteBuffer) throws IOException { + int written; + int totalWritten = 0; + do { + written = this.channel.write(byteBuffer); + totalWritten += written; + } + while (byteBuffer.hasRemaining() && written > 0); + return totalWritten; + } + + @Override + protected void dataReceived(DataBuffer dataBuffer) { + super.dataReceived(dataBuffer); + this.byteBuffer = dataBuffer.asByteBuffer(); + } + + @Override + protected boolean isDataEmpty(DataBuffer dataBuffer) { + return (dataBuffer.readableByteCount() == 0); + } + + @Override + protected void writingComplete() { + this.channel.getWriteSetter().set(null); + this.channel.resumeWrites(); + } + + @Override + protected void writingFailed(Throwable ex) { + cancel(); + onError(ex); + } + + @Override + protected void discardData(DataBuffer dataBuffer) { + DataBufferUtils.release(dataBuffer); + } + } + + + private class ResponseBodyFlushProcessor extends AbstractListenerWriteFlushProcessor { + + public ResponseBodyFlushProcessor() { + super(request.getLogPrefix()); + } + + @Override + protected Processor createWriteProcessor() { + return UndertowServerHttpResponse.this.createBodyProcessor(); + } + + @Override + protected void flush() throws IOException { + StreamSinkChannel channel = UndertowServerHttpResponse.this.responseChannel; + if (channel != null) { + if (rsWriteFlushLogger.isTraceEnabled()) { + rsWriteFlushLogger.trace(getLogPrefix() + "flush"); + } + channel.flush(); + } + } + + @Override + protected void flushingFailed(Throwable t) { + cancel(); + onError(t); + } + + @Override + protected boolean isWritePossible() { + StreamSinkChannel channel = UndertowServerHttpResponse.this.responseChannel; + if (channel != null) { + // We can always call flush, just ensure writes are on.. + channel.resumeWrites(); + return true; + } + return false; + } + + @Override + protected boolean isFlushPending() { + return false; + } + } + + + private static class TransferBodyListener { + + private final FileChannel source; + + private final MonoSink sink; + + private long position; + + private long count; + + + public TransferBodyListener(FileChannel source, long position, long count, MonoSink sink) { + this.source = source; + this.sink = sink; + this.position = position; + this.count = count; + } + + public void transfer(StreamSinkChannel destination) { + try { + while (this.count > 0) { + long len = destination.transferFrom(this.source, this.position, this.count); + if (len != 0) { + this.position += len; + this.count -= len; + } + else { + destination.resumeWrites(); + return; + } + } + this.sink.success(); + } + catch (IOException ex) { + this.sink.error(ex); + } + + } + + public void closeSource() { + try { + this.source.close(); + } + catch (IOException ignore) { + } + } + + + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/WriteResultPublisher.java b/spring-web/src/main/java/org/springframework/http/server/reactive/WriteResultPublisher.java new file mode 100644 index 0000000000000000000000000000000000000000..39f6051b4f6170c6f1ec4d538b5fa638e1ecdbc8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/WriteResultPublisher.java @@ -0,0 +1,275 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.commons.logging.Log; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.publisher.Operators; + +import org.springframework.core.log.LogDelegateFactory; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Publisher returned from {@link ServerHttpResponse#writeWith(Publisher)}. + * + * @author Arjen Poutsma + * @author Violeta Georgieva + * @author Rossen Stoyanchev + * @since 5.0 + */ +class WriteResultPublisher implements Publisher { + + /** + * Special logger for debugging Reactive Streams signals. + * @see LogDelegateFactory#getHiddenLog(Class) + * @see AbstractListenerReadPublisher#rsReadLogger + * @see AbstractListenerWriteProcessor#rsWriteLogger + * @see AbstractListenerWriteFlushProcessor#rsWriteFlushLogger + */ + private static final Log rsWriteResultLogger = LogDelegateFactory.getHiddenLog(WriteResultPublisher.class); + + + private final AtomicReference state = new AtomicReference<>(State.UNSUBSCRIBED); + + @Nullable + private volatile Subscriber subscriber; + + private volatile boolean completedBeforeSubscribed; + + @Nullable + private volatile Throwable errorBeforeSubscribed; + + private final String logPrefix; + + + public WriteResultPublisher(String logPrefix) { + this.logPrefix = logPrefix; + } + + + @Override + public final void subscribe(Subscriber subscriber) { + if (rsWriteResultLogger.isTraceEnabled()) { + rsWriteResultLogger.trace(this.logPrefix + this.state + " subscribe: " + subscriber); + } + this.state.get().subscribe(this, subscriber); + } + + /** + * Invoke this to delegate a completion signal to the subscriber. + */ + public void publishComplete() { + if (rsWriteResultLogger.isTraceEnabled()) { + rsWriteResultLogger.trace(this.logPrefix + this.state + " publishComplete"); + } + this.state.get().publishComplete(this); + } + + /** + * Invoke this to delegate an error signal to the subscriber. + */ + public void publishError(Throwable t) { + if (rsWriteResultLogger.isTraceEnabled()) { + rsWriteResultLogger.trace(this.logPrefix + this.state + " publishError: " + t); + } + this.state.get().publishError(this, t); + } + + private boolean changeState(State oldState, State newState) { + return this.state.compareAndSet(oldState, newState); + } + + + /** + * Subscription to receive and delegate request and cancel signals from the + * subscriber to this publisher. + */ + private static final class WriteResultSubscription implements Subscription { + + private final WriteResultPublisher publisher; + + public WriteResultSubscription(WriteResultPublisher publisher) { + this.publisher = publisher; + } + + @Override + public final void request(long n) { + if (rsWriteResultLogger.isTraceEnabled()) { + rsWriteResultLogger.trace(this.publisher.logPrefix + state() + " request: " + n); + } + state().request(this.publisher, n); + } + + @Override + public final void cancel() { + if (rsWriteResultLogger.isTraceEnabled()) { + rsWriteResultLogger.trace(this.publisher.logPrefix + state() + " cancel"); + } + state().cancel(this.publisher); + } + + private State state() { + return this.publisher.state.get(); + } + } + + + /** + * Represents a state for the {@link Publisher} to be in. + *

+	 *     UNSUBSCRIBED
+	 *          |
+	 *          v
+	 *     SUBSCRIBING
+	 *          |
+	 *          v
+	 *      SUBSCRIBED
+	 *          |
+	 *          v
+	 *      COMPLETED
+	 * 
+ */ + private enum State { + + UNSUBSCRIBED { + @Override + void subscribe(WriteResultPublisher publisher, Subscriber subscriber) { + Assert.notNull(subscriber, "Subscriber must not be null"); + if (publisher.changeState(this, SUBSCRIBING)) { + Subscription subscription = new WriteResultSubscription(publisher); + publisher.subscriber = subscriber; + subscriber.onSubscribe(subscription); + publisher.changeState(SUBSCRIBING, SUBSCRIBED); + // Now safe to check "beforeSubscribed" flags, they won't change once in NO_DEMAND + if (publisher.completedBeforeSubscribed) { + publisher.publishComplete(); + } + Throwable publisherError = publisher.errorBeforeSubscribed; + if (publisherError != null) { + publisher.publishError(publisherError); + } + } + else { + throw new IllegalStateException(toString()); + } + } + @Override + void publishComplete(WriteResultPublisher publisher) { + publisher.completedBeforeSubscribed = true; + if(State.SUBSCRIBED.equals(publisher.state.get())) { + publisher.state.get().publishComplete(publisher); + } + } + @Override + void publishError(WriteResultPublisher publisher, Throwable ex) { + publisher.errorBeforeSubscribed = ex; + if(State.SUBSCRIBED.equals(publisher.state.get())) { + publisher.state.get().publishError(publisher, ex); + } + } + }, + + SUBSCRIBING { + @Override + void request(WriteResultPublisher publisher, long n) { + Operators.validate(n); + } + @Override + void publishComplete(WriteResultPublisher publisher) { + publisher.completedBeforeSubscribed = true; + if(State.SUBSCRIBED.equals(publisher.state.get())) { + publisher.state.get().publishComplete(publisher); + } + } + @Override + void publishError(WriteResultPublisher publisher, Throwable ex) { + publisher.errorBeforeSubscribed = ex; + if(State.SUBSCRIBED.equals(publisher.state.get())) { + publisher.state.get().publishError(publisher, ex); + } + } + }, + + SUBSCRIBED { + @Override + void request(WriteResultPublisher publisher, long n) { + Operators.validate(n); + } + }, + + COMPLETED { + @Override + void request(WriteResultPublisher publisher, long n) { + // ignore + } + @Override + void cancel(WriteResultPublisher publisher) { + // ignore + } + @Override + void publishComplete(WriteResultPublisher publisher) { + // ignore + } + @Override + void publishError(WriteResultPublisher publisher, Throwable t) { + // ignore + } + }; + + void subscribe(WriteResultPublisher publisher, Subscriber subscriber) { + throw new IllegalStateException(toString()); + } + + void request(WriteResultPublisher publisher, long n) { + throw new IllegalStateException(toString()); + } + + void cancel(WriteResultPublisher publisher) { + if (!publisher.changeState(this, COMPLETED)) { + publisher.state.get().cancel(publisher); + } + } + + void publishComplete(WriteResultPublisher publisher) { + if (publisher.changeState(this, COMPLETED)) { + Subscriber s = publisher.subscriber; + Assert.state(s != null, "No subscriber"); + s.onComplete(); + } + else { + publisher.state.get().publishComplete(publisher); + } + } + + void publishError(WriteResultPublisher publisher, Throwable t) { + if (publisher.changeState(this, COMPLETED)) { + Subscriber s = publisher.subscriber; + Assert.state(s != null, "No subscriber"); + s.onError(t); + } + else { + publisher.state.get().publishError(publisher, t); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/package-info.java b/spring-web/src/main/java/org/springframework/http/server/reactive/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..f2025bc49324d3fb28fbad821e93d7e7925393b0 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/package-info.java @@ -0,0 +1,15 @@ +/** + * Abstractions for reactive HTTP server support including a + * {@link org.springframework.http.server.reactive.ServerHttpRequest} and + * {@link org.springframework.http.server.reactive.ServerHttpResponse} along with an + * {@link org.springframework.http.server.reactive.HttpHandler} for processing. + * + *

Also provides implementations adapting to different runtimes + * including Servlet 3.1 containers, Netty + Reactor IO, and Undertow. + */ +@NonNullApi +@NonNullFields +package org.springframework.http.server.reactive; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/remoting/caucho/HessianClientInterceptor.java b/spring-web/src/main/java/org/springframework/remoting/caucho/HessianClientInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..980842410499c7264986764df4258d08ee63a2e4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/caucho/HessianClientInterceptor.java @@ -0,0 +1,300 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.caucho; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.UndeclaredThrowableException; +import java.net.ConnectException; +import java.net.MalformedURLException; + +import com.caucho.hessian.HessianException; +import com.caucho.hessian.client.HessianConnectionException; +import com.caucho.hessian.client.HessianConnectionFactory; +import com.caucho.hessian.client.HessianProxyFactory; +import com.caucho.hessian.client.HessianRuntimeException; +import com.caucho.hessian.io.SerializerFactory; +import org.aopalliance.intercept.MethodInterceptor; +import org.aopalliance.intercept.MethodInvocation; + +import org.springframework.lang.Nullable; +import org.springframework.remoting.RemoteAccessException; +import org.springframework.remoting.RemoteConnectFailureException; +import org.springframework.remoting.RemoteLookupFailureException; +import org.springframework.remoting.RemoteProxyFailureException; +import org.springframework.remoting.support.UrlBasedRemoteAccessor; +import org.springframework.util.Assert; + +/** + * {@link org.aopalliance.intercept.MethodInterceptor} for accessing a Hessian service. + * Supports authentication via username and password. + * The service URL must be an HTTP URL exposing a Hessian service. + * + *

Hessian is a slim, binary RPC protocol. + * For information on Hessian, see the + * Hessian website + * Note: As of Spring 4.0, this client requires Hessian 4.0 or above. + * + *

Note: There is no requirement for services accessed with this proxy factory + * to have been exported using Spring's {@link HessianServiceExporter}, as there is + * no special handling involved. As a consequence, you can also access services that + * have been exported using Caucho's {@link com.caucho.hessian.server.HessianServlet}. + * + * @author Juergen Hoeller + * @since 29.09.2003 + * @see #setServiceInterface + * @see #setServiceUrl + * @see #setUsername + * @see #setPassword + * @see HessianServiceExporter + * @see HessianProxyFactoryBean + * @see com.caucho.hessian.client.HessianProxyFactory + * @see com.caucho.hessian.server.HessianServlet + */ +public class HessianClientInterceptor extends UrlBasedRemoteAccessor implements MethodInterceptor { + + private HessianProxyFactory proxyFactory = new HessianProxyFactory(); + + @Nullable + private Object hessianProxy; + + + /** + * Set the HessianProxyFactory instance to use. + * If not specified, a default HessianProxyFactory will be created. + *

Allows to use an externally configured factory instance, + * in particular a custom HessianProxyFactory subclass. + */ + public void setProxyFactory(@Nullable HessianProxyFactory proxyFactory) { + this.proxyFactory = (proxyFactory != null ? proxyFactory : new HessianProxyFactory()); + } + + /** + * Specify the Hessian SerializerFactory to use. + *

This will typically be passed in as an inner bean definition + * of type {@code com.caucho.hessian.io.SerializerFactory}, + * with custom bean property values applied. + */ + public void setSerializerFactory(SerializerFactory serializerFactory) { + this.proxyFactory.setSerializerFactory(serializerFactory); + } + + /** + * Set whether to send the Java collection type for each serialized + * collection. Default is "true". + */ + public void setSendCollectionType(boolean sendCollectionType) { + this.proxyFactory.getSerializerFactory().setSendCollectionType(sendCollectionType); + } + + /** + * Set whether to allow non-serializable types as Hessian arguments + * and return values. Default is "true". + */ + public void setAllowNonSerializable(boolean allowNonSerializable) { + this.proxyFactory.getSerializerFactory().setAllowNonSerializable(allowNonSerializable); + } + + /** + * Set whether overloaded methods should be enabled for remote invocations. + * Default is "false". + * @see com.caucho.hessian.client.HessianProxyFactory#setOverloadEnabled + */ + public void setOverloadEnabled(boolean overloadEnabled) { + this.proxyFactory.setOverloadEnabled(overloadEnabled); + } + + /** + * Set the username that this factory should use to access the remote service. + * Default is none. + *

The username will be sent by Hessian via HTTP Basic Authentication. + * @see com.caucho.hessian.client.HessianProxyFactory#setUser + */ + public void setUsername(String username) { + this.proxyFactory.setUser(username); + } + + /** + * Set the password that this factory should use to access the remote service. + * Default is none. + *

The password will be sent by Hessian via HTTP Basic Authentication. + * @see com.caucho.hessian.client.HessianProxyFactory#setPassword + */ + public void setPassword(String password) { + this.proxyFactory.setPassword(password); + } + + /** + * Set whether Hessian's debug mode should be enabled. + * Default is "false". + * @see com.caucho.hessian.client.HessianProxyFactory#setDebug + */ + public void setDebug(boolean debug) { + this.proxyFactory.setDebug(debug); + } + + /** + * Set whether to use a chunked post for sending a Hessian request. + * @see com.caucho.hessian.client.HessianProxyFactory#setChunkedPost + */ + public void setChunkedPost(boolean chunkedPost) { + this.proxyFactory.setChunkedPost(chunkedPost); + } + + /** + * Specify a custom HessianConnectionFactory to use for the Hessian client. + */ + public void setConnectionFactory(HessianConnectionFactory connectionFactory) { + this.proxyFactory.setConnectionFactory(connectionFactory); + } + + /** + * Set the socket connect timeout to use for the Hessian client. + * @see com.caucho.hessian.client.HessianProxyFactory#setConnectTimeout + */ + public void setConnectTimeout(long timeout) { + this.proxyFactory.setConnectTimeout(timeout); + } + + /** + * Set the timeout to use when waiting for a reply from the Hessian service. + * @see com.caucho.hessian.client.HessianProxyFactory#setReadTimeout + */ + public void setReadTimeout(long timeout) { + this.proxyFactory.setReadTimeout(timeout); + } + + /** + * Set whether version 2 of the Hessian protocol should be used for + * parsing requests and replies. Default is "false". + * @see com.caucho.hessian.client.HessianProxyFactory#setHessian2Request + */ + public void setHessian2(boolean hessian2) { + this.proxyFactory.setHessian2Request(hessian2); + this.proxyFactory.setHessian2Reply(hessian2); + } + + /** + * Set whether version 2 of the Hessian protocol should be used for + * parsing requests. Default is "false". + * @see com.caucho.hessian.client.HessianProxyFactory#setHessian2Request + */ + public void setHessian2Request(boolean hessian2) { + this.proxyFactory.setHessian2Request(hessian2); + } + + /** + * Set whether version 2 of the Hessian protocol should be used for + * parsing replies. Default is "false". + * @see com.caucho.hessian.client.HessianProxyFactory#setHessian2Reply + */ + public void setHessian2Reply(boolean hessian2) { + this.proxyFactory.setHessian2Reply(hessian2); + } + + + @Override + public void afterPropertiesSet() { + super.afterPropertiesSet(); + prepare(); + } + + /** + * Initialize the Hessian proxy for this interceptor. + * @throws RemoteLookupFailureException if the service URL is invalid + */ + public void prepare() throws RemoteLookupFailureException { + try { + this.hessianProxy = createHessianProxy(this.proxyFactory); + } + catch (MalformedURLException ex) { + throw new RemoteLookupFailureException("Service URL [" + getServiceUrl() + "] is invalid", ex); + } + } + + /** + * Create the Hessian proxy that is wrapped by this interceptor. + * @param proxyFactory the proxy factory to use + * @return the Hessian proxy + * @throws MalformedURLException if thrown by the proxy factory + * @see com.caucho.hessian.client.HessianProxyFactory#create + */ + protected Object createHessianProxy(HessianProxyFactory proxyFactory) throws MalformedURLException { + Assert.notNull(getServiceInterface(), "'serviceInterface' is required"); + return proxyFactory.create(getServiceInterface(), getServiceUrl(), getBeanClassLoader()); + } + + + @Override + @Nullable + public Object invoke(MethodInvocation invocation) throws Throwable { + if (this.hessianProxy == null) { + throw new IllegalStateException("HessianClientInterceptor is not properly initialized - " + + "invoke 'prepare' before attempting any operations"); + } + + ClassLoader originalClassLoader = overrideThreadContextClassLoader(); + try { + return invocation.getMethod().invoke(this.hessianProxy, invocation.getArguments()); + } + catch (InvocationTargetException ex) { + Throwable targetEx = ex.getTargetException(); + // Hessian 4.0 check: another layer of InvocationTargetException. + if (targetEx instanceof InvocationTargetException) { + targetEx = ((InvocationTargetException) targetEx).getTargetException(); + } + if (targetEx instanceof HessianConnectionException) { + throw convertHessianAccessException(targetEx); + } + else if (targetEx instanceof HessianException || targetEx instanceof HessianRuntimeException) { + Throwable cause = targetEx.getCause(); + throw convertHessianAccessException(cause != null ? cause : targetEx); + } + else if (targetEx instanceof UndeclaredThrowableException) { + UndeclaredThrowableException utex = (UndeclaredThrowableException) targetEx; + throw convertHessianAccessException(utex.getUndeclaredThrowable()); + } + else { + throw targetEx; + } + } + catch (Throwable ex) { + throw new RemoteProxyFailureException( + "Failed to invoke Hessian proxy for remote service [" + getServiceUrl() + "]", ex); + } + finally { + resetThreadContextClassLoader(originalClassLoader); + } + } + + /** + * Convert the given Hessian access exception to an appropriate + * Spring RemoteAccessException. + * @param ex the exception to convert + * @return the RemoteAccessException to throw + */ + protected RemoteAccessException convertHessianAccessException(Throwable ex) { + if (ex instanceof HessianConnectionException || ex instanceof ConnectException) { + return new RemoteConnectFailureException( + "Cannot connect to Hessian remote service at [" + getServiceUrl() + "]", ex); + } + else { + return new RemoteAccessException( + "Cannot access Hessian remote service at [" + getServiceUrl() + "]", ex); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/caucho/HessianExporter.java b/spring-web/src/main/java/org/springframework/remoting/caucho/HessianExporter.java new file mode 100644 index 0000000000000000000000000000000000000000..5344d77151873cca588ae7fa94add286ea14eefa --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/caucho/HessianExporter.java @@ -0,0 +1,252 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.caucho; + +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.PrintWriter; + +import com.caucho.hessian.io.AbstractHessianInput; +import com.caucho.hessian.io.AbstractHessianOutput; +import com.caucho.hessian.io.Hessian2Input; +import com.caucho.hessian.io.Hessian2Output; +import com.caucho.hessian.io.HessianDebugInputStream; +import com.caucho.hessian.io.HessianDebugOutputStream; +import com.caucho.hessian.io.HessianInput; +import com.caucho.hessian.io.HessianOutput; +import com.caucho.hessian.io.HessianRemoteResolver; +import com.caucho.hessian.io.SerializerFactory; +import com.caucho.hessian.server.HessianSkeleton; +import org.apache.commons.logging.Log; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; +import org.springframework.remoting.support.RemoteExporter; +import org.springframework.util.Assert; +import org.springframework.util.CommonsLogWriter; + +/** + * General stream-based protocol exporter for a Hessian endpoint. + * + *

Hessian is a slim, binary RPC protocol. + * For information on Hessian, see the + * Hessian website. + * Note: As of Spring 4.0, this exporter requires Hessian 4.0 or above. + * + * @author Juergen Hoeller + * @since 2.5.1 + * @see #invoke(java.io.InputStream, java.io.OutputStream) + * @see HessianServiceExporter + */ +public class HessianExporter extends RemoteExporter implements InitializingBean { + + /** + * The content type for hessian ({@code application/x-hessian}). + */ + public static final String CONTENT_TYPE_HESSIAN = "application/x-hessian"; + + + private SerializerFactory serializerFactory = new SerializerFactory(); + + @Nullable + private HessianRemoteResolver remoteResolver; + + @Nullable + private Log debugLogger; + + @Nullable + private HessianSkeleton skeleton; + + + /** + * Specify the Hessian SerializerFactory to use. + *

This will typically be passed in as an inner bean definition + * of type {@code com.caucho.hessian.io.SerializerFactory}, + * with custom bean property values applied. + */ + public void setSerializerFactory(@Nullable SerializerFactory serializerFactory) { + this.serializerFactory = (serializerFactory != null ? serializerFactory : new SerializerFactory()); + } + + /** + * Set whether to send the Java collection type for each serialized + * collection. Default is "true". + */ + public void setSendCollectionType(boolean sendCollectionType) { + this.serializerFactory.setSendCollectionType(sendCollectionType); + } + + /** + * Set whether to allow non-serializable types as Hessian arguments + * and return values. Default is "true". + */ + public void setAllowNonSerializable(boolean allowNonSerializable) { + this.serializerFactory.setAllowNonSerializable(allowNonSerializable); + } + + /** + * Specify a custom HessianRemoteResolver to use for resolving remote + * object references. + */ + public void setRemoteResolver(HessianRemoteResolver remoteResolver) { + this.remoteResolver = remoteResolver; + } + + /** + * Set whether Hessian's debug mode should be enabled, logging to + * this exporter's Commons Logging log. Default is "false". + * @see com.caucho.hessian.client.HessianProxyFactory#setDebug + */ + public void setDebug(boolean debug) { + this.debugLogger = (debug ? logger : null); + } + + + @Override + public void afterPropertiesSet() { + prepare(); + } + + /** + * Initialize this exporter. + */ + public void prepare() { + checkService(); + checkServiceInterface(); + this.skeleton = new HessianSkeleton(getProxyForService(), getServiceInterface()); + } + + + /** + * Perform an invocation on the exported object. + * @param inputStream the request stream + * @param outputStream the response stream + * @throws Throwable if invocation failed + */ + public void invoke(InputStream inputStream, OutputStream outputStream) throws Throwable { + Assert.notNull(this.skeleton, "Hessian exporter has not been initialized"); + doInvoke(this.skeleton, inputStream, outputStream); + } + + /** + * Actually invoke the skeleton with the given streams. + * @param skeleton the skeleton to invoke + * @param inputStream the request stream + * @param outputStream the response stream + * @throws Throwable if invocation failed + */ + protected void doInvoke(HessianSkeleton skeleton, InputStream inputStream, OutputStream outputStream) + throws Throwable { + + ClassLoader originalClassLoader = overrideThreadContextClassLoader(); + try { + InputStream isToUse = inputStream; + OutputStream osToUse = outputStream; + + if (this.debugLogger != null && this.debugLogger.isDebugEnabled()) { + try (PrintWriter debugWriter = new PrintWriter(new CommonsLogWriter(this.debugLogger))){ + @SuppressWarnings("resource") + HessianDebugInputStream dis = new HessianDebugInputStream(inputStream, debugWriter); + @SuppressWarnings("resource") + HessianDebugOutputStream dos = new HessianDebugOutputStream(outputStream, debugWriter); + dis.startTop2(); + dos.startTop2(); + isToUse = dis; + osToUse = dos; + } + } + + if (!isToUse.markSupported()) { + isToUse = new BufferedInputStream(isToUse); + isToUse.mark(1); + } + + int code = isToUse.read(); + int major; + int minor; + + AbstractHessianInput in; + AbstractHessianOutput out; + + if (code == 'H') { + // Hessian 2.0 stream + major = isToUse.read(); + minor = isToUse.read(); + if (major != 0x02) { + throw new IOException("Version " + major + '.' + minor + " is not understood"); + } + in = new Hessian2Input(isToUse); + out = new Hessian2Output(osToUse); + in.readCall(); + } + else if (code == 'C') { + // Hessian 2.0 call... for some reason not handled in HessianServlet! + isToUse.reset(); + in = new Hessian2Input(isToUse); + out = new Hessian2Output(osToUse); + in.readCall(); + } + else if (code == 'c') { + // Hessian 1.0 call + major = isToUse.read(); + minor = isToUse.read(); + in = new HessianInput(isToUse); + if (major >= 2) { + out = new Hessian2Output(osToUse); + } + else { + out = new HessianOutput(osToUse); + } + } + else { + throw new IOException("Expected 'H'/'C' (Hessian 2.0) or 'c' (Hessian 1.0) in hessian input at " + code); + } + + in.setSerializerFactory(this.serializerFactory); + out.setSerializerFactory(this.serializerFactory); + if (this.remoteResolver != null) { + in.setRemoteResolver(this.remoteResolver); + } + + try { + skeleton.invoke(in, out); + } + finally { + try { + in.close(); + isToUse.close(); + } + catch (IOException ex) { + // ignore + } + try { + out.close(); + osToUse.close(); + } + catch (IOException ex) { + // ignore + } + } + } + finally { + resetThreadContextClassLoader(originalClassLoader); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/caucho/HessianProxyFactoryBean.java b/spring-web/src/main/java/org/springframework/remoting/caucho/HessianProxyFactoryBean.java new file mode 100644 index 0000000000000000000000000000000000000000..39d2d6aa0a33f50749d06441b36e2e6e7225b18e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/caucho/HessianProxyFactoryBean.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.caucho; + +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.lang.Nullable; + +/** + * {@link FactoryBean} for Hessian proxies. Exposes the proxied service + * for use as a bean reference, using the specified service interface. + * + *

Hessian is a slim, binary RPC protocol. + * For information on Hessian, see the + * Hessian website + * Note: As of Spring 4.0, this proxy factory requires Hessian 4.0 or above. + * + *

The service URL must be an HTTP URL exposing a Hessian service. + * For details, see the {@link HessianClientInterceptor} javadoc. + * + * @author Juergen Hoeller + * @since 13.05.2003 + * @see #setServiceInterface + * @see #setServiceUrl + * @see HessianClientInterceptor + * @see HessianServiceExporter + * @see org.springframework.remoting.httpinvoker.HttpInvokerProxyFactoryBean + * @see org.springframework.remoting.rmi.RmiProxyFactoryBean + */ +public class HessianProxyFactoryBean extends HessianClientInterceptor implements FactoryBean { + + @Nullable + private Object serviceProxy; + + + @Override + public void afterPropertiesSet() { + super.afterPropertiesSet(); + this.serviceProxy = new ProxyFactory(getServiceInterface(), this).getProxy(getBeanClassLoader()); + } + + + @Override + @Nullable + public Object getObject() { + return this.serviceProxy; + } + + @Override + public Class getObjectType() { + return getServiceInterface(); + } + + @Override + public boolean isSingleton() { + return true; + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/caucho/HessianServiceExporter.java b/spring-web/src/main/java/org/springframework/remoting/caucho/HessianServiceExporter.java new file mode 100644 index 0000000000000000000000000000000000000000..1ac88871f031f2c3e1031be8e887755c2ef00489 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/caucho/HessianServiceExporter.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.caucho; + +import java.io.IOException; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.web.HttpRequestHandler; +import org.springframework.web.HttpRequestMethodNotSupportedException; +import org.springframework.web.util.NestedServletException; + +/** + * Servlet-API-based HTTP request handler that exports the specified service bean + * as Hessian service endpoint, accessible via a Hessian proxy. + * + *

Hessian is a slim, binary RPC protocol. + * For information on Hessian, see the + * Hessian website. + * Note: As of Spring 4.0, this exporter requires Hessian 4.0 or above. + * + *

Hessian services exported with this class can be accessed by + * any Hessian client, as there isn't any special handling involved. + * + * @author Juergen Hoeller + * @since 13.05.2003 + * @see HessianClientInterceptor + * @see HessianProxyFactoryBean + * @see org.springframework.remoting.httpinvoker.HttpInvokerServiceExporter + * @see org.springframework.remoting.rmi.RmiServiceExporter + */ +public class HessianServiceExporter extends HessianExporter implements HttpRequestHandler { + + /** + * Processes the incoming Hessian request and creates a Hessian response. + */ + @Override + public void handleRequest(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + if (!"POST".equals(request.getMethod())) { + throw new HttpRequestMethodNotSupportedException(request.getMethod(), + new String[] {"POST"}, "HessianServiceExporter only supports POST requests"); + } + + response.setContentType(CONTENT_TYPE_HESSIAN); + try { + invoke(request.getInputStream(), response.getOutputStream()); + } + catch (Throwable ex) { + throw new NestedServletException("Hessian skeleton invocation failed", ex); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/caucho/SimpleHessianServiceExporter.java b/spring-web/src/main/java/org/springframework/remoting/caucho/SimpleHessianServiceExporter.java new file mode 100644 index 0000000000000000000000000000000000000000..cc96b28bfa0cc2f72b53f89d1491f9e5735e22b8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/caucho/SimpleHessianServiceExporter.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.caucho; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpHandler; + +import org.springframework.util.FileCopyUtils; + +/** + * HTTP request handler that exports the specified service bean as + * Hessian service endpoint, accessible via a Hessian proxy. + * Designed for Sun's JRE 1.6 HTTP server, implementing the + * {@link com.sun.net.httpserver.HttpHandler} interface. + * + *

Hessian is a slim, binary RPC protocol. + * For information on Hessian, see the + * Hessian website. + * Note: As of Spring 4.0, this exporter requires Hessian 4.0 or above. + * + *

Hessian services exported with this class can be accessed by + * any Hessian client, as there isn't any special handling involved. + * + * @author Juergen Hoeller + * @since 2.5.1 + * @see org.springframework.remoting.caucho.HessianClientInterceptor + * @see org.springframework.remoting.caucho.HessianProxyFactoryBean + * @deprecated as of Spring Framework 5.1, in favor of {@link HessianServiceExporter} + */ +@Deprecated +@org.springframework.lang.UsesSunHttpServer +public class SimpleHessianServiceExporter extends HessianExporter implements HttpHandler { + + /** + * Processes the incoming Hessian request and creates a Hessian response. + */ + @Override + public void handle(HttpExchange exchange) throws IOException { + if (!"POST".equals(exchange.getRequestMethod())) { + exchange.getResponseHeaders().set("Allow", "POST"); + exchange.sendResponseHeaders(405, -1); + return; + } + + ByteArrayOutputStream output = new ByteArrayOutputStream(1024); + try { + invoke(exchange.getRequestBody(), output); + } + catch (Throwable ex) { + exchange.sendResponseHeaders(500, -1); + logger.error("Hessian skeleton invocation failed", ex); + return; + } + + exchange.getResponseHeaders().set("Content-Type", CONTENT_TYPE_HESSIAN); + exchange.sendResponseHeaders(200, output.size()); + FileCopyUtils.copy(output.toByteArray(), exchange.getResponseBody()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/caucho/package-info.java b/spring-web/src/main/java/org/springframework/remoting/caucho/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..30d03c51762f7bb801c0cb94443ebe2deee94132 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/caucho/package-info.java @@ -0,0 +1,15 @@ +/** + * This package provides remoting classes for Caucho's Hessian protocol: + * a proxy factory for accessing Hessian services, and an exporter for + * making beans available to Hessian clients. + * + *

Hessian is a slim, binary RPC protocol over HTTP. + * For information on Hessian, see the + * Hessian website + */ +@NonNullApi +@NonNullFields +package org.springframework.remoting.caucho; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/remoting/httpinvoker/AbstractHttpInvokerRequestExecutor.java b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/AbstractHttpInvokerRequestExecutor.java new file mode 100644 index 0000000000000000000000000000000000000000..0953e769529aa47ceed862e37d5250f605594058 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/AbstractHttpInvokerRequestExecutor.java @@ -0,0 +1,302 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.httpinvoker; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.OutputStream; +import java.rmi.RemoteException; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.BeanClassLoaderAware; +import org.springframework.lang.Nullable; +import org.springframework.remoting.rmi.CodebaseAwareObjectInputStream; +import org.springframework.remoting.support.RemoteInvocation; +import org.springframework.remoting.support.RemoteInvocationResult; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; + +/** + * Abstract base implementation of the HttpInvokerRequestExecutor interface. + * + *

Pre-implements serialization of RemoteInvocation objects and + * deserialization of RemoteInvocationResults objects. + * + * @author Juergen Hoeller + * @since 1.1 + * @see #doExecuteRequest + */ +public abstract class AbstractHttpInvokerRequestExecutor implements HttpInvokerRequestExecutor, BeanClassLoaderAware { + + /** + * Default content type: "application/x-java-serialized-object". + */ + public static final String CONTENT_TYPE_SERIALIZED_OBJECT = "application/x-java-serialized-object"; + + private static final int SERIALIZED_INVOCATION_BYTE_ARRAY_INITIAL_SIZE = 1024; + + + protected static final String HTTP_METHOD_POST = "POST"; + + protected static final String HTTP_HEADER_ACCEPT_LANGUAGE = "Accept-Language"; + + protected static final String HTTP_HEADER_ACCEPT_ENCODING = "Accept-Encoding"; + + protected static final String HTTP_HEADER_CONTENT_ENCODING = "Content-Encoding"; + + protected static final String HTTP_HEADER_CONTENT_TYPE = "Content-Type"; + + protected static final String HTTP_HEADER_CONTENT_LENGTH = "Content-Length"; + + protected static final String ENCODING_GZIP = "gzip"; + + + protected final Log logger = LogFactory.getLog(getClass()); + + private String contentType = CONTENT_TYPE_SERIALIZED_OBJECT; + + private boolean acceptGzipEncoding = true; + + @Nullable + private ClassLoader beanClassLoader; + + + /** + * Specify the content type to use for sending HTTP invoker requests. + *

Default is "application/x-java-serialized-object". + */ + public void setContentType(String contentType) { + Assert.notNull(contentType, "'contentType' must not be null"); + this.contentType = contentType; + } + + /** + * Return the content type to use for sending HTTP invoker requests. + */ + public String getContentType() { + return this.contentType; + } + + /** + * Set whether to accept GZIP encoding, that is, whether to + * send the HTTP "Accept-Encoding" header with "gzip" as value. + *

Default is "true". Turn this flag off if you do not want + * GZIP response compression even if enabled on the HTTP server. + */ + public void setAcceptGzipEncoding(boolean acceptGzipEncoding) { + this.acceptGzipEncoding = acceptGzipEncoding; + } + + /** + * Return whether to accept GZIP encoding, that is, whether to + * send the HTTP "Accept-Encoding" header with "gzip" as value. + */ + public boolean isAcceptGzipEncoding() { + return this.acceptGzipEncoding; + } + + @Override + public void setBeanClassLoader(ClassLoader classLoader) { + this.beanClassLoader = classLoader; + } + + /** + * Return the bean ClassLoader that this executor is supposed to use. + */ + @Nullable + protected ClassLoader getBeanClassLoader() { + return this.beanClassLoader; + } + + + @Override + public final RemoteInvocationResult executeRequest( + HttpInvokerClientConfiguration config, RemoteInvocation invocation) throws Exception { + + ByteArrayOutputStream baos = getByteArrayOutputStream(invocation); + if (logger.isDebugEnabled()) { + logger.debug("Sending HTTP invoker request for service at [" + config.getServiceUrl() + + "], with size " + baos.size()); + } + return doExecuteRequest(config, baos); + } + + /** + * Serialize the given RemoteInvocation into a ByteArrayOutputStream. + * @param invocation the RemoteInvocation object + * @return a ByteArrayOutputStream with the serialized RemoteInvocation + * @throws IOException if thrown by I/O methods + */ + protected ByteArrayOutputStream getByteArrayOutputStream(RemoteInvocation invocation) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(SERIALIZED_INVOCATION_BYTE_ARRAY_INITIAL_SIZE); + writeRemoteInvocation(invocation, baos); + return baos; + } + + /** + * Serialize the given RemoteInvocation to the given OutputStream. + *

The default implementation gives {@code decorateOutputStream} a chance + * to decorate the stream first (for example, for custom encryption or compression). + * Creates an {@code ObjectOutputStream} for the final stream and calls + * {@code doWriteRemoteInvocation} to actually write the object. + *

Can be overridden for custom serialization of the invocation. + * @param invocation the RemoteInvocation object + * @param os the OutputStream to write to + * @throws IOException if thrown by I/O methods + * @see #decorateOutputStream + * @see #doWriteRemoteInvocation + */ + protected void writeRemoteInvocation(RemoteInvocation invocation, OutputStream os) throws IOException { + ObjectOutputStream oos = new ObjectOutputStream(decorateOutputStream(os)); + try { + doWriteRemoteInvocation(invocation, oos); + } + finally { + oos.close(); + } + } + + /** + * Return the OutputStream to use for writing remote invocations, + * potentially decorating the given original OutputStream. + *

The default implementation returns the given stream as-is. + * Can be overridden, for example, for custom encryption or compression. + * @param os the original OutputStream + * @return the potentially decorated OutputStream + */ + protected OutputStream decorateOutputStream(OutputStream os) throws IOException { + return os; + } + + /** + * Perform the actual writing of the given invocation object to the + * given ObjectOutputStream. + *

The default implementation simply calls {@code writeObject}. + * Can be overridden for serialization of a custom wrapper object rather + * than the plain invocation, for example an encryption-aware holder. + * @param invocation the RemoteInvocation object + * @param oos the ObjectOutputStream to write to + * @throws IOException if thrown by I/O methods + * @see java.io.ObjectOutputStream#writeObject + */ + protected void doWriteRemoteInvocation(RemoteInvocation invocation, ObjectOutputStream oos) throws IOException { + oos.writeObject(invocation); + } + + + /** + * Execute a request to send the given serialized remote invocation. + *

Implementations will usually call {@code readRemoteInvocationResult} + * to deserialize a returned RemoteInvocationResult object. + * @param config the HTTP invoker configuration that specifies the + * target service + * @param baos the ByteArrayOutputStream that contains the serialized + * RemoteInvocation object + * @return the RemoteInvocationResult object + * @throws IOException if thrown by I/O operations + * @throws ClassNotFoundException if thrown during deserialization + * @throws Exception in case of general errors + * @see #readRemoteInvocationResult(java.io.InputStream, String) + */ + protected abstract RemoteInvocationResult doExecuteRequest( + HttpInvokerClientConfiguration config, ByteArrayOutputStream baos) + throws Exception; + + /** + * Deserialize a RemoteInvocationResult object from the given InputStream. + *

Gives {@code decorateInputStream} a chance to decorate the stream + * first (for example, for custom encryption or compression). Creates an + * {@code ObjectInputStream} via {@code createObjectInputStream} and + * calls {@code doReadRemoteInvocationResult} to actually read the object. + *

Can be overridden for custom serialization of the invocation. + * @param is the InputStream to read from + * @param codebaseUrl the codebase URL to load classes from if not found locally + * @return the RemoteInvocationResult object + * @throws IOException if thrown by I/O methods + * @throws ClassNotFoundException if thrown during deserialization + * @see #decorateInputStream + * @see #createObjectInputStream + * @see #doReadRemoteInvocationResult + */ + protected RemoteInvocationResult readRemoteInvocationResult(InputStream is, @Nullable String codebaseUrl) + throws IOException, ClassNotFoundException { + + ObjectInputStream ois = createObjectInputStream(decorateInputStream(is), codebaseUrl); + try { + return doReadRemoteInvocationResult(ois); + } + finally { + ois.close(); + } + } + + /** + * Return the InputStream to use for reading remote invocation results, + * potentially decorating the given original InputStream. + *

The default implementation returns the given stream as-is. + * Can be overridden, for example, for custom encryption or compression. + * @param is the original InputStream + * @return the potentially decorated InputStream + */ + protected InputStream decorateInputStream(InputStream is) throws IOException { + return is; + } + + /** + * Create an ObjectInputStream for the given InputStream and codebase. + * The default implementation creates a CodebaseAwareObjectInputStream. + * @param is the InputStream to read from + * @param codebaseUrl the codebase URL to load classes from if not found locally + * (can be {@code null}) + * @return the new ObjectInputStream instance to use + * @throws IOException if creation of the ObjectInputStream failed + * @see org.springframework.remoting.rmi.CodebaseAwareObjectInputStream + */ + protected ObjectInputStream createObjectInputStream(InputStream is, @Nullable String codebaseUrl) throws IOException { + return new CodebaseAwareObjectInputStream(is, getBeanClassLoader(), codebaseUrl); + } + + /** + * Perform the actual reading of an invocation object from the + * given ObjectInputStream. + *

The default implementation simply calls {@code readObject}. + * Can be overridden for deserialization of a custom wrapper object rather + * than the plain invocation, for example an encryption-aware holder. + * @param ois the ObjectInputStream to read from + * @return the RemoteInvocationResult object + * @throws IOException if thrown by I/O methods + * @throws ClassNotFoundException if the class name of a serialized object + * couldn't get resolved + * @see java.io.ObjectOutputStream#writeObject + */ + protected RemoteInvocationResult doReadRemoteInvocationResult(ObjectInputStream ois) + throws IOException, ClassNotFoundException { + + Object obj = ois.readObject(); + if (!(obj instanceof RemoteInvocationResult)) { + throw new RemoteException("Deserialized object needs to be assignable to type [" + + RemoteInvocationResult.class.getName() + "]: " + ClassUtils.getDescriptiveType(obj)); + } + return (RemoteInvocationResult) obj; + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpComponentsHttpInvokerRequestExecutor.java b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpComponentsHttpInvokerRequestExecutor.java new file mode 100644 index 0000000000000000000000000000000000000000..e35adf1ebb600ba5939c5c6103cd2b4d6258873a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpComponentsHttpInvokerRequestExecutor.java @@ -0,0 +1,368 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.httpinvoker; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Locale; +import java.util.zip.GZIPInputStream; + +import org.apache.http.Header; +import org.apache.http.HttpResponse; +import org.apache.http.NoHttpResponseException; +import org.apache.http.StatusLine; +import org.apache.http.client.HttpClient; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.Configurable; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.config.Registry; +import org.apache.http.config.RegistryBuilder; +import org.apache.http.conn.socket.ConnectionSocketFactory; +import org.apache.http.conn.socket.PlainConnectionSocketFactory; +import org.apache.http.conn.ssl.SSLConnectionSocketFactory; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.http.impl.client.HttpClientBuilder; +import org.apache.http.impl.conn.PoolingHttpClientConnectionManager; + +import org.springframework.context.i18n.LocaleContext; +import org.springframework.context.i18n.LocaleContextHolder; +import org.springframework.lang.Nullable; +import org.springframework.remoting.support.RemoteInvocationResult; +import org.springframework.util.Assert; + +/** + * {@link org.springframework.remoting.httpinvoker.HttpInvokerRequestExecutor} implementation that uses + * Apache HttpComponents HttpClient + * to execute POST requests. + * + *

Allows to use a pre-configured {@link org.apache.http.client.HttpClient} + * instance, potentially with authentication, HTTP connection pooling, etc. + * Also designed for easy subclassing, providing specific template methods. + * + *

As of Spring 4.1, this request executor requires Apache HttpComponents 4.3 or higher. + * + * @author Juergen Hoeller + * @author Stephane Nicoll + * @since 3.1 + * @see org.springframework.remoting.httpinvoker.SimpleHttpInvokerRequestExecutor + */ +public class HttpComponentsHttpInvokerRequestExecutor extends AbstractHttpInvokerRequestExecutor { + + private static final int DEFAULT_MAX_TOTAL_CONNECTIONS = 100; + + private static final int DEFAULT_MAX_CONNECTIONS_PER_ROUTE = 5; + + private static final int DEFAULT_READ_TIMEOUT_MILLISECONDS = (60 * 1000); + + + private HttpClient httpClient; + + @Nullable + private RequestConfig requestConfig; + + + /** + * Create a new instance of the HttpComponentsHttpInvokerRequestExecutor with a default + * {@link HttpClient} that uses a default {@code org.apache.http.impl.conn.PoolingClientConnectionManager}. + */ + public HttpComponentsHttpInvokerRequestExecutor() { + this(createDefaultHttpClient(), RequestConfig.custom() + .setSocketTimeout(DEFAULT_READ_TIMEOUT_MILLISECONDS).build()); + } + + /** + * Create a new instance of the HttpComponentsClientHttpRequestFactory + * with the given {@link HttpClient} instance. + * @param httpClient the HttpClient instance to use for this request executor + */ + public HttpComponentsHttpInvokerRequestExecutor(HttpClient httpClient) { + this(httpClient, null); + } + + private HttpComponentsHttpInvokerRequestExecutor(HttpClient httpClient, @Nullable RequestConfig requestConfig) { + this.httpClient = httpClient; + this.requestConfig = requestConfig; + } + + + private static HttpClient createDefaultHttpClient() { + Registry schemeRegistry = RegistryBuilder.create() + .register("http", PlainConnectionSocketFactory.getSocketFactory()) + .register("https", SSLConnectionSocketFactory.getSocketFactory()) + .build(); + + PoolingHttpClientConnectionManager connectionManager = new PoolingHttpClientConnectionManager(schemeRegistry); + connectionManager.setMaxTotal(DEFAULT_MAX_TOTAL_CONNECTIONS); + connectionManager.setDefaultMaxPerRoute(DEFAULT_MAX_CONNECTIONS_PER_ROUTE); + + return HttpClientBuilder.create().setConnectionManager(connectionManager).build(); + } + + + /** + * Set the {@link HttpClient} instance to use for this request executor. + */ + public void setHttpClient(HttpClient httpClient) { + this.httpClient = httpClient; + } + + /** + * Return the {@link HttpClient} instance that this request executor uses. + */ + public HttpClient getHttpClient() { + return this.httpClient; + } + + /** + * Set the connection timeout for the underlying HttpClient. + * A timeout value of 0 specifies an infinite timeout. + *

Additional properties can be configured by specifying a + * {@link RequestConfig} instance on a custom {@link HttpClient}. + * @param timeout the timeout value in milliseconds + * @see RequestConfig#getConnectTimeout() + */ + public void setConnectTimeout(int timeout) { + Assert.isTrue(timeout >= 0, "Timeout must be a non-negative value"); + this.requestConfig = cloneRequestConfig().setConnectTimeout(timeout).build(); + } + + /** + * Set the timeout in milliseconds used when requesting a connection from the connection + * manager using the underlying HttpClient. + * A timeout value of 0 specifies an infinite timeout. + *

Additional properties can be configured by specifying a + * {@link RequestConfig} instance on a custom {@link HttpClient}. + * @param connectionRequestTimeout the timeout value to request a connection in milliseconds + * @see RequestConfig#getConnectionRequestTimeout() + */ + public void setConnectionRequestTimeout(int connectionRequestTimeout) { + this.requestConfig = cloneRequestConfig().setConnectionRequestTimeout(connectionRequestTimeout).build(); + } + + /** + * Set the socket read timeout for the underlying HttpClient. + * A timeout value of 0 specifies an infinite timeout. + *

Additional properties can be configured by specifying a + * {@link RequestConfig} instance on a custom {@link HttpClient}. + * @param timeout the timeout value in milliseconds + * @see #DEFAULT_READ_TIMEOUT_MILLISECONDS + * @see RequestConfig#getSocketTimeout() + */ + public void setReadTimeout(int timeout) { + Assert.isTrue(timeout >= 0, "Timeout must be a non-negative value"); + this.requestConfig = cloneRequestConfig().setSocketTimeout(timeout).build(); + } + + private RequestConfig.Builder cloneRequestConfig() { + return (this.requestConfig != null ? RequestConfig.copy(this.requestConfig) : RequestConfig.custom()); + } + + + /** + * Execute the given request through the HttpClient. + *

This method implements the basic processing workflow: + * The actual work happens in this class's template methods. + * @see #createHttpPost + * @see #setRequestBody + * @see #executeHttpPost + * @see #validateResponse + * @see #getResponseBody + */ + @Override + protected RemoteInvocationResult doExecuteRequest( + HttpInvokerClientConfiguration config, ByteArrayOutputStream baos) + throws IOException, ClassNotFoundException { + + HttpPost postMethod = createHttpPost(config); + setRequestBody(config, postMethod, baos); + try { + HttpResponse response = executeHttpPost(config, getHttpClient(), postMethod); + validateResponse(config, response); + InputStream responseBody = getResponseBody(config, response); + return readRemoteInvocationResult(responseBody, config.getCodebaseUrl()); + } + finally { + postMethod.releaseConnection(); + } + } + + /** + * Create a HttpPost for the given configuration. + *

The default implementation creates a standard HttpPost with + * "application/x-java-serialized-object" as "Content-Type" header. + * @param config the HTTP invoker configuration that specifies the + * target service + * @return the HttpPost instance + * @throws java.io.IOException if thrown by I/O methods + */ + protected HttpPost createHttpPost(HttpInvokerClientConfiguration config) throws IOException { + HttpPost httpPost = new HttpPost(config.getServiceUrl()); + + RequestConfig requestConfig = createRequestConfig(config); + if (requestConfig != null) { + httpPost.setConfig(requestConfig); + } + + LocaleContext localeContext = LocaleContextHolder.getLocaleContext(); + if (localeContext != null) { + Locale locale = localeContext.getLocale(); + if (locale != null) { + httpPost.addHeader(HTTP_HEADER_ACCEPT_LANGUAGE, locale.toLanguageTag()); + } + } + + if (isAcceptGzipEncoding()) { + httpPost.addHeader(HTTP_HEADER_ACCEPT_ENCODING, ENCODING_GZIP); + } + + return httpPost; + } + + /** + * Create a {@link RequestConfig} for the given configuration. Can return {@code null} + * to indicate that no custom request config should be set and the defaults of the + * {@link HttpClient} should be used. + *

The default implementation tries to merge the defaults of the client with the + * local customizations of the instance, if any. + * @param config the HTTP invoker configuration that specifies the + * target service + * @return the RequestConfig to use + */ + @Nullable + protected RequestConfig createRequestConfig(HttpInvokerClientConfiguration config) { + HttpClient client = getHttpClient(); + if (client instanceof Configurable) { + RequestConfig clientRequestConfig = ((Configurable) client).getConfig(); + return mergeRequestConfig(clientRequestConfig); + } + return this.requestConfig; + } + + private RequestConfig mergeRequestConfig(RequestConfig defaultRequestConfig) { + if (this.requestConfig == null) { // nothing to merge + return defaultRequestConfig; + } + + RequestConfig.Builder builder = RequestConfig.copy(defaultRequestConfig); + int connectTimeout = this.requestConfig.getConnectTimeout(); + if (connectTimeout >= 0) { + builder.setConnectTimeout(connectTimeout); + } + int connectionRequestTimeout = this.requestConfig.getConnectionRequestTimeout(); + if (connectionRequestTimeout >= 0) { + builder.setConnectionRequestTimeout(connectionRequestTimeout); + } + int socketTimeout = this.requestConfig.getSocketTimeout(); + if (socketTimeout >= 0) { + builder.setSocketTimeout(socketTimeout); + } + return builder.build(); + } + + /** + * Set the given serialized remote invocation as request body. + *

The default implementation simply sets the serialized invocation as the + * HttpPost's request body. This can be overridden, for example, to write a + * specific encoding and to potentially set appropriate HTTP request headers. + * @param config the HTTP invoker configuration that specifies the target service + * @param httpPost the HttpPost to set the request body on + * @param baos the ByteArrayOutputStream that contains the serialized + * RemoteInvocation object + * @throws java.io.IOException if thrown by I/O methods + */ + protected void setRequestBody( + HttpInvokerClientConfiguration config, HttpPost httpPost, ByteArrayOutputStream baos) + throws IOException { + + ByteArrayEntity entity = new ByteArrayEntity(baos.toByteArray()); + entity.setContentType(getContentType()); + httpPost.setEntity(entity); + } + + /** + * Execute the given HttpPost instance. + * @param config the HTTP invoker configuration that specifies the target service + * @param httpClient the HttpClient to execute on + * @param httpPost the HttpPost to execute + * @return the resulting HttpResponse + * @throws java.io.IOException if thrown by I/O methods + */ + protected HttpResponse executeHttpPost( + HttpInvokerClientConfiguration config, HttpClient httpClient, HttpPost httpPost) + throws IOException { + + return httpClient.execute(httpPost); + } + + /** + * Validate the given response as contained in the HttpPost object, + * throwing an exception if it does not correspond to a successful HTTP response. + *

Default implementation rejects any HTTP status code beyond 2xx, to avoid + * parsing the response body and trying to deserialize from a corrupted stream. + * @param config the HTTP invoker configuration that specifies the target service + * @param response the resulting HttpResponse to validate + * @throws java.io.IOException if validation failed + */ + protected void validateResponse(HttpInvokerClientConfiguration config, HttpResponse response) + throws IOException { + + StatusLine status = response.getStatusLine(); + if (status.getStatusCode() >= 300) { + throw new NoHttpResponseException( + "Did not receive successful HTTP response: status code = " + status.getStatusCode() + + ", status message = [" + status.getReasonPhrase() + "]"); + } + } + + /** + * Extract the response body from the given executed remote invocation request. + *

The default implementation simply fetches the HttpPost's response body stream. + * If the response is recognized as GZIP response, the InputStream will get wrapped + * in a GZIPInputStream. + * @param config the HTTP invoker configuration that specifies the target service + * @param httpResponse the resulting HttpResponse to read the response body from + * @return an InputStream for the response body + * @throws java.io.IOException if thrown by I/O methods + * @see #isGzipResponse + * @see java.util.zip.GZIPInputStream + */ + protected InputStream getResponseBody(HttpInvokerClientConfiguration config, HttpResponse httpResponse) + throws IOException { + + if (isGzipResponse(httpResponse)) { + return new GZIPInputStream(httpResponse.getEntity().getContent()); + } + else { + return httpResponse.getEntity().getContent(); + } + } + + /** + * Determine whether the given response indicates a GZIP response. + *

The default implementation checks whether the HTTP "Content-Encoding" + * header contains "gzip" (in any casing). + * @param httpResponse the resulting HttpResponse to check + * @return whether the given response indicates a GZIP response + */ + protected boolean isGzipResponse(HttpResponse httpResponse) { + Header encodingHeader = httpResponse.getFirstHeader(HTTP_HEADER_CONTENT_ENCODING); + return (encodingHeader != null && encodingHeader.getValue() != null && + encodingHeader.getValue().toLowerCase().contains(ENCODING_GZIP)); + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpInvokerClientConfiguration.java b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpInvokerClientConfiguration.java new file mode 100644 index 0000000000000000000000000000000000000000..c60793749f4b7ebdcae4b52368365889c7e3a54d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpInvokerClientConfiguration.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.httpinvoker; + +import org.springframework.lang.Nullable; + +/** + * Configuration interface for executing HTTP invoker requests. + * + * @author Juergen Hoeller + * @since 1.1 + * @see HttpInvokerRequestExecutor + * @see HttpInvokerClientInterceptor + */ +public interface HttpInvokerClientConfiguration { + + /** + * Return the HTTP URL of the target service. + */ + String getServiceUrl(); + + /** + * Return the codebase URL to download classes from if not found locally. + * Can consist of multiple URLs, separated by spaces. + * @return the codebase URL, or {@code null} if none + * @see java.rmi.server.RMIClassLoader + */ + @Nullable + String getCodebaseUrl(); + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpInvokerClientInterceptor.java b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpInvokerClientInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..7ca54983d5dd1fa96cb14bfc97af54bb4c4b9787 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpInvokerClientInterceptor.java @@ -0,0 +1,238 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.httpinvoker; + +import java.io.IOException; +import java.io.InvalidClassException; +import java.net.ConnectException; + +import org.aopalliance.intercept.MethodInterceptor; +import org.aopalliance.intercept.MethodInvocation; + +import org.springframework.aop.support.AopUtils; +import org.springframework.lang.Nullable; +import org.springframework.remoting.RemoteAccessException; +import org.springframework.remoting.RemoteConnectFailureException; +import org.springframework.remoting.RemoteInvocationFailureException; +import org.springframework.remoting.support.RemoteInvocation; +import org.springframework.remoting.support.RemoteInvocationBasedAccessor; +import org.springframework.remoting.support.RemoteInvocationResult; + +/** + * {@link org.aopalliance.intercept.MethodInterceptor} for accessing an + * HTTP invoker service. The service URL must be an HTTP URL exposing + * an HTTP invoker service. + * + *

Serializes remote invocation objects and deserializes remote invocation + * result objects. Uses Java serialization just like RMI, but provides the + * same ease of setup as Caucho's HTTP-based Hessian protocol. + * + *

HTTP invoker is a very extensible and customizable protocol. + * It supports the RemoteInvocationFactory mechanism, like RMI invoker, + * allowing to include additional invocation attributes (for example, + * a security context). Furthermore, it allows to customize request + * execution via the {@link HttpInvokerRequestExecutor} strategy. + * + *

Can use the JDK's {@link java.rmi.server.RMIClassLoader} to load classes + * from a given {@link #setCodebaseUrl codebase}, performing on-demand dynamic + * code download from a remote location. The codebase can consist of multiple + * URLs, separated by spaces. Note that RMIClassLoader requires a SecurityManager + * to be set, analogous to when using dynamic class download with standard RMI! + * (See the RMI documentation for details.) + * + *

WARNING: Be aware of vulnerabilities due to unsafe Java deserialization: + * Manipulated input streams could lead to unwanted code execution on the server + * during the deserialization step. As a consequence, do not expose HTTP invoker + * endpoints to untrusted clients but rather just between your own services. + * In general, we strongly recommend any other message format (e.g. JSON) instead. + * + * @author Juergen Hoeller + * @since 1.1 + * @see #setServiceUrl + * @see #setCodebaseUrl + * @see #setRemoteInvocationFactory + * @see #setHttpInvokerRequestExecutor + * @see HttpInvokerServiceExporter + * @see HttpInvokerProxyFactoryBean + * @see java.rmi.server.RMIClassLoader + */ +public class HttpInvokerClientInterceptor extends RemoteInvocationBasedAccessor + implements MethodInterceptor, HttpInvokerClientConfiguration { + + @Nullable + private String codebaseUrl; + + @Nullable + private HttpInvokerRequestExecutor httpInvokerRequestExecutor; + + + /** + * Set the codebase URL to download classes from if not found locally. + * Can consists of multiple URLs, separated by spaces. + *

Follows RMI's codebase conventions for dynamic class download. + * In contrast to RMI, where the server determines the URL for class download + * (via the "java.rmi.server.codebase" system property), it's the client + * that determines the codebase URL here. The server will usually be the + * same as for the service URL, just pointing to a different path there. + * @see #setServiceUrl + * @see org.springframework.remoting.rmi.CodebaseAwareObjectInputStream + * @see java.rmi.server.RMIClassLoader + */ + public void setCodebaseUrl(@Nullable String codebaseUrl) { + this.codebaseUrl = codebaseUrl; + } + + /** + * Return the codebase URL to download classes from if not found locally. + */ + @Override + @Nullable + public String getCodebaseUrl() { + return this.codebaseUrl; + } + + /** + * Set the HttpInvokerRequestExecutor implementation to use for executing + * remote invocations. + *

Default is {@link SimpleHttpInvokerRequestExecutor}. Alternatively, + * consider using {@link HttpComponentsHttpInvokerRequestExecutor} for more + * sophisticated needs. + * @see SimpleHttpInvokerRequestExecutor + * @see HttpComponentsHttpInvokerRequestExecutor + */ + public void setHttpInvokerRequestExecutor(HttpInvokerRequestExecutor httpInvokerRequestExecutor) { + this.httpInvokerRequestExecutor = httpInvokerRequestExecutor; + } + + /** + * Return the HttpInvokerRequestExecutor used by this remote accessor. + *

Creates a default SimpleHttpInvokerRequestExecutor if no executor + * has been initialized already. + */ + public HttpInvokerRequestExecutor getHttpInvokerRequestExecutor() { + if (this.httpInvokerRequestExecutor == null) { + SimpleHttpInvokerRequestExecutor executor = new SimpleHttpInvokerRequestExecutor(); + executor.setBeanClassLoader(getBeanClassLoader()); + this.httpInvokerRequestExecutor = executor; + } + return this.httpInvokerRequestExecutor; + } + + @Override + public void afterPropertiesSet() { + super.afterPropertiesSet(); + + // Eagerly initialize the default HttpInvokerRequestExecutor, if needed. + getHttpInvokerRequestExecutor(); + } + + + @Override + public Object invoke(MethodInvocation methodInvocation) throws Throwable { + if (AopUtils.isToStringMethod(methodInvocation.getMethod())) { + return "HTTP invoker proxy for service URL [" + getServiceUrl() + "]"; + } + + RemoteInvocation invocation = createRemoteInvocation(methodInvocation); + RemoteInvocationResult result; + + try { + result = executeRequest(invocation, methodInvocation); + } + catch (Throwable ex) { + RemoteAccessException rae = convertHttpInvokerAccessException(ex); + throw (rae != null ? rae : ex); + } + + try { + return recreateRemoteInvocationResult(result); + } + catch (Throwable ex) { + if (result.hasInvocationTargetException()) { + throw ex; + } + else { + throw new RemoteInvocationFailureException("Invocation of method [" + methodInvocation.getMethod() + + "] failed in HTTP invoker remote service at [" + getServiceUrl() + "]", ex); + } + } + } + + /** + * Execute the given remote invocation via the {@link HttpInvokerRequestExecutor}. + *

This implementation delegates to {@link #executeRequest(RemoteInvocation)}. + * Can be overridden to react to the specific original MethodInvocation. + * @param invocation the RemoteInvocation to execute + * @param originalInvocation the original MethodInvocation (can e.g. be cast + * to the ProxyMethodInvocation interface for accessing user attributes) + * @return the RemoteInvocationResult object + * @throws Exception in case of errors + */ + protected RemoteInvocationResult executeRequest( + RemoteInvocation invocation, MethodInvocation originalInvocation) throws Exception { + + return executeRequest(invocation); + } + + /** + * Execute the given remote invocation via the {@link HttpInvokerRequestExecutor}. + *

Can be overridden in subclasses to pass a different configuration object + * to the executor. Alternatively, add further configuration properties in a + * subclass of this accessor: By default, the accessor passed itself as + * configuration object to the executor. + * @param invocation the RemoteInvocation to execute + * @return the RemoteInvocationResult object + * @throws IOException if thrown by I/O operations + * @throws ClassNotFoundException if thrown during deserialization + * @throws Exception in case of general errors + * @see #getHttpInvokerRequestExecutor + * @see HttpInvokerClientConfiguration + */ + protected RemoteInvocationResult executeRequest(RemoteInvocation invocation) throws Exception { + return getHttpInvokerRequestExecutor().executeRequest(this, invocation); + } + + /** + * Convert the given HTTP invoker access exception to an appropriate + * Spring {@link RemoteAccessException}. + * @param ex the exception to convert + * @return the RemoteAccessException to throw, or {@code null} to have the + * original exception propagated to the caller + */ + @Nullable + protected RemoteAccessException convertHttpInvokerAccessException(Throwable ex) { + if (ex instanceof ConnectException) { + return new RemoteConnectFailureException( + "Could not connect to HTTP invoker remote service at [" + getServiceUrl() + "]", ex); + } + + if (ex instanceof ClassNotFoundException || ex instanceof NoClassDefFoundError || + ex instanceof InvalidClassException) { + return new RemoteAccessException( + "Could not deserialize result from HTTP invoker remote service [" + getServiceUrl() + "]", ex); + } + + if (ex instanceof Exception) { + return new RemoteAccessException( + "Could not access HTTP invoker remote service at [" + getServiceUrl() + "]", ex); + } + + // For any other Throwable, e.g. OutOfMemoryError: let it get propagated as-is. + return null; + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpInvokerProxyFactoryBean.java b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpInvokerProxyFactoryBean.java new file mode 100644 index 0000000000000000000000000000000000000000..554d74487a629a7364613f81fdc6ccfacb99453c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpInvokerProxyFactoryBean.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.httpinvoker; + +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@link FactoryBean} for HTTP invoker proxies. Exposes the proxied service + * for use as a bean reference, using the specified service interface. + * + *

The service URL must be an HTTP URL exposing an HTTP invoker service. + * Optionally, a codebase URL can be specified for on-demand dynamic code download + * from a remote location. For details, see HttpInvokerClientInterceptor docs. + * + *

Serializes remote invocation objects and deserializes remote invocation + * result objects. Uses Java serialization just like RMI, but provides the + * same ease of setup as Caucho's HTTP-based Hessian protocol. + * + *

HTTP invoker is the recommended protocol for Java-to-Java remoting. + * It is more powerful and more extensible than Hessian, at the expense of + * being tied to Java. Nevertheless, it is as easy to set up as Hessian, + * which is its main advantage compared to RMI. + * + *

WARNING: Be aware of vulnerabilities due to unsafe Java deserialization: + * Manipulated input streams could lead to unwanted code execution on the server + * during the deserialization step. As a consequence, do not expose HTTP invoker + * endpoints to untrusted clients but rather just between your own services. + * In general, we strongly recommend any other message format (e.g. JSON) instead. + * + * @author Juergen Hoeller + * @since 1.1 + * @see #setServiceInterface + * @see #setServiceUrl + * @see #setCodebaseUrl + * @see HttpInvokerClientInterceptor + * @see HttpInvokerServiceExporter + * @see org.springframework.remoting.rmi.RmiProxyFactoryBean + * @see org.springframework.remoting.caucho.HessianProxyFactoryBean + */ +public class HttpInvokerProxyFactoryBean extends HttpInvokerClientInterceptor implements FactoryBean { + + @Nullable + private Object serviceProxy; + + + @Override + public void afterPropertiesSet() { + super.afterPropertiesSet(); + Class ifc = getServiceInterface(); + Assert.notNull(ifc, "Property 'serviceInterface' is required"); + this.serviceProxy = new ProxyFactory(ifc, this).getProxy(getBeanClassLoader()); + } + + + @Override + @Nullable + public Object getObject() { + return this.serviceProxy; + } + + @Override + public Class getObjectType() { + return getServiceInterface(); + } + + @Override + public boolean isSingleton() { + return true; + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpInvokerRequestExecutor.java b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpInvokerRequestExecutor.java new file mode 100644 index 0000000000000000000000000000000000000000..963993bdb56662c67ea2563c30f16cbf91af99a7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpInvokerRequestExecutor.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.httpinvoker; + +import java.io.IOException; + +import org.springframework.remoting.support.RemoteInvocation; +import org.springframework.remoting.support.RemoteInvocationResult; + +/** + * Strategy interface for actual execution of an HTTP invoker request. + * Used by HttpInvokerClientInterceptor and its subclass + * HttpInvokerProxyFactoryBean. + * + *

Two implementations are provided out of the box: + *

    + *
  • {@code SimpleHttpInvokerRequestExecutor}: + * Uses JDK facilities to execute POST requests, without support + * for HTTP authentication or advanced configuration options. + *
  • {@code HttpComponentsHttpInvokerRequestExecutor}: + * Uses Apache's Commons HttpClient to execute POST requests, + * allowing to use a preconfigured HttpClient instance + * (potentially with authentication, HTTP connection pooling, etc). + *
+ * + * @author Juergen Hoeller + * @since 1.1 + * @see HttpInvokerClientInterceptor#setHttpInvokerRequestExecutor + */ +@FunctionalInterface +public interface HttpInvokerRequestExecutor { + + /** + * Execute a request to send the given remote invocation. + * @param config the HTTP invoker configuration that specifies the + * target service + * @param invocation the RemoteInvocation to execute + * @return the RemoteInvocationResult object + * @throws IOException if thrown by I/O operations + * @throws ClassNotFoundException if thrown during deserialization + * @throws Exception in case of general errors + */ + RemoteInvocationResult executeRequest(HttpInvokerClientConfiguration config, RemoteInvocation invocation) + throws Exception; + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpInvokerServiceExporter.java b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpInvokerServiceExporter.java new file mode 100644 index 0000000000000000000000000000000000000000..22ddb8fd0456a7e023f53a8e8da05e3004bd34cf --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/HttpInvokerServiceExporter.java @@ -0,0 +1,222 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.httpinvoker; + +import java.io.FilterOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.OutputStream; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.remoting.rmi.RemoteInvocationSerializingExporter; +import org.springframework.remoting.support.RemoteInvocation; +import org.springframework.remoting.support.RemoteInvocationResult; +import org.springframework.web.HttpRequestHandler; +import org.springframework.web.util.NestedServletException; + +/** + * Servlet-API-based HTTP request handler that exports the specified service bean + * as HTTP invoker service endpoint, accessible via an HTTP invoker proxy. + * + *

Deserializes remote invocation objects and serializes remote invocation + * result objects. Uses Java serialization just like RMI, but provides the + * same ease of setup as Caucho's HTTP-based Hessian protocol. + * + *

HTTP invoker is the recommended protocol for Java-to-Java remoting. + * It is more powerful and more extensible than Hessian, at the expense of + * being tied to Java. Nevertheless, it is as easy to set up as Hessian, + * which is its main advantage compared to RMI. + * + *

WARNING: Be aware of vulnerabilities due to unsafe Java deserialization: + * Manipulated input streams could lead to unwanted code execution on the server + * during the deserialization step. As a consequence, do not expose HTTP invoker + * endpoints to untrusted clients but rather just between your own services. + * In general, we strongly recommend any other message format (e.g. JSON) instead. + * + * @author Juergen Hoeller + * @since 1.1 + * @see HttpInvokerClientInterceptor + * @see HttpInvokerProxyFactoryBean + * @see org.springframework.remoting.rmi.RmiServiceExporter + * @see org.springframework.remoting.caucho.HessianServiceExporter + */ +public class HttpInvokerServiceExporter extends RemoteInvocationSerializingExporter implements HttpRequestHandler { + + /** + * Reads a remote invocation from the request, executes it, + * and writes the remote invocation result to the response. + * @see #readRemoteInvocation(HttpServletRequest) + * @see #invokeAndCreateResult(org.springframework.remoting.support.RemoteInvocation, Object) + * @see #writeRemoteInvocationResult(HttpServletRequest, HttpServletResponse, RemoteInvocationResult) + */ + @Override + public void handleRequest(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + try { + RemoteInvocation invocation = readRemoteInvocation(request); + RemoteInvocationResult result = invokeAndCreateResult(invocation, getProxy()); + writeRemoteInvocationResult(request, response, result); + } + catch (ClassNotFoundException ex) { + throw new NestedServletException("Class not found during deserialization", ex); + } + } + + /** + * Read a RemoteInvocation from the given HTTP request. + *

Delegates to {@link #readRemoteInvocation(HttpServletRequest, InputStream)} with + * the {@link HttpServletRequest#getInputStream() servlet request's input stream}. + * @param request current HTTP request + * @return the RemoteInvocation object + * @throws IOException in case of I/O failure + * @throws ClassNotFoundException if thrown by deserialization + */ + protected RemoteInvocation readRemoteInvocation(HttpServletRequest request) + throws IOException, ClassNotFoundException { + + return readRemoteInvocation(request, request.getInputStream()); + } + + /** + * Deserialize a RemoteInvocation object from the given InputStream. + *

Gives {@link #decorateInputStream} a chance to decorate the stream + * first (for example, for custom encryption or compression). Creates a + * {@link org.springframework.remoting.rmi.CodebaseAwareObjectInputStream} + * and calls {@link #doReadRemoteInvocation} to actually read the object. + *

Can be overridden for custom serialization of the invocation. + * @param request current HTTP request + * @param is the InputStream to read from + * @return the RemoteInvocation object + * @throws IOException in case of I/O failure + * @throws ClassNotFoundException if thrown during deserialization + */ + protected RemoteInvocation readRemoteInvocation(HttpServletRequest request, InputStream is) + throws IOException, ClassNotFoundException { + + ObjectInputStream ois = createObjectInputStream(decorateInputStream(request, is)); + try { + return doReadRemoteInvocation(ois); + } + finally { + ois.close(); + } + } + + /** + * Return the InputStream to use for reading remote invocations, + * potentially decorating the given original InputStream. + *

The default implementation returns the given stream as-is. + * Can be overridden, for example, for custom encryption or compression. + * @param request current HTTP request + * @param is the original InputStream + * @return the potentially decorated InputStream + * @throws IOException in case of I/O failure + */ + protected InputStream decorateInputStream(HttpServletRequest request, InputStream is) throws IOException { + return is; + } + + /** + * Write the given RemoteInvocationResult to the given HTTP response. + * @param request current HTTP request + * @param response current HTTP response + * @param result the RemoteInvocationResult object + * @throws IOException in case of I/O failure + */ + protected void writeRemoteInvocationResult( + HttpServletRequest request, HttpServletResponse response, RemoteInvocationResult result) + throws IOException { + + response.setContentType(getContentType()); + writeRemoteInvocationResult(request, response, result, response.getOutputStream()); + } + + /** + * Serialize the given RemoteInvocation to the given OutputStream. + *

The default implementation gives {@link #decorateOutputStream} a chance + * to decorate the stream first (for example, for custom encryption or compression). + * Creates an {@link java.io.ObjectOutputStream} for the final stream and calls + * {@link #doWriteRemoteInvocationResult} to actually write the object. + *

Can be overridden for custom serialization of the invocation. + * @param request current HTTP request + * @param response current HTTP response + * @param result the RemoteInvocationResult object + * @param os the OutputStream to write to + * @throws IOException in case of I/O failure + * @see #decorateOutputStream + * @see #doWriteRemoteInvocationResult + */ + protected void writeRemoteInvocationResult( + HttpServletRequest request, HttpServletResponse response, RemoteInvocationResult result, OutputStream os) + throws IOException { + + ObjectOutputStream oos = + createObjectOutputStream(new FlushGuardedOutputStream(decorateOutputStream(request, response, os))); + try { + doWriteRemoteInvocationResult(result, oos); + } + finally { + oos.close(); + } + } + + /** + * Return the OutputStream to use for writing remote invocation results, + * potentially decorating the given original OutputStream. + *

The default implementation returns the given stream as-is. + * Can be overridden, for example, for custom encryption or compression. + * @param request current HTTP request + * @param response current HTTP response + * @param os the original OutputStream + * @return the potentially decorated OutputStream + * @throws IOException in case of I/O failure + */ + protected OutputStream decorateOutputStream( + HttpServletRequest request, HttpServletResponse response, OutputStream os) throws IOException { + + return os; + } + + + /** + * Decorate an {@code OutputStream} to guard against {@code flush()} calls, + * which are turned into no-ops. + *

Because {@link ObjectOutputStream#close()} will in fact flush/drain + * the underlying stream twice, this {@link FilterOutputStream} will + * guard against individual flush calls. Multiple flush calls can lead + * to performance issues, since writes aren't gathered as they should be. + * @see SPR-14040 + */ + private static class FlushGuardedOutputStream extends FilterOutputStream { + + public FlushGuardedOutputStream(OutputStream out) { + super(out); + } + + @Override + public void flush() throws IOException { + // Do nothing on flush + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/httpinvoker/SimpleHttpInvokerRequestExecutor.java b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/SimpleHttpInvokerRequestExecutor.java new file mode 100644 index 0000000000000000000000000000000000000000..bcaa28e5ff9773782e5b47ace39ee5d97edd9bb3 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/SimpleHttpInvokerRequestExecutor.java @@ -0,0 +1,231 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.httpinvoker; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.net.URLConnection; +import java.util.Locale; +import java.util.zip.GZIPInputStream; + +import org.springframework.context.i18n.LocaleContext; +import org.springframework.context.i18n.LocaleContextHolder; +import org.springframework.remoting.support.RemoteInvocationResult; + +/** + * {@link org.springframework.remoting.httpinvoker.HttpInvokerRequestExecutor} implementation + * that uses standard Java facilities to execute POST requests, without support for HTTP + * authentication or advanced configuration options. + * + *

Designed for easy subclassing, customizing specific template methods. However, + * consider {@code HttpComponentsHttpInvokerRequestExecutor} for more sophisticated needs: + * The standard {@link HttpURLConnection} class is rather limited in its capabilities. + * + * @author Juergen Hoeller + * @since 1.1 + * @see java.net.HttpURLConnection + */ +public class SimpleHttpInvokerRequestExecutor extends AbstractHttpInvokerRequestExecutor { + + private int connectTimeout = -1; + + private int readTimeout = -1; + + + /** + * Set the underlying URLConnection's connect timeout (in milliseconds). + * A timeout value of 0 specifies an infinite timeout. + *

Default is the system's default timeout. + * @see URLConnection#setConnectTimeout(int) + */ + public void setConnectTimeout(int connectTimeout) { + this.connectTimeout = connectTimeout; + } + + /** + * Set the underlying URLConnection's read timeout (in milliseconds). + * A timeout value of 0 specifies an infinite timeout. + *

Default is the system's default timeout. + * @see URLConnection#setReadTimeout(int) + */ + public void setReadTimeout(int readTimeout) { + this.readTimeout = readTimeout; + } + + + /** + * Execute the given request through a standard {@link HttpURLConnection}. + *

This method implements the basic processing workflow: + * The actual work happens in this class's template methods. + * @see #openConnection + * @see #prepareConnection + * @see #writeRequestBody + * @see #validateResponse + * @see #readResponseBody + */ + @Override + protected RemoteInvocationResult doExecuteRequest( + HttpInvokerClientConfiguration config, ByteArrayOutputStream baos) + throws IOException, ClassNotFoundException { + + HttpURLConnection con = openConnection(config); + prepareConnection(con, baos.size()); + writeRequestBody(config, con, baos); + validateResponse(config, con); + InputStream responseBody = readResponseBody(config, con); + + return readRemoteInvocationResult(responseBody, config.getCodebaseUrl()); + } + + /** + * Open an {@link HttpURLConnection} for the given remote invocation request. + * @param config the HTTP invoker configuration that specifies the + * target service + * @return the HttpURLConnection for the given request + * @throws IOException if thrown by I/O methods + * @see java.net.URL#openConnection() + */ + protected HttpURLConnection openConnection(HttpInvokerClientConfiguration config) throws IOException { + URLConnection con = new URL(config.getServiceUrl()).openConnection(); + if (!(con instanceof HttpURLConnection)) { + throw new IOException( + "Service URL [" + config.getServiceUrl() + "] does not resolve to an HTTP connection"); + } + return (HttpURLConnection) con; + } + + /** + * Prepare the given HTTP connection. + *

The default implementation specifies POST as method, + * "application/x-java-serialized-object" as "Content-Type" header, + * and the given content length as "Content-Length" header. + * @param connection the HTTP connection to prepare + * @param contentLength the length of the content to send + * @throws IOException if thrown by HttpURLConnection methods + * @see java.net.HttpURLConnection#setRequestMethod + * @see java.net.HttpURLConnection#setRequestProperty + */ + protected void prepareConnection(HttpURLConnection connection, int contentLength) throws IOException { + if (this.connectTimeout >= 0) { + connection.setConnectTimeout(this.connectTimeout); + } + if (this.readTimeout >= 0) { + connection.setReadTimeout(this.readTimeout); + } + + connection.setDoOutput(true); + connection.setRequestMethod(HTTP_METHOD_POST); + connection.setRequestProperty(HTTP_HEADER_CONTENT_TYPE, getContentType()); + connection.setRequestProperty(HTTP_HEADER_CONTENT_LENGTH, Integer.toString(contentLength)); + + LocaleContext localeContext = LocaleContextHolder.getLocaleContext(); + if (localeContext != null) { + Locale locale = localeContext.getLocale(); + if (locale != null) { + connection.setRequestProperty(HTTP_HEADER_ACCEPT_LANGUAGE, locale.toLanguageTag()); + } + } + + if (isAcceptGzipEncoding()) { + connection.setRequestProperty(HTTP_HEADER_ACCEPT_ENCODING, ENCODING_GZIP); + } + } + + /** + * Set the given serialized remote invocation as request body. + *

The default implementation simply write the serialized invocation to the + * HttpURLConnection's OutputStream. This can be overridden, for example, to write + * a specific encoding and potentially set appropriate HTTP request headers. + * @param config the HTTP invoker configuration that specifies the target service + * @param con the HttpURLConnection to write the request body to + * @param baos the ByteArrayOutputStream that contains the serialized + * RemoteInvocation object + * @throws IOException if thrown by I/O methods + * @see java.net.HttpURLConnection#getOutputStream() + * @see java.net.HttpURLConnection#setRequestProperty + */ + protected void writeRequestBody( + HttpInvokerClientConfiguration config, HttpURLConnection con, ByteArrayOutputStream baos) + throws IOException { + + baos.writeTo(con.getOutputStream()); + } + + /** + * Validate the given response as contained in the {@link HttpURLConnection} object, + * throwing an exception if it does not correspond to a successful HTTP response. + *

Default implementation rejects any HTTP status code beyond 2xx, to avoid + * parsing the response body and trying to deserialize from a corrupted stream. + * @param config the HTTP invoker configuration that specifies the target service + * @param con the HttpURLConnection to validate + * @throws IOException if validation failed + * @see java.net.HttpURLConnection#getResponseCode() + */ + protected void validateResponse(HttpInvokerClientConfiguration config, HttpURLConnection con) + throws IOException { + + if (con.getResponseCode() >= 300) { + throw new IOException( + "Did not receive successful HTTP response: status code = " + con.getResponseCode() + + ", status message = [" + con.getResponseMessage() + "]"); + } + } + + /** + * Extract the response body from the given executed remote invocation + * request. + *

The default implementation simply reads the serialized invocation + * from the HttpURLConnection's InputStream. If the response is recognized + * as GZIP response, the InputStream will get wrapped in a GZIPInputStream. + * @param config the HTTP invoker configuration that specifies the target service + * @param con the HttpURLConnection to read the response body from + * @return an InputStream for the response body + * @throws IOException if thrown by I/O methods + * @see #isGzipResponse + * @see java.util.zip.GZIPInputStream + * @see java.net.HttpURLConnection#getInputStream() + * @see java.net.HttpURLConnection#getHeaderField(int) + * @see java.net.HttpURLConnection#getHeaderFieldKey(int) + */ + protected InputStream readResponseBody(HttpInvokerClientConfiguration config, HttpURLConnection con) + throws IOException { + + if (isGzipResponse(con)) { + // GZIP response found - need to unzip. + return new GZIPInputStream(con.getInputStream()); + } + else { + // Plain response found. + return con.getInputStream(); + } + } + + /** + * Determine whether the given response is a GZIP response. + *

Default implementation checks whether the HTTP "Content-Encoding" + * header contains "gzip" (in any casing). + * @param con the HttpURLConnection to check + */ + protected boolean isGzipResponse(HttpURLConnection con) { + String encodingHeader = con.getHeaderField(HTTP_HEADER_CONTENT_ENCODING); + return (encodingHeader != null && encodingHeader.toLowerCase().contains(ENCODING_GZIP)); + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/httpinvoker/SimpleHttpInvokerServiceExporter.java b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/SimpleHttpInvokerServiceExporter.java new file mode 100644 index 0000000000000000000000000000000000000000..d5396d75b98f61e1e90e06d30f011c4d65cf632a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/SimpleHttpInvokerServiceExporter.java @@ -0,0 +1,183 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.httpinvoker; + +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.OutputStream; + +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpHandler; + +import org.springframework.remoting.rmi.RemoteInvocationSerializingExporter; +import org.springframework.remoting.support.RemoteInvocation; +import org.springframework.remoting.support.RemoteInvocationResult; + +/** + * HTTP request handler that exports the specified service bean as + * HTTP invoker service endpoint, accessible via an HTTP invoker proxy. + * Designed for Sun's JRE 1.6 HTTP server, implementing the + * {@link com.sun.net.httpserver.HttpHandler} interface. + * + *

Deserializes remote invocation objects and serializes remote invocation + * result objects. Uses Java serialization just like RMI, but provides the + * same ease of setup as Caucho's HTTP-based Hessian protocol. + * + *

HTTP invoker is the recommended protocol for Java-to-Java remoting. + * It is more powerful and more extensible than Hessian, at the expense of + * being tied to Java. Nevertheless, it is as easy to set up as Hessian, + * which is its main advantage compared to RMI. + * + *

WARNING: Be aware of vulnerabilities due to unsafe Java deserialization: + * Manipulated input streams could lead to unwanted code execution on the server + * during the deserialization step. As a consequence, do not expose HTTP invoker + * endpoints to untrusted clients but rather just between your own services. + * In general, we strongly recommend any other message format (e.g. JSON) instead. + * + * @author Juergen Hoeller + * @since 2.5.1 + * @see org.springframework.remoting.httpinvoker.HttpInvokerClientInterceptor + * @see org.springframework.remoting.httpinvoker.HttpInvokerProxyFactoryBean + * @deprecated as of Spring Framework 5.1, in favor of {@link HttpInvokerServiceExporter} + */ +@Deprecated +@org.springframework.lang.UsesSunHttpServer +public class SimpleHttpInvokerServiceExporter extends RemoteInvocationSerializingExporter implements HttpHandler { + + /** + * Reads a remote invocation from the request, executes it, + * and writes the remote invocation result to the response. + * @see #readRemoteInvocation(HttpExchange) + * @see #invokeAndCreateResult(RemoteInvocation, Object) + * @see #writeRemoteInvocationResult(HttpExchange, RemoteInvocationResult) + */ + @Override + public void handle(HttpExchange exchange) throws IOException { + try { + RemoteInvocation invocation = readRemoteInvocation(exchange); + RemoteInvocationResult result = invokeAndCreateResult(invocation, getProxy()); + writeRemoteInvocationResult(exchange, result); + exchange.close(); + } + catch (ClassNotFoundException ex) { + exchange.sendResponseHeaders(500, -1); + logger.error("Class not found during deserialization", ex); + } + } + + /** + * Read a RemoteInvocation from the given HTTP request. + *

Delegates to {@link #readRemoteInvocation(HttpExchange, InputStream)} + * with the {@link HttpExchange#getRequestBody()} request's input stream}. + * @param exchange current HTTP request/response + * @return the RemoteInvocation object + * @throws java.io.IOException in case of I/O failure + * @throws ClassNotFoundException if thrown by deserialization + */ + protected RemoteInvocation readRemoteInvocation(HttpExchange exchange) + throws IOException, ClassNotFoundException { + + return readRemoteInvocation(exchange, exchange.getRequestBody()); + } + + /** + * Deserialize a RemoteInvocation object from the given InputStream. + *

Gives {@link #decorateInputStream} a chance to decorate the stream + * first (for example, for custom encryption or compression). Creates a + * {@link org.springframework.remoting.rmi.CodebaseAwareObjectInputStream} + * and calls {@link #doReadRemoteInvocation} to actually read the object. + *

Can be overridden for custom serialization of the invocation. + * @param exchange current HTTP request/response + * @param is the InputStream to read from + * @return the RemoteInvocation object + * @throws java.io.IOException in case of I/O failure + * @throws ClassNotFoundException if thrown during deserialization + */ + protected RemoteInvocation readRemoteInvocation(HttpExchange exchange, InputStream is) + throws IOException, ClassNotFoundException { + + ObjectInputStream ois = createObjectInputStream(decorateInputStream(exchange, is)); + return doReadRemoteInvocation(ois); + } + + /** + * Return the InputStream to use for reading remote invocations, + * potentially decorating the given original InputStream. + *

The default implementation returns the given stream as-is. + * Can be overridden, for example, for custom encryption or compression. + * @param exchange current HTTP request/response + * @param is the original InputStream + * @return the potentially decorated InputStream + * @throws java.io.IOException in case of I/O failure + */ + protected InputStream decorateInputStream(HttpExchange exchange, InputStream is) throws IOException { + return is; + } + + /** + * Write the given RemoteInvocationResult to the given HTTP response. + * @param exchange current HTTP request/response + * @param result the RemoteInvocationResult object + * @throws java.io.IOException in case of I/O failure + */ + protected void writeRemoteInvocationResult(HttpExchange exchange, RemoteInvocationResult result) + throws IOException { + + exchange.getResponseHeaders().set("Content-Type", getContentType()); + exchange.sendResponseHeaders(200, 0); + writeRemoteInvocationResult(exchange, result, exchange.getResponseBody()); + } + + /** + * Serialize the given RemoteInvocation to the given OutputStream. + *

The default implementation gives {@link #decorateOutputStream} a chance + * to decorate the stream first (for example, for custom encryption or compression). + * Creates an {@link java.io.ObjectOutputStream} for the final stream and calls + * {@link #doWriteRemoteInvocationResult} to actually write the object. + *

Can be overridden for custom serialization of the invocation. + * @param exchange current HTTP request/response + * @param result the RemoteInvocationResult object + * @param os the OutputStream to write to + * @throws java.io.IOException in case of I/O failure + * @see #decorateOutputStream + * @see #doWriteRemoteInvocationResult + */ + protected void writeRemoteInvocationResult( + HttpExchange exchange, RemoteInvocationResult result, OutputStream os) throws IOException { + + ObjectOutputStream oos = createObjectOutputStream(decorateOutputStream(exchange, os)); + doWriteRemoteInvocationResult(result, oos); + oos.flush(); + } + + /** + * Return the OutputStream to use for writing remote invocation results, + * potentially decorating the given original OutputStream. + *

The default implementation returns the given stream as-is. + * Can be overridden, for example, for custom encryption or compression. + * @param exchange current HTTP request/response + * @param os the original OutputStream + * @return the potentially decorated OutputStream + * @throws java.io.IOException in case of I/O failure + */ + protected OutputStream decorateOutputStream(HttpExchange exchange, OutputStream os) throws IOException { + return os; + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/httpinvoker/package-info.java b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..e0dfd03aa651e7c90ed72722324be449d6310174 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/httpinvoker/package-info.java @@ -0,0 +1,16 @@ +/** + * Remoting classes for transparent Java-to-Java remoting via HTTP invokers. + * Uses Java serialization just like RMI, but provides the same ease of setup + * as Caucho's HTTP-based Hessian protocol. + * + *

HTTP invoker is the recommended protocol for Java-to-Java remoting. + * It is more powerful and more extensible than Hessian, at the expense of + * being tied to Java. Nevertheless, it is as easy to set up as Hessian, + * which is its main advantage compared to RMI. + */ +@NonNullApi +@NonNullFields +package org.springframework.remoting.httpinvoker; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/remoting/jaxws/AbstractJaxWsServiceExporter.java b/spring-web/src/main/java/org/springframework/remoting/jaxws/AbstractJaxWsServiceExporter.java new file mode 100644 index 0000000000000000000000000000000000000000..610651bef2060d0fec7b5db43df329ccfcfb1e9d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/jaxws/AbstractJaxWsServiceExporter.java @@ -0,0 +1,214 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.jaxws; + +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Executor; + +import javax.jws.WebService; +import javax.xml.ws.Endpoint; +import javax.xml.ws.WebServiceFeature; +import javax.xml.ws.WebServiceProvider; + +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.CannotLoadBeanClassException; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Abstract exporter for JAX-WS services, autodetecting annotated service beans + * (through the JAX-WS {@link javax.jws.WebService} annotation). + * + *

Subclasses need to implement the {@link #publishEndpoint} template methods + * for actual endpoint exposure. + * + * @author Juergen Hoeller + * @since 2.5.5 + * @see javax.jws.WebService + * @see javax.xml.ws.Endpoint + * @see SimpleJaxWsServiceExporter + */ +public abstract class AbstractJaxWsServiceExporter implements BeanFactoryAware, InitializingBean, DisposableBean { + + @Nullable + private Map endpointProperties; + + @Nullable + private Executor executor; + + @Nullable + private String bindingType; + + @Nullable + private WebServiceFeature[] endpointFeatures; + + @Nullable + private ListableBeanFactory beanFactory; + + private final Set publishedEndpoints = new LinkedHashSet<>(); + + + /** + * Set the property bag for the endpoint, including properties such as + * "javax.xml.ws.wsdl.service" or "javax.xml.ws.wsdl.port". + * @see javax.xml.ws.Endpoint#setProperties + * @see javax.xml.ws.Endpoint#WSDL_SERVICE + * @see javax.xml.ws.Endpoint#WSDL_PORT + */ + public void setEndpointProperties(Map endpointProperties) { + this.endpointProperties = endpointProperties; + } + + /** + * Set the JDK concurrent executor to use for dispatching incoming requests + * to exported service instances. + * @see javax.xml.ws.Endpoint#setExecutor + */ + public void setExecutor(Executor executor) { + this.executor = executor; + } + + /** + * Specify the binding type to use, overriding the value of + * the JAX-WS {@link javax.xml.ws.BindingType} annotation. + */ + public void setBindingType(String bindingType) { + this.bindingType = bindingType; + } + + /** + * Specify WebServiceFeature objects (e.g. as inner bean definitions) + * to apply to JAX-WS endpoint creation. + * @since 4.0 + */ + public void setEndpointFeatures(WebServiceFeature... endpointFeatures) { + this.endpointFeatures = endpointFeatures; + } + + /** + * Obtains all web service beans and publishes them as JAX-WS endpoints. + */ + @Override + public void setBeanFactory(BeanFactory beanFactory) { + if (!(beanFactory instanceof ListableBeanFactory)) { + throw new IllegalStateException(getClass().getSimpleName() + " requires a ListableBeanFactory"); + } + this.beanFactory = (ListableBeanFactory) beanFactory; + } + + + /** + * Immediately publish all endpoints when fully configured. + * @see #publishEndpoints() + */ + @Override + public void afterPropertiesSet() throws Exception { + publishEndpoints(); + } + + /** + * Publish all {@link javax.jws.WebService} annotated beans in the + * containing BeanFactory. + * @see #publishEndpoint + */ + public void publishEndpoints() { + Assert.state(this.beanFactory != null, "No BeanFactory set"); + + Set beanNames = new LinkedHashSet<>(this.beanFactory.getBeanDefinitionCount()); + Collections.addAll(beanNames, this.beanFactory.getBeanDefinitionNames()); + if (this.beanFactory instanceof ConfigurableBeanFactory) { + Collections.addAll(beanNames, ((ConfigurableBeanFactory) this.beanFactory).getSingletonNames()); + } + + for (String beanName : beanNames) { + try { + Class type = this.beanFactory.getType(beanName); + if (type != null && !type.isInterface()) { + WebService wsAnnotation = type.getAnnotation(WebService.class); + WebServiceProvider wsProviderAnnotation = type.getAnnotation(WebServiceProvider.class); + if (wsAnnotation != null || wsProviderAnnotation != null) { + Endpoint endpoint = createEndpoint(this.beanFactory.getBean(beanName)); + if (this.endpointProperties != null) { + endpoint.setProperties(this.endpointProperties); + } + if (this.executor != null) { + endpoint.setExecutor(this.executor); + } + if (wsAnnotation != null) { + publishEndpoint(endpoint, wsAnnotation); + } + else { + publishEndpoint(endpoint, wsProviderAnnotation); + } + this.publishedEndpoints.add(endpoint); + } + } + } + catch (CannotLoadBeanClassException ex) { + // ignore beans where the class is not resolvable + } + } + } + + /** + * Create the actual Endpoint instance. + * @param bean the service object to wrap + * @return the Endpoint instance + * @see Endpoint#create(Object) + * @see Endpoint#create(String, Object) + */ + protected Endpoint createEndpoint(Object bean) { + return (this.endpointFeatures != null ? + Endpoint.create(this.bindingType, bean, this.endpointFeatures) : + Endpoint.create(this.bindingType, bean)); + } + + + /** + * Actually publish the given endpoint. To be implemented by subclasses. + * @param endpoint the JAX-WS Endpoint object + * @param annotation the service bean's WebService annotation + */ + protected abstract void publishEndpoint(Endpoint endpoint, WebService annotation); + + /** + * Actually publish the given provider endpoint. To be implemented by subclasses. + * @param endpoint the JAX-WS Provider Endpoint object + * @param annotation the service bean's WebServiceProvider annotation + */ + protected abstract void publishEndpoint(Endpoint endpoint, WebServiceProvider annotation); + + + /** + * Stops all published endpoints, taking the web services offline. + */ + @Override + public void destroy() { + for (Endpoint endpoint : this.publishedEndpoints) { + endpoint.stop(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/jaxws/JaxWsPortClientInterceptor.java b/spring-web/src/main/java/org/springframework/remoting/jaxws/JaxWsPortClientInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..e8ff14889aaee9e35931b51a2b49a2580e1ded1e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/jaxws/JaxWsPortClientInterceptor.java @@ -0,0 +1,561 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.jaxws; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.MalformedURLException; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; + +import javax.jws.WebService; +import javax.xml.namespace.QName; +import javax.xml.ws.BindingProvider; +import javax.xml.ws.ProtocolException; +import javax.xml.ws.Service; +import javax.xml.ws.WebServiceException; +import javax.xml.ws.WebServiceFeature; +import javax.xml.ws.soap.SOAPFaultException; + +import org.aopalliance.intercept.MethodInterceptor; +import org.aopalliance.intercept.MethodInvocation; + +import org.springframework.aop.support.AopUtils; +import org.springframework.beans.factory.BeanClassLoaderAware; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; +import org.springframework.remoting.RemoteAccessException; +import org.springframework.remoting.RemoteConnectFailureException; +import org.springframework.remoting.RemoteLookupFailureException; +import org.springframework.remoting.RemoteProxyFailureException; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.StringUtils; + +/** + * {@link org.aopalliance.intercept.MethodInterceptor} for accessing a + * specific port of a JAX-WS service. + * + *

Uses either {@link LocalJaxWsServiceFactory}'s facilities underneath, + * or takes an explicit reference to an existing JAX-WS Service instance + * (e.g. obtained via {@link org.springframework.jndi.JndiObjectFactoryBean}). + * + * @author Juergen Hoeller + * @since 2.5 + * @see #setPortName + * @see #setServiceInterface + * @see javax.xml.ws.Service#getPort + * @see org.springframework.remoting.RemoteAccessException + * @see org.springframework.jndi.JndiObjectFactoryBean + */ +public class JaxWsPortClientInterceptor extends LocalJaxWsServiceFactory + implements MethodInterceptor, BeanClassLoaderAware, InitializingBean { + + @Nullable + private Service jaxWsService; + + @Nullable + private String portName; + + @Nullable + private String username; + + @Nullable + private String password; + + @Nullable + private String endpointAddress; + + private boolean maintainSession; + + private boolean useSoapAction; + + @Nullable + private String soapActionUri; + + @Nullable + private Map customProperties; + + @Nullable + private WebServiceFeature[] portFeatures; + + @Nullable + private Class serviceInterface; + + private boolean lookupServiceOnStartup = true; + + @Nullable + private ClassLoader beanClassLoader = ClassUtils.getDefaultClassLoader(); + + @Nullable + private QName portQName; + + @Nullable + private Object portStub; + + private final Object preparationMonitor = new Object(); + + + /** + * Set a reference to an existing JAX-WS Service instance, + * for example obtained via {@link org.springframework.jndi.JndiObjectFactoryBean}. + * If not set, {@link LocalJaxWsServiceFactory}'s properties have to be specified. + * @see #setWsdlDocumentUrl + * @see #setNamespaceUri + * @see #setServiceName + * @see org.springframework.jndi.JndiObjectFactoryBean + */ + public void setJaxWsService(@Nullable Service jaxWsService) { + this.jaxWsService = jaxWsService; + } + + /** + * Return a reference to an existing JAX-WS Service instance, if any. + */ + @Nullable + public Service getJaxWsService() { + return this.jaxWsService; + } + + /** + * Set the name of the port. + * Corresponds to the "wsdl:port" name. + */ + public void setPortName(@Nullable String portName) { + this.portName = portName; + } + + /** + * Return the name of the port. + */ + @Nullable + public String getPortName() { + return this.portName; + } + + /** + * Set the username to specify on the stub. + * @see javax.xml.ws.BindingProvider#USERNAME_PROPERTY + */ + public void setUsername(@Nullable String username) { + this.username = username; + } + + /** + * Return the username to specify on the stub. + */ + @Nullable + public String getUsername() { + return this.username; + } + + /** + * Set the password to specify on the stub. + * @see javax.xml.ws.BindingProvider#PASSWORD_PROPERTY + */ + public void setPassword(@Nullable String password) { + this.password = password; + } + + /** + * Return the password to specify on the stub. + */ + @Nullable + public String getPassword() { + return this.password; + } + + /** + * Set the endpoint address to specify on the stub. + * @see javax.xml.ws.BindingProvider#ENDPOINT_ADDRESS_PROPERTY + */ + public void setEndpointAddress(@Nullable String endpointAddress) { + this.endpointAddress = endpointAddress; + } + + /** + * Return the endpoint address to specify on the stub. + */ + @Nullable + public String getEndpointAddress() { + return this.endpointAddress; + } + + /** + * Set the "session.maintain" flag to specify on the stub. + * @see javax.xml.ws.BindingProvider#SESSION_MAINTAIN_PROPERTY + */ + public void setMaintainSession(boolean maintainSession) { + this.maintainSession = maintainSession; + } + + /** + * Return the "session.maintain" flag to specify on the stub. + */ + public boolean isMaintainSession() { + return this.maintainSession; + } + + /** + * Set the "soapaction.use" flag to specify on the stub. + * @see javax.xml.ws.BindingProvider#SOAPACTION_USE_PROPERTY + */ + public void setUseSoapAction(boolean useSoapAction) { + this.useSoapAction = useSoapAction; + } + + /** + * Return the "soapaction.use" flag to specify on the stub. + */ + public boolean isUseSoapAction() { + return this.useSoapAction; + } + + /** + * Set the SOAP action URI to specify on the stub. + * @see javax.xml.ws.BindingProvider#SOAPACTION_URI_PROPERTY + */ + public void setSoapActionUri(@Nullable String soapActionUri) { + this.soapActionUri = soapActionUri; + } + + /** + * Return the SOAP action URI to specify on the stub. + */ + @Nullable + public String getSoapActionUri() { + return this.soapActionUri; + } + + /** + * Set custom properties to be set on the stub. + *

Can be populated with a String "value" (parsed via PropertiesEditor) + * or a "props" element in XML bean definitions. + * @see javax.xml.ws.BindingProvider#getRequestContext() + */ + public void setCustomProperties(Map customProperties) { + this.customProperties = customProperties; + } + + /** + * Allow Map access to the custom properties to be set on the stub, + * with the option to add or override specific entries. + *

Useful for specifying entries directly, for example via + * "customProperties[myKey]". This is particularly useful for + * adding or overriding entries in child bean definitions. + */ + public Map getCustomProperties() { + if (this.customProperties == null) { + this.customProperties = new HashMap<>(); + } + return this.customProperties; + } + + /** + * Add a custom property to this JAX-WS BindingProvider. + * @param name the name of the attribute to expose + * @param value the attribute value to expose + * @see javax.xml.ws.BindingProvider#getRequestContext() + */ + public void addCustomProperty(String name, Object value) { + getCustomProperties().put(name, value); + } + + /** + * Specify WebServiceFeature objects (e.g. as inner bean definitions) + * to apply to JAX-WS port stub creation. + * @since 4.0 + * @see Service#getPort(Class, javax.xml.ws.WebServiceFeature...) + * @see #setServiceFeatures + */ + public void setPortFeatures(WebServiceFeature... features) { + this.portFeatures = features; + } + + /** + * Set the interface of the service that this factory should create a proxy for. + */ + public void setServiceInterface(@Nullable Class serviceInterface) { + if (serviceInterface != null) { + Assert.isTrue(serviceInterface.isInterface(), "'serviceInterface' must be an interface"); + } + this.serviceInterface = serviceInterface; + } + + /** + * Return the interface of the service that this factory should create a proxy for. + */ + @Nullable + public Class getServiceInterface() { + return this.serviceInterface; + } + + /** + * Set whether to look up the JAX-WS service on startup. + *

Default is "true". Turn this flag off to allow for late start + * of the target server. In this case, the JAX-WS service will be + * lazily fetched on first access. + */ + public void setLookupServiceOnStartup(boolean lookupServiceOnStartup) { + this.lookupServiceOnStartup = lookupServiceOnStartup; + } + + /** + * Set the bean ClassLoader to use for this interceptor: primarily for + * building a client proxy in the {@link JaxWsPortProxyFactoryBean} subclass. + */ + @Override + public void setBeanClassLoader(@Nullable ClassLoader classLoader) { + this.beanClassLoader = classLoader; + } + + /** + * Return the bean ClassLoader to use for this interceptor. + */ + @Nullable + protected ClassLoader getBeanClassLoader() { + return this.beanClassLoader; + } + + + @Override + public void afterPropertiesSet() { + if (this.lookupServiceOnStartup) { + prepare(); + } + } + + /** + * Initialize the JAX-WS port for this interceptor. + */ + public void prepare() { + Class ifc = getServiceInterface(); + Assert.notNull(ifc, "Property 'serviceInterface' is required"); + + WebService ann = ifc.getAnnotation(WebService.class); + if (ann != null) { + applyDefaultsFromAnnotation(ann); + } + + Service serviceToUse = getJaxWsService(); + if (serviceToUse == null) { + serviceToUse = createJaxWsService(); + } + + this.portQName = getQName(getPortName() != null ? getPortName() : ifc.getName()); + Object stub = getPortStub(serviceToUse, (getPortName() != null ? this.portQName : null)); + preparePortStub(stub); + this.portStub = stub; + } + + /** + * Initialize this client interceptor's properties from the given WebService annotation, + * if necessary and possible (i.e. if "wsdlDocumentUrl", "namespaceUri", "serviceName" + * and "portName" haven't been set but corresponding values are declared at the + * annotation level of the specified service interface). + * @param ann the WebService annotation found on the specified service interface + */ + protected void applyDefaultsFromAnnotation(WebService ann) { + if (getWsdlDocumentUrl() == null) { + String wsdl = ann.wsdlLocation(); + if (StringUtils.hasText(wsdl)) { + try { + setWsdlDocumentUrl(new URL(wsdl)); + } + catch (MalformedURLException ex) { + throw new IllegalStateException( + "Encountered invalid @Service wsdlLocation value [" + wsdl + "]", ex); + } + } + } + if (getNamespaceUri() == null) { + String ns = ann.targetNamespace(); + if (StringUtils.hasText(ns)) { + setNamespaceUri(ns); + } + } + if (getServiceName() == null) { + String sn = ann.serviceName(); + if (StringUtils.hasText(sn)) { + setServiceName(sn); + } + } + if (getPortName() == null) { + String pn = ann.portName(); + if (StringUtils.hasText(pn)) { + setPortName(pn); + } + } + } + + /** + * Return whether this client interceptor has already been prepared, + * i.e. has already looked up the JAX-WS service and port. + */ + protected boolean isPrepared() { + synchronized (this.preparationMonitor) { + return (this.portStub != null); + } + } + + /** + * Return the prepared QName for the port. + * @see #setPortName + * @see #getQName + */ + @Nullable + protected final QName getPortQName() { + return this.portQName; + } + + /** + * Obtain the port stub from the given JAX-WS Service. + * @param service the Service object to obtain the port from + * @param portQName the name of the desired port, if specified + * @return the corresponding port object as returned from + * {@code Service.getPort(...)} + */ + protected Object getPortStub(Service service, @Nullable QName portQName) { + if (this.portFeatures != null) { + return (portQName != null ? service.getPort(portQName, getServiceInterface(), this.portFeatures) : + service.getPort(getServiceInterface(), this.portFeatures)); + } + else { + return (portQName != null ? service.getPort(portQName, getServiceInterface()) : + service.getPort(getServiceInterface())); + } + } + + /** + * Prepare the given JAX-WS port stub, applying properties to it. + * Called by {@link #prepare}. + * @param stub the current JAX-WS port stub + * @see #setUsername + * @see #setPassword + * @see #setEndpointAddress + * @see #setMaintainSession + * @see #setCustomProperties + */ + protected void preparePortStub(Object stub) { + Map stubProperties = new HashMap<>(); + String username = getUsername(); + if (username != null) { + stubProperties.put(BindingProvider.USERNAME_PROPERTY, username); + } + String password = getPassword(); + if (password != null) { + stubProperties.put(BindingProvider.PASSWORD_PROPERTY, password); + } + String endpointAddress = getEndpointAddress(); + if (endpointAddress != null) { + stubProperties.put(BindingProvider.ENDPOINT_ADDRESS_PROPERTY, endpointAddress); + } + if (isMaintainSession()) { + stubProperties.put(BindingProvider.SESSION_MAINTAIN_PROPERTY, Boolean.TRUE); + } + if (isUseSoapAction()) { + stubProperties.put(BindingProvider.SOAPACTION_USE_PROPERTY, Boolean.TRUE); + } + String soapActionUri = getSoapActionUri(); + if (soapActionUri != null) { + stubProperties.put(BindingProvider.SOAPACTION_URI_PROPERTY, soapActionUri); + } + stubProperties.putAll(getCustomProperties()); + if (!stubProperties.isEmpty()) { + if (!(stub instanceof BindingProvider)) { + throw new RemoteLookupFailureException("Port stub of class [" + stub.getClass().getName() + + "] is not a customizable JAX-WS stub: it does not implement interface [javax.xml.ws.BindingProvider]"); + } + ((BindingProvider) stub).getRequestContext().putAll(stubProperties); + } + } + + /** + * Return the underlying JAX-WS port stub that this interceptor delegates to + * for each method invocation on the proxy. + */ + @Nullable + protected Object getPortStub() { + return this.portStub; + } + + + @Override + @Nullable + public Object invoke(MethodInvocation invocation) throws Throwable { + if (AopUtils.isToStringMethod(invocation.getMethod())) { + return "JAX-WS proxy for port [" + getPortName() + "] of service [" + getServiceName() + "]"; + } + // Lazily prepare service and stub if necessary. + synchronized (this.preparationMonitor) { + if (!isPrepared()) { + prepare(); + } + } + return doInvoke(invocation); + } + + /** + * Perform a JAX-WS service invocation based on the given method invocation. + * @param invocation the AOP method invocation + * @return the invocation result, if any + * @throws Throwable in case of invocation failure + * @see #getPortStub() + * @see #doInvoke(org.aopalliance.intercept.MethodInvocation, Object) + */ + @Nullable + protected Object doInvoke(MethodInvocation invocation) throws Throwable { + try { + return doInvoke(invocation, getPortStub()); + } + catch (SOAPFaultException ex) { + throw new JaxWsSoapFaultException(ex); + } + catch (ProtocolException ex) { + throw new RemoteConnectFailureException( + "Could not connect to remote service [" + getEndpointAddress() + "]", ex); + } + catch (WebServiceException ex) { + throw new RemoteAccessException( + "Could not access remote service at [" + getEndpointAddress() + "]", ex); + } + } + + /** + * Perform a JAX-WS service invocation on the given port stub. + * @param invocation the AOP method invocation + * @param portStub the RMI port stub to invoke + * @return the invocation result, if any + * @throws Throwable in case of invocation failure + * @see #getPortStub() + */ + @Nullable + protected Object doInvoke(MethodInvocation invocation, @Nullable Object portStub) throws Throwable { + Method method = invocation.getMethod(); + try { + return method.invoke(portStub, invocation.getArguments()); + } + catch (InvocationTargetException ex) { + throw ex.getTargetException(); + } + catch (Throwable ex) { + throw new RemoteProxyFailureException("Invocation of stub method failed: " + method, ex); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/jaxws/JaxWsPortProxyFactoryBean.java b/spring-web/src/main/java/org/springframework/remoting/jaxws/JaxWsPortProxyFactoryBean.java new file mode 100644 index 0000000000000000000000000000000000000000..3788d0e0e7a069e3dcdf818238cf2e38129ac9d0 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/jaxws/JaxWsPortProxyFactoryBean.java @@ -0,0 +1,74 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.jaxws; + +import javax.xml.ws.BindingProvider; + +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@link org.springframework.beans.factory.FactoryBean} for a specific port of a + * JAX-WS service. Exposes a proxy for the port, to be used for bean references. + * Inherits configuration properties from {@link JaxWsPortClientInterceptor}. + * + * @author Juergen Hoeller + * @since 2.5 + * @see #setServiceInterface + * @see LocalJaxWsServiceFactoryBean + */ +public class JaxWsPortProxyFactoryBean extends JaxWsPortClientInterceptor implements FactoryBean { + + @Nullable + private Object serviceProxy; + + + @Override + public void afterPropertiesSet() { + super.afterPropertiesSet(); + + Class ifc = getServiceInterface(); + Assert.notNull(ifc, "Property 'serviceInterface' is required"); + + // Build a proxy that also exposes the JAX-WS BindingProvider interface. + ProxyFactory pf = new ProxyFactory(); + pf.addInterface(ifc); + pf.addInterface(BindingProvider.class); + pf.addAdvice(this); + this.serviceProxy = pf.getProxy(getBeanClassLoader()); + } + + + @Override + @Nullable + public Object getObject() { + return this.serviceProxy; + } + + @Override + public Class getObjectType() { + return getServiceInterface(); + } + + @Override + public boolean isSingleton() { + return true; + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/jaxws/JaxWsSoapFaultException.java b/spring-web/src/main/java/org/springframework/remoting/jaxws/JaxWsSoapFaultException.java new file mode 100644 index 0000000000000000000000000000000000000000..0205f56e68e381d22d817fa91b2cbed008f77495 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/jaxws/JaxWsSoapFaultException.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.jaxws; + +import javax.xml.namespace.QName; +import javax.xml.soap.SOAPFault; +import javax.xml.ws.soap.SOAPFaultException; + +import org.springframework.remoting.soap.SoapFaultException; + +/** + * Spring SoapFaultException adapter for the JAX-WS + * {@link javax.xml.ws.soap.SOAPFaultException} class. + * + * @author Juergen Hoeller + * @since 2.5 + */ +@SuppressWarnings("serial") +public class JaxWsSoapFaultException extends SoapFaultException { + + /** + * Constructor for JaxWsSoapFaultException. + * @param original the original JAX-WS SOAPFaultException to wrap + */ + public JaxWsSoapFaultException(SOAPFaultException original) { + super(original.getMessage(), original); + } + + /** + * Return the wrapped JAX-WS SOAPFault. + */ + public final SOAPFault getFault() { + return ((SOAPFaultException) getCause()).getFault(); + } + + + @Override + public String getFaultCode() { + return getFault().getFaultCode(); + } + + @Override + public QName getFaultCodeAsQName() { + return getFault().getFaultCodeAsQName(); + } + + @Override + public String getFaultString() { + return getFault().getFaultString(); + } + + @Override + public String getFaultActor() { + return getFault().getFaultActor(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/jaxws/LocalJaxWsServiceFactory.java b/spring-web/src/main/java/org/springframework/remoting/jaxws/LocalJaxWsServiceFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..8abc4f1b0298ac06b9048784d6911e5010fcd3e4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/jaxws/LocalJaxWsServiceFactory.java @@ -0,0 +1,192 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.jaxws; + +import java.io.IOException; +import java.net.URL; +import java.util.concurrent.Executor; + +import javax.xml.namespace.QName; +import javax.xml.ws.Service; +import javax.xml.ws.WebServiceFeature; +import javax.xml.ws.handler.HandlerResolver; + +import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Factory for locally defined JAX-WS {@link javax.xml.ws.Service} references. + * Uses the JAX-WS {@link javax.xml.ws.Service#create} factory API underneath. + * + *

Serves as base class for {@link LocalJaxWsServiceFactoryBean} as well as + * {@link JaxWsPortClientInterceptor} and {@link JaxWsPortProxyFactoryBean}. + * + * @author Juergen Hoeller + * @since 2.5 + * @see javax.xml.ws.Service + * @see LocalJaxWsServiceFactoryBean + * @see JaxWsPortClientInterceptor + * @see JaxWsPortProxyFactoryBean + */ +public class LocalJaxWsServiceFactory { + + @Nullable + private URL wsdlDocumentUrl; + + @Nullable + private String namespaceUri; + + @Nullable + private String serviceName; + + @Nullable + private WebServiceFeature[] serviceFeatures; + + @Nullable + private Executor executor; + + @Nullable + private HandlerResolver handlerResolver; + + + /** + * Set the URL of the WSDL document that describes the service. + * @see #setWsdlDocumentResource(Resource) + */ + public void setWsdlDocumentUrl(@Nullable URL wsdlDocumentUrl) { + this.wsdlDocumentUrl = wsdlDocumentUrl; + } + + /** + * Set the WSDL document URL as a {@link Resource}. + * @since 3.2 + */ + public void setWsdlDocumentResource(Resource wsdlDocumentResource) throws IOException { + Assert.notNull(wsdlDocumentResource, "WSDL Resource must not be null"); + this.wsdlDocumentUrl = wsdlDocumentResource.getURL(); + } + + /** + * Return the URL of the WSDL document that describes the service. + */ + @Nullable + public URL getWsdlDocumentUrl() { + return this.wsdlDocumentUrl; + } + + /** + * Set the namespace URI of the service. + * Corresponds to the WSDL "targetNamespace". + */ + public void setNamespaceUri(@Nullable String namespaceUri) { + this.namespaceUri = (namespaceUri != null ? namespaceUri.trim() : null); + } + + /** + * Return the namespace URI of the service. + */ + @Nullable + public String getNamespaceUri() { + return this.namespaceUri; + } + + /** + * Set the name of the service to look up. + * Corresponds to the "wsdl:service" name. + */ + public void setServiceName(@Nullable String serviceName) { + this.serviceName = serviceName; + } + + /** + * Return the name of the service. + */ + @Nullable + public String getServiceName() { + return this.serviceName; + } + + /** + * Specify WebServiceFeature objects (e.g. as inner bean definitions) + * to apply to JAX-WS service creation. + * @since 4.0 + * @see Service#create(QName, WebServiceFeature...) + */ + public void setServiceFeatures(WebServiceFeature... serviceFeatures) { + this.serviceFeatures = serviceFeatures; + } + + /** + * Set the JDK concurrent executor to use for asynchronous executions + * that require callbacks. + * @see javax.xml.ws.Service#setExecutor + */ + public void setExecutor(Executor executor) { + this.executor = executor; + } + + /** + * Set the JAX-WS HandlerResolver to use for all proxies and dispatchers + * created through this factory. + * @see javax.xml.ws.Service#setHandlerResolver + */ + public void setHandlerResolver(HandlerResolver handlerResolver) { + this.handlerResolver = handlerResolver; + } + + + /** + * Create a JAX-WS Service according to the parameters of this factory. + * @see #setServiceName + * @see #setWsdlDocumentUrl + */ + public Service createJaxWsService() { + Assert.notNull(this.serviceName, "No service name specified"); + Service service; + + if (this.serviceFeatures != null) { + service = (this.wsdlDocumentUrl != null ? + Service.create(this.wsdlDocumentUrl, getQName(this.serviceName), this.serviceFeatures) : + Service.create(getQName(this.serviceName), this.serviceFeatures)); + } + else { + service = (this.wsdlDocumentUrl != null ? + Service.create(this.wsdlDocumentUrl, getQName(this.serviceName)) : + Service.create(getQName(this.serviceName))); + } + + if (this.executor != null) { + service.setExecutor(this.executor); + } + if (this.handlerResolver != null) { + service.setHandlerResolver(this.handlerResolver); + } + + return service; + } + + /** + * Return a QName for the given name, relative to the namespace URI + * of this factory, if given. + * @see #setNamespaceUri + */ + protected QName getQName(String name) { + return (getNamespaceUri() != null ? new QName(getNamespaceUri(), name) : new QName(name)); + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/jaxws/LocalJaxWsServiceFactoryBean.java b/spring-web/src/main/java/org/springframework/remoting/jaxws/LocalJaxWsServiceFactoryBean.java new file mode 100644 index 0000000000000000000000000000000000000000..d512f7de6f5163749d07179a366609f16e027782 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/jaxws/LocalJaxWsServiceFactoryBean.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.jaxws; + +import javax.xml.ws.Service; + +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; + +/** + * {@link org.springframework.beans.factory.FactoryBean} for locally + * defined JAX-WS Service references. + * Uses {@link LocalJaxWsServiceFactory}'s facilities underneath. + * + *

Alternatively, JAX-WS Service references can be looked up + * in the JNDI environment of the Java EE container. + * + * @author Juergen Hoeller + * @since 2.5 + * @see javax.xml.ws.Service + * @see org.springframework.jndi.JndiObjectFactoryBean + * @see JaxWsPortProxyFactoryBean + */ +public class LocalJaxWsServiceFactoryBean extends LocalJaxWsServiceFactory + implements FactoryBean, InitializingBean { + + @Nullable + private Service service; + + + @Override + public void afterPropertiesSet() { + this.service = createJaxWsService(); + } + + @Override + @Nullable + public Service getObject() { + return this.service; + } + + @Override + public Class getObjectType() { + return (this.service != null ? this.service.getClass() : Service.class); + } + + @Override + public boolean isSingleton() { + return true; + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/jaxws/SimpleHttpServerJaxWsServiceExporter.java b/spring-web/src/main/java/org/springframework/remoting/jaxws/SimpleHttpServerJaxWsServiceExporter.java new file mode 100644 index 0000000000000000000000000000000000000000..54382b30d6e74397414dfc735337531253de2023 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/jaxws/SimpleHttpServerJaxWsServiceExporter.java @@ -0,0 +1,228 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.jaxws; + +import java.net.InetSocketAddress; +import java.util.List; + +import javax.jws.WebService; +import javax.xml.ws.Endpoint; +import javax.xml.ws.WebServiceProvider; + +import com.sun.net.httpserver.Authenticator; +import com.sun.net.httpserver.Filter; +import com.sun.net.httpserver.HttpContext; +import com.sun.net.httpserver.HttpServer; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Simple exporter for JAX-WS services, autodetecting annotated service beans + * (through the JAX-WS {@link javax.jws.WebService} annotation) and exporting + * them through the HTTP server included in Sun's JDK 1.6. The full address + * for each service will consist of the server's base address with the + * service name appended (e.g. "http://localhost:8080/OrderService"). + * + *

Note that this exporter will only work on Sun's JDK 1.6 or higher, as well + * as on JDKs that ship Sun's entire class library as included in the Sun JDK. + * For a portable JAX-WS exporter, have a look at {@link SimpleJaxWsServiceExporter}. + * + * @author Juergen Hoeller + * @since 2.5.5 + * @see javax.jws.WebService + * @see javax.xml.ws.Endpoint#publish(Object) + * @see SimpleJaxWsServiceExporter + * @deprecated as of Spring Framework 5.1, in favor of {@link SimpleJaxWsServiceExporter} + */ +@Deprecated +@org.springframework.lang.UsesSunHttpServer +public class SimpleHttpServerJaxWsServiceExporter extends AbstractJaxWsServiceExporter { + + protected final Log logger = LogFactory.getLog(getClass()); + + @Nullable + private HttpServer server; + + private int port = 8080; + + @Nullable + private String hostname; + + private int backlog = -1; + + private int shutdownDelay = 0; + + private String basePath = "/"; + + @Nullable + private List filters; + + @Nullable + private Authenticator authenticator; + + private boolean localServer = false; + + + /** + * Specify an existing HTTP server to register the web service contexts + * with. This will typically be a server managed by the general Spring + * {@link org.springframework.remoting.support.SimpleHttpServerFactoryBean}. + *

Alternatively, configure a local HTTP server through the + * {@link #setPort "port"}, {@link #setHostname "hostname"} and + * {@link #setBacklog "backlog"} properties (or rely on the defaults there). + */ + public void setServer(HttpServer server) { + this.server = server; + } + + /** + * Specify the HTTP server's port. Default is 8080. + *

Only applicable for a locally configured HTTP server. + * Ignored when the {@link #setServer "server"} property has been specified. + */ + public void setPort(int port) { + this.port = port; + } + + /** + * Specify the HTTP server's hostname to bind to. Default is localhost; + * can be overridden with a specific network address to bind to. + *

Only applicable for a locally configured HTTP server. + * Ignored when the {@link #setServer "server"} property has been specified. + */ + public void setHostname(String hostname) { + this.hostname = hostname; + } + + /** + * Specify the HTTP server's TCP backlog. Default is -1, + * indicating the system's default value. + *

Only applicable for a locally configured HTTP server. + * Ignored when the {@link #setServer "server"} property has been specified. + */ + public void setBacklog(int backlog) { + this.backlog = backlog; + } + + /** + * Specify the number of seconds to wait until HTTP exchanges have + * completed when shutting down the HTTP server. Default is 0. + *

Only applicable for a locally configured HTTP server. + * Ignored when the {@link #setServer "server"} property has been specified. + */ + public void setShutdownDelay(int shutdownDelay) { + this.shutdownDelay = shutdownDelay; + } + + /** + * Set the base path for context publication. Default is "/". + *

For each context publication path, the service name will be + * appended to this base address. E.g. service name "OrderService" + * -> "/OrderService". + * @see javax.xml.ws.Endpoint#publish(Object) + * @see javax.jws.WebService#serviceName() + */ + public void setBasePath(String basePath) { + this.basePath = basePath; + } + + /** + * Register common {@link com.sun.net.httpserver.Filter Filters} to be + * applied to all detected {@link javax.jws.WebService} annotated beans. + */ + public void setFilters(List filters) { + this.filters = filters; + } + + /** + * Register a common {@link com.sun.net.httpserver.Authenticator} to be + * applied to all detected {@link javax.jws.WebService} annotated beans. + */ + public void setAuthenticator(Authenticator authenticator) { + this.authenticator = authenticator; + } + + + @Override + public void afterPropertiesSet() throws Exception { + if (this.server == null) { + InetSocketAddress address = (this.hostname != null ? + new InetSocketAddress(this.hostname, this.port) : new InetSocketAddress(this.port)); + HttpServer server = HttpServer.create(address, this.backlog); + if (logger.isInfoEnabled()) { + logger.info("Starting HttpServer at address " + address); + } + server.start(); + this.server = server; + this.localServer = true; + } + super.afterPropertiesSet(); + } + + @Override + protected void publishEndpoint(Endpoint endpoint, WebService annotation) { + endpoint.publish(buildHttpContext(endpoint, annotation.serviceName())); + } + + @Override + protected void publishEndpoint(Endpoint endpoint, WebServiceProvider annotation) { + endpoint.publish(buildHttpContext(endpoint, annotation.serviceName())); + } + + /** + * Build the HttpContext for the given endpoint. + * @param endpoint the JAX-WS Provider Endpoint object + * @param serviceName the given service name + * @return the fully populated HttpContext + */ + protected HttpContext buildHttpContext(Endpoint endpoint, String serviceName) { + Assert.state(this.server != null, "No HttpServer available"); + String fullPath = calculateEndpointPath(endpoint, serviceName); + HttpContext httpContext = this.server.createContext(fullPath); + if (this.filters != null) { + httpContext.getFilters().addAll(this.filters); + } + if (this.authenticator != null) { + httpContext.setAuthenticator(this.authenticator); + } + return httpContext; + } + + /** + * Calculate the full endpoint path for the given endpoint. + * @param endpoint the JAX-WS Provider Endpoint object + * @param serviceName the given service name + * @return the full endpoint path + */ + protected String calculateEndpointPath(Endpoint endpoint, String serviceName) { + return this.basePath + serviceName; + } + + + @Override + public void destroy() { + super.destroy(); + if (this.server != null && this.localServer) { + logger.info("Stopping HttpServer"); + this.server.stop(this.shutdownDelay); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/jaxws/SimpleJaxWsServiceExporter.java b/spring-web/src/main/java/org/springframework/remoting/jaxws/SimpleJaxWsServiceExporter.java new file mode 100644 index 0000000000000000000000000000000000000000..06786578f9c0e55a80ef21b2f38a27a1d6dab515 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/jaxws/SimpleJaxWsServiceExporter.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.jaxws; + +import javax.jws.WebService; +import javax.xml.ws.Endpoint; +import javax.xml.ws.WebServiceProvider; + +/** + * Simple exporter for JAX-WS services, autodetecting annotated service beans + * (through the JAX-WS {@link javax.jws.WebService} annotation) and exporting + * them with a configured base address (by default "http://localhost:8080/") + * using the JAX-WS provider's built-in publication support. The full address + * for each service will consist of the base address with the service name + * appended (e.g. "http://localhost:8080/OrderService"). + * + *

Note that this exporter will only work if the JAX-WS runtime actually + * supports publishing with an address argument, i.e. if the JAX-WS runtime + * ships an internal HTTP server. + * + * @author Juergen Hoeller + * @since 2.5 + * @see javax.jws.WebService + * @see javax.xml.ws.Endpoint#publish(String) + */ +public class SimpleJaxWsServiceExporter extends AbstractJaxWsServiceExporter { + + /** + * The default base address. + */ + public static final String DEFAULT_BASE_ADDRESS = "http://localhost:8080/"; + + private String baseAddress = DEFAULT_BASE_ADDRESS; + + + /** + * Set the base address for exported services. + * Default is "http://localhost:8080/". + *

For each actual publication address, the service name will be + * appended to this base address. E.g. service name "OrderService" + * -> "http://localhost:8080/OrderService". + * @see javax.xml.ws.Endpoint#publish(String) + * @see javax.jws.WebService#serviceName() + */ + public void setBaseAddress(String baseAddress) { + this.baseAddress = baseAddress; + } + + + @Override + protected void publishEndpoint(Endpoint endpoint, WebService annotation) { + endpoint.publish(calculateEndpointAddress(endpoint, annotation.serviceName())); + } + + @Override + protected void publishEndpoint(Endpoint endpoint, WebServiceProvider annotation) { + endpoint.publish(calculateEndpointAddress(endpoint, annotation.serviceName())); + } + + /** + * Calculate the full endpoint address for the given endpoint. + * @param endpoint the JAX-WS Provider Endpoint object + * @param serviceName the given service name + * @return the full endpoint address + */ + protected String calculateEndpointAddress(Endpoint endpoint, String serviceName) { + String fullAddress = this.baseAddress + serviceName; + if (endpoint.getClass().getName().startsWith("weblogic.")) { + // Workaround for WebLogic 10.3 + fullAddress = fullAddress + "/"; + } + return fullAddress; + } + +} diff --git a/spring-web/src/main/java/org/springframework/remoting/jaxws/package-info.java b/spring-web/src/main/java/org/springframework/remoting/jaxws/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..6dc1423ab67d049010db69441ad81b5f074cac72 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/remoting/jaxws/package-info.java @@ -0,0 +1,11 @@ +/** + * Remoting classes for Web Services via JAX-WS (the successor of JAX-RPC), + * as included in Java 6 and Java EE 5. This package provides proxy + * factories for accessing JAX-WS services and ports. + */ +@NonNullApi +@NonNullFields +package org.springframework.remoting.jaxws; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/HttpMediaTypeException.java b/spring-web/src/main/java/org/springframework/web/HttpMediaTypeException.java new file mode 100644 index 0000000000000000000000000000000000000000..91c572147f8f9effbffc1b27c8492b13bbde504f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/HttpMediaTypeException.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web; + +import java.util.Collections; +import java.util.List; + +import javax.servlet.ServletException; + +import org.springframework.http.MediaType; + +/** + * Abstract base for exceptions related to media types. Adds a list of supported {@link MediaType MediaTypes}. + * + * @author Arjen Poutsma + * @since 3.0 + */ +@SuppressWarnings("serial") +public abstract class HttpMediaTypeException extends ServletException { + + private final List supportedMediaTypes; + + + /** + * Create a new HttpMediaTypeException. + * @param message the exception message + */ + protected HttpMediaTypeException(String message) { + super(message); + this.supportedMediaTypes = Collections.emptyList(); + } + + /** + * Create a new HttpMediaTypeException with a list of supported media types. + * @param supportedMediaTypes the list of supported media types + */ + protected HttpMediaTypeException(String message, List supportedMediaTypes) { + super(message); + this.supportedMediaTypes = Collections.unmodifiableList(supportedMediaTypes); + } + + + /** + * Return the list of supported media types. + */ + public List getSupportedMediaTypes() { + return this.supportedMediaTypes; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/HttpMediaTypeNotAcceptableException.java b/spring-web/src/main/java/org/springframework/web/HttpMediaTypeNotAcceptableException.java new file mode 100644 index 0000000000000000000000000000000000000000..36df3ae33239e2eb27ec82337db175ffa1b9fd57 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/HttpMediaTypeNotAcceptableException.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web; + +import java.util.List; + +import org.springframework.http.MediaType; + +/** + * Exception thrown when the request handler cannot generate a response that is acceptable by the client. + * + * @author Arjen Poutsma + * @since 3.0 + */ +@SuppressWarnings("serial") +public class HttpMediaTypeNotAcceptableException extends HttpMediaTypeException { + + /** + * Create a new HttpMediaTypeNotAcceptableException. + * @param message the exception message + */ + public HttpMediaTypeNotAcceptableException(String message) { + super(message); + } + + /** + * Create a new HttpMediaTypeNotSupportedException. + * @param supportedMediaTypes the list of supported media types + */ + public HttpMediaTypeNotAcceptableException(List supportedMediaTypes) { + super("Could not find acceptable representation", supportedMediaTypes); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/HttpMediaTypeNotSupportedException.java b/spring-web/src/main/java/org/springframework/web/HttpMediaTypeNotSupportedException.java new file mode 100644 index 0000000000000000000000000000000000000000..13585aa8c9858a3db2af63340fcc32b3f2ea552b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/HttpMediaTypeNotSupportedException.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web; + +import java.util.List; + +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; + +/** + * Exception thrown when a client POSTs, PUTs, or PATCHes content of a type + * not supported by request handler. + * + * @author Arjen Poutsma + * @since 3.0 + */ +@SuppressWarnings("serial") +public class HttpMediaTypeNotSupportedException extends HttpMediaTypeException { + + @Nullable + private final MediaType contentType; + + + /** + * Create a new HttpMediaTypeNotSupportedException. + * @param message the exception message + */ + public HttpMediaTypeNotSupportedException(String message) { + super(message); + this.contentType = null; + } + + /** + * Create a new HttpMediaTypeNotSupportedException. + * @param contentType the unsupported content type + * @param supportedMediaTypes the list of supported media types + */ + public HttpMediaTypeNotSupportedException(@Nullable MediaType contentType, List supportedMediaTypes) { + this(contentType, supportedMediaTypes, "Content type '" + + (contentType != null ? contentType : "") + "' not supported"); + } + + /** + * Create a new HttpMediaTypeNotSupportedException. + * @param contentType the unsupported content type + * @param supportedMediaTypes the list of supported media types + * @param msg the detail message + */ + public HttpMediaTypeNotSupportedException(@Nullable MediaType contentType, + List supportedMediaTypes, String msg) { + + super(msg, supportedMediaTypes); + this.contentType = contentType; + } + + + /** + * Return the HTTP request content type method that caused the failure. + */ + @Nullable + public MediaType getContentType() { + return this.contentType; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/HttpRequestHandler.java b/spring-web/src/main/java/org/springframework/web/HttpRequestHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..187d9dd8c544e3d51e53d3fe46bc91d653d44e2c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/HttpRequestHandler.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web; + +import java.io.IOException; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * Plain handler interface for components that process HTTP requests, + * analogous to a Servlet. Only declares {@link javax.servlet.ServletException} + * and {@link java.io.IOException}, to allow for usage within any + * {@link javax.servlet.http.HttpServlet}. This interface is essentially the + * direct equivalent of an HttpServlet, reduced to a central handle method. + * + *

The easiest way to expose an HttpRequestHandler bean in Spring style + * is to define it in Spring's root web application context and define + * an {@link org.springframework.web.context.support.HttpRequestHandlerServlet} + * in {@code web.xml}, pointing to the target HttpRequestHandler bean + * through its {@code servlet-name} which needs to match the target bean name. + * + *

Supported as a handler type within Spring's + * {@link org.springframework.web.servlet.DispatcherServlet}, being able + * to interact with the dispatcher's advanced mapping and interception + * facilities. This is the recommended way of exposing an HttpRequestHandler, + * while keeping the handler implementations free of direct dependencies + * on a DispatcherServlet environment. + * + *

Typically implemented to generate binary responses directly, + * with no separate view resource involved. This differentiates it from a + * {@link org.springframework.web.servlet.mvc.Controller} within Spring's Web MVC + * framework. The lack of a {@link org.springframework.web.servlet.ModelAndView} + * return value gives a clearer signature to callers other than the + * DispatcherServlet, indicating that there will never be a view to render. + * + *

As of Spring 2.0, Spring's HTTP-based remote exporters, such as + * {@link org.springframework.remoting.httpinvoker.HttpInvokerServiceExporter} + * and {@link org.springframework.remoting.caucho.HessianServiceExporter}, + * implement this interface rather than the more extensive Controller interface, + * for minimal dependencies on Spring-specific web infrastructure. + * + *

Note that HttpRequestHandlers may optionally implement the + * {@link org.springframework.web.servlet.mvc.LastModified} interface, + * just like Controllers can, provided that they run within Spring's + * DispatcherServlet. However, this is usually not necessary, since + * HttpRequestHandlers typically only support POST requests to begin with. + * Alternatively, a handler may implement the "If-Modified-Since" HTTP + * header processing manually within its {@code handle} method. + * + * @author Juergen Hoeller + * @since 2.0 + * @see org.springframework.web.context.support.HttpRequestHandlerServlet + * @see org.springframework.web.servlet.DispatcherServlet + * @see org.springframework.web.servlet.ModelAndView + * @see org.springframework.web.servlet.mvc.Controller + * @see org.springframework.web.servlet.mvc.LastModified + * @see org.springframework.web.servlet.mvc.HttpRequestHandlerAdapter + * @see org.springframework.remoting.httpinvoker.HttpInvokerServiceExporter + * @see org.springframework.remoting.caucho.HessianServiceExporter + */ +@FunctionalInterface +public interface HttpRequestHandler { + + /** + * Process the given request, generating a response. + * @param request current HTTP request + * @param response current HTTP response + * @throws ServletException in case of general errors + * @throws IOException in case of I/O errors + */ + void handleRequest(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/web/HttpRequestMethodNotSupportedException.java b/spring-web/src/main/java/org/springframework/web/HttpRequestMethodNotSupportedException.java new file mode 100644 index 0000000000000000000000000000000000000000..1d76e991a5785c351b0bd984971d2083184941d4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/HttpRequestMethodNotSupportedException.java @@ -0,0 +1,130 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web; + +import java.util.Collection; +import java.util.EnumSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Set; + +import javax.servlet.ServletException; + +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +/** + * Exception thrown when a request handler does not support a + * specific request method. + * + * @author Juergen Hoeller + * @since 2.0 + */ +@SuppressWarnings("serial") +public class HttpRequestMethodNotSupportedException extends ServletException { + + private final String method; + + @Nullable + private final String[] supportedMethods; + + + /** + * Create a new HttpRequestMethodNotSupportedException. + * @param method the unsupported HTTP request method + */ + public HttpRequestMethodNotSupportedException(String method) { + this(method, (String[]) null); + } + + /** + * Create a new HttpRequestMethodNotSupportedException. + * @param method the unsupported HTTP request method + * @param msg the detail message + */ + public HttpRequestMethodNotSupportedException(String method, String msg) { + this(method, null, msg); + } + + /** + * Create a new HttpRequestMethodNotSupportedException. + * @param method the unsupported HTTP request method + * @param supportedMethods the actually supported HTTP methods (may be {@code null}) + */ + public HttpRequestMethodNotSupportedException(String method, @Nullable Collection supportedMethods) { + this(method, (supportedMethods != null ? StringUtils.toStringArray(supportedMethods) : null)); + } + + /** + * Create a new HttpRequestMethodNotSupportedException. + * @param method the unsupported HTTP request method + * @param supportedMethods the actually supported HTTP methods (may be {@code null}) + */ + public HttpRequestMethodNotSupportedException(String method, @Nullable String[] supportedMethods) { + this(method, supportedMethods, "Request method '" + method + "' not supported"); + } + + /** + * Create a new HttpRequestMethodNotSupportedException. + * @param method the unsupported HTTP request method + * @param supportedMethods the actually supported HTTP methods + * @param msg the detail message + */ + public HttpRequestMethodNotSupportedException(String method, @Nullable String[] supportedMethods, String msg) { + super(msg); + this.method = method; + this.supportedMethods = supportedMethods; + } + + + /** + * Return the HTTP request method that caused the failure. + */ + public String getMethod() { + return this.method; + } + + /** + * Return the actually supported HTTP methods, or {@code null} if not known. + */ + @Nullable + public String[] getSupportedMethods() { + return this.supportedMethods; + } + + /** + * Return the actually supported HTTP methods as {@link HttpMethod} instances, + * or {@code null} if not known. + * @since 3.2 + */ + @Nullable + public Set getSupportedHttpMethods() { + if (this.supportedMethods == null) { + return null; + } + List supportedMethods = new LinkedList<>(); + for (String value : this.supportedMethods) { + HttpMethod resolved = HttpMethod.resolve(value); + if (resolved != null) { + supportedMethods.add(resolved); + } + } + return EnumSet.copyOf(supportedMethods); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/HttpSessionRequiredException.java b/spring-web/src/main/java/org/springframework/web/HttpSessionRequiredException.java new file mode 100644 index 0000000000000000000000000000000000000000..83a00b0de5a86c336dd903edab8ca4befd1e91f9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/HttpSessionRequiredException.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web; + +import javax.servlet.ServletException; + +import org.springframework.lang.Nullable; + +/** + * Exception thrown when an HTTP request handler requires a pre-existing session. + * + * @author Juergen Hoeller + * @since 2.0 + */ +@SuppressWarnings("serial") +public class HttpSessionRequiredException extends ServletException { + + @Nullable + private final String expectedAttribute; + + + /** + * Create a new HttpSessionRequiredException. + * @param msg the detail message + */ + public HttpSessionRequiredException(String msg) { + super(msg); + this.expectedAttribute = null; + } + + /** + * Create a new HttpSessionRequiredException. + * @param msg the detail message + * @param expectedAttribute the name of the expected session attribute + * @since 4.3 + */ + public HttpSessionRequiredException(String msg, String expectedAttribute) { + super(msg); + this.expectedAttribute = expectedAttribute; + } + + + /** + * Return the name of the expected session attribute, if any. + * @since 4.3 + */ + @Nullable + public String getExpectedAttribute() { + return this.expectedAttribute; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/SpringServletContainerInitializer.java b/spring-web/src/main/java/org/springframework/web/SpringServletContainerInitializer.java new file mode 100644 index 0000000000000000000000000000000000000000..ee3b4509e31558d4714cf999748a5940d8626b35 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/SpringServletContainerInitializer.java @@ -0,0 +1,176 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web; + +import java.lang.reflect.Modifier; +import java.util.LinkedList; +import java.util.List; +import java.util.ServiceLoader; +import java.util.Set; + +import javax.servlet.ServletContainerInitializer; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.annotation.HandlesTypes; + +import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.lang.Nullable; +import org.springframework.util.ReflectionUtils; + +/** + * Servlet 3.0 {@link ServletContainerInitializer} designed to support code-based + * configuration of the servlet container using Spring's {@link WebApplicationInitializer} + * SPI as opposed to (or possibly in combination with) the traditional + * {@code web.xml}-based approach. + * + *

Mechanism of Operation

+ * This class will be loaded and instantiated and have its {@link #onStartup} + * method invoked by any Servlet 3.0-compliant container during container startup assuming + * that the {@code spring-web} module JAR is present on the classpath. This occurs through + * the JAR Services API {@link ServiceLoader#load(Class)} method detecting the + * {@code spring-web} module's {@code META-INF/services/javax.servlet.ServletContainerInitializer} + * service provider configuration file. See the + * + * JAR Services API documentation as well as section 8.2.4 of the Servlet 3.0 + * Final Draft specification for complete details. + * + *

In combination with {@code web.xml}

+ * A web application can choose to limit the amount of classpath scanning the Servlet + * container does at startup either through the {@code metadata-complete} attribute in + * {@code web.xml}, which controls scanning for Servlet annotations or through an + * {@code } element also in {@code web.xml}, which controls which + * web fragments (i.e. jars) are allowed to perform a {@code ServletContainerInitializer} + * scan. When using this feature, the {@link SpringServletContainerInitializer} + * can be enabled by adding "spring_web" to the list of named web fragments in + * {@code web.xml} as follows: + * + *
+ * <absolute-ordering>
+ *   <name>some_web_fragment</name>
+ *   <name>spring_web</name>
+ * </absolute-ordering>
+ * 
+ * + *

Relationship to Spring's {@code WebApplicationInitializer}

+ * Spring's {@code WebApplicationInitializer} SPI consists of just one method: + * {@link WebApplicationInitializer#onStartup(ServletContext)}. The signature is intentionally + * quite similar to {@link ServletContainerInitializer#onStartup(Set, ServletContext)}: + * simply put, {@code SpringServletContainerInitializer} is responsible for instantiating + * and delegating the {@code ServletContext} to any user-defined + * {@code WebApplicationInitializer} implementations. It is then the responsibility of + * each {@code WebApplicationInitializer} to do the actual work of initializing the + * {@code ServletContext}. The exact process of delegation is described in detail in the + * {@link #onStartup onStartup} documentation below. + * + *

General Notes

+ * In general, this class should be viewed as supporting infrastructure for + * the more important and user-facing {@code WebApplicationInitializer} SPI. Taking + * advantage of this container initializer is also completely optional: while + * it is true that this initializer will be loaded and invoked under all Servlet 3.0+ + * runtimes, it remains the user's choice whether to make any + * {@code WebApplicationInitializer} implementations available on the classpath. If no + * {@code WebApplicationInitializer} types are detected, this container initializer will + * have no effect. + * + *

Note that use of this container initializer and of {@code WebApplicationInitializer} + * is not in any way "tied" to Spring MVC other than the fact that the types are shipped + * in the {@code spring-web} module JAR. Rather, they can be considered general-purpose + * in their ability to facilitate convenient code-based configuration of the + * {@code ServletContext}. In other words, any servlet, listener, or filter may be + * registered within a {@code WebApplicationInitializer}, not just Spring MVC-specific + * components. + * + *

This class is neither designed for extension nor intended to be extended. + * It should be considered an internal type, with {@code WebApplicationInitializer} + * being the public-facing SPI. + * + *

See Also

+ * See {@link WebApplicationInitializer} Javadoc for examples and detailed usage + * recommendations.

+ * + * @author Chris Beams + * @author Juergen Hoeller + * @author Rossen Stoyanchev + * @since 3.1 + * @see #onStartup(Set, ServletContext) + * @see WebApplicationInitializer + */ +@HandlesTypes(WebApplicationInitializer.class) +public class SpringServletContainerInitializer implements ServletContainerInitializer { + + /** + * Delegate the {@code ServletContext} to any {@link WebApplicationInitializer} + * implementations present on the application classpath. + *

Because this class declares @{@code HandlesTypes(WebApplicationInitializer.class)}, + * Servlet 3.0+ containers will automatically scan the classpath for implementations + * of Spring's {@code WebApplicationInitializer} interface and provide the set of all + * such types to the {@code webAppInitializerClasses} parameter of this method. + *

If no {@code WebApplicationInitializer} implementations are found on the classpath, + * this method is effectively a no-op. An INFO-level log message will be issued notifying + * the user that the {@code ServletContainerInitializer} has indeed been invoked but that + * no {@code WebApplicationInitializer} implementations were found. + *

Assuming that one or more {@code WebApplicationInitializer} types are detected, + * they will be instantiated (and sorted if the @{@link + * org.springframework.core.annotation.Order @Order} annotation is present or + * the {@link org.springframework.core.Ordered Ordered} interface has been + * implemented). Then the {@link WebApplicationInitializer#onStartup(ServletContext)} + * method will be invoked on each instance, delegating the {@code ServletContext} such + * that each instance may register and configure servlets such as Spring's + * {@code DispatcherServlet}, listeners such as Spring's {@code ContextLoaderListener}, + * or any other Servlet API componentry such as filters. + * @param webAppInitializerClasses all implementations of + * {@link WebApplicationInitializer} found on the application classpath + * @param servletContext the servlet context to be initialized + * @see WebApplicationInitializer#onStartup(ServletContext) + * @see AnnotationAwareOrderComparator + */ + @Override + public void onStartup(@Nullable Set> webAppInitializerClasses, ServletContext servletContext) + throws ServletException { + + List initializers = new LinkedList<>(); + + if (webAppInitializerClasses != null) { + for (Class waiClass : webAppInitializerClasses) { + // Be defensive: Some servlet containers provide us with invalid classes, + // no matter what @HandlesTypes says... + if (!waiClass.isInterface() && !Modifier.isAbstract(waiClass.getModifiers()) && + WebApplicationInitializer.class.isAssignableFrom(waiClass)) { + try { + initializers.add((WebApplicationInitializer) + ReflectionUtils.accessibleConstructor(waiClass).newInstance()); + } + catch (Throwable ex) { + throw new ServletException("Failed to instantiate WebApplicationInitializer class", ex); + } + } + } + } + + if (initializers.isEmpty()) { + servletContext.log("No Spring WebApplicationInitializer types detected on classpath"); + return; + } + + servletContext.log(initializers.size() + " Spring WebApplicationInitializers detected on classpath"); + AnnotationAwareOrderComparator.sort(initializers); + for (WebApplicationInitializer initializer : initializers) { + initializer.onStartup(servletContext); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/WebApplicationInitializer.java b/spring-web/src/main/java/org/springframework/web/WebApplicationInitializer.java new file mode 100644 index 0000000000000000000000000000000000000000..337f53291e7a568bc113388755cbbb54c7293740 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/WebApplicationInitializer.java @@ -0,0 +1,189 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web; + +import javax.servlet.ServletContext; +import javax.servlet.ServletException; + +/** + * Interface to be implemented in Servlet 3.0+ environments in order to configure the + * {@link ServletContext} programmatically -- as opposed to (or possibly in conjunction + * with) the traditional {@code web.xml}-based approach. + * + *

Implementations of this SPI will be detected automatically by {@link + * SpringServletContainerInitializer}, which itself is bootstrapped automatically + * by any Servlet 3.0 container. See {@linkplain SpringServletContainerInitializer its + * Javadoc} for details on this bootstrapping mechanism. + * + *

Example

+ *

The traditional, XML-based approach

+ * Most Spring users building a web application will need to register Spring's {@code + * DispatcherServlet}. For reference, in WEB-INF/web.xml, this would typically be done as + * follows: + *
+ * <servlet>
+ *   <servlet-name>dispatcher</servlet-name>
+ *   <servlet-class>
+ *     org.springframework.web.servlet.DispatcherServlet
+ *   </servlet-class>
+ *   <init-param>
+ *     <param-name>contextConfigLocation</param-name>
+ *     <param-value>/WEB-INF/spring/dispatcher-config.xml</param-value>
+ *   </init-param>
+ *   <load-on-startup>1</load-on-startup>
+ * </servlet>
+ *
+ * <servlet-mapping>
+ *   <servlet-name>dispatcher</servlet-name>
+ *   <url-pattern>/</url-pattern>
+ * </servlet-mapping>
+ * + *

The code-based approach with {@code WebApplicationInitializer}

+ * Here is the equivalent {@code DispatcherServlet} registration logic, + * {@code WebApplicationInitializer}-style: + *
+ * public class MyWebAppInitializer implements WebApplicationInitializer {
+ *
+ *    @Override
+ *    public void onStartup(ServletContext container) {
+ *      XmlWebApplicationContext appContext = new XmlWebApplicationContext();
+ *      appContext.setConfigLocation("/WEB-INF/spring/dispatcher-config.xml");
+ *
+ *      ServletRegistration.Dynamic dispatcher =
+ *        container.addServlet("dispatcher", new DispatcherServlet(appContext));
+ *      dispatcher.setLoadOnStartup(1);
+ *      dispatcher.addMapping("/");
+ *    }
+ *
+ * }
+ * + * As an alternative to the above, you can also extend from {@link + * org.springframework.web.servlet.support.AbstractDispatcherServletInitializer}. + * + * As you can see, thanks to Servlet 3.0's new {@link ServletContext#addServlet} method + * we're actually registering an instance of the {@code DispatcherServlet}, and + * this means that the {@code DispatcherServlet} can now be treated like any other object + * -- receiving constructor injection of its application context in this case. + * + *

This style is both simpler and more concise. There is no concern for dealing with + * init-params, etc, just normal JavaBean-style properties and constructor arguments. You + * are free to create and work with your Spring application contexts as necessary before + * injecting them into the {@code DispatcherServlet}. + * + *

Most major Spring Web components have been updated to support this style of + * registration. You'll find that {@code DispatcherServlet}, {@code FrameworkServlet}, + * {@code ContextLoaderListener} and {@code DelegatingFilterProxy} all now support + * constructor arguments. Even if a component (e.g. non-Spring, other third party) has not + * been specifically updated for use within {@code WebApplicationInitializers}, they still + * may be used in any case. The Servlet 3.0 {@code ServletContext} API allows for setting + * init-params, context-params, etc programmatically. + * + *

A 100% code-based approach to configuration

+ * In the example above, {@code WEB-INF/web.xml} was successfully replaced with code in + * the form of a {@code WebApplicationInitializer}, but the actual + * {@code dispatcher-config.xml} Spring configuration remained XML-based. + * {@code WebApplicationInitializer} is a perfect fit for use with Spring's code-based + * {@code @Configuration} classes. See @{@link + * org.springframework.context.annotation.Configuration Configuration} Javadoc for + * complete details, but the following example demonstrates refactoring to use Spring's + * {@link org.springframework.web.context.support.AnnotationConfigWebApplicationContext + * AnnotationConfigWebApplicationContext} in lieu of {@code XmlWebApplicationContext}, and + * user-defined {@code @Configuration} classes {@code AppConfig} and + * {@code DispatcherConfig} instead of Spring XML files. This example also goes a bit + * beyond those above to demonstrate typical configuration of the 'root' application + * context and registration of the {@code ContextLoaderListener}: + *
+ * public class MyWebAppInitializer implements WebApplicationInitializer {
+ *
+ *    @Override
+ *    public void onStartup(ServletContext container) {
+ *      // Create the 'root' Spring application context
+ *      AnnotationConfigWebApplicationContext rootContext =
+ *        new AnnotationConfigWebApplicationContext();
+ *      rootContext.register(AppConfig.class);
+ *
+ *      // Manage the lifecycle of the root application context
+ *      container.addListener(new ContextLoaderListener(rootContext));
+ *
+ *      // Create the dispatcher servlet's Spring application context
+ *      AnnotationConfigWebApplicationContext dispatcherContext =
+ *        new AnnotationConfigWebApplicationContext();
+ *      dispatcherContext.register(DispatcherConfig.class);
+ *
+ *      // Register and map the dispatcher servlet
+ *      ServletRegistration.Dynamic dispatcher =
+ *        container.addServlet("dispatcher", new DispatcherServlet(dispatcherContext));
+ *      dispatcher.setLoadOnStartup(1);
+ *      dispatcher.addMapping("/");
+ *    }
+ *
+ * }
+ * + * As an alternative to the above, you can also extend from {@link + * org.springframework.web.servlet.support.AbstractAnnotationConfigDispatcherServletInitializer}. + * + * Remember that {@code WebApplicationInitializer} implementations are detected + * automatically -- so you are free to package them within your application as you + * see fit. + * + *

Ordering {@code WebApplicationInitializer} execution

+ * {@code WebApplicationInitializer} implementations may optionally be annotated at the + * class level with Spring's @{@link org.springframework.core.annotation.Order Order} + * annotation or may implement Spring's {@link org.springframework.core.Ordered Ordered} + * interface. If so, the initializers will be ordered prior to invocation. This provides + * a mechanism for users to ensure the order in which servlet container initialization + * occurs. Use of this feature is expected to be rare, as typical applications will likely + * centralize all container initialization within a single {@code WebApplicationInitializer}. + * + *

Caveats

+ * + *

web.xml versioning

+ *

{@code WEB-INF/web.xml} and {@code WebApplicationInitializer} use are not mutually + * exclusive; for example, web.xml can register one servlet, and a {@code + * WebApplicationInitializer} can register another. An initializer can even + * modify registrations performed in {@code web.xml} through methods such as + * {@link ServletContext#getServletRegistration(String)}. However, if + * {@code WEB-INF/web.xml} is present in the application, its {@code version} attribute + * must be set to "3.0" or greater, otherwise {@code ServletContainerInitializer} + * bootstrapping will be ignored by the servlet container. + * + *

Mapping to '/' under Tomcat

+ *

Apache Tomcat maps its internal {@code DefaultServlet} to "/", and on Tomcat versions + * <= 7.0.14, this servlet mapping cannot be overridden programmatically. + * 7.0.15 fixes this issue. Overriding the "/" servlet mapping has also been tested + * successfully under GlassFish 3.1.

+ * + * @author Chris Beams + * @since 3.1 + * @see SpringServletContainerInitializer + * @see org.springframework.web.context.AbstractContextLoaderInitializer + * @see org.springframework.web.servlet.support.AbstractDispatcherServletInitializer + * @see org.springframework.web.servlet.support.AbstractAnnotationConfigDispatcherServletInitializer + */ +public interface WebApplicationInitializer { + + /** + * Configure the given {@link ServletContext} with any servlets, filters, listeners + * context-params and attributes necessary for initializing this web application. See + * examples {@linkplain WebApplicationInitializer above}. + * @param servletContext the {@code ServletContext} to initialize + * @throws ServletException if any call against the given {@code ServletContext} + * throws a {@code ServletException} + */ + void onStartup(ServletContext servletContext) throws ServletException; + +} diff --git a/spring-web/src/main/java/org/springframework/web/accept/AbstractMappingContentNegotiationStrategy.java b/spring-web/src/main/java/org/springframework/web/accept/AbstractMappingContentNegotiationStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..5413a34a5a0d95cba01d087051b361f29509a1ea --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/accept/AbstractMappingContentNegotiationStrategy.java @@ -0,0 +1,165 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.accept; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.http.MediaType; +import org.springframework.http.MediaTypeFactory; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; +import org.springframework.web.HttpMediaTypeNotAcceptableException; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Base class for {@code ContentNegotiationStrategy} implementations with the + * steps to resolve a request to media types. + * + *

First a key (e.g. "json", "pdf") must be extracted from the request (e.g. + * file extension, query param). The key must then be resolved to media type(s) + * through the base class {@link MappingMediaTypeFileExtensionResolver} which + * stores such mappings. + * + *

The method {@link #handleNoMatch} allow sub-classes to plug in additional + * ways of looking up media types (e.g. through the Java Activation framework, + * or {@link javax.servlet.ServletContext#getMimeType}. Media types resolved + * via base classes are then added to the base class + * {@link MappingMediaTypeFileExtensionResolver}, i.e. cached for new lookups. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public abstract class AbstractMappingContentNegotiationStrategy extends MappingMediaTypeFileExtensionResolver + implements ContentNegotiationStrategy { + + protected final Log logger = LogFactory.getLog(getClass()); + + private boolean useRegisteredExtensionsOnly = false; + + private boolean ignoreUnknownExtensions = false; + + + /** + * Create an instance with the given map of file extensions and media types. + */ + public AbstractMappingContentNegotiationStrategy(@Nullable Map mediaTypes) { + super(mediaTypes); + } + + + /** + * Whether to only use the registered mappings to look up file extensions, + * or also to use dynamic resolution (e.g. via {@link MediaTypeFactory}. + *

By default this is set to {@code false}. + */ + public void setUseRegisteredExtensionsOnly(boolean useRegisteredExtensionsOnly) { + this.useRegisteredExtensionsOnly = useRegisteredExtensionsOnly; + } + + public boolean isUseRegisteredExtensionsOnly() { + return this.useRegisteredExtensionsOnly; + } + + /** + * Whether to ignore requests with unknown file extension. Setting this to + * {@code false} results in {@code HttpMediaTypeNotAcceptableException}. + *

By default this is set to {@literal false} but is overridden in + * {@link PathExtensionContentNegotiationStrategy} to {@literal true}. + */ + public void setIgnoreUnknownExtensions(boolean ignoreUnknownExtensions) { + this.ignoreUnknownExtensions = ignoreUnknownExtensions; + } + + public boolean isIgnoreUnknownExtensions() { + return this.ignoreUnknownExtensions; + } + + + @Override + public List resolveMediaTypes(NativeWebRequest webRequest) + throws HttpMediaTypeNotAcceptableException { + + return resolveMediaTypeKey(webRequest, getMediaTypeKey(webRequest)); + } + + /** + * An alternative to {@link #resolveMediaTypes(NativeWebRequest)} that accepts + * an already extracted key. + * @since 3.2.16 + */ + public List resolveMediaTypeKey(NativeWebRequest webRequest, @Nullable String key) + throws HttpMediaTypeNotAcceptableException { + + if (StringUtils.hasText(key)) { + MediaType mediaType = lookupMediaType(key); + if (mediaType != null) { + handleMatch(key, mediaType); + return Collections.singletonList(mediaType); + } + mediaType = handleNoMatch(webRequest, key); + if (mediaType != null) { + addMapping(key, mediaType); + return Collections.singletonList(mediaType); + } + } + return MEDIA_TYPE_ALL_LIST; + } + + + /** + * Extract a key from the request to use to look up media types. + * @return the lookup key, or {@code null} if none + */ + @Nullable + protected abstract String getMediaTypeKey(NativeWebRequest request); + + /** + * Override to provide handling when a key is successfully resolved via + * {@link #lookupMediaType}. + */ + protected void handleMatch(String key, MediaType mediaType) { + } + + /** + * Override to provide handling when a key is not resolved via. + * {@link #lookupMediaType}. Sub-classes can take further steps to + * determine the media type(s). If a MediaType is returned from + * this method it will be added to the cache in the base class. + */ + @Nullable + protected MediaType handleNoMatch(NativeWebRequest request, String key) + throws HttpMediaTypeNotAcceptableException { + + if (!isUseRegisteredExtensionsOnly()) { + Optional mediaType = MediaTypeFactory.getMediaType("file." + key); + if (mediaType.isPresent()) { + return mediaType.get(); + } + } + if (isIgnoreUnknownExtensions()) { + return null; + } + throw new HttpMediaTypeNotAcceptableException(getAllMediaTypes()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/accept/ContentNegotiationManager.java b/spring-web/src/main/java/org/springframework/web/accept/ContentNegotiationManager.java new file mode 100644 index 0000000000000000000000000000000000000000..8ad0110deb5695608f6a39f67113533521062940 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/accept/ContentNegotiationManager.java @@ -0,0 +1,162 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.accept; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; + +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.HttpMediaTypeNotAcceptableException; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Central class to determine requested {@linkplain MediaType media types} + * for a request. This is done by delegating to a list of configured + * {@code ContentNegotiationStrategy} instances. + * + *

Also provides methods to look up file extensions for a media type. + * This is done by delegating to the list of configured + * {@code MediaTypeFileExtensionResolver} instances. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.2 + */ +public class ContentNegotiationManager implements ContentNegotiationStrategy, MediaTypeFileExtensionResolver { + + private final List strategies = new ArrayList<>(); + + private final Set resolvers = new LinkedHashSet<>(); + + + /** + * Create an instance with the given list of + * {@code ContentNegotiationStrategy} strategies each of which may also be + * an instance of {@code MediaTypeFileExtensionResolver}. + * @param strategies the strategies to use + */ + public ContentNegotiationManager(ContentNegotiationStrategy... strategies) { + this(Arrays.asList(strategies)); + } + + /** + * A collection-based alternative to + * {@link #ContentNegotiationManager(ContentNegotiationStrategy...)}. + * @param strategies the strategies to use + * @since 3.2.2 + */ + public ContentNegotiationManager(Collection strategies) { + Assert.notEmpty(strategies, "At least one ContentNegotiationStrategy is expected"); + this.strategies.addAll(strategies); + for (ContentNegotiationStrategy strategy : this.strategies) { + if (strategy instanceof MediaTypeFileExtensionResolver) { + this.resolvers.add((MediaTypeFileExtensionResolver) strategy); + } + } + } + + /** + * Create a default instance with a {@link HeaderContentNegotiationStrategy}. + */ + public ContentNegotiationManager() { + this(new HeaderContentNegotiationStrategy()); + } + + + /** + * Return the configured content negotiation strategies. + * @since 3.2.16 + */ + public List getStrategies() { + return this.strategies; + } + + /** + * Find a {@code ContentNegotiationStrategy} of the given type. + * @param strategyType the strategy type + * @return the first matching strategy, or {@code null} if none + * @since 4.3 + */ + @SuppressWarnings("unchecked") + @Nullable + public T getStrategy(Class strategyType) { + for (ContentNegotiationStrategy strategy : getStrategies()) { + if (strategyType.isInstance(strategy)) { + return (T) strategy; + } + } + return null; + } + + /** + * Register more {@code MediaTypeFileExtensionResolver} instances in addition + * to those detected at construction. + * @param resolvers the resolvers to add + */ + public void addFileExtensionResolvers(MediaTypeFileExtensionResolver... resolvers) { + Collections.addAll(this.resolvers, resolvers); + } + + @Override + public List resolveMediaTypes(NativeWebRequest request) throws HttpMediaTypeNotAcceptableException { + for (ContentNegotiationStrategy strategy : this.strategies) { + List mediaTypes = strategy.resolveMediaTypes(request); + if (mediaTypes.equals(MEDIA_TYPE_ALL_LIST)) { + continue; + } + return mediaTypes; + } + return MEDIA_TYPE_ALL_LIST; + } + + @Override + public List resolveFileExtensions(MediaType mediaType) { + Set result = new LinkedHashSet<>(); + for (MediaTypeFileExtensionResolver resolver : this.resolvers) { + result.addAll(resolver.resolveFileExtensions(mediaType)); + } + return new ArrayList<>(result); + } + + /** + * {@inheritDoc} + *

At startup this method returns extensions explicitly registered with + * either {@link PathExtensionContentNegotiationStrategy} or + * {@link ParameterContentNegotiationStrategy}. At runtime if there is a + * "path extension" strategy and its + * {@link PathExtensionContentNegotiationStrategy#setUseRegisteredExtensionsOnly(boolean) + * useRegisteredExtensionsOnly} property is set to "false", the list of extensions may + * increase as file extensions are resolved via + * {@link org.springframework.http.MediaTypeFactory} and cached. + */ + @Override + public List getAllFileExtensions() { + Set result = new LinkedHashSet<>(); + for (MediaTypeFileExtensionResolver resolver : this.resolvers) { + result.addAll(resolver.getAllFileExtensions()); + } + return new ArrayList<>(result); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/accept/ContentNegotiationManagerFactoryBean.java b/spring-web/src/main/java/org/springframework/web/accept/ContentNegotiationManagerFactoryBean.java new file mode 100644 index 0000000000000000000000000000000000000000..f49cff6395279ec646ac075b64974dbda15d55b5 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/accept/ContentNegotiationManagerFactoryBean.java @@ -0,0 +1,373 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.accept; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Properties; + +import javax.servlet.ServletContext; + +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.http.MediaType; +import org.springframework.http.MediaTypeFactory; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.web.context.ServletContextAware; + +/** + * Factory to create a {@code ContentNegotiationManager} and configure it with + * {@link ContentNegotiationStrategy} instances. + * + *

This factory offers properties that in turn result in configuring the + * underlying strategies. The table below shows the property names, their + * default settings, as well as the strategies that they help to configure: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
Property SetterDefault ValueUnderlying StrategyEnabled Or Not
{@link #setFavorPathExtension favorPathExtension}true{@link PathExtensionContentNegotiationStrategy}Enabled
{@link #setFavorParameter favorParameter}false{@link ParameterContentNegotiationStrategy}Off
{@link #setIgnoreAcceptHeader ignoreAcceptHeader}false{@link HeaderContentNegotiationStrategy}Enabled
{@link #setDefaultContentType defaultContentType}null{@link FixedContentNegotiationStrategy}Off
{@link #setDefaultContentTypeStrategy defaultContentTypeStrategy}null{@link ContentNegotiationStrategy}Off
+ * + *

As of 5.0 you can set the exact strategies to use via + * {@link #setStrategies(List)}. + * + *

Note: if you must use URL-based content type resolution, + * the use of a query parameter is simpler and preferable to the use of a path + * extension since the latter can cause issues with URI variables, path + * parameters, and URI decoding. Consider setting {@link #setFavorPathExtension} + * to {@literal false} or otherwise set the strategies to use explicitly via + * {@link #setStrategies(List)}. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 3.2 + */ +public class ContentNegotiationManagerFactoryBean + implements FactoryBean, ServletContextAware, InitializingBean { + + @Nullable + private List strategies; + + + private boolean favorPathExtension = true; + + private boolean favorParameter = false; + + private boolean ignoreAcceptHeader = false; + + private Map mediaTypes = new HashMap<>(); + + private boolean ignoreUnknownPathExtensions = true; + + @Nullable + private Boolean useRegisteredExtensionsOnly; + + private String parameterName = "format"; + + @Nullable + private ContentNegotiationStrategy defaultNegotiationStrategy; + + @Nullable + private ContentNegotiationManager contentNegotiationManager; + + @Nullable + private ServletContext servletContext; + + + /** + * Set the exact list of strategies to use. + *

Note: use of this method is mutually exclusive with + * use of all other setters in this class which customize a default, fixed + * set of strategies. See class level doc for more details. + * @param strategies the strategies to use + * @since 5.0 + */ + public void setStrategies(@Nullable List strategies) { + this.strategies = (strategies != null ? new ArrayList<>(strategies) : null); + } + + /** + * Whether the path extension in the URL path should be used to determine + * the requested media type. + *

By default this is set to {@code true} in which case a request + * for {@code /hotels.pdf} will be interpreted as a request for + * {@code "application/pdf"} regardless of the 'Accept' header. + */ + public void setFavorPathExtension(boolean favorPathExtension) { + this.favorPathExtension = favorPathExtension; + } + + /** + * Add a mapping from a key, extracted from a path extension or a query + * parameter, to a MediaType. This is required in order for the parameter + * strategy to work. Any extensions explicitly registered here are also + * whitelisted for the purpose of Reflected File Download attack detection + * (see Spring Framework reference documentation for more details on RFD + * attack protection). + *

The path extension strategy will also try to use + * {@link ServletContext#getMimeType} and + * {@link org.springframework.http.MediaTypeFactory} to resolve path extensions. + * @param mediaTypes media type mappings + * @see #addMediaType(String, MediaType) + * @see #addMediaTypes(Map) + */ + public void setMediaTypes(Properties mediaTypes) { + if (!CollectionUtils.isEmpty(mediaTypes)) { + mediaTypes.forEach((key, value) -> { + String extension = ((String) key).toLowerCase(Locale.ENGLISH); + MediaType mediaType = MediaType.valueOf((String) value); + this.mediaTypes.put(extension, mediaType); + }); + } + } + + /** + * An alternative to {@link #setMediaTypes} for use in Java code. + * @see #setMediaTypes + * @see #addMediaTypes + */ + public void addMediaType(String fileExtension, MediaType mediaType) { + this.mediaTypes.put(fileExtension, mediaType); + } + + /** + * An alternative to {@link #setMediaTypes} for use in Java code. + * @see #setMediaTypes + * @see #addMediaType + */ + public void addMediaTypes(@Nullable Map mediaTypes) { + if (mediaTypes != null) { + this.mediaTypes.putAll(mediaTypes); + } + } + + /** + * Whether to ignore requests with path extension that cannot be resolved + * to any media type. Setting this to {@code false} will result in an + * {@code HttpMediaTypeNotAcceptableException} if there is no match. + *

By default this is set to {@code true}. + */ + public void setIgnoreUnknownPathExtensions(boolean ignore) { + this.ignoreUnknownPathExtensions = ignore; + } + + /** + * Indicate whether to use the Java Activation Framework as a fallback option + * to map from file extensions to media types. + * @deprecated as of 5.0, in favor of {@link #setUseRegisteredExtensionsOnly(boolean)}, which + * has reverse behavior. + */ + @Deprecated + public void setUseJaf(boolean useJaf) { + setUseRegisteredExtensionsOnly(!useJaf); + } + + /** + * When {@link #setFavorPathExtension favorPathExtension} or + * {@link #setFavorParameter(boolean)} is set, this property determines + * whether to use only registered {@code MediaType} mappings or to allow + * dynamic resolution, e.g. via {@link MediaTypeFactory}. + *

By default this is not set in which case dynamic resolution is on. + */ + public void setUseRegisteredExtensionsOnly(boolean useRegisteredExtensionsOnly) { + this.useRegisteredExtensionsOnly = useRegisteredExtensionsOnly; + } + + private boolean useRegisteredExtensionsOnly() { + return (this.useRegisteredExtensionsOnly != null && this.useRegisteredExtensionsOnly); + } + + /** + * Whether a request parameter ("format" by default) should be used to + * determine the requested media type. For this option to work you must + * register {@link #setMediaTypes media type mappings}. + *

By default this is set to {@code false}. + * @see #setParameterName + */ + public void setFavorParameter(boolean favorParameter) { + this.favorParameter = favorParameter; + } + + /** + * Set the query parameter name to use when {@link #setFavorParameter} is on. + *

The default parameter name is {@code "format"}. + */ + public void setParameterName(String parameterName) { + Assert.notNull(parameterName, "parameterName is required"); + this.parameterName = parameterName; + } + + /** + * Whether to disable checking the 'Accept' request header. + *

By default this value is set to {@code false}. + */ + public void setIgnoreAcceptHeader(boolean ignoreAcceptHeader) { + this.ignoreAcceptHeader = ignoreAcceptHeader; + } + + /** + * Set the default content type to use when no content type is requested. + *

By default this is not set. + * @see #setDefaultContentTypeStrategy + */ + public void setDefaultContentType(MediaType contentType) { + this.defaultNegotiationStrategy = new FixedContentNegotiationStrategy(contentType); + } + + /** + * Set the default content types to use when no content type is requested. + *

By default this is not set. + * @since 5.0 + * @see #setDefaultContentTypeStrategy + */ + public void setDefaultContentTypes(List contentTypes) { + this.defaultNegotiationStrategy = new FixedContentNegotiationStrategy(contentTypes); + } + + /** + * Set a custom {@link ContentNegotiationStrategy} to use to determine + * the content type to use when no content type is requested. + *

By default this is not set. + * @since 4.1.2 + * @see #setDefaultContentType + */ + public void setDefaultContentTypeStrategy(ContentNegotiationStrategy strategy) { + this.defaultNegotiationStrategy = strategy; + } + + /** + * Invoked by Spring to inject the ServletContext. + */ + @Override + public void setServletContext(ServletContext servletContext) { + this.servletContext = servletContext; + } + + + @Override + public void afterPropertiesSet() { + build(); + } + + /** + * Actually build the {@link ContentNegotiationManager}. + * @since 5.0 + */ + public ContentNegotiationManager build() { + List strategies = new ArrayList<>(); + + if (this.strategies != null) { + strategies.addAll(this.strategies); + } + else { + if (this.favorPathExtension) { + PathExtensionContentNegotiationStrategy strategy; + if (this.servletContext != null && !useRegisteredExtensionsOnly()) { + strategy = new ServletPathExtensionContentNegotiationStrategy(this.servletContext, this.mediaTypes); + } + else { + strategy = new PathExtensionContentNegotiationStrategy(this.mediaTypes); + } + strategy.setIgnoreUnknownExtensions(this.ignoreUnknownPathExtensions); + if (this.useRegisteredExtensionsOnly != null) { + strategy.setUseRegisteredExtensionsOnly(this.useRegisteredExtensionsOnly); + } + strategies.add(strategy); + } + + if (this.favorParameter) { + ParameterContentNegotiationStrategy strategy = new ParameterContentNegotiationStrategy(this.mediaTypes); + strategy.setParameterName(this.parameterName); + if (this.useRegisteredExtensionsOnly != null) { + strategy.setUseRegisteredExtensionsOnly(this.useRegisteredExtensionsOnly); + } + else { + strategy.setUseRegisteredExtensionsOnly(true); // backwards compatibility + } + strategies.add(strategy); + } + + if (!this.ignoreAcceptHeader) { + strategies.add(new HeaderContentNegotiationStrategy()); + } + + if (this.defaultNegotiationStrategy != null) { + strategies.add(this.defaultNegotiationStrategy); + } + } + + this.contentNegotiationManager = new ContentNegotiationManager(strategies); + return this.contentNegotiationManager; + } + + + @Override + @Nullable + public ContentNegotiationManager getObject() { + return this.contentNegotiationManager; + } + + @Override + public Class getObjectType() { + return ContentNegotiationManager.class; + } + + @Override + public boolean isSingleton() { + return true; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/accept/ContentNegotiationStrategy.java b/spring-web/src/main/java/org/springframework/web/accept/ContentNegotiationStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..0d0664b25de9c05ac52fc7b908f994be60ad306b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/accept/ContentNegotiationStrategy.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.accept; + +import java.util.Collections; +import java.util.List; + +import org.springframework.http.MediaType; +import org.springframework.web.HttpMediaTypeNotAcceptableException; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * A strategy for resolving the requested media types for a request. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +@FunctionalInterface +public interface ContentNegotiationStrategy { + + /** + * A singleton list with {@link MediaType#ALL} that is returned from + * {@link #resolveMediaTypes} when no specific media types are requested. + * @since 5.0.5 + */ + List MEDIA_TYPE_ALL_LIST = Collections.singletonList(MediaType.ALL); + + + /** + * Resolve the given request to a list of media types. The returned list is + * ordered by specificity first and by quality parameter second. + * @param webRequest the current request + * @return the requested media types, or {@link #MEDIA_TYPE_ALL_LIST} if none + * were requested. + * @throws HttpMediaTypeNotAcceptableException if the requested media + * types cannot be parsed + */ + List resolveMediaTypes(NativeWebRequest webRequest) + throws HttpMediaTypeNotAcceptableException; + +} diff --git a/spring-web/src/main/java/org/springframework/web/accept/FixedContentNegotiationStrategy.java b/spring-web/src/main/java/org/springframework/web/accept/FixedContentNegotiationStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..3b2073f7152e549f560af3a1af15b3a93b618732 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/accept/FixedContentNegotiationStrategy.java @@ -0,0 +1,75 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.accept; + +import java.util.Collections; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.http.MediaType; +import org.springframework.util.Assert; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * A {@code ContentNegotiationStrategy} that returns a fixed content type. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public class FixedContentNegotiationStrategy implements ContentNegotiationStrategy { + + private static final Log logger = LogFactory.getLog(FixedContentNegotiationStrategy.class); + + private final List contentTypes; + + + /** + * Constructor with a single default {@code MediaType}. + */ + public FixedContentNegotiationStrategy(MediaType contentType) { + this(Collections.singletonList(contentType)); + } + + /** + * Constructor with an ordered List of default {@code MediaType}'s to return + * for use in applications that support a variety of content types. + *

Consider appending {@link MediaType#ALL} at the end if destinations + * are present which do not support any of the other default media types. + * @since 5.0 + */ + public FixedContentNegotiationStrategy(List contentTypes) { + Assert.notNull(contentTypes, "'contentTypes' must not be null"); + this.contentTypes = Collections.unmodifiableList(contentTypes); + } + + + /** + * Return the configured list of media types. + */ + public List getContentTypes() { + return this.contentTypes; + } + + + @Override + public List resolveMediaTypes(NativeWebRequest request) { + return this.contentTypes; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/accept/HeaderContentNegotiationStrategy.java b/spring-web/src/main/java/org/springframework/web/accept/HeaderContentNegotiationStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..08880a61e97ec2b15fd078b0c1f8865b77a38b85 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/accept/HeaderContentNegotiationStrategy.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.accept; + +import java.util.Arrays; +import java.util.List; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.InvalidMediaTypeException; +import org.springframework.http.MediaType; +import org.springframework.util.CollectionUtils; +import org.springframework.web.HttpMediaTypeNotAcceptableException; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * A {@code ContentNegotiationStrategy} that checks the 'Accept' request header. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.2 + */ +public class HeaderContentNegotiationStrategy implements ContentNegotiationStrategy { + + /** + * {@inheritDoc} + * @throws HttpMediaTypeNotAcceptableException if the 'Accept' header cannot be parsed + */ + @Override + public List resolveMediaTypes(NativeWebRequest request) + throws HttpMediaTypeNotAcceptableException { + + String[] headerValueArray = request.getHeaderValues(HttpHeaders.ACCEPT); + if (headerValueArray == null) { + return MEDIA_TYPE_ALL_LIST; + } + + List headerValues = Arrays.asList(headerValueArray); + try { + List mediaTypes = MediaType.parseMediaTypes(headerValues); + MediaType.sortBySpecificityAndQuality(mediaTypes); + return !CollectionUtils.isEmpty(mediaTypes) ? mediaTypes : MEDIA_TYPE_ALL_LIST; + } + catch (InvalidMediaTypeException ex) { + throw new HttpMediaTypeNotAcceptableException( + "Could not parse 'Accept' header " + headerValues + ": " + ex.getMessage()); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/accept/MappingMediaTypeFileExtensionResolver.java b/spring-web/src/main/java/org/springframework/web/accept/MappingMediaTypeFileExtensionResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..806c4b0320cf189243b7e6d649589f0efe6673a9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/accept/MappingMediaTypeFileExtensionResolver.java @@ -0,0 +1,114 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.accept; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CopyOnWriteArrayList; + +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; + +/** + * An implementation of {@code MediaTypeFileExtensionResolver} that maintains + * lookups between file extensions and MediaTypes in both directions. + * + *

Initially created with a map of file extensions and media types. + * Subsequently subclasses can use {@link #addMapping} to add more mappings. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.2 + */ +public class MappingMediaTypeFileExtensionResolver implements MediaTypeFileExtensionResolver { + + private final ConcurrentMap mediaTypes = new ConcurrentHashMap<>(64); + + private final ConcurrentMap> fileExtensions = new ConcurrentHashMap<>(64); + + private final List allFileExtensions = new CopyOnWriteArrayList<>(); + + + /** + * Create an instance with the given map of file extensions and media types. + */ + public MappingMediaTypeFileExtensionResolver(@Nullable Map mediaTypes) { + if (mediaTypes != null) { + List allFileExtensions = new ArrayList<>(); + mediaTypes.forEach((extension, mediaType) -> { + String lowerCaseExtension = extension.toLowerCase(Locale.ENGLISH); + this.mediaTypes.put(lowerCaseExtension, mediaType); + addFileExtension(mediaType, lowerCaseExtension); + allFileExtensions.add(lowerCaseExtension); + }); + this.allFileExtensions.addAll(allFileExtensions); + } + } + + + public Map getMediaTypes() { + return this.mediaTypes; + } + + protected List getAllMediaTypes() { + return new ArrayList<>(this.mediaTypes.values()); + } + + /** + * Map an extension to a MediaType. Ignore if extension already mapped. + */ + protected void addMapping(String extension, MediaType mediaType) { + MediaType previous = this.mediaTypes.putIfAbsent(extension, mediaType); + if (previous == null) { + addFileExtension(mediaType, extension); + this.allFileExtensions.add(extension); + } + } + + private void addFileExtension(MediaType mediaType, String extension) { + List newList = new CopyOnWriteArrayList<>(); + List oldList = this.fileExtensions.putIfAbsent(mediaType, newList); + (oldList != null ? oldList : newList).add(extension); + } + + + @Override + public List resolveFileExtensions(MediaType mediaType) { + List fileExtensions = this.fileExtensions.get(mediaType); + return (fileExtensions != null ? fileExtensions : Collections.emptyList()); + } + + @Override + public List getAllFileExtensions() { + return Collections.unmodifiableList(this.allFileExtensions); + } + + /** + * Use this method for a reverse lookup from extension to MediaType. + * @return a MediaType for the key, or {@code null} if none found + */ + @Nullable + protected MediaType lookupMediaType(String extension) { + return this.mediaTypes.get(extension.toLowerCase(Locale.ENGLISH)); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/accept/MediaTypeFileExtensionResolver.java b/spring-web/src/main/java/org/springframework/web/accept/MediaTypeFileExtensionResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..a4b3b20b17f8c4707cbcdf1e68bddcbb52c5e9ec --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/accept/MediaTypeFileExtensionResolver.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.accept; + +import java.util.List; + +import org.springframework.http.MediaType; + +/** + * Strategy to resolve {@link MediaType} to a list of file extensions. + * For example resolve "application/json" to "json". + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public interface MediaTypeFileExtensionResolver { + + /** + * Resolve the given media type to a list of path extensions. + * @param mediaType the media type to resolve + * @return a list of extensions or an empty list (never {@code null}) + */ + List resolveFileExtensions(MediaType mediaType); + + /** + * Return all registered file extensions. + * @return a list of extensions or an empty list (never {@code null}) + */ + List getAllFileExtensions(); + +} diff --git a/spring-web/src/main/java/org/springframework/web/accept/ParameterContentNegotiationStrategy.java b/spring-web/src/main/java/org/springframework/web/accept/ParameterContentNegotiationStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..ea5b881804190144e125f5f51a6d014da3a61a3e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/accept/ParameterContentNegotiationStrategy.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.accept; + +import java.util.Map; + +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Strategy that resolves the requested content type from a query parameter. + * The default query parameter name is {@literal "format"}. + * + *

You can register static mappings between keys (i.e. the expected value of + * the query parameter) and MediaType's via {@link #addMapping(String, MediaType)}. + * As of 5.0 this strategy also supports dynamic lookups of keys via + * {@link org.springframework.http.MediaTypeFactory#getMediaType}. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public class ParameterContentNegotiationStrategy extends AbstractMappingContentNegotiationStrategy { + + private String parameterName = "format"; + + + /** + * Create an instance with the given map of file extensions and media types. + */ + public ParameterContentNegotiationStrategy(Map mediaTypes) { + super(mediaTypes); + } + + + /** + * Set the name of the parameter to use to determine requested media types. + *

By default this is set to {@code "format"}. + */ + public void setParameterName(String parameterName) { + Assert.notNull(parameterName, "'parameterName' is required"); + this.parameterName = parameterName; + } + + public String getParameterName() { + return this.parameterName; + } + + + @Override + @Nullable + protected String getMediaTypeKey(NativeWebRequest request) { + return request.getParameter(getParameterName()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/accept/PathExtensionContentNegotiationStrategy.java b/spring-web/src/main/java/org/springframework/web/accept/PathExtensionContentNegotiationStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..084c785ae4cd95d8534df340500e0fea034b9012 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/accept/PathExtensionContentNegotiationStrategy.java @@ -0,0 +1,124 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.accept; + +import java.util.Locale; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.core.io.Resource; +import org.springframework.http.MediaType; +import org.springframework.http.MediaTypeFactory; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.util.UriUtils; +import org.springframework.web.util.UrlPathHelper; + +/** + * A {@code ContentNegotiationStrategy} that resolves the file extension in the + * request path to a key to be used to look up a media type. + * + *

If the file extension is not found in the explicit registrations provided + * to the constructor, the {@link MediaTypeFactory} is used as a fallback + * mechanism. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public class PathExtensionContentNegotiationStrategy extends AbstractMappingContentNegotiationStrategy { + + private UrlPathHelper urlPathHelper = new UrlPathHelper(); + + + /** + * Create an instance without any mappings to start with. Mappings may be added + * later on if any extensions are resolved through the Java Activation framework. + */ + public PathExtensionContentNegotiationStrategy() { + this(null); + } + + /** + * Create an instance with the given map of file extensions and media types. + */ + public PathExtensionContentNegotiationStrategy(@Nullable Map mediaTypes) { + super(mediaTypes); + setUseRegisteredExtensionsOnly(false); + setIgnoreUnknownExtensions(true); + this.urlPathHelper.setUrlDecode(false); + } + + + /** + * Configure a {@code UrlPathHelper} to use in {@link #getMediaTypeKey} + * in order to derive the lookup path for a target request URL path. + * @since 4.2.8 + */ + public void setUrlPathHelper(UrlPathHelper urlPathHelper) { + this.urlPathHelper = urlPathHelper; + } + + /** + * Indicate whether to use the Java Activation Framework as a fallback option + * to map from file extensions to media types. + * @deprecated as of 5.0, in favor of {@link #setUseRegisteredExtensionsOnly(boolean)}. + */ + @Deprecated + public void setUseJaf(boolean useJaf) { + setUseRegisteredExtensionsOnly(!useJaf); + } + + @Override + @Nullable + protected String getMediaTypeKey(NativeWebRequest webRequest) { + HttpServletRequest request = webRequest.getNativeRequest(HttpServletRequest.class); + if (request == null) { + return null; + } + String path = this.urlPathHelper.getLookupPathForRequest(request); + String extension = UriUtils.extractFileExtension(path); + return (StringUtils.hasText(extension) ? extension.toLowerCase(Locale.ENGLISH) : null); + } + + /** + * A public method exposing the knowledge of the path extension strategy to + * resolve file extensions to a {@link MediaType} in this case for a given + * {@link Resource}. The method first looks up any explicitly registered + * file extensions first and then falls back on {@link MediaTypeFactory} if available. + * @param resource the resource to look up + * @return the MediaType for the extension, or {@code null} if none found + * @since 4.3 + */ + @Nullable + public MediaType getMediaTypeForResource(Resource resource) { + Assert.notNull(resource, "Resource must not be null"); + MediaType mediaType = null; + String filename = resource.getFilename(); + String extension = StringUtils.getFilenameExtension(filename); + if (extension != null) { + mediaType = lookupMediaType(extension); + } + if (mediaType == null) { + mediaType = MediaTypeFactory.getMediaType(filename).orElse(null); + } + return mediaType; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/accept/ServletPathExtensionContentNegotiationStrategy.java b/spring-web/src/main/java/org/springframework/web/accept/ServletPathExtensionContentNegotiationStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..f47af3ca2f82b8f2e7284a4129f9a544251a3d0f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/accept/ServletPathExtensionContentNegotiationStrategy.java @@ -0,0 +1,113 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.accept; + +import java.util.Map; + +import javax.servlet.ServletContext; + +import org.springframework.core.io.Resource; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.HttpMediaTypeNotAcceptableException; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Extends {@code PathExtensionContentNegotiationStrategy} that also uses + * {@link ServletContext#getMimeType(String)} to resolve file extensions. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public class ServletPathExtensionContentNegotiationStrategy extends PathExtensionContentNegotiationStrategy { + + private final ServletContext servletContext; + + + /** + * Create an instance without any mappings to start with. Mappings may be + * added later when extensions are resolved through + * {@link ServletContext#getMimeType(String)} or via + * {@link org.springframework.http.MediaTypeFactory}. + */ + public ServletPathExtensionContentNegotiationStrategy(ServletContext context) { + this(context, null); + } + + /** + * Create an instance with the given extension-to-MediaType lookup. + */ + public ServletPathExtensionContentNegotiationStrategy( + ServletContext servletContext, @Nullable Map mediaTypes) { + + super(mediaTypes); + Assert.notNull(servletContext, "ServletContext is required"); + this.servletContext = servletContext; + } + + + /** + * Resolve file extension via {@link ServletContext#getMimeType(String)} + * and also delegate to base class for a potential + * {@link org.springframework.http.MediaTypeFactory} lookup. + */ + @Override + @Nullable + protected MediaType handleNoMatch(NativeWebRequest webRequest, String extension) + throws HttpMediaTypeNotAcceptableException { + + MediaType mediaType = null; + String mimeType = this.servletContext.getMimeType("file." + extension); + if (StringUtils.hasText(mimeType)) { + mediaType = MediaType.parseMediaType(mimeType); + } + if (mediaType == null || MediaType.APPLICATION_OCTET_STREAM.equals(mediaType)) { + MediaType superMediaType = super.handleNoMatch(webRequest, extension); + if (superMediaType != null) { + mediaType = superMediaType; + } + } + return mediaType; + } + + /** + * Extends the base class + * {@link PathExtensionContentNegotiationStrategy#getMediaTypeForResource} + * with the ability to also look up through the ServletContext. + * @param resource the resource to look up + * @return the MediaType for the extension, or {@code null} if none found + * @since 4.3 + */ + @Override + public MediaType getMediaTypeForResource(Resource resource) { + MediaType mediaType = null; + String mimeType = this.servletContext.getMimeType(resource.getFilename()); + if (StringUtils.hasText(mimeType)) { + mediaType = MediaType.parseMediaType(mimeType); + } + if (mediaType == null || MediaType.APPLICATION_OCTET_STREAM.equals(mediaType)) { + MediaType superMediaType = super.getMediaTypeForResource(resource); + if (superMediaType != null) { + mediaType = superMediaType; + } + } + return mediaType; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/accept/package-info.java b/spring-web/src/main/java/org/springframework/web/accept/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..928dca73b775371e3c86c2608e3bb617f9524893 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/accept/package-info.java @@ -0,0 +1,20 @@ +/** + * This package contains classes used to determine the requested the media types in a request. + * + *

{@link org.springframework.web.accept.ContentNegotiationStrategy} is the main + * abstraction for determining requested {@linkplain org.springframework.http.MediaType media types} + * with implementations based on + * {@linkplain org.springframework.web.accept.PathExtensionContentNegotiationStrategy path extensions}, a + * {@linkplain org.springframework.web.accept.ParameterContentNegotiationStrategy a request parameter}, the + * {@linkplain org.springframework.web.accept.HeaderContentNegotiationStrategy 'Accept' header}, or a + * {@linkplain org.springframework.web.accept.FixedContentNegotiationStrategy default content type}. + * + *

{@link org.springframework.web.accept.ContentNegotiationManager} is used to delegate to one + * ore more of the above strategies in a specific order. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.accept; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/bind/EscapedErrors.java b/spring-web/src/main/java/org/springframework/web/bind/EscapedErrors.java new file mode 100644 index 0000000000000000000000000000000000000000..a2cb47e4787aff0a450433b921a3c1b959575e57 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/EscapedErrors.java @@ -0,0 +1,250 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +import java.util.ArrayList; +import java.util.List; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.validation.Errors; +import org.springframework.validation.FieldError; +import org.springframework.validation.ObjectError; +import org.springframework.web.util.HtmlUtils; + +/** + * Errors wrapper that adds automatic HTML escaping to the wrapped instance, + * for convenient usage in HTML views. Can be retrieved easily via + * RequestContext's {@code getErrors} method. + * + *

Note that BindTag does not use this class to avoid unnecessary + * creation of ObjectError instances. It just escapes the messages and values + * that get copied into the respective BindStatus instance. + * + * @author Juergen Hoeller + * @since 01.03.2003 + * @see org.springframework.web.servlet.support.RequestContext#getErrors + * @see org.springframework.web.servlet.tags.BindTag + */ +public class EscapedErrors implements Errors { + + private final Errors source; + + + /** + * Create a new EscapedErrors instance for the given source instance. + */ + public EscapedErrors(Errors source) { + Assert.notNull(source, "Errors source must not be null"); + this.source = source; + } + + public Errors getSource() { + return this.source; + } + + + @Override + public String getObjectName() { + return this.source.getObjectName(); + } + + @Override + public void setNestedPath(String nestedPath) { + this.source.setNestedPath(nestedPath); + } + + @Override + public String getNestedPath() { + return this.source.getNestedPath(); + } + + @Override + public void pushNestedPath(String subPath) { + this.source.pushNestedPath(subPath); + } + + @Override + public void popNestedPath() throws IllegalStateException { + this.source.popNestedPath(); + } + + + @Override + public void reject(String errorCode) { + this.source.reject(errorCode); + } + + @Override + public void reject(String errorCode, String defaultMessage) { + this.source.reject(errorCode, defaultMessage); + } + + @Override + public void reject(String errorCode, @Nullable Object[] errorArgs, @Nullable String defaultMessage) { + this.source.reject(errorCode, errorArgs, defaultMessage); + } + + @Override + public void rejectValue(@Nullable String field, String errorCode) { + this.source.rejectValue(field, errorCode); + } + + @Override + public void rejectValue(@Nullable String field, String errorCode, String defaultMessage) { + this.source.rejectValue(field, errorCode, defaultMessage); + } + + @Override + public void rejectValue(@Nullable String field, String errorCode, @Nullable Object[] errorArgs, + @Nullable String defaultMessage) { + + this.source.rejectValue(field, errorCode, errorArgs, defaultMessage); + } + + @Override + public void addAllErrors(Errors errors) { + this.source.addAllErrors(errors); + } + + + @Override + public boolean hasErrors() { + return this.source.hasErrors(); + } + + @Override + public int getErrorCount() { + return this.source.getErrorCount(); + } + + @Override + public List getAllErrors() { + return escapeObjectErrors(this.source.getAllErrors()); + } + + @Override + public boolean hasGlobalErrors() { + return this.source.hasGlobalErrors(); + } + + @Override + public int getGlobalErrorCount() { + return this.source.getGlobalErrorCount(); + } + + @Override + public List getGlobalErrors() { + return escapeObjectErrors(this.source.getGlobalErrors()); + } + + @Override + @Nullable + public ObjectError getGlobalError() { + return escapeObjectError(this.source.getGlobalError()); + } + + @Override + public boolean hasFieldErrors() { + return this.source.hasFieldErrors(); + } + + @Override + public int getFieldErrorCount() { + return this.source.getFieldErrorCount(); + } + + @Override + public List getFieldErrors() { + return this.source.getFieldErrors(); + } + + @Override + @Nullable + public FieldError getFieldError() { + return this.source.getFieldError(); + } + + @Override + public boolean hasFieldErrors(String field) { + return this.source.hasFieldErrors(field); + } + + @Override + public int getFieldErrorCount(String field) { + return this.source.getFieldErrorCount(field); + } + + @Override + public List getFieldErrors(String field) { + return escapeObjectErrors(this.source.getFieldErrors(field)); + } + + @Override + @Nullable + public FieldError getFieldError(String field) { + return escapeObjectError(this.source.getFieldError(field)); + } + + @Override + @Nullable + public Object getFieldValue(String field) { + Object value = this.source.getFieldValue(field); + return (value instanceof String ? HtmlUtils.htmlEscape((String) value) : value); + } + + @Override + @Nullable + public Class getFieldType(String field) { + return this.source.getFieldType(field); + } + + @SuppressWarnings("unchecked") + @Nullable + private T escapeObjectError(@Nullable T source) { + if (source == null) { + return null; + } + String defaultMessage = source.getDefaultMessage(); + if (defaultMessage != null) { + defaultMessage = HtmlUtils.htmlEscape(defaultMessage); + } + if (source instanceof FieldError) { + FieldError fieldError = (FieldError) source; + Object value = fieldError.getRejectedValue(); + if (value instanceof String) { + value = HtmlUtils.htmlEscape((String) value); + } + return (T) new FieldError( + fieldError.getObjectName(), fieldError.getField(), value, fieldError.isBindingFailure(), + fieldError.getCodes(), fieldError.getArguments(), defaultMessage); + } + else { + return (T) new ObjectError( + source.getObjectName(), source.getCodes(), source.getArguments(), defaultMessage); + } + } + + private List escapeObjectErrors(List source) { + List escaped = new ArrayList<>(source.size()); + for (T objectError : source) { + escaped.add(escapeObjectError(objectError)); + } + return escaped; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/MethodArgumentNotValidException.java b/spring-web/src/main/java/org/springframework/web/bind/MethodArgumentNotValidException.java new file mode 100644 index 0000000000000000000000000000000000000000..954bf6d85b8f16328a28929975d7a6ee54c1c709 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/MethodArgumentNotValidException.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +import org.springframework.core.MethodParameter; +import org.springframework.validation.BindingResult; +import org.springframework.validation.ObjectError; + +/** + * Exception to be thrown when validation on an argument annotated with {@code @Valid} fails. + * + * @author Rossen Stoyanchev + * @since 3.1 + */ +@SuppressWarnings("serial") +public class MethodArgumentNotValidException extends Exception { + + private final MethodParameter parameter; + + private final BindingResult bindingResult; + + + /** + * Constructor for {@link MethodArgumentNotValidException}. + * @param parameter the parameter that failed validation + * @param bindingResult the results of the validation + */ + public MethodArgumentNotValidException(MethodParameter parameter, BindingResult bindingResult) { + this.parameter = parameter; + this.bindingResult = bindingResult; + } + + /** + * Return the method parameter that failed validation. + */ + public MethodParameter getParameter() { + return this.parameter; + } + + /** + * Return the results of the failed validation. + */ + public BindingResult getBindingResult() { + return this.bindingResult; + } + + + @Override + public String getMessage() { + StringBuilder sb = new StringBuilder("Validation failed for argument [") + .append(this.parameter.getParameterIndex()).append("] in ") + .append(this.parameter.getExecutable().toGenericString()); + if (this.bindingResult.getErrorCount() > 1) { + sb.append(" with ").append(this.bindingResult.getErrorCount()).append(" errors"); + } + sb.append(": "); + for (ObjectError error : this.bindingResult.getAllErrors()) { + sb.append("[").append(error).append("] "); + } + return sb.toString(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/MissingMatrixVariableException.java b/spring-web/src/main/java/org/springframework/web/bind/MissingMatrixVariableException.java new file mode 100644 index 0000000000000000000000000000000000000000..3074a5c317a4d12b4882474787eee0e7234801c4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/MissingMatrixVariableException.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +import org.springframework.core.MethodParameter; + +/** + * {@link ServletRequestBindingException} subclass that indicates that a matrix + * variable expected in the method parameters of an {@code @RequestMapping} + * method is not present among the matrix variables extracted from the URL. + * + * @author Juergen Hoeller + * @since 5.1 + * @see MissingPathVariableException + */ +@SuppressWarnings("serial") +public class MissingMatrixVariableException extends ServletRequestBindingException { + + private final String variableName; + + private final MethodParameter parameter; + + + /** + * Constructor for MissingMatrixVariableException. + * @param variableName the name of the missing matrix variable + * @param parameter the method parameter + */ + public MissingMatrixVariableException(String variableName, MethodParameter parameter) { + super(""); + this.variableName = variableName; + this.parameter = parameter; + } + + + @Override + public String getMessage() { + return "Missing matrix variable '" + this.variableName + + "' for method parameter of type " + this.parameter.getNestedParameterType().getSimpleName(); + } + + /** + * Return the expected name of the matrix variable. + */ + public final String getVariableName() { + return this.variableName; + } + + /** + * Return the method parameter bound to the matrix variable. + */ + public final MethodParameter getParameter() { + return this.parameter; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/MissingPathVariableException.java b/spring-web/src/main/java/org/springframework/web/bind/MissingPathVariableException.java new file mode 100644 index 0000000000000000000000000000000000000000..191a27d4715ad245837ad13e37f0902bf5349846 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/MissingPathVariableException.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +import org.springframework.core.MethodParameter; + +/** + * {@link ServletRequestBindingException} subclass that indicates that a path + * variable expected in the method parameters of an {@code @RequestMapping} + * method is not present among the URI variables extracted from the URL. + * Typically that means the URI template does not match the path variable name + * declared on the method parameter. + * + * @author Rossen Stoyanchev + * @since 4.2 + * @see MissingMatrixVariableException + */ +@SuppressWarnings("serial") +public class MissingPathVariableException extends ServletRequestBindingException { + + private final String variableName; + + private final MethodParameter parameter; + + + /** + * Constructor for MissingPathVariableException. + * @param variableName the name of the missing path variable + * @param parameter the method parameter + */ + public MissingPathVariableException(String variableName, MethodParameter parameter) { + super(""); + this.variableName = variableName; + this.parameter = parameter; + } + + + @Override + public String getMessage() { + return "Missing URI template variable '" + this.variableName + + "' for method parameter of type " + this.parameter.getNestedParameterType().getSimpleName(); + } + + /** + * Return the expected name of the path variable. + */ + public final String getVariableName() { + return this.variableName; + } + + /** + * Return the method parameter bound to the path variable. + */ + public final MethodParameter getParameter() { + return this.parameter; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/MissingRequestCookieException.java b/spring-web/src/main/java/org/springframework/web/bind/MissingRequestCookieException.java new file mode 100644 index 0000000000000000000000000000000000000000..7a7ea209127a9a8b92d9dfc769a8bc1d2472dc1b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/MissingRequestCookieException.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +import org.springframework.core.MethodParameter; + +/** + * {@link ServletRequestBindingException} subclass that indicates + * that a request cookie expected in the method parameters of an + * {@code @RequestMapping} method is not present. + * + * @author Juergen Hoeller + * @since 5.1 + * @see MissingRequestHeaderException + */ +@SuppressWarnings("serial") +public class MissingRequestCookieException extends ServletRequestBindingException { + + private final String cookieName; + + private final MethodParameter parameter; + + + /** + * Constructor for MissingRequestCookieException. + * @param cookieName the name of the missing request cookie + * @param parameter the method parameter + */ + public MissingRequestCookieException(String cookieName, MethodParameter parameter) { + super(""); + this.cookieName = cookieName; + this.parameter = parameter; + } + + + @Override + public String getMessage() { + return "Missing cookie '" + this.cookieName + + "' for method parameter of type " + this.parameter.getNestedParameterType().getSimpleName(); + } + + /** + * Return the expected name of the request cookie. + */ + public final String getCookieName() { + return this.cookieName; + } + + /** + * Return the method parameter bound to the request cookie. + */ + public final MethodParameter getParameter() { + return this.parameter; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/MissingRequestHeaderException.java b/spring-web/src/main/java/org/springframework/web/bind/MissingRequestHeaderException.java new file mode 100644 index 0000000000000000000000000000000000000000..eb3bf660ea64b4acae5e2f20be8c56d838ccde13 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/MissingRequestHeaderException.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +import org.springframework.core.MethodParameter; + +/** + * {@link ServletRequestBindingException} subclass that indicates + * that a request header expected in the method parameters of an + * {@code @RequestMapping} method is not present. + * + * @author Juergen Hoeller + * @since 5.1 + * @see MissingRequestCookieException + */ +@SuppressWarnings("serial") +public class MissingRequestHeaderException extends ServletRequestBindingException { + + private final String headerName; + + private final MethodParameter parameter; + + + /** + * Constructor for MissingRequestHeaderException. + * @param headerName the name of the missing request header + * @param parameter the method parameter + */ + public MissingRequestHeaderException(String headerName, MethodParameter parameter) { + super(""); + this.headerName = headerName; + this.parameter = parameter; + } + + + @Override + public String getMessage() { + return "Missing request header '" + this.headerName + + "' for method parameter of type " + this.parameter.getNestedParameterType().getSimpleName(); + } + + /** + * Return the expected name of the request header. + */ + public final String getHeaderName() { + return this.headerName; + } + + /** + * Return the method parameter bound to the request header. + */ + public final MethodParameter getParameter() { + return this.parameter; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/MissingServletRequestParameterException.java b/spring-web/src/main/java/org/springframework/web/bind/MissingServletRequestParameterException.java new file mode 100644 index 0000000000000000000000000000000000000000..67feef6bfa848da3087ec4ce3ac32ee047be6197 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/MissingServletRequestParameterException.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +/** + * {@link ServletRequestBindingException} subclass that indicates a missing parameter. + * + * @author Juergen Hoeller + * @since 2.0.2 + */ +@SuppressWarnings("serial") +public class MissingServletRequestParameterException extends ServletRequestBindingException { + + private final String parameterName; + + private final String parameterType; + + + /** + * Constructor for MissingServletRequestParameterException. + * @param parameterName the name of the missing parameter + * @param parameterType the expected type of the missing parameter + */ + public MissingServletRequestParameterException(String parameterName, String parameterType) { + super(""); + this.parameterName = parameterName; + this.parameterType = parameterType; + } + + + @Override + public String getMessage() { + return "Required " + this.parameterType + " parameter '" + this.parameterName + "' is not present"; + } + + /** + * Return the name of the offending parameter. + */ + public final String getParameterName() { + return this.parameterName; + } + + /** + * Return the expected type of the offending parameter. + */ + public final String getParameterType() { + return this.parameterType; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/ServletRequestBindingException.java b/spring-web/src/main/java/org/springframework/web/bind/ServletRequestBindingException.java new file mode 100644 index 0000000000000000000000000000000000000000..fc7b7a6c0b463ff02f28ee0c587e80be994f7183 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/ServletRequestBindingException.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +import org.springframework.web.util.NestedServletException; + +/** + * Fatal binding exception, thrown when we want to + * treat binding exceptions as unrecoverable. + * + *

Extends ServletException for convenient throwing in any Servlet resource + * (such as a Filter), and NestedServletException for proper root cause handling + * (as the plain ServletException doesn't expose its root cause at all). + * + * @author Rod Johnson + * @author Juergen Hoeller + */ +@SuppressWarnings("serial") +public class ServletRequestBindingException extends NestedServletException { + + /** + * Constructor for ServletRequestBindingException. + * @param msg the detail message + */ + public ServletRequestBindingException(String msg) { + super(msg); + } + + /** + * Constructor for ServletRequestBindingException. + * @param msg the detail message + * @param cause the root cause + */ + public ServletRequestBindingException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/ServletRequestDataBinder.java b/spring-web/src/main/java/org/springframework/web/bind/ServletRequestDataBinder.java new file mode 100644 index 0000000000000000000000000000000000000000..7a861575a1a82f1da923742f38f40fb69e3663f1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/ServletRequestDataBinder.java @@ -0,0 +1,134 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +import javax.servlet.ServletRequest; + +import org.springframework.beans.MutablePropertyValues; +import org.springframework.lang.Nullable; +import org.springframework.validation.BindException; +import org.springframework.web.multipart.MultipartRequest; +import org.springframework.web.util.WebUtils; + +/** + * Special {@link org.springframework.validation.DataBinder} to perform data binding + * from servlet request parameters to JavaBeans, including support for multipart files. + * + *

See the DataBinder/WebDataBinder superclasses for customization options, + * which include specifying allowed/required fields, and registering custom + * property editors. + * + *

Can also be used for manual data binding in custom web controllers: + * for example, in a plain Controller implementation or in a MultiActionController + * handler method. Simply instantiate a ServletRequestDataBinder for each binding + * process, and invoke {@code bind} with the current ServletRequest as argument: + * + *

+ * MyBean myBean = new MyBean();
+ * // apply binder to custom target object
+ * ServletRequestDataBinder binder = new ServletRequestDataBinder(myBean);
+ * // register custom editors, if desired
+ * binder.registerCustomEditor(...);
+ * // trigger actual binding of request parameters
+ * binder.bind(request);
+ * // optionally evaluate binding errors
+ * Errors errors = binder.getErrors();
+ * ...
+ * + * @author Rod Johnson + * @author Juergen Hoeller + * @see #bind(javax.servlet.ServletRequest) + * @see #registerCustomEditor + * @see #setAllowedFields + * @see #setRequiredFields + * @see #setFieldMarkerPrefix + */ +public class ServletRequestDataBinder extends WebDataBinder { + + /** + * Create a new ServletRequestDataBinder instance, with default object name. + * @param target the target object to bind onto (or {@code null} + * if the binder is just used to convert a plain parameter value) + * @see #DEFAULT_OBJECT_NAME + */ + public ServletRequestDataBinder(@Nullable Object target) { + super(target); + } + + /** + * Create a new ServletRequestDataBinder instance. + * @param target the target object to bind onto (or {@code null} + * if the binder is just used to convert a plain parameter value) + * @param objectName the name of the target object + */ + public ServletRequestDataBinder(@Nullable Object target, String objectName) { + super(target, objectName); + } + + + /** + * Bind the parameters of the given request to this binder's target, + * also binding multipart files in case of a multipart request. + *

This call can create field errors, representing basic binding + * errors like a required field (code "required"), or type mismatch + * between value and bean property (code "typeMismatch"). + *

Multipart files are bound via their parameter name, just like normal + * HTTP parameters: i.e. "uploadedFile" to an "uploadedFile" bean property, + * invoking a "setUploadedFile" setter method. + *

The type of the target property for a multipart file can be MultipartFile, + * byte[], or String. The latter two receive the contents of the uploaded file; + * all metadata like original file name, content type, etc are lost in those cases. + * @param request request with parameters to bind (can be multipart) + * @see org.springframework.web.multipart.MultipartHttpServletRequest + * @see org.springframework.web.multipart.MultipartFile + * @see #bind(org.springframework.beans.PropertyValues) + */ + public void bind(ServletRequest request) { + MutablePropertyValues mpvs = new ServletRequestParameterPropertyValues(request); + MultipartRequest multipartRequest = WebUtils.getNativeRequest(request, MultipartRequest.class); + if (multipartRequest != null) { + bindMultipart(multipartRequest.getMultiFileMap(), mpvs); + } + addBindValues(mpvs, request); + doBind(mpvs); + } + + /** + * Extension point that subclasses can use to add extra bind values for a + * request. Invoked before {@link #doBind(MutablePropertyValues)}. + * The default implementation is empty. + * @param mpvs the property values that will be used for data binding + * @param request the current request + */ + protected void addBindValues(MutablePropertyValues mpvs, ServletRequest request) { + } + + /** + * Treats errors as fatal. + *

Use this method only if it's an error if the input isn't valid. + * This might be appropriate if all input is from dropdowns, for example. + * @throws ServletRequestBindingException subclass of ServletException on any binding problem + */ + public void closeNoCatch() throws ServletRequestBindingException { + if (getBindingResult().hasErrors()) { + throw new ServletRequestBindingException( + "Errors binding onto object '" + getBindingResult().getObjectName() + "'", + new BindException(getBindingResult())); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/ServletRequestParameterPropertyValues.java b/spring-web/src/main/java/org/springframework/web/bind/ServletRequestParameterPropertyValues.java new file mode 100644 index 0000000000000000000000000000000000000000..6d1b17de2c28a34bc3b4b42118e9b2888c53d5c9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/ServletRequestParameterPropertyValues.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +import javax.servlet.ServletRequest; + +import org.springframework.beans.MutablePropertyValues; +import org.springframework.lang.Nullable; +import org.springframework.web.util.WebUtils; + +/** + * PropertyValues implementation created from parameters in a ServletRequest. + * Can look for all property values beginning with a certain prefix and + * prefix separator (default is "_"). + * + *

For example, with a prefix of "spring", "spring_param1" and + * "spring_param2" result in a Map with "param1" and "param2" as keys. + * + *

This class is not immutable to be able to efficiently remove property + * values that should be ignored for binding. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @see org.springframework.web.util.WebUtils#getParametersStartingWith + */ +@SuppressWarnings("serial") +public class ServletRequestParameterPropertyValues extends MutablePropertyValues { + + /** Default prefix separator. */ + public static final String DEFAULT_PREFIX_SEPARATOR = "_"; + + + /** + * Create new ServletRequestPropertyValues using no prefix + * (and hence, no prefix separator). + * @param request the HTTP request + */ + public ServletRequestParameterPropertyValues(ServletRequest request) { + this(request, null, null); + } + + /** + * Create new ServletRequestPropertyValues using the given prefix and + * the default prefix separator (the underscore character "_"). + * @param request the HTTP request + * @param prefix the prefix for parameters (the full prefix will + * consist of this plus the separator) + * @see #DEFAULT_PREFIX_SEPARATOR + */ + public ServletRequestParameterPropertyValues(ServletRequest request, @Nullable String prefix) { + this(request, prefix, DEFAULT_PREFIX_SEPARATOR); + } + + /** + * Create new ServletRequestPropertyValues supplying both prefix and + * prefix separator. + * @param request the HTTP request + * @param prefix the prefix for parameters (the full prefix will + * consist of this plus the separator) + * @param prefixSeparator separator delimiting prefix (e.g. "spring") + * and the rest of the parameter name ("param1", "param2") + */ + public ServletRequestParameterPropertyValues( + ServletRequest request, @Nullable String prefix, @Nullable String prefixSeparator) { + + super(WebUtils.getParametersStartingWith( + request, (prefix != null ? prefix + prefixSeparator : null))); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/ServletRequestUtils.java b/spring-web/src/main/java/org/springframework/web/bind/ServletRequestUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..9537758b7930e8974450edb6a48e1a98de4aa0a2 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/ServletRequestUtils.java @@ -0,0 +1,717 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +import javax.servlet.ServletRequest; + +import org.springframework.lang.Nullable; + +/** + * Parameter extraction methods, for an approach distinct from data binding, + * in which parameters of specific types are required. + * + *

This approach is very useful for simple submissions, where binding + * request parameters to a command object would be overkill. + * + * @author Juergen Hoeller + * @author Keith Donald + * @since 2.0 + */ +public abstract class ServletRequestUtils { + + private static final IntParser INT_PARSER = new IntParser(); + + private static final LongParser LONG_PARSER = new LongParser(); + + private static final FloatParser FLOAT_PARSER = new FloatParser(); + + private static final DoubleParser DOUBLE_PARSER = new DoubleParser(); + + private static final BooleanParser BOOLEAN_PARSER = new BooleanParser(); + + private static final StringParser STRING_PARSER = new StringParser(); + + + /** + * Get an Integer parameter, or {@code null} if not present. + * Throws an exception if it the parameter value isn't a number. + * @param request current HTTP request + * @param name the name of the parameter + * @return the Integer value, or {@code null} if not present + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + @Nullable + public static Integer getIntParameter(ServletRequest request, String name) + throws ServletRequestBindingException { + + if (request.getParameter(name) == null) { + return null; + } + return getRequiredIntParameter(request, name); + } + + /** + * Get an int parameter, with a fallback value. Never throws an exception. + * Can pass a distinguished value as default to enable checks of whether it was supplied. + * @param request current HTTP request + * @param name the name of the parameter + * @param defaultVal the default value to use as fallback + */ + public static int getIntParameter(ServletRequest request, String name, int defaultVal) { + if (request.getParameter(name) == null) { + return defaultVal; + } + try { + return getRequiredIntParameter(request, name); + } + catch (ServletRequestBindingException ex) { + return defaultVal; + } + } + + /** + * Get an array of int parameters, return an empty array if not found. + * @param request current HTTP request + * @param name the name of the parameter with multiple possible values + */ + public static int[] getIntParameters(ServletRequest request, String name) { + try { + return getRequiredIntParameters(request, name); + } + catch (ServletRequestBindingException ex) { + return new int[0]; + } + } + + /** + * Get an int parameter, throwing an exception if it isn't found or isn't a number. + * @param request current HTTP request + * @param name the name of the parameter + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + public static int getRequiredIntParameter(ServletRequest request, String name) + throws ServletRequestBindingException { + + return INT_PARSER.parseInt(name, request.getParameter(name)); + } + + /** + * Get an array of int parameters, throwing an exception if not found or one is not a number.. + * @param request current HTTP request + * @param name the name of the parameter with multiple possible values + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + public static int[] getRequiredIntParameters(ServletRequest request, String name) + throws ServletRequestBindingException { + + return INT_PARSER.parseInts(name, request.getParameterValues(name)); + } + + + /** + * Get a Long parameter, or {@code null} if not present. + * Throws an exception if it the parameter value isn't a number. + * @param request current HTTP request + * @param name the name of the parameter + * @return the Long value, or {@code null} if not present + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + @Nullable + public static Long getLongParameter(ServletRequest request, String name) + throws ServletRequestBindingException { + + if (request.getParameter(name) == null) { + return null; + } + return getRequiredLongParameter(request, name); + } + + /** + * Get a long parameter, with a fallback value. Never throws an exception. + * Can pass a distinguished value as default to enable checks of whether it was supplied. + * @param request current HTTP request + * @param name the name of the parameter + * @param defaultVal the default value to use as fallback + */ + public static long getLongParameter(ServletRequest request, String name, long defaultVal) { + if (request.getParameter(name) == null) { + return defaultVal; + } + try { + return getRequiredLongParameter(request, name); + } + catch (ServletRequestBindingException ex) { + return defaultVal; + } + } + + /** + * Get an array of long parameters, return an empty array if not found. + * @param request current HTTP request + * @param name the name of the parameter with multiple possible values + */ + public static long[] getLongParameters(ServletRequest request, String name) { + try { + return getRequiredLongParameters(request, name); + } + catch (ServletRequestBindingException ex) { + return new long[0]; + } + } + + /** + * Get a long parameter, throwing an exception if it isn't found or isn't a number. + * @param request current HTTP request + * @param name the name of the parameter + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + public static long getRequiredLongParameter(ServletRequest request, String name) + throws ServletRequestBindingException { + + return LONG_PARSER.parseLong(name, request.getParameter(name)); + } + + /** + * Get an array of long parameters, throwing an exception if not found or one is not a number. + * @param request current HTTP request + * @param name the name of the parameter with multiple possible values + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + public static long[] getRequiredLongParameters(ServletRequest request, String name) + throws ServletRequestBindingException { + + return LONG_PARSER.parseLongs(name, request.getParameterValues(name)); + } + + + /** + * Get a Float parameter, or {@code null} if not present. + * Throws an exception if it the parameter value isn't a number. + * @param request current HTTP request + * @param name the name of the parameter + * @return the Float value, or {@code null} if not present + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + @Nullable + public static Float getFloatParameter(ServletRequest request, String name) + throws ServletRequestBindingException { + + if (request.getParameter(name) == null) { + return null; + } + return getRequiredFloatParameter(request, name); + } + + /** + * Get a float parameter, with a fallback value. Never throws an exception. + * Can pass a distinguished value as default to enable checks of whether it was supplied. + * @param request current HTTP request + * @param name the name of the parameter + * @param defaultVal the default value to use as fallback + */ + public static float getFloatParameter(ServletRequest request, String name, float defaultVal) { + if (request.getParameter(name) == null) { + return defaultVal; + } + try { + return getRequiredFloatParameter(request, name); + } + catch (ServletRequestBindingException ex) { + return defaultVal; + } + } + + /** + * Get an array of float parameters, return an empty array if not found. + * @param request current HTTP request + * @param name the name of the parameter with multiple possible values + */ + public static float[] getFloatParameters(ServletRequest request, String name) { + try { + return getRequiredFloatParameters(request, name); + } + catch (ServletRequestBindingException ex) { + return new float[0]; + } + } + + /** + * Get a float parameter, throwing an exception if it isn't found or isn't a number. + * @param request current HTTP request + * @param name the name of the parameter + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + public static float getRequiredFloatParameter(ServletRequest request, String name) + throws ServletRequestBindingException { + + return FLOAT_PARSER.parseFloat(name, request.getParameter(name)); + } + + /** + * Get an array of float parameters, throwing an exception if not found or one is not a number. + * @param request current HTTP request + * @param name the name of the parameter with multiple possible values + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + public static float[] getRequiredFloatParameters(ServletRequest request, String name) + throws ServletRequestBindingException { + + return FLOAT_PARSER.parseFloats(name, request.getParameterValues(name)); + } + + + /** + * Get a Double parameter, or {@code null} if not present. + * Throws an exception if it the parameter value isn't a number. + * @param request current HTTP request + * @param name the name of the parameter + * @return the Double value, or {@code null} if not present + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + @Nullable + public static Double getDoubleParameter(ServletRequest request, String name) + throws ServletRequestBindingException { + + if (request.getParameter(name) == null) { + return null; + } + return getRequiredDoubleParameter(request, name); + } + + /** + * Get a double parameter, with a fallback value. Never throws an exception. + * Can pass a distinguished value as default to enable checks of whether it was supplied. + * @param request current HTTP request + * @param name the name of the parameter + * @param defaultVal the default value to use as fallback + */ + public static double getDoubleParameter(ServletRequest request, String name, double defaultVal) { + if (request.getParameter(name) == null) { + return defaultVal; + } + try { + return getRequiredDoubleParameter(request, name); + } + catch (ServletRequestBindingException ex) { + return defaultVal; + } + } + + /** + * Get an array of double parameters, return an empty array if not found. + * @param request current HTTP request + * @param name the name of the parameter with multiple possible values + */ + public static double[] getDoubleParameters(ServletRequest request, String name) { + try { + return getRequiredDoubleParameters(request, name); + } + catch (ServletRequestBindingException ex) { + return new double[0]; + } + } + + /** + * Get a double parameter, throwing an exception if it isn't found or isn't a number. + * @param request current HTTP request + * @param name the name of the parameter + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + public static double getRequiredDoubleParameter(ServletRequest request, String name) + throws ServletRequestBindingException { + + return DOUBLE_PARSER.parseDouble(name, request.getParameter(name)); + } + + /** + * Get an array of double parameters, throwing an exception if not found or one is not a number. + * @param request current HTTP request + * @param name the name of the parameter with multiple possible values + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + public static double[] getRequiredDoubleParameters(ServletRequest request, String name) + throws ServletRequestBindingException { + + return DOUBLE_PARSER.parseDoubles(name, request.getParameterValues(name)); + } + + + /** + * Get a Boolean parameter, or {@code null} if not present. + * Throws an exception if it the parameter value isn't a boolean. + *

Accepts "true", "on", "yes" (any case) and "1" as values for true; + * treats every other non-empty value as false (i.e. parses leniently). + * @param request current HTTP request + * @param name the name of the parameter + * @return the Boolean value, or {@code null} if not present + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + @Nullable + public static Boolean getBooleanParameter(ServletRequest request, String name) + throws ServletRequestBindingException { + + if (request.getParameter(name) == null) { + return null; + } + return (getRequiredBooleanParameter(request, name)); + } + + /** + * Get a boolean parameter, with a fallback value. Never throws an exception. + * Can pass a distinguished value as default to enable checks of whether it was supplied. + *

Accepts "true", "on", "yes" (any case) and "1" as values for true; + * treats every other non-empty value as false (i.e. parses leniently). + * @param request current HTTP request + * @param name the name of the parameter + * @param defaultVal the default value to use as fallback + */ + public static boolean getBooleanParameter(ServletRequest request, String name, boolean defaultVal) { + if (request.getParameter(name) == null) { + return defaultVal; + } + try { + return getRequiredBooleanParameter(request, name); + } + catch (ServletRequestBindingException ex) { + return defaultVal; + } + } + + /** + * Get an array of boolean parameters, return an empty array if not found. + *

Accepts "true", "on", "yes" (any case) and "1" as values for true; + * treats every other non-empty value as false (i.e. parses leniently). + * @param request current HTTP request + * @param name the name of the parameter with multiple possible values + */ + public static boolean[] getBooleanParameters(ServletRequest request, String name) { + try { + return getRequiredBooleanParameters(request, name); + } + catch (ServletRequestBindingException ex) { + return new boolean[0]; + } + } + + /** + * Get a boolean parameter, throwing an exception if it isn't found + * or isn't a boolean. + *

Accepts "true", "on", "yes" (any case) and "1" as values for true; + * treats every other non-empty value as false (i.e. parses leniently). + * @param request current HTTP request + * @param name the name of the parameter + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + public static boolean getRequiredBooleanParameter(ServletRequest request, String name) + throws ServletRequestBindingException { + + return BOOLEAN_PARSER.parseBoolean(name, request.getParameter(name)); + } + + /** + * Get an array of boolean parameters, throwing an exception if not found + * or one isn't a boolean. + *

Accepts "true", "on", "yes" (any case) and "1" as values for true; + * treats every other non-empty value as false (i.e. parses leniently). + * @param request current HTTP request + * @param name the name of the parameter + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + public static boolean[] getRequiredBooleanParameters(ServletRequest request, String name) + throws ServletRequestBindingException { + + return BOOLEAN_PARSER.parseBooleans(name, request.getParameterValues(name)); + } + + + /** + * Get a String parameter, or {@code null} if not present. + * @param request current HTTP request + * @param name the name of the parameter + * @return the String value, or {@code null} if not present + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + @Nullable + public static String getStringParameter(ServletRequest request, String name) + throws ServletRequestBindingException { + + if (request.getParameter(name) == null) { + return null; + } + return getRequiredStringParameter(request, name); + } + + /** + * Get a String parameter, with a fallback value. Never throws an exception. + * Can pass a distinguished value to default to enable checks of whether it was supplied. + * @param request current HTTP request + * @param name the name of the parameter + * @param defaultVal the default value to use as fallback + */ + public static String getStringParameter(ServletRequest request, String name, String defaultVal) { + String val = request.getParameter(name); + return (val != null ? val : defaultVal); + } + + /** + * Get an array of String parameters, return an empty array if not found. + * @param request current HTTP request + * @param name the name of the parameter with multiple possible values + */ + public static String[] getStringParameters(ServletRequest request, String name) { + try { + return getRequiredStringParameters(request, name); + } + catch (ServletRequestBindingException ex) { + return new String[0]; + } + } + + /** + * Get a String parameter, throwing an exception if it isn't found. + * @param request current HTTP request + * @param name the name of the parameter + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + public static String getRequiredStringParameter(ServletRequest request, String name) + throws ServletRequestBindingException { + + return STRING_PARSER.validateRequiredString(name, request.getParameter(name)); + } + + /** + * Get an array of String parameters, throwing an exception if not found. + * @param request current HTTP request + * @param name the name of the parameter + * @throws ServletRequestBindingException a subclass of ServletException, + * so it doesn't need to be caught + */ + public static String[] getRequiredStringParameters(ServletRequest request, String name) + throws ServletRequestBindingException { + + return STRING_PARSER.validateRequiredStrings(name, request.getParameterValues(name)); + } + + + private abstract static class ParameterParser { + + protected final T parse(String name, String parameter) throws ServletRequestBindingException { + validateRequiredParameter(name, parameter); + try { + return doParse(parameter); + } + catch (NumberFormatException ex) { + throw new ServletRequestBindingException( + "Required " + getType() + " parameter '" + name + "' with value of '" + + parameter + "' is not a valid number", ex); + } + } + + protected final void validateRequiredParameter(String name, @Nullable Object parameter) + throws ServletRequestBindingException { + + if (parameter == null) { + throw new MissingServletRequestParameterException(name, getType()); + } + } + + protected abstract String getType(); + + protected abstract T doParse(String parameter) throws NumberFormatException; + } + + + private static class IntParser extends ParameterParser { + + @Override + protected String getType() { + return "int"; + } + + @Override + protected Integer doParse(String s) throws NumberFormatException { + return Integer.valueOf(s); + } + + public int parseInt(String name, String parameter) throws ServletRequestBindingException { + return parse(name, parameter); + } + + public int[] parseInts(String name, String[] values) throws ServletRequestBindingException { + validateRequiredParameter(name, values); + int[] parameters = new int[values.length]; + for (int i = 0; i < values.length; i++) { + parameters[i] = parseInt(name, values[i]); + } + return parameters; + } + } + + + private static class LongParser extends ParameterParser { + + @Override + protected String getType() { + return "long"; + } + + @Override + protected Long doParse(String parameter) throws NumberFormatException { + return Long.valueOf(parameter); + } + + public long parseLong(String name, String parameter) throws ServletRequestBindingException { + return parse(name, parameter); + } + + public long[] parseLongs(String name, String[] values) throws ServletRequestBindingException { + validateRequiredParameter(name, values); + long[] parameters = new long[values.length]; + for (int i = 0; i < values.length; i++) { + parameters[i] = parseLong(name, values[i]); + } + return parameters; + } + } + + + private static class FloatParser extends ParameterParser { + + @Override + protected String getType() { + return "float"; + } + + @Override + protected Float doParse(String parameter) throws NumberFormatException { + return Float.valueOf(parameter); + } + + public float parseFloat(String name, String parameter) throws ServletRequestBindingException { + return parse(name, parameter); + } + + public float[] parseFloats(String name, String[] values) throws ServletRequestBindingException { + validateRequiredParameter(name, values); + float[] parameters = new float[values.length]; + for (int i = 0; i < values.length; i++) { + parameters[i] = parseFloat(name, values[i]); + } + return parameters; + } + } + + + private static class DoubleParser extends ParameterParser { + + @Override + protected String getType() { + return "double"; + } + + @Override + protected Double doParse(String parameter) throws NumberFormatException { + return Double.valueOf(parameter); + } + + public double parseDouble(String name, String parameter) throws ServletRequestBindingException { + return parse(name, parameter); + } + + public double[] parseDoubles(String name, String[] values) throws ServletRequestBindingException { + validateRequiredParameter(name, values); + double[] parameters = new double[values.length]; + for (int i = 0; i < values.length; i++) { + parameters[i] = parseDouble(name, values[i]); + } + return parameters; + } + } + + + private static class BooleanParser extends ParameterParser { + + @Override + protected String getType() { + return "boolean"; + } + + @Override + protected Boolean doParse(String parameter) throws NumberFormatException { + return (parameter.equalsIgnoreCase("true") || parameter.equalsIgnoreCase("on") || + parameter.equalsIgnoreCase("yes") || parameter.equals("1")); + } + + public boolean parseBoolean(String name, String parameter) throws ServletRequestBindingException { + return parse(name, parameter); + } + + public boolean[] parseBooleans(String name, String[] values) throws ServletRequestBindingException { + validateRequiredParameter(name, values); + boolean[] parameters = new boolean[values.length]; + for (int i = 0; i < values.length; i++) { + parameters[i] = parseBoolean(name, values[i]); + } + return parameters; + } + } + + + private static class StringParser extends ParameterParser { + + @Override + protected String getType() { + return "string"; + } + + @Override + protected String doParse(String parameter) throws NumberFormatException { + return parameter; + } + + public String validateRequiredString(String name, String value) throws ServletRequestBindingException { + validateRequiredParameter(name, value); + return value; + } + + public String[] validateRequiredStrings(String name, String[] values) throws ServletRequestBindingException { + validateRequiredParameter(name, values); + for (String value : values) { + validateRequiredParameter(name, value); + } + return values; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/UnsatisfiedServletRequestParameterException.java b/spring-web/src/main/java/org/springframework/web/bind/UnsatisfiedServletRequestParameterException.java new file mode 100644 index 0000000000000000000000000000000000000000..5418a06accdf1373da821a73c3bcff9a74c8073b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/UnsatisfiedServletRequestParameterException.java @@ -0,0 +1,129 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; + +/** + * {@link ServletRequestBindingException} subclass that indicates an unsatisfied + * parameter condition, as typically expressed using an {@code @RequestMapping} + * annotation at the {@code @Controller} type level. + * + * @author Juergen Hoeller + * @since 3.0 + * @see org.springframework.web.bind.annotation.RequestMapping#params() + */ +@SuppressWarnings("serial") +public class UnsatisfiedServletRequestParameterException extends ServletRequestBindingException { + + private final List paramConditions; + + private final Map actualParams; + + + /** + * Create a new UnsatisfiedServletRequestParameterException. + * @param paramConditions the parameter conditions that have been violated + * @param actualParams the actual parameter Map associated with the ServletRequest + */ + public UnsatisfiedServletRequestParameterException(String[] paramConditions, Map actualParams) { + super(""); + this.paramConditions = Arrays.asList(paramConditions); + this.actualParams = actualParams; + } + + /** + * Create a new UnsatisfiedServletRequestParameterException. + * @param paramConditions all sets of parameter conditions that have been violated + * @param actualParams the actual parameter Map associated with the ServletRequest + * @since 4.2 + */ + public UnsatisfiedServletRequestParameterException(List paramConditions, + Map actualParams) { + + super(""); + Assert.notEmpty(paramConditions, "Parameter conditions must not be empty"); + this.paramConditions = paramConditions; + this.actualParams = actualParams; + } + + + @Override + public String getMessage() { + StringBuilder sb = new StringBuilder("Parameter conditions "); + int i = 0; + for (String[] conditions : this.paramConditions) { + if (i > 0) { + sb.append(" OR "); + } + sb.append("\""); + sb.append(StringUtils.arrayToDelimitedString(conditions, ", ")); + sb.append("\""); + i++; + } + sb.append(" not met for actual request parameters: "); + sb.append(requestParameterMapToString(this.actualParams)); + return sb.toString(); + } + + /** + * Return the parameter conditions that have been violated or the first group + * in case of multiple groups. + * @see org.springframework.web.bind.annotation.RequestMapping#params() + */ + public final String[] getParamConditions() { + return this.paramConditions.get(0); + } + + /** + * Return all parameter condition groups that have been violated. + * @since 4.2 + * @see org.springframework.web.bind.annotation.RequestMapping#params() + */ + public final List getParamConditionGroups() { + return this.paramConditions; + } + + /** + * Return the actual parameter Map associated with the ServletRequest. + * @see javax.servlet.ServletRequest#getParameterMap() + */ + public final Map getActualParams() { + return this.actualParams; + } + + + private static String requestParameterMapToString(Map actualParams) { + StringBuilder result = new StringBuilder(); + for (Iterator> it = actualParams.entrySet().iterator(); it.hasNext();) { + Map.Entry entry = it.next(); + result.append(entry.getKey()).append('=').append(ObjectUtils.nullSafeToString(entry.getValue())); + if (it.hasNext()) { + result.append(", "); + } + } + return result.toString(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/WebDataBinder.java b/spring-web/src/main/java/org/springframework/web/bind/WebDataBinder.java new file mode 100644 index 0000000000000000000000000000000000000000..c2893de0fea4c192d07eb98026657754b76d518a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/WebDataBinder.java @@ -0,0 +1,331 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +import java.lang.reflect.Array; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import org.springframework.beans.MutablePropertyValues; +import org.springframework.beans.PropertyValue; +import org.springframework.core.CollectionFactory; +import org.springframework.lang.Nullable; +import org.springframework.validation.DataBinder; +import org.springframework.web.multipart.MultipartFile; + +/** + * Special {@link DataBinder} for data binding from web request parameters + * to JavaBean objects. Designed for web environments, but not dependent on + * the Servlet API; serves as base class for more specific DataBinder variants, + * such as {@link org.springframework.web.bind.ServletRequestDataBinder}. + * + *

Includes support for field markers which address a common problem with + * HTML checkboxes and select options: detecting that a field was part of + * the form, but did not generate a request parameter because it was empty. + * A field marker allows to detect that state and reset the corresponding + * bean property accordingly. Default values, for parameters that are otherwise + * not present, can specify a value for the field other then empty. + * + * @author Juergen Hoeller + * @author Scott Andrews + * @author Brian Clozel + * @since 1.2 + * @see #registerCustomEditor + * @see #setAllowedFields + * @see #setRequiredFields + * @see #setFieldMarkerPrefix + * @see #setFieldDefaultPrefix + * @see ServletRequestDataBinder + */ +public class WebDataBinder extends DataBinder { + + /** + * Default prefix that field marker parameters start with, followed by the field + * name: e.g. "_subscribeToNewsletter" for a field "subscribeToNewsletter". + *

Such a marker parameter indicates that the field was visible, that is, + * existed in the form that caused the submission. If no corresponding field + * value parameter was found, the field will be reset. The value of the field + * marker parameter does not matter in this case; an arbitrary value can be used. + * This is particularly useful for HTML checkboxes and select options. + * @see #setFieldMarkerPrefix + */ + public static final String DEFAULT_FIELD_MARKER_PREFIX = "_"; + + /** + * Default prefix that field default parameters start with, followed by the field + * name: e.g. "!subscribeToNewsletter" for a field "subscribeToNewsletter". + *

Default parameters differ from field markers in that they provide a default + * value instead of an empty value. + * @see #setFieldDefaultPrefix + */ + public static final String DEFAULT_FIELD_DEFAULT_PREFIX = "!"; + + @Nullable + private String fieldMarkerPrefix = DEFAULT_FIELD_MARKER_PREFIX; + + @Nullable + private String fieldDefaultPrefix = DEFAULT_FIELD_DEFAULT_PREFIX; + + private boolean bindEmptyMultipartFiles = true; + + + /** + * Create a new WebDataBinder instance, with default object name. + * @param target the target object to bind onto (or {@code null} + * if the binder is just used to convert a plain parameter value) + * @see #DEFAULT_OBJECT_NAME + */ + public WebDataBinder(@Nullable Object target) { + super(target); + } + + /** + * Create a new WebDataBinder instance. + * @param target the target object to bind onto (or {@code null} + * if the binder is just used to convert a plain parameter value) + * @param objectName the name of the target object + */ + public WebDataBinder(@Nullable Object target, String objectName) { + super(target, objectName); + } + + + /** + * Specify a prefix that can be used for parameters that mark potentially + * empty fields, having "prefix + field" as name. Such a marker parameter is + * checked by existence: You can send any value for it, for example "visible". + * This is particularly useful for HTML checkboxes and select options. + *

Default is "_", for "_FIELD" parameters (e.g. "_subscribeToNewsletter"). + * Set this to null if you want to turn off the empty field check completely. + *

HTML checkboxes only send a value when they're checked, so it is not + * possible to detect that a formerly checked box has just been unchecked, + * at least not with standard HTML means. + *

One way to address this is to look for a checkbox parameter value if + * you know that the checkbox has been visible in the form, resetting the + * checkbox if no value found. In Spring web MVC, this typically happens + * in a custom {@code onBind} implementation. + *

This auto-reset mechanism addresses this deficiency, provided + * that a marker parameter is sent for each checkbox field, like + * "_subscribeToNewsletter" for a "subscribeToNewsletter" field. + * As the marker parameter is sent in any case, the data binder can + * detect an empty field and automatically reset its value. + * @see #DEFAULT_FIELD_MARKER_PREFIX + */ + public void setFieldMarkerPrefix(@Nullable String fieldMarkerPrefix) { + this.fieldMarkerPrefix = fieldMarkerPrefix; + } + + /** + * Return the prefix for parameters that mark potentially empty fields. + */ + @Nullable + public String getFieldMarkerPrefix() { + return this.fieldMarkerPrefix; + } + + /** + * Specify a prefix that can be used for parameters that indicate default + * value fields, having "prefix + field" as name. The value of the default + * field is used when the field is not provided. + *

Default is "!", for "!FIELD" parameters (e.g. "!subscribeToNewsletter"). + * Set this to null if you want to turn off the field defaults completely. + *

HTML checkboxes only send a value when they're checked, so it is not + * possible to detect that a formerly checked box has just been unchecked, + * at least not with standard HTML means. A default field is especially + * useful when a checkbox represents a non-boolean value. + *

The presence of a default parameter preempts the behavior of a field + * marker for the given field. + * @see #DEFAULT_FIELD_DEFAULT_PREFIX + */ + public void setFieldDefaultPrefix(@Nullable String fieldDefaultPrefix) { + this.fieldDefaultPrefix = fieldDefaultPrefix; + } + + /** + * Return the prefix for parameters that mark default fields. + */ + @Nullable + public String getFieldDefaultPrefix() { + return this.fieldDefaultPrefix; + } + + /** + * Set whether to bind empty MultipartFile parameters. Default is "true". + *

Turn this off if you want to keep an already bound MultipartFile + * when the user resubmits the form without choosing a different file. + * Else, the already bound MultipartFile will be replaced by an empty + * MultipartFile holder. + * @see org.springframework.web.multipart.MultipartFile + */ + public void setBindEmptyMultipartFiles(boolean bindEmptyMultipartFiles) { + this.bindEmptyMultipartFiles = bindEmptyMultipartFiles; + } + + /** + * Return whether to bind empty MultipartFile parameters. + */ + public boolean isBindEmptyMultipartFiles() { + return this.bindEmptyMultipartFiles; + } + + + /** + * This implementation performs a field default and marker check + * before delegating to the superclass binding process. + * @see #checkFieldDefaults + * @see #checkFieldMarkers + */ + @Override + protected void doBind(MutablePropertyValues mpvs) { + checkFieldDefaults(mpvs); + checkFieldMarkers(mpvs); + super.doBind(mpvs); + } + + /** + * Check the given property values for field defaults, + * i.e. for fields that start with the field default prefix. + *

The existence of a field defaults indicates that the specified + * value should be used if the field is otherwise not present. + * @param mpvs the property values to be bound (can be modified) + * @see #getFieldDefaultPrefix + */ + protected void checkFieldDefaults(MutablePropertyValues mpvs) { + String fieldDefaultPrefix = getFieldDefaultPrefix(); + if (fieldDefaultPrefix != null) { + PropertyValue[] pvArray = mpvs.getPropertyValues(); + for (PropertyValue pv : pvArray) { + if (pv.getName().startsWith(fieldDefaultPrefix)) { + String field = pv.getName().substring(fieldDefaultPrefix.length()); + if (getPropertyAccessor().isWritableProperty(field) && !mpvs.contains(field)) { + mpvs.add(field, pv.getValue()); + } + mpvs.removePropertyValue(pv); + } + } + } + } + + /** + * Check the given property values for field markers, + * i.e. for fields that start with the field marker prefix. + *

The existence of a field marker indicates that the specified + * field existed in the form. If the property values do not contain + * a corresponding field value, the field will be considered as empty + * and will be reset appropriately. + * @param mpvs the property values to be bound (can be modified) + * @see #getFieldMarkerPrefix + * @see #getEmptyValue(String, Class) + */ + protected void checkFieldMarkers(MutablePropertyValues mpvs) { + String fieldMarkerPrefix = getFieldMarkerPrefix(); + if (fieldMarkerPrefix != null) { + PropertyValue[] pvArray = mpvs.getPropertyValues(); + for (PropertyValue pv : pvArray) { + if (pv.getName().startsWith(fieldMarkerPrefix)) { + String field = pv.getName().substring(fieldMarkerPrefix.length()); + if (getPropertyAccessor().isWritableProperty(field) && !mpvs.contains(field)) { + Class fieldType = getPropertyAccessor().getPropertyType(field); + mpvs.add(field, getEmptyValue(field, fieldType)); + } + mpvs.removePropertyValue(pv); + } + } + } + } + + /** + * Determine an empty value for the specified field. + *

The default implementation delegates to {@link #getEmptyValue(Class)} + * if the field type is known, otherwise falls back to {@code null}. + * @param field the name of the field + * @param fieldType the type of the field + * @return the empty value (for most fields: {@code null}) + */ + @Nullable + protected Object getEmptyValue(String field, @Nullable Class fieldType) { + return (fieldType != null ? getEmptyValue(fieldType) : null); + } + + /** + * Determine an empty value for the specified field. + *

The default implementation returns: + *

    + *
  • {@code Boolean.FALSE} for boolean fields + *
  • an empty array for array types + *
  • Collection implementations for Collection types + *
  • Map implementations for Map types + *
  • else, {@code null} is used as default + *
+ * @param fieldType the type of the field + * @return the empty value (for most fields: {@code null}) + * @since 5.0 + */ + @Nullable + public Object getEmptyValue(Class fieldType) { + try { + if (boolean.class == fieldType || Boolean.class == fieldType) { + // Special handling of boolean property. + return Boolean.FALSE; + } + else if (fieldType.isArray()) { + // Special handling of array property. + return Array.newInstance(fieldType.getComponentType(), 0); + } + else if (Collection.class.isAssignableFrom(fieldType)) { + return CollectionFactory.createCollection(fieldType, 0); + } + else if (Map.class.isAssignableFrom(fieldType)) { + return CollectionFactory.createMap(fieldType, 0); + } + } + catch (IllegalArgumentException ex) { + if (logger.isDebugEnabled()) { + logger.debug("Failed to create default value - falling back to null: " + ex.getMessage()); + } + } + // Default value: null. + return null; + } + + + /** + * Bind all multipart files contained in the given request, if any + * (in case of a multipart request). To be called by subclasses. + *

Multipart files will only be added to the property values if they + * are not empty or if we're configured to bind empty multipart files too. + * @param multipartFiles a Map of field name String to MultipartFile object + * @param mpvs the property values to be bound (can be modified) + * @see org.springframework.web.multipart.MultipartFile + * @see #setBindEmptyMultipartFiles + */ + protected void bindMultipart(Map> multipartFiles, MutablePropertyValues mpvs) { + multipartFiles.forEach((key, values) -> { + if (values.size() == 1) { + MultipartFile value = values.get(0); + if (isBindEmptyMultipartFiles() || !value.isEmpty()) { + mpvs.add(key, value); + } + } + else { + mpvs.add(key, values); + } + }); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/ControllerAdvice.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/ControllerAdvice.java new file mode 100644 index 0000000000000000000000000000000000000000..e850b23f0164cf55ee3b23b95277b328528618a8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/ControllerAdvice.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Annotation; +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; +import org.springframework.stereotype.Component; + +/** + * Specialization of {@link Component @Component} for classes that declare + * {@link ExceptionHandler @ExceptionHandler}, {@link InitBinder @InitBinder}, or + * {@link ModelAttribute @ModelAttribute} methods to be shared across + * multiple {@code @Controller} classes. + * + *

Classes annotated with {@code @ControllerAdvice} can be declared explicitly + * as Spring beans or auto-detected via classpath scanning. All such beans are + * sorted based on {@link org.springframework.core.annotation.Order @Order} + * semantics and applied in that order at runtime. For handling exceptions, an + * {@code @ExceptionHandler} will be picked on the first advice with a matching + * exception handler method. For model attributes and {@code InitBinder} + * initialization, {@code @ModelAttribute} and {@code @InitBinder} methods will + * also follow {@code @ControllerAdvice} order. + * + *

Note: For {@code @ExceptionHandler} methods, a root exception match will be + * preferred to just matching a cause of the current exception, among the handler + * methods of a particular advice bean. However, a cause match on a higher-priority + * advice will still be preferred over any match (whether root or cause level) + * on a lower-priority advice bean. As a consequence, please declare your primary + * root exception mappings on a prioritized advice bean with a corresponding order. + * + *

By default the methods in an {@code @ControllerAdvice} apply globally to + * all controllers. Use selectors such as {@link #annotations}, + * {@link #basePackageClasses}, and {@link #basePackages} (or its alias + * {@link #value}) to define a more narrow subset of targeted controllers. + * If multiple selectors are declared, {@code OR} logic is applied, meaning selected + * controllers should match at least one selector. Note that selector checks + * are performed at runtime, so adding many selectors may negatively impact + * performance and add complexity. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @author Sam Brannen + * @since 3.2 + * @see org.springframework.stereotype.Controller + * @see RestControllerAdvice + */ +@Target(ElementType.TYPE) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@Component +public @interface ControllerAdvice { + + /** + * Alias for the {@link #basePackages} attribute. + *

Allows for more concise annotation declarations e.g.: + * {@code @ControllerAdvice("org.my.pkg")} is equivalent to + * {@code @ControllerAdvice(basePackages="org.my.pkg")}. + * @since 4.0 + * @see #basePackages() + */ + @AliasFor("basePackages") + String[] value() default {}; + + /** + * Array of base packages. + *

Controllers that belong to those base packages or sub-packages thereof + * will be included, e.g.: {@code @ControllerAdvice(basePackages="org.my.pkg")} + * or {@code @ControllerAdvice(basePackages={"org.my.pkg", "org.my.other.pkg"})}. + *

{@link #value} is an alias for this attribute, simply allowing for + * more concise use of the annotation. + *

Also consider using {@link #basePackageClasses()} as a type-safe + * alternative to String-based package names. + * @since 4.0 + */ + @AliasFor("value") + String[] basePackages() default {}; + + /** + * Type-safe alternative to {@link #value()} for specifying the packages + * to select Controllers to be assisted by the {@code @ControllerAdvice} + * annotated class. + *

Consider creating a special no-op marker class or interface in each package + * that serves no purpose other than being referenced by this attribute. + * @since 4.0 + */ + Class[] basePackageClasses() default {}; + + /** + * Array of classes. + *

Controllers that are assignable to at least one of the given types + * will be assisted by the {@code @ControllerAdvice} annotated class. + * @since 4.0 + */ + Class[] assignableTypes() default {}; + + /** + * Array of annotations. + *

Controllers that are annotated with this/one of those annotation(s) + * will be assisted by the {@code @ControllerAdvice} annotated class. + *

Consider creating a special annotation or use a predefined one, + * like {@link RestController @RestController}. + * @since 4.0 + */ + Class[] annotations() default {}; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/CookieValue.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/CookieValue.java new file mode 100644 index 0000000000000000000000000000000000000000..d49ffddaf21563c5e8b9c42e68ba2f7392a8c1a4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/CookieValue.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; + +/** + * Annotation which indicates that a method parameter should be bound to an HTTP cookie. + * + *

The method parameter may be declared as type {@link javax.servlet.http.Cookie} + * or as cookie value type (String, int, etc.). + * + * @author Juergen Hoeller + * @author Sam Brannen + * @since 3.0 + * @see RequestMapping + * @see RequestParam + * @see RequestHeader + * @see org.springframework.web.bind.annotation.RequestMapping + */ +@Target(ElementType.PARAMETER) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface CookieValue { + + /** + * Alias for {@link #name}. + */ + @AliasFor("name") + String value() default ""; + + /** + * The name of the cookie to bind to. + * @since 4.2 + */ + @AliasFor("value") + String name() default ""; + + /** + * Whether the cookie is required. + *

Defaults to {@code true}, leading to an exception being thrown + * if the cookie is missing in the request. Switch this to + * {@code false} if you prefer a {@code null} value if the cookie is + * not present in the request. + *

Alternatively, provide a {@link #defaultValue}, which implicitly + * sets this flag to {@code false}. + */ + boolean required() default true; + + /** + * The default value to use as a fallback. + *

Supplying a default value implicitly sets {@link #required} to + * {@code false}. + */ + String defaultValue() default ValueConstants.DEFAULT_NONE; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/CrossOrigin.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/CrossOrigin.java new file mode 100644 index 0000000000000000000000000000000000000000..a7c7e7474c9c0ef221e0c9cac86edda674f6e90b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/CrossOrigin.java @@ -0,0 +1,154 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; +import org.springframework.web.cors.CorsConfiguration; + +/** + * Annotation for permitting cross-origin requests on specific handler classes + * and/or handler methods. Processed if an appropriate {@code HandlerMapping} + * is configured. + * + *

Both Spring Web MVC and Spring WebFlux support this annotation through the + * {@code RequestMappingHandlerMapping} in their respective modules. The values + * from each type and method level pair of annotations are added to a + * {@link CorsConfiguration} and then default values are applied via + * {@link CorsConfiguration#applyPermitDefaultValues()}. + * + *

The rules for combining global and local configuration are generally + * additive -- e.g. all global and all local origins. For those attributes + * where only a single value can be accepted such as {@code allowCredentials} + * and {@code maxAge}, the local overrides the global value. + * See {@link CorsConfiguration#combine(CorsConfiguration)} for more details. + * + * @author Russell Allen + * @author Sebastien Deleuze + * @author Sam Brannen + * @since 4.2 + */ +@Target({ElementType.TYPE, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface CrossOrigin { + + /** @deprecated as of Spring 5.0, in favor of {@link CorsConfiguration#applyPermitDefaultValues} */ + @Deprecated + String[] DEFAULT_ORIGINS = {"*"}; + + /** @deprecated as of Spring 5.0, in favor of {@link CorsConfiguration#applyPermitDefaultValues} */ + @Deprecated + String[] DEFAULT_ALLOWED_HEADERS = {"*"}; + + /** @deprecated as of Spring 5.0, in favor of {@link CorsConfiguration#applyPermitDefaultValues} */ + @Deprecated + boolean DEFAULT_ALLOW_CREDENTIALS = false; + + /** @deprecated as of Spring 5.0, in favor of {@link CorsConfiguration#applyPermitDefaultValues} */ + @Deprecated + long DEFAULT_MAX_AGE = 1800; + + + /** + * Alias for {@link #origins}. + */ + @AliasFor("origins") + String[] value() default {}; + + /** + * The list of allowed origins that be specific origins, e.g. + * {@code "https://domain1.com"}, or {@code "*"} for all origins. + *

A matched origin is listed in the {@code Access-Control-Allow-Origin} + * response header of preflight actual CORS requests. + *

By default all origins are allowed. + *

Note: CORS checks use values from "Forwarded" + * (RFC 7239), + * "X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" headers, + * if present, in order to reflect the client-originated address. + * Consider using the {@code ForwardedHeaderFilter} in order to choose from a + * central place whether to extract and use, or to discard such headers. + * See the Spring Framework reference for more on this filter. + * @see #value + */ + @AliasFor("value") + String[] origins() default {}; + + /** + * The list of request headers that are permitted in actual requests, + * possibly {@code "*"} to allow all headers. + *

Allowed headers are listed in the {@code Access-Control-Allow-Headers} + * response header of preflight requests. + *

A header name is not required to be listed if it is one of: + * {@code Cache-Control}, {@code Content-Language}, {@code Expires}, + * {@code Last-Modified}, or {@code Pragma} as per the CORS spec. + *

By default all requested headers are allowed. + */ + String[] allowedHeaders() default {}; + + /** + * The List of response headers that the user-agent will allow the client + * to access on an actual response, other than "simple" headers, i.e. + * {@code Cache-Control}, {@code Content-Language}, {@code Content-Type}, + * {@code Expires}, {@code Last-Modified}, or {@code Pragma}, + *

Exposed headers are listed in the {@code Access-Control-Expose-Headers} + * response header of actual CORS requests. + *

The special value {@code "*"} allows all headers to be exposed for + * non-credentialed requests. + *

By default no headers are listed as exposed. + */ + String[] exposedHeaders() default {}; + + /** + * The list of supported HTTP request methods. + *

By default the supported methods are the same as the ones to which a + * controller method is mapped. + */ + RequestMethod[] methods() default {}; + + /** + * Whether the browser should send credentials, such as cookies along with + * cross domain requests, to the annotated endpoint. The configured value is + * set on the {@code Access-Control-Allow-Credentials} response header of + * preflight requests. + *

NOTE: Be aware that this option establishes a high + * level of trust with the configured domains and also increases the surface + * attack of the web application by exposing sensitive user-specific + * information such as cookies and CSRF tokens. + *

By default this is not set in which case the + * {@code Access-Control-Allow-Credentials} header is also not set and + * credentials are therefore not allowed. + */ + String allowCredentials() default ""; + + /** + * The maximum age (in seconds) of the cache duration for preflight responses. + *

This property controls the value of the {@code Access-Control-Max-Age} + * response header of preflight requests. + *

Setting this to a reasonable value can reduce the number of preflight + * request/response interactions required by the browser. + * A negative value means undefined. + *

By default this is set to {@code 1800} seconds (30 minutes). + */ + long maxAge() default -1; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/DeleteMapping.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/DeleteMapping.java new file mode 100644 index 0000000000000000000000000000000000000000..1b9150ccf95b21e26fdf8b5f5152472d17388f0f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/DeleteMapping.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; + +/** + * Annotation for mapping HTTP {@code DELETE} requests onto specific handler + * methods. + * + *

Specifically, {@code @DeleteMapping} is a composed annotation that + * acts as a shortcut for {@code @RequestMapping(method = RequestMethod.DELETE)}. + * + * @author Sam Brannen + * @since 4.3 + * @see GetMapping + * @see PostMapping + * @see PutMapping + * @see PatchMapping + * @see RequestMapping + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@RequestMapping(method = RequestMethod.DELETE) +public @interface DeleteMapping { + + /** + * Alias for {@link RequestMapping#name}. + */ + @AliasFor(annotation = RequestMapping.class) + String name() default ""; + + /** + * Alias for {@link RequestMapping#value}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] value() default {}; + + /** + * Alias for {@link RequestMapping#path}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] path() default {}; + + /** + * Alias for {@link RequestMapping#params}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] params() default {}; + + /** + * Alias for {@link RequestMapping#headers}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] headers() default {}; + + /** + * Alias for {@link RequestMapping#consumes}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] consumes() default {}; + + /** + * Alias for {@link RequestMapping#produces}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] produces() default {}; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/ExceptionHandler.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/ExceptionHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..6188a45d3d5866638e0410ba2fd38ce566dce5e4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/ExceptionHandler.java @@ -0,0 +1,113 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation for handling exceptions in specific handler classes and/or + * handler methods. + * + *

Handler methods which are annotated with this annotation are allowed to + * have very flexible signatures. They may have parameters of the following + * types, in arbitrary order: + *

    + *
  • An exception argument: declared as a general Exception or as a more + * specific exception. This also serves as a mapping hint if the annotation + * itself does not narrow the exception types through its {@link #value()}. + *
  • Request and/or response objects (typically from the Servlet API). + * You may choose any specific request/response type, e.g. + * {@link javax.servlet.ServletRequest} / {@link javax.servlet.http.HttpServletRequest}. + *
  • Session object: typically {@link javax.servlet.http.HttpSession}. + * An argument of this type will enforce the presence of a corresponding session. + * As a consequence, such an argument will never be {@code null}. + * Note that session access may not be thread-safe, in particular in a + * Servlet environment: Consider switching the + * {@link org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter#setSynchronizeOnSession + * "synchronizeOnSession"} flag to "true" if multiple requests are allowed to + * access a session concurrently. + *
  • {@link org.springframework.web.context.request.WebRequest} or + * {@link org.springframework.web.context.request.NativeWebRequest}. + * Allows for generic request parameter access as well as request/session + * attribute access, without ties to the native Servlet API. + *
  • {@link java.util.Locale} for the current request locale + * (determined by the most specific locale resolver available, + * i.e. the configured {@link org.springframework.web.servlet.LocaleResolver} + * in a Servlet environment). + *
  • {@link java.io.InputStream} / {@link java.io.Reader} for access + * to the request's content. This will be the raw InputStream/Reader as + * exposed by the Servlet API. + *
  • {@link java.io.OutputStream} / {@link java.io.Writer} for generating + * the response's content. This will be the raw OutputStream/Writer as + * exposed by the Servlet API. + *
  • {@link org.springframework.ui.Model} as an alternative to returning + * a model map from the handler method. Note that the provided model is not + * pre-populated with regular model attributes and therefore always empty, + * as a convenience for preparing the model for an exception-specific view. + *
+ * + *

The following return types are supported for handler methods: + *

    + *
  • A {@code ModelAndView} object (from Servlet MVC). + *
  • A {@link org.springframework.ui.Model} object, with the view name implicitly + * determined through a {@link org.springframework.web.servlet.RequestToViewNameTranslator}. + *
  • A {@link java.util.Map} object for exposing a model, + * with the view name implicitly determined through a + * {@link org.springframework.web.servlet.RequestToViewNameTranslator}. + *
  • A {@link org.springframework.web.servlet.View} object. + *
  • A {@link String} value which is interpreted as view name. + *
  • {@link ResponseBody @ResponseBody} annotated methods (Servlet-only) + * to set the response content. The return value will be converted to the + * response stream using + * {@linkplain org.springframework.http.converter.HttpMessageConverter message converters}. + *
  • An {@link org.springframework.http.HttpEntity HttpEntity<?>} or + * {@link org.springframework.http.ResponseEntity ResponseEntity<?>} object + * (Servlet-only) to set response headers and content. The ResponseEntity body + * will be converted and written to the response stream using + * {@linkplain org.springframework.http.converter.HttpMessageConverter message converters}. + *
  • {@code void} if the method handles the response itself (by + * writing the response content directly, declaring an argument of type + * {@link javax.servlet.ServletResponse} / {@link javax.servlet.http.HttpServletResponse} + * for that purpose) or if the view name is supposed to be implicitly determined + * through a {@link org.springframework.web.servlet.RequestToViewNameTranslator} + * (not declaring a response argument in the handler method signature). + *
+ * + *

You may combine the {@code ExceptionHandler} annotation with + * {@link ResponseStatus @ResponseStatus} for a specific HTTP error status. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.0 + * @see org.springframework.web.context.request.WebRequest + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface ExceptionHandler { + + /** + * Exceptions handled by the annotated method. If empty, will default to any + * exceptions listed in the method argument list. + */ + Class[] value() default {}; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/GetMapping.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/GetMapping.java new file mode 100644 index 0000000000000000000000000000000000000000..bceb66e419fd8dc1323de92a8446516451cda139 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/GetMapping.java @@ -0,0 +1,92 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; + +/** + * Annotation for mapping HTTP {@code GET} requests onto specific handler + * methods. + * + *

Specifically, {@code @GetMapping} is a composed annotation that + * acts as a shortcut for {@code @RequestMapping(method = RequestMethod.GET)}. + * + * + * @author Sam Brannen + * @since 4.3 + * @see PostMapping + * @see PutMapping + * @see DeleteMapping + * @see PatchMapping + * @see RequestMapping + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@RequestMapping(method = RequestMethod.GET) +public @interface GetMapping { + + /** + * Alias for {@link RequestMapping#name}. + */ + @AliasFor(annotation = RequestMapping.class) + String name() default ""; + + /** + * Alias for {@link RequestMapping#value}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] value() default {}; + + /** + * Alias for {@link RequestMapping#path}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] path() default {}; + + /** + * Alias for {@link RequestMapping#params}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] params() default {}; + + /** + * Alias for {@link RequestMapping#headers}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] headers() default {}; + + /** + * Alias for {@link RequestMapping#consumes}. + * @since 4.3.5 + */ + @AliasFor(annotation = RequestMapping.class) + String[] consumes() default {}; + + /** + * Alias for {@link RequestMapping#produces}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] produces() default {}; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/InitBinder.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/InitBinder.java new file mode 100644 index 0000000000000000000000000000000000000000..5fc5d6bcc2797eedf243533b3698e4028f9e3ba4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/InitBinder.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation that identifies methods which initialize the + * {@link org.springframework.web.bind.WebDataBinder} which + * will be used for populating command and form object arguments + * of annotated handler methods. + * + *

Such init-binder methods support all arguments that {@link RequestMapping} + * supports, except for command/form objects and corresponding validation result + * objects. Init-binder methods must not have a return value; they are usually + * declared as {@code void}. + * + *

Typical arguments are {@link org.springframework.web.bind.WebDataBinder} + * in combination with {@link org.springframework.web.context.request.WebRequest} + * or {@link java.util.Locale}, allowing to register context-specific editors. + * + * @author Juergen Hoeller + * @since 2.5 + * @see org.springframework.web.bind.WebDataBinder + * @see org.springframework.web.context.request.WebRequest + */ +@Target({ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface InitBinder { + + /** + * The names of command/form attributes and/or request parameters + * that this init-binder method is supposed to apply to. + *

Default is to apply to all command/form attributes and all request parameters + * processed by the annotated handler class. Specifying model attribute names or + * request parameter names here restricts the init-binder method to those specific + * attributes/parameters, with different init-binder methods typically applying to + * different groups of attributes or parameters. + */ + String[] value() default {}; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/Mapping.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/Mapping.java new file mode 100644 index 0000000000000000000000000000000000000000..beda6d7bb63d6990c43d0bcf00ef415c61c8df3d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/Mapping.java @@ -0,0 +1,35 @@ +/* + * Copyright 2002-2009 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Meta annotation that indicates a web mapping annotation. + * + * @author Juergen Hoeller + * @since 3.0 + * @see RequestMapping + */ +@Target({ElementType.ANNOTATION_TYPE}) +@Retention(RetentionPolicy.RUNTIME) +public @interface Mapping { + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/MatrixVariable.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/MatrixVariable.java new file mode 100644 index 0000000000000000000000000000000000000000..12ae0c02e005182d7d7b95054f6f9c09803e4d6d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/MatrixVariable.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; + +/** + * Annotation which indicates that a method parameter should be bound to a + * name-value pair within a path segment. Supported for {@link RequestMapping} + * annotated handler methods. + * + *

If the method parameter type is {@link java.util.Map} and a matrix variable + * name is specified, then the matrix variable value is converted to a + * {@link java.util.Map} assuming an appropriate conversion strategy is available. + * + *

If the method parameter is {@link java.util.Map Map<String, String>} or + * {@link org.springframework.util.MultiValueMap MultiValueMap<String, String>} + * and a variable name is not specified, then the map is populated with all + * matrix variable names and values. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + * @since 3.2 + */ +@Target(ElementType.PARAMETER) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface MatrixVariable { + + /** + * Alias for {@link #name}. + */ + @AliasFor("name") + String value() default ""; + + /** + * The name of the matrix variable. + * @since 4.2 + * @see #value + */ + @AliasFor("value") + String name() default ""; + + /** + * The name of the URI path variable where the matrix variable is located, + * if necessary for disambiguation (e.g. a matrix variable with the same + * name present in more than one path segment). + */ + String pathVar() default ValueConstants.DEFAULT_NONE; + + /** + * Whether the matrix variable is required. + *

Default is {@code true}, leading to an exception being thrown in + * case the variable is missing in the request. Switch this to {@code false} + * if you prefer a {@code null} if the variable is missing. + *

Alternatively, provide a {@link #defaultValue}, which implicitly sets + * this flag to {@code false}. + */ + boolean required() default true; + + /** + * The default value to use as a fallback. + *

Supplying a default value implicitly sets {@link #required} to + * {@code false}. + */ + String defaultValue() default ValueConstants.DEFAULT_NONE; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/ModelAttribute.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/ModelAttribute.java new file mode 100644 index 0000000000000000000000000000000000000000..717a6d0106ef65829bfa98a5fbc1de6379e695d4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/ModelAttribute.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; +import org.springframework.ui.Model; + +/** + * Annotation that binds a method parameter or method return value + * to a named model attribute, exposed to a web view. Supported + * for controller classes with {@link RequestMapping @RequestMapping} + * methods. + * + *

Can be used to expose command objects to a web view, using + * specific attribute names, through annotating corresponding + * parameters of an {@link RequestMapping @RequestMapping} method. + * + *

Can also be used to expose reference data to a web view + * through annotating accessor methods in a controller class with + * {@link RequestMapping @RequestMapping} methods. Such accessor + * methods are allowed to have any arguments that + * {@link RequestMapping @RequestMapping} methods support, returning + * the model attribute value to expose. + * + *

Note however that reference data and all other model content is + * not available to web views when request processing results in an + * {@code Exception} since the exception could be raised at any time + * making the content of the model unreliable. For this reason + * {@link ExceptionHandler @ExceptionHandler} methods do not provide + * access to a {@link Model} argument. + * + * @author Juergen Hoeller + * @author Rossen Stoyanchev + * @since 2.5 + */ +@Target({ElementType.PARAMETER, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface ModelAttribute { + + /** + * Alias for {@link #name}. + */ + @AliasFor("name") + String value() default ""; + + /** + * The name of the model attribute to bind to. + *

The default model attribute name is inferred from the declared + * attribute type (i.e. the method parameter type or method return type), + * based on the non-qualified class name: + * e.g. "orderAddress" for class "mypackage.OrderAddress", + * or "orderAddressList" for "List<mypackage.OrderAddress>". + * @since 4.3 + */ + @AliasFor("value") + String name() default ""; + + /** + * Allows declaring data binding disabled directly on an {@code @ModelAttribute} + * method parameter or on the attribute returned from an {@code @ModelAttribute} + * method, both of which would prevent data binding for that attribute. + *

By default this is set to {@code true} in which case data binding applies. + * Set this to {@code false} to disable data binding. + * @since 4.3 + */ + boolean binding() default true; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/PatchMapping.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/PatchMapping.java new file mode 100644 index 0000000000000000000000000000000000000000..72fd7111919f52940213b14936b441af6851dc24 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/PatchMapping.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; + +/** + * Annotation for mapping HTTP {@code PATCH} requests onto specific handler + * methods. + * + *

Specifically, {@code @PatchMapping} is a composed annotation that + * acts as a shortcut for {@code @RequestMapping(method = RequestMethod.PATCH)}. + * + * @author Sam Brannen + * @since 4.3 + * @see GetMapping + * @see PostMapping + * @see PutMapping + * @see DeleteMapping + * @see RequestMapping + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@RequestMapping(method = RequestMethod.PATCH) +public @interface PatchMapping { + + /** + * Alias for {@link RequestMapping#name}. + */ + @AliasFor(annotation = RequestMapping.class) + String name() default ""; + + /** + * Alias for {@link RequestMapping#value}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] value() default {}; + + /** + * Alias for {@link RequestMapping#path}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] path() default {}; + + /** + * Alias for {@link RequestMapping#params}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] params() default {}; + + /** + * Alias for {@link RequestMapping#headers}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] headers() default {}; + + /** + * Alias for {@link RequestMapping#consumes}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] consumes() default {}; + + /** + * Alias for {@link RequestMapping#produces}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] produces() default {}; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/PathVariable.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/PathVariable.java new file mode 100644 index 0000000000000000000000000000000000000000..cee9fe2485c254f6fd234612dd5b99435b1a59ce --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/PathVariable.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; + +/** + * Annotation which indicates that a method parameter should be bound to a URI template + * variable. Supported for {@link RequestMapping} annotated handler methods. + * + *

If the method parameter is {@link java.util.Map Map<String, String>} + * then the map is populated with all path variable names and values. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.0 + * @see RequestMapping + * @see org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter + */ +@Target(ElementType.PARAMETER) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface PathVariable { + + /** + * Alias for {@link #name}. + */ + @AliasFor("name") + String value() default ""; + + /** + * The name of the path variable to bind to. + * @since 4.3.3 + */ + @AliasFor("value") + String name() default ""; + + /** + * Whether the path variable is required. + *

Defaults to {@code true}, leading to an exception being thrown if the path + * variable is missing in the incoming request. Switch this to {@code false} if + * you prefer a {@code null} or Java 8 {@code java.util.Optional} in this case. + * e.g. on a {@code ModelAttribute} method which serves for different requests. + * @since 4.3.3 + */ + boolean required() default true; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/PostMapping.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/PostMapping.java new file mode 100644 index 0000000000000000000000000000000000000000..f5a304038b33ea986a5e6de8877bccc7823bcd0b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/PostMapping.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; + +/** + * Annotation for mapping HTTP {@code POST} requests onto specific handler + * methods. + * + *

Specifically, {@code @PostMapping} is a composed annotation that + * acts as a shortcut for {@code @RequestMapping(method = RequestMethod.POST)}. + * + * @author Sam Brannen + * @since 4.3 + * @see GetMapping + * @see PutMapping + * @see DeleteMapping + * @see PatchMapping + * @see RequestMapping + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@RequestMapping(method = RequestMethod.POST) +public @interface PostMapping { + + /** + * Alias for {@link RequestMapping#name}. + */ + @AliasFor(annotation = RequestMapping.class) + String name() default ""; + + /** + * Alias for {@link RequestMapping#value}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] value() default {}; + + /** + * Alias for {@link RequestMapping#path}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] path() default {}; + + /** + * Alias for {@link RequestMapping#params}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] params() default {}; + + /** + * Alias for {@link RequestMapping#headers}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] headers() default {}; + + /** + * Alias for {@link RequestMapping#consumes}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] consumes() default {}; + + /** + * Alias for {@link RequestMapping#produces}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] produces() default {}; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/PutMapping.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/PutMapping.java new file mode 100644 index 0000000000000000000000000000000000000000..0040291dbefbb4e9dff9caba45200f269b5094ae --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/PutMapping.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; + +/** + * Annotation for mapping HTTP {@code PUT} requests onto specific handler + * methods. + * + *

Specifically, {@code @PutMapping} is a composed annotation that + * acts as a shortcut for {@code @RequestMapping(method = RequestMethod.PUT)}. + * + * @author Sam Brannen + * @since 4.3 + * @see GetMapping + * @see PostMapping + * @see DeleteMapping + * @see PatchMapping + * @see RequestMapping + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@RequestMapping(method = RequestMethod.PUT) +public @interface PutMapping { + + /** + * Alias for {@link RequestMapping#name}. + */ + @AliasFor(annotation = RequestMapping.class) + String name() default ""; + + /** + * Alias for {@link RequestMapping#value}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] value() default {}; + + /** + * Alias for {@link RequestMapping#path}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] path() default {}; + + /** + * Alias for {@link RequestMapping#params}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] params() default {}; + + /** + * Alias for {@link RequestMapping#headers}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] headers() default {}; + + /** + * Alias for {@link RequestMapping#consumes}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] consumes() default {}; + + /** + * Alias for {@link RequestMapping#produces}. + */ + @AliasFor(annotation = RequestMapping.class) + String[] produces() default {}; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestAttribute.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestAttribute.java new file mode 100644 index 0000000000000000000000000000000000000000..8a760d84c0c5b10105c80b3c597011f251027350 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestAttribute.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; + +/** + * Annotation to bind a method parameter to a request attribute. + * + *

The main motivation is to provide convenient access to request attributes + * from a controller method with an optional/required check and a cast to the + * target method parameter type. + * + * @author Rossen Stoyanchev + * @since 4.3 + * @see RequestMapping + * @see SessionAttribute + */ +@Target(ElementType.PARAMETER) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface RequestAttribute { + + /** + * Alias for {@link #name}. + */ + @AliasFor("name") + String value() default ""; + + /** + * The name of the request attribute to bind to. + *

The default name is inferred from the method parameter name. + */ + @AliasFor("value") + String name() default ""; + + /** + * Whether the request attribute is required. + *

Defaults to {@code true}, leading to an exception being thrown if + * the attribute is missing. Switch this to {@code false} if you prefer + * a {@code null} or Java 8 {@code java.util.Optional} if the attribute + * doesn't exist. + */ + boolean required() default true; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestBody.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestBody.java new file mode 100644 index 0000000000000000000000000000000000000000..5147a4e8077a7a22c11423a83d366d100d634085 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestBody.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.http.converter.HttpMessageConverter; + +/** + * Annotation indicating a method parameter should be bound to the body of the web request. + * The body of the request is passed through an {@link HttpMessageConverter} to resolve the + * method argument depending on the content type of the request. Optionally, automatic + * validation can be applied by annotating the argument with {@code @Valid}. + * + *

Supported for annotated handler methods. + * + * @author Arjen Poutsma + * @since 3.0 + * @see RequestHeader + * @see ResponseBody + * @see org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter + */ +@Target(ElementType.PARAMETER) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface RequestBody { + + /** + * Whether body content is required. + *

Default is {@code true}, leading to an exception thrown in case + * there is no body content. Switch this to {@code false} if you prefer + * {@code null} to be passed when the body content is {@code null}. + * @since 3.2 + */ + boolean required() default true; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestHeader.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestHeader.java new file mode 100644 index 0000000000000000000000000000000000000000..b02e635693c145e2a3c4dff6e0fb7a31c00bd202 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestHeader.java @@ -0,0 +1,80 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; + +/** + * Annotation which indicates that a method parameter should be bound to a web request header. + * + *

Supported for annotated handler methods in Spring MVC and Spring WebFlux. + * + *

If the method parameter is {@link java.util.Map Map<String, String>}, + * {@link org.springframework.util.MultiValueMap MultiValueMap<String, String>}, + * or {@link org.springframework.http.HttpHeaders HttpHeaders} then the map is + * populated with all header names and values. + * + * @author Juergen Hoeller + * @author Sam Brannen + * @since 3.0 + * @see RequestMapping + * @see RequestParam + * @see CookieValue + */ +@Target(ElementType.PARAMETER) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface RequestHeader { + + /** + * Alias for {@link #name}. + */ + @AliasFor("name") + String value() default ""; + + /** + * The name of the request header to bind to. + * @since 4.2 + */ + @AliasFor("value") + String name() default ""; + + /** + * Whether the header is required. + *

Defaults to {@code true}, leading to an exception being thrown + * if the header is missing in the request. Switch this to + * {@code false} if you prefer a {@code null} value if the header is + * not present in the request. + *

Alternatively, provide a {@link #defaultValue}, which implicitly + * sets this flag to {@code false}. + */ + boolean required() default true; + + /** + * The default value to use as a fallback. + *

Supplying a default value implicitly sets {@link #required} to + * {@code false}. + */ + String defaultValue() default ValueConstants.DEFAULT_NONE; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestMapping.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestMapping.java new file mode 100644 index 0000000000000000000000000000000000000000..1a510da07dab3f2d543b4a605ff839549372d3e4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestMapping.java @@ -0,0 +1,209 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; + +/** + * Annotation for mapping web requests onto methods in request-handling classes + * with flexible method signatures. + * + *

Both Spring MVC and Spring WebFlux support this annotation through a + * {@code RequestMappingHandlerMapping} and {@code RequestMappingHandlerAdapter} + * in their respective modules and package structure. For the exact list of + * supported handler method arguments and return types in each, please use the + * reference documentation links below: + *

+ * + *

Note: This annotation can be used both at the class and + * at the method level. In most cases, at the method level applications will + * prefer to use one of the HTTP method specific variants + * {@link GetMapping @GetMapping}, {@link PostMapping @PostMapping}, + * {@link PutMapping @PutMapping}, {@link DeleteMapping @DeleteMapping}, or + * {@link PatchMapping @PatchMapping}.

+ * + *

NOTE: When using controller interfaces (e.g. for AOP proxying), + * make sure to consistently put all your mapping annotations - such as + * {@code @RequestMapping} and {@code @SessionAttributes} - on + * the controller interface rather than on the implementation class. + * + * @author Juergen Hoeller + * @author Arjen Poutsma + * @author Sam Brannen + * @since 2.5 + * @see GetMapping + * @see PostMapping + * @see PutMapping + * @see DeleteMapping + * @see PatchMapping + * @see org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter + * @see org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter + */ +@Target({ElementType.TYPE, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@Mapping +public @interface RequestMapping { + + /** + * Assign a name to this mapping. + *

Supported at the type level as well as at the method level! + * When used on both levels, a combined name is derived by concatenation + * with "#" as separator. + * @see org.springframework.web.servlet.mvc.method.annotation.MvcUriComponentsBuilder + * @see org.springframework.web.servlet.handler.HandlerMethodMappingNamingStrategy + */ + String name() default ""; + + /** + * The primary mapping expressed by this annotation. + *

This is an alias for {@link #path}. For example + * {@code @RequestMapping("/foo")} is equivalent to + * {@code @RequestMapping(path="/foo")}. + *

Supported at the type level as well as at the method level! + * When used at the type level, all method-level mappings inherit + * this primary mapping, narrowing it for a specific handler method. + */ + @AliasFor("path") + String[] value() default {}; + + /** + * The path mapping URIs (e.g. "/myPath.do"). + * Ant-style path patterns are also supported (e.g. "/myPath/*.do"). + * At the method level, relative paths (e.g. "edit.do") are supported + * within the primary mapping expressed at the type level. + * Path mapping URIs may contain placeholders (e.g. "/${connect}"). + *

Supported at the type level as well as at the method level! + * When used at the type level, all method-level mappings inherit + * this primary mapping, narrowing it for a specific handler method. + * @see org.springframework.web.bind.annotation.ValueConstants#DEFAULT_NONE + * @since 4.2 + */ + @AliasFor("value") + String[] path() default {}; + + /** + * The HTTP request methods to map to, narrowing the primary mapping: + * GET, POST, HEAD, OPTIONS, PUT, PATCH, DELETE, TRACE. + *

Supported at the type level as well as at the method level! + * When used at the type level, all method-level mappings inherit + * this HTTP method restriction (i.e. the type-level restriction + * gets checked before the handler method is even resolved). + */ + RequestMethod[] method() default {}; + + /** + * The parameters of the mapped request, narrowing the primary mapping. + *

Same format for any environment: a sequence of "myParam=myValue" style + * expressions, with a request only mapped if each such parameter is found + * to have the given value. Expressions can be negated by using the "!=" operator, + * as in "myParam!=myValue". "myParam" style expressions are also supported, + * with such parameters having to be present in the request (allowed to have + * any value). Finally, "!myParam" style expressions indicate that the + * specified parameter is not supposed to be present in the request. + *

Supported at the type level as well as at the method level! + * When used at the type level, all method-level mappings inherit + * this parameter restriction (i.e. the type-level restriction + * gets checked before the handler method is even resolved). + *

Parameter mappings are considered as restrictions that are enforced at + * the type level. The primary path mapping (i.e. the specified URI value) + * still has to uniquely identify the target handler, with parameter mappings + * simply expressing preconditions for invoking the handler. + */ + String[] params() default {}; + + /** + * The headers of the mapped request, narrowing the primary mapping. + *

Same format for any environment: a sequence of "My-Header=myValue" style + * expressions, with a request only mapped if each such header is found + * to have the given value. Expressions can be negated by using the "!=" operator, + * as in "My-Header!=myValue". "My-Header" style expressions are also supported, + * with such headers having to be present in the request (allowed to have + * any value). Finally, "!My-Header" style expressions indicate that the + * specified header is not supposed to be present in the request. + *

Also supports media type wildcards (*), for headers such as Accept + * and Content-Type. For instance, + *

+	 * @RequestMapping(value = "/something", headers = "content-type=text/*")
+	 * 
+ * will match requests with a Content-Type of "text/html", "text/plain", etc. + *

Supported at the type level as well as at the method level! + * When used at the type level, all method-level mappings inherit + * this header restriction (i.e. the type-level restriction + * gets checked before the handler method is even resolved). + * @see org.springframework.http.MediaType + */ + String[] headers() default {}; + + /** + * The consumable media types of the mapped request, narrowing the primary mapping. + *

The format is a single media type or a sequence of media types, + * with a request only mapped if the {@code Content-Type} matches one of these media types. + * Examples: + *

+	 * consumes = "text/plain"
+	 * consumes = {"text/plain", "application/*"}
+	 * 
+ * Expressions can be negated by using the "!" operator, as in "!text/plain", which matches + * all requests with a {@code Content-Type} other than "text/plain". + *

Supported at the type level as well as at the method level! + * When used at the type level, all method-level mappings override + * this consumes restriction. + * @see org.springframework.http.MediaType + * @see javax.servlet.http.HttpServletRequest#getContentType() + */ + String[] consumes() default {}; + + /** + * The producible media types of the mapped request, narrowing the primary mapping. + *

The format is a single media type or a sequence of media types, + * with a request only mapped if the {@code Accept} matches one of these media types. + * Examples: + *

+	 * produces = "text/plain"
+	 * produces = {"text/plain", "application/*"}
+	 * produces = MediaType.APPLICATION_JSON_UTF8_VALUE
+	 * 
+ *

It affects the actual content type written, for example to produce a JSON response + * with UTF-8 encoding, {@link org.springframework.http.MediaType#APPLICATION_JSON_UTF8_VALUE} should be used. + *

Expressions can be negated by using the "!" operator, as in "!text/plain", which matches + * all requests with a {@code Accept} other than "text/plain". + *

Supported at the type level as well as at the method level! + * When used at the type level, all method-level mappings override + * this produces restriction. + * @see org.springframework.http.MediaType + */ + String[] produces() default {}; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestMethod.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestMethod.java new file mode 100644 index 0000000000000000000000000000000000000000..3c7d700f5d1d670b70656c62bf2ae889339c203b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestMethod.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +/** + * Java 5 enumeration of HTTP request methods. Intended for use with the + * {@link RequestMapping#method()} attribute of the {@link RequestMapping} annotation. + * + *

Note that, by default, {@link org.springframework.web.servlet.DispatcherServlet} + * supports GET, HEAD, POST, PUT, PATCH and DELETE only. DispatcherServlet will + * process TRACE and OPTIONS with the default HttpServlet behavior unless explicitly + * told to dispatch those request types as well: Check out the "dispatchOptionsRequest" + * and "dispatchTraceRequest" properties, switching them to "true" if necessary. + * + * @author Juergen Hoeller + * @since 2.5 + * @see RequestMapping + * @see org.springframework.web.servlet.DispatcherServlet#setDispatchOptionsRequest + * @see org.springframework.web.servlet.DispatcherServlet#setDispatchTraceRequest + */ +public enum RequestMethod { + + GET, HEAD, POST, PUT, PATCH, DELETE, OPTIONS, TRACE + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestParam.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestParam.java new file mode 100644 index 0000000000000000000000000000000000000000..28e706a3c4a0af29d5ead2fb74d9e8caed238cc6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestParam.java @@ -0,0 +1,98 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.util.Map; + +import org.springframework.core.annotation.AliasFor; + +/** + * Annotation which indicates that a method parameter should be bound to a web + * request parameter. + * + *

Supported for annotated handler methods in Spring MVC and Spring WebFlux + * as follows: + *

    + *
  • In Spring MVC, "request parameters" map to query parameters, form data, + * and parts in multipart requests. This is because the Servlet API combines + * query parameters and form data into a single map called "parameters", and + * that includes automatic parsing of the request body. + *
  • In Spring WebFlux, "request parameters" map to query parameters only. + * To work with all 3, query, form data, and multipart data, you can use data + * binding to a command object annotated with {@link ModelAttribute}. + *
+ * + *

If the method parameter type is {@link Map} and a request parameter name + * is specified, then the request parameter value is converted to a {@link Map} + * assuming an appropriate conversion strategy is available. + * + *

If the method parameter is {@link java.util.Map Map<String, String>} or + * {@link org.springframework.util.MultiValueMap MultiValueMap<String, String>} + * and a parameter name is not specified, then the map parameter is populated + * with all request parameter names and values. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @author Sam Brannen + * @since 2.5 + * @see RequestMapping + * @see RequestHeader + * @see CookieValue + */ +@Target(ElementType.PARAMETER) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface RequestParam { + + /** + * Alias for {@link #name}. + */ + @AliasFor("name") + String value() default ""; + + /** + * The name of the request parameter to bind to. + * @since 4.2 + */ + @AliasFor("value") + String name() default ""; + + /** + * Whether the parameter is required. + *

Defaults to {@code true}, leading to an exception being thrown + * if the parameter is missing in the request. Switch this to + * {@code false} if you prefer a {@code null} value if the parameter is + * not present in the request. + *

Alternatively, provide a {@link #defaultValue}, which implicitly + * sets this flag to {@code false}. + */ + boolean required() default true; + + /** + * The default value to use as a fallback when the request parameter is + * not provided or has an empty value. + *

Supplying a default value implicitly sets {@link #required} to + * {@code false}. + */ + String defaultValue() default ValueConstants.DEFAULT_NONE; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestPart.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestPart.java new file mode 100644 index 0000000000000000000000000000000000000000..0f8b430e54a57771a52173951b87aaf507a518c5 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/RequestPart.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.beans.PropertyEditor; +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.multipart.MultipartResolver; + +/** + * Annotation that can be used to associate the part of a "multipart/form-data" request + * with a method argument. + * + *

Supported method argument types include {@link MultipartFile} in conjunction with + * Spring's {@link MultipartResolver} abstraction, {@code javax.servlet.http.Part} in + * conjunction with Servlet 3.0 multipart requests, or otherwise for any other method + * argument, the content of the part is passed through an {@link HttpMessageConverter} + * taking into consideration the 'Content-Type' header of the request part. This is + * analogous to what @{@link RequestBody} does to resolve an argument based on the + * content of a non-multipart regular request. + * + *

Note that @{@link RequestParam} annotation can also be used to associate the part + * of a "multipart/form-data" request with a method argument supporting the same method + * argument types. The main difference is that when the method argument is not a String + * or raw {@code MultipartFile} / {@code Part}, {@code @RequestParam} relies on type + * conversion via a registered {@link Converter} or {@link PropertyEditor} while + * {@link RequestPart} relies on {@link HttpMessageConverter HttpMessageConverters} + * taking into consideration the 'Content-Type' header of the request part. + * {@link RequestParam} is likely to be used with name-value form fields while + * {@link RequestPart} is likely to be used with parts containing more complex content + * e.g. JSON, XML). + * + * @author Rossen Stoyanchev + * @author Arjen Poutsma + * @author Sam Brannen + * @since 3.1 + * @see RequestParam + * @see org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter + */ +@Target(ElementType.PARAMETER) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface RequestPart { + + /** + * Alias for {@link #name}. + */ + @AliasFor("name") + String value() default ""; + + /** + * The name of the part in the {@code "multipart/form-data"} request to bind to. + * @since 4.2 + */ + @AliasFor("value") + String name() default ""; + + /** + * Whether the part is required. + *

Defaults to {@code true}, leading to an exception being thrown + * if the part is missing in the request. Switch this to + * {@code false} if you prefer a {@code null} value if the part is + * not present in the request. + */ + boolean required() default true; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/ResponseBody.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/ResponseBody.java new file mode 100644 index 0000000000000000000000000000000000000000..69f9c82f3988d8cb655a43d2b385b2a8f29db2eb --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/ResponseBody.java @@ -0,0 +1,42 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation that indicates a method return value should be bound to the web + * response body. Supported for annotated handler methods. + * + *

As of version 4.0 this annotation can also be added on the type level in + * which case it is inherited and does not need to be added on the method level. + * + * @author Arjen Poutsma + * @since 3.0 + * @see RequestBody + * @see RestController + */ +@Target({ElementType.TYPE, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface ResponseBody { + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/ResponseStatus.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/ResponseStatus.java new file mode 100644 index 0000000000000000000000000000000000000000..c283fd85a5f43325735a4312b24dda4517accc24 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/ResponseStatus.java @@ -0,0 +1,85 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; +import org.springframework.http.HttpStatus; + +/** + * Marks a method or exception class with the status {@link #code} and + * {@link #reason} that should be returned. + * + *

The status code is applied to the HTTP response when the handler + * method is invoked and overrides status information set by other means, + * like {@code ResponseEntity} or {@code "redirect:"}. + * + *

Warning: when using this annotation on an exception + * class, or when setting the {@code reason} attribute of this annotation, + * the {@code HttpServletResponse.sendError} method will be used. + * + *

With {@code HttpServletResponse.sendError}, the response is considered + * complete and should not be written to any further. Furthermore, the Servlet + * container will typically write an HTML error page therefore making the + * use of a {@code reason} unsuitable for REST APIs. For such cases it is + * preferable to use a {@link org.springframework.http.ResponseEntity} as + * a return type and avoid the use of {@code @ResponseStatus} altogether. + * + *

Note that a controller class may also be annotated with + * {@code @ResponseStatus} and is then inherited by all {@code @RequestMapping} + * methods. + * + * @author Arjen Poutsma + * @author Sam Brannen + * @since 3.0 + * @see org.springframework.web.servlet.mvc.annotation.ResponseStatusExceptionResolver + * @see javax.servlet.http.HttpServletResponse#sendError(int, String) + */ +@Target({ElementType.TYPE, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface ResponseStatus { + + /** + * Alias for {@link #code}. + */ + @AliasFor("code") + HttpStatus value() default HttpStatus.INTERNAL_SERVER_ERROR; + + /** + * The status code to use for the response. + *

Default is {@link HttpStatus#INTERNAL_SERVER_ERROR}, which should + * typically be changed to something more appropriate. + * @since 4.2 + * @see javax.servlet.http.HttpServletResponse#setStatus(int) + * @see javax.servlet.http.HttpServletResponse#sendError(int) + */ + @AliasFor("value") + HttpStatus code() default HttpStatus.INTERNAL_SERVER_ERROR; + + /** + * The reason to be used for the response. + * @see javax.servlet.http.HttpServletResponse#sendError(int, String) + */ + String reason() default ""; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/RestController.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/RestController.java new file mode 100644 index 0000000000000000000000000000000000000000..68abaf310f5a9eb2b59013080fdb4a64945f9bb0 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/RestController.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; +import org.springframework.stereotype.Controller; + +/** + * A convenience annotation that is itself annotated with + * {@link Controller @Controller} and {@link ResponseBody @ResponseBody}. + *

+ * Types that carry this annotation are treated as controllers where + * {@link RequestMapping @RequestMapping} methods assume + * {@link ResponseBody @ResponseBody} semantics by default. + * + *

NOTE: {@code @RestController} is processed if an appropriate + * {@code HandlerMapping}-{@code HandlerAdapter} pair is configured such as the + * {@code RequestMappingHandlerMapping}-{@code RequestMappingHandlerAdapter} + * pair which are the default in the MVC Java config and the MVC namespace. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + * @since 4.0 + */ +@Target(ElementType.TYPE) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@Controller +@ResponseBody +public @interface RestController { + + /** + * The value may indicate a suggestion for a logical component name, + * to be turned into a Spring bean in case of an autodetected component. + * @return the suggested component name, if any (or empty String otherwise) + * @since 4.0.1 + */ + @AliasFor(annotation = Controller.class) + String value() default ""; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/RestControllerAdvice.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/RestControllerAdvice.java new file mode 100644 index 0000000000000000000000000000000000000000..9e3ab4afa37683bb1a68246a7810cfa760e5b7f7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/RestControllerAdvice.java @@ -0,0 +1,102 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Annotation; +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; + +/** + * A convenience annotation that is itself annotated with + * {@link ControllerAdvice @ControllerAdvice} + * and {@link ResponseBody @ResponseBody}. + * + *

Types that carry this annotation are treated as controller advice where + * {@link ExceptionHandler @ExceptionHandler} methods assume + * {@link ResponseBody @ResponseBody} semantics by default. + * + *

NOTE: {@code @RestControllerAdvice} is processed if an appropriate + * {@code HandlerMapping}-{@code HandlerAdapter} pair is configured such as the + * {@code RequestMappingHandlerMapping}-{@code RequestMappingHandlerAdapter} pair + * which are the default in the MVC Java config and the MVC namespace. + * + * @author Rossen Stoyanchev + * @since 4.3 + * @see RestController + * @see ControllerAdvice + */ +@Target(ElementType.TYPE) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@ControllerAdvice +@ResponseBody +public @interface RestControllerAdvice { + + /** + * Alias for the {@link #basePackages} attribute. + *

Allows for more concise annotation declarations e.g.: + * {@code @ControllerAdvice("org.my.pkg")} is equivalent to + * {@code @ControllerAdvice(basePackages="org.my.pkg")}. + * @see #basePackages() + */ + @AliasFor("basePackages") + String[] value() default {}; + + /** + * Array of base packages. + *

Controllers that belong to those base packages or sub-packages thereof + * will be included, e.g.: {@code @ControllerAdvice(basePackages="org.my.pkg")} + * or {@code @ControllerAdvice(basePackages={"org.my.pkg", "org.my.other.pkg"})}. + *

{@link #value} is an alias for this attribute, simply allowing for + * more concise use of the annotation. + *

Also consider using {@link #basePackageClasses()} as a type-safe + * alternative to String-based package names. + */ + @AliasFor("value") + String[] basePackages() default {}; + + /** + * Type-safe alternative to {@link #value()} for specifying the packages + * to select Controllers to be assisted by the {@code @ControllerAdvice} + * annotated class. + *

Consider creating a special no-op marker class or interface in each package + * that serves no purpose other than being referenced by this attribute. + */ + Class[] basePackageClasses() default {}; + + /** + * Array of classes. + *

Controllers that are assignable to at least one of the given types + * will be assisted by the {@code @ControllerAdvice} annotated class. + */ + Class[] assignableTypes() default {}; + + /** + * Array of annotations. + *

Controllers that are annotated with this/one of those annotation(s) + * will be assisted by the {@code @ControllerAdvice} annotated class. + *

Consider creating a special annotation or use a predefined one, + * like {@link RestController @RestController}. + */ + Class[] annotations() default {}; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/SessionAttribute.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/SessionAttribute.java new file mode 100644 index 0000000000000000000000000000000000000000..ca06dd1bdaeadec9e4d559104a83960daa05214c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/SessionAttribute.java @@ -0,0 +1,74 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; + +/** + * Annotation to bind a method parameter to a session attribute. + * + *

The main motivation is to provide convenient access to existing, permanent + * session attributes (e.g. user authentication object) with an optional/required + * check and a cast to the target method parameter type. + * + *

For use cases that require adding or removing session attributes consider + * injecting {@code org.springframework.web.context.request.WebRequest} or + * {@code javax.servlet.http.HttpSession} into the controller method. + * + *

For temporary storage of model attributes in the session as part of the + * workflow for a controller, consider using {@link SessionAttributes} instead. + * + * @author Rossen Stoyanchev + * @since 4.3 + * @see RequestMapping + * @see SessionAttributes + * @see RequestAttribute + */ +@Target(ElementType.PARAMETER) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface SessionAttribute { + + /** + * Alias for {@link #name}. + */ + @AliasFor("name") + String value() default ""; + + /** + * The name of the session attribute to bind to. + *

The default name is inferred from the method parameter name. + */ + @AliasFor("value") + String name() default ""; + + /** + * Whether the session attribute is required. + *

Defaults to {@code true}, leading to an exception being thrown + * if the attribute is missing in the session or there is no session. + * Switch this to {@code false} if you prefer a {@code null} or Java 8 + * {@code java.util.Optional} if the attribute doesn't exist. + */ + boolean required() default true; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/SessionAttributes.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/SessionAttributes.java new file mode 100644 index 0000000000000000000000000000000000000000..218af83d3c2c32c9ea933cd242425c6e750e7a0b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/SessionAttributes.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Inherited; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; + +/** + * Annotation that indicates the session attributes that a specific handler uses. + * + *

This will typically list the names of model attributes which should be + * transparently stored in the session or some conversational storage, + * serving as form-backing beans. Declared at the type level, applying + * to the model attributes that the annotated handler class operates on. + * + *

NOTE: Session attributes as indicated using this annotation + * correspond to a specific handler's model attributes, getting transparently + * stored in a conversational session. Those attributes will be removed once + * the handler indicates completion of its conversational session. Therefore, + * use this facility for such conversational attributes which are supposed + * to be stored in the session temporarily during the course of a + * specific handler's conversation. + * + *

For permanent session attributes, e.g. a user authentication object, + * use the traditional {@code session.setAttribute} method instead. + * Alternatively, consider using the attribute management capabilities of the + * generic {@link org.springframework.web.context.request.WebRequest} interface. + * + *

NOTE: When using controller interfaces (e.g. for AOP proxying), + * make sure to consistently put all your mapping annotations — + * such as {@code @RequestMapping} and {@code @SessionAttributes} — on + * the controller interface rather than on the implementation class. + * + * @author Juergen Hoeller + * @author Sam Brannen + * @since 2.5 + */ +@Target({ElementType.TYPE}) +@Retention(RetentionPolicy.RUNTIME) +@Inherited +@Documented +public @interface SessionAttributes { + + /** + * Alias for {@link #names}. + */ + @AliasFor("names") + String[] value() default {}; + + /** + * The names of session attributes in the model that should be stored in the + * session or some conversational storage. + *

Note: This indicates the model attribute names. + * The session attribute names may or may not match the model attribute + * names. Applications should therefore not rely on the session attribute + * names but rather operate on the model only. + * @since 4.2 + */ + @AliasFor("value") + String[] names() default {}; + + /** + * The types of session attributes in the model that should be stored in the + * session or some conversational storage. + *

All model attributes of these types will be stored in the session, + * regardless of attribute name. + */ + Class[] types() default {}; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/ValueConstants.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/ValueConstants.java new file mode 100644 index 0000000000000000000000000000000000000000..0993fa584c2c680f9545f0eb740685cdffde36a8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/ValueConstants.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +/** + * Common value constants shared between bind annotations. + * + * @author Juergen Hoeller + * @since 3.0.1 + */ +public interface ValueConstants { + + /** + * Constant defining a value for no default - as a replacement for + * {@code null} which we cannot use in annotation attributes. + *

This is an artificial arrangement of 16 unicode characters, + * with its sole purpose being to never match user-declared values. + * @see RequestParam#defaultValue() + * @see RequestHeader#defaultValue() + * @see CookieValue#defaultValue() + */ + String DEFAULT_NONE = "\n\t\t\n\t\t\n\uE000\uE001\uE002\n\t\t\t\t\n"; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/package-info.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..fb7e6cc76f6123ab7b062acc1809f9f99c0b0c1e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/package-info.java @@ -0,0 +1,10 @@ +/** + * Annotations for binding requests to controllers and handler methods + * as well as for binding request parameters to method arguments. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.bind.annotation; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/bind/package-info.java b/spring-web/src/main/java/org/springframework/web/bind/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..7b87239b23caba919fceacca48f582a114fd62dc --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/package-info.java @@ -0,0 +1,9 @@ +/** + * Provides web-specific data binding functionality. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.bind; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/ConfigurableWebBindingInitializer.java b/spring-web/src/main/java/org/springframework/web/bind/support/ConfigurableWebBindingInitializer.java new file mode 100644 index 0000000000000000000000000000000000000000..e2f600e005cf8e5e5210969b2f1920821111f802 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/support/ConfigurableWebBindingInitializer.java @@ -0,0 +1,220 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +import org.springframework.beans.PropertyEditorRegistrar; +import org.springframework.core.convert.ConversionService; +import org.springframework.lang.Nullable; +import org.springframework.validation.BindingErrorProcessor; +import org.springframework.validation.MessageCodesResolver; +import org.springframework.validation.Validator; +import org.springframework.web.bind.WebDataBinder; + +/** + * Convenient {@link WebBindingInitializer} for declarative configuration + * in a Spring application context. Allows for reusing pre-configured + * initializers with multiple controller/handlers. + * + * @author Juergen Hoeller + * @since 2.5 + * @see #setDirectFieldAccess + * @see #setMessageCodesResolver + * @see #setBindingErrorProcessor + * @see #setValidator(Validator) + * @see #setConversionService(ConversionService) + * @see #setPropertyEditorRegistrar + */ +public class ConfigurableWebBindingInitializer implements WebBindingInitializer { + + private boolean autoGrowNestedPaths = true; + + private boolean directFieldAccess = false; + + @Nullable + private MessageCodesResolver messageCodesResolver; + + @Nullable + private BindingErrorProcessor bindingErrorProcessor; + + @Nullable + private Validator validator; + + @Nullable + private ConversionService conversionService; + + @Nullable + private PropertyEditorRegistrar[] propertyEditorRegistrars; + + + /** + * Set whether a binder should attempt to "auto-grow" a nested path that contains a null value. + *

If "true", a null path location will be populated with a default object value and traversed + * instead of resulting in an exception. This flag also enables auto-growth of collection elements + * when accessing an out-of-bounds index. + *

Default is "true" on a standard DataBinder. Note that this feature is only supported + * for bean property access (DataBinder's default mode), not for field access. + * @see org.springframework.validation.DataBinder#initBeanPropertyAccess() + * @see org.springframework.validation.DataBinder#setAutoGrowNestedPaths + */ + public void setAutoGrowNestedPaths(boolean autoGrowNestedPaths) { + this.autoGrowNestedPaths = autoGrowNestedPaths; + } + + /** + * Return whether a binder should attempt to "auto-grow" a nested path that contains a null value. + */ + public boolean isAutoGrowNestedPaths() { + return this.autoGrowNestedPaths; + } + + /** + * Set whether to use direct field access instead of bean property access. + *

Default is {@code false}, using bean property access. + * Switch this to {@code true} in order to enforce direct field access. + * @see org.springframework.validation.DataBinder#initDirectFieldAccess() + * @see org.springframework.validation.DataBinder#initBeanPropertyAccess() + */ + public final void setDirectFieldAccess(boolean directFieldAccess) { + this.directFieldAccess = directFieldAccess; + } + + /** + * Return whether to use direct field access instead of bean property access. + */ + public boolean isDirectFieldAccess() { + return this.directFieldAccess; + } + + /** + * Set the strategy to use for resolving errors into message codes. + * Applies the given strategy to all data binders used by this controller. + *

Default is {@code null}, i.e. using the default strategy of + * the data binder. + * @see org.springframework.validation.DataBinder#setMessageCodesResolver + */ + public final void setMessageCodesResolver(@Nullable MessageCodesResolver messageCodesResolver) { + this.messageCodesResolver = messageCodesResolver; + } + + /** + * Return the strategy to use for resolving errors into message codes. + */ + @Nullable + public final MessageCodesResolver getMessageCodesResolver() { + return this.messageCodesResolver; + } + + /** + * Set the strategy to use for processing binding errors, that is, + * required field errors and {@code PropertyAccessException}s. + *

Default is {@code null}, that is, using the default strategy + * of the data binder. + * @see org.springframework.validation.DataBinder#setBindingErrorProcessor + */ + public final void setBindingErrorProcessor(@Nullable BindingErrorProcessor bindingErrorProcessor) { + this.bindingErrorProcessor = bindingErrorProcessor; + } + + /** + * Return the strategy to use for processing binding errors. + */ + @Nullable + public final BindingErrorProcessor getBindingErrorProcessor() { + return this.bindingErrorProcessor; + } + + /** + * Set the Validator to apply after each binding step. + */ + public final void setValidator(@Nullable Validator validator) { + this.validator = validator; + } + + /** + * Return the Validator to apply after each binding step, if any. + */ + @Nullable + public final Validator getValidator() { + return this.validator; + } + + /** + * Specify a ConversionService which will apply to every DataBinder. + * @since 3.0 + */ + public final void setConversionService(@Nullable ConversionService conversionService) { + this.conversionService = conversionService; + } + + /** + * Return the ConversionService which will apply to every DataBinder. + */ + @Nullable + public final ConversionService getConversionService() { + return this.conversionService; + } + + /** + * Specify a single PropertyEditorRegistrar to be applied to every DataBinder. + */ + public final void setPropertyEditorRegistrar(PropertyEditorRegistrar propertyEditorRegistrar) { + this.propertyEditorRegistrars = new PropertyEditorRegistrar[] {propertyEditorRegistrar}; + } + + /** + * Specify multiple PropertyEditorRegistrars to be applied to every DataBinder. + */ + public final void setPropertyEditorRegistrars(@Nullable PropertyEditorRegistrar[] propertyEditorRegistrars) { + this.propertyEditorRegistrars = propertyEditorRegistrars; + } + + /** + * Return the PropertyEditorRegistrars to be applied to every DataBinder. + */ + @Nullable + public final PropertyEditorRegistrar[] getPropertyEditorRegistrars() { + return this.propertyEditorRegistrars; + } + + + @Override + public void initBinder(WebDataBinder binder) { + binder.setAutoGrowNestedPaths(this.autoGrowNestedPaths); + if (this.directFieldAccess) { + binder.initDirectFieldAccess(); + } + if (this.messageCodesResolver != null) { + binder.setMessageCodesResolver(this.messageCodesResolver); + } + if (this.bindingErrorProcessor != null) { + binder.setBindingErrorProcessor(this.bindingErrorProcessor); + } + if (this.validator != null && binder.getTarget() != null && + this.validator.supports(binder.getTarget().getClass())) { + binder.setValidator(this.validator); + } + if (this.conversionService != null) { + binder.setConversionService(this.conversionService); + } + if (this.propertyEditorRegistrars != null) { + for (PropertyEditorRegistrar propertyEditorRegistrar : this.propertyEditorRegistrars) { + propertyEditorRegistrar.registerCustomEditors(binder); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/DefaultDataBinderFactory.java b/spring-web/src/main/java/org/springframework/web/bind/support/DefaultDataBinderFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..c6482d8b4683d4ea6f8ac74bba8f3828e601adba --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/support/DefaultDataBinderFactory.java @@ -0,0 +1,91 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +import org.springframework.lang.Nullable; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Create a {@link WebRequestDataBinder} instance and initialize it with a + * {@link WebBindingInitializer}. + * + * @author Rossen Stoyanchev + * @since 3.1 + */ +public class DefaultDataBinderFactory implements WebDataBinderFactory { + + @Nullable + private final WebBindingInitializer initializer; + + + /** + * Create a new {@code DefaultDataBinderFactory} instance. + * @param initializer for global data binder initialization + * (or {@code null} if none) + */ + public DefaultDataBinderFactory(@Nullable WebBindingInitializer initializer) { + this.initializer = initializer; + } + + + /** + * Create a new {@link WebDataBinder} for the given target object and + * initialize it through a {@link WebBindingInitializer}. + * @throws Exception in case of invalid state or arguments + */ + @Override + @SuppressWarnings("deprecation") + public final WebDataBinder createBinder( + NativeWebRequest webRequest, @Nullable Object target, String objectName) throws Exception { + + WebDataBinder dataBinder = createBinderInstance(target, objectName, webRequest); + if (this.initializer != null) { + this.initializer.initBinder(dataBinder, webRequest); + } + initBinder(dataBinder, webRequest); + return dataBinder; + } + + /** + * Extension point to create the WebDataBinder instance. + * By default this is {@code WebRequestDataBinder}. + * @param target the binding target or {@code null} for type conversion only + * @param objectName the binding target object name + * @param webRequest the current request + * @throws Exception in case of invalid state or arguments + */ + protected WebDataBinder createBinderInstance( + @Nullable Object target, String objectName, NativeWebRequest webRequest) throws Exception { + + return new WebRequestDataBinder(target, objectName); + } + + /** + * Extension point to further initialize the created data binder instance + * (e.g. with {@code @InitBinder} methods) after "global" initialization + * via {@link WebBindingInitializer}. + * @param dataBinder the data binder instance to customize + * @param webRequest the current request + * @throws Exception if initialization fails + */ + protected void initBinder(WebDataBinder dataBinder, NativeWebRequest webRequest) + throws Exception { + + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/DefaultSessionAttributeStore.java b/spring-web/src/main/java/org/springframework/web/bind/support/DefaultSessionAttributeStore.java new file mode 100644 index 0000000000000000000000000000000000000000..30952bf81faebffd30165e05ef2e5d57a1d013c2 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/support/DefaultSessionAttributeStore.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.context.request.WebRequest; + +/** + * Default implementation of the {@link SessionAttributeStore} interface, + * storing the attributes in the WebRequest session (i.e. HttpSession). + * + * @author Juergen Hoeller + * @since 2.5 + * @see #setAttributeNamePrefix + * @see org.springframework.web.context.request.WebRequest#setAttribute + * @see org.springframework.web.context.request.WebRequest#getAttribute + * @see org.springframework.web.context.request.WebRequest#removeAttribute + */ +public class DefaultSessionAttributeStore implements SessionAttributeStore { + + private String attributeNamePrefix = ""; + + + /** + * Specify a prefix to use for the attribute names in the backend session. + *

Default is to use no prefix, storing the session attributes with the + * same name as in the model. + */ + public void setAttributeNamePrefix(@Nullable String attributeNamePrefix) { + this.attributeNamePrefix = (attributeNamePrefix != null ? attributeNamePrefix : ""); + } + + + @Override + public void storeAttribute(WebRequest request, String attributeName, Object attributeValue) { + Assert.notNull(request, "WebRequest must not be null"); + Assert.notNull(attributeName, "Attribute name must not be null"); + Assert.notNull(attributeValue, "Attribute value must not be null"); + String storeAttributeName = getAttributeNameInSession(request, attributeName); + request.setAttribute(storeAttributeName, attributeValue, WebRequest.SCOPE_SESSION); + } + + @Override + @Nullable + public Object retrieveAttribute(WebRequest request, String attributeName) { + Assert.notNull(request, "WebRequest must not be null"); + Assert.notNull(attributeName, "Attribute name must not be null"); + String storeAttributeName = getAttributeNameInSession(request, attributeName); + return request.getAttribute(storeAttributeName, WebRequest.SCOPE_SESSION); + } + + @Override + public void cleanupAttribute(WebRequest request, String attributeName) { + Assert.notNull(request, "WebRequest must not be null"); + Assert.notNull(attributeName, "Attribute name must not be null"); + String storeAttributeName = getAttributeNameInSession(request, attributeName); + request.removeAttribute(storeAttributeName, WebRequest.SCOPE_SESSION); + } + + + /** + * Calculate the attribute name in the backend session. + *

The default implementation simply prepends the configured + * {@link #setAttributeNamePrefix "attributeNamePrefix"}, if any. + * @param request the current request + * @param attributeName the name of the attribute + * @return the attribute name in the backend session + */ + protected String getAttributeNameInSession(WebRequest request, String attributeName) { + return this.attributeNamePrefix + attributeName; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/SessionAttributeStore.java b/spring-web/src/main/java/org/springframework/web/bind/support/SessionAttributeStore.java new file mode 100644 index 0000000000000000000000000000000000000000..42acf3bbf1462b365026714c8afe0340538ec398 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/support/SessionAttributeStore.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +import org.springframework.lang.Nullable; +import org.springframework.web.context.request.WebRequest; + +/** + * Strategy interface for storing model attributes in a backend session. + * + * @author Juergen Hoeller + * @since 2.5 + * @see org.springframework.web.bind.annotation.SessionAttributes + */ +public interface SessionAttributeStore { + + /** + * Store the supplied attribute in the backend session. + *

Can be called for new attributes as well as for existing attributes. + * In the latter case, this signals that the attribute value may have been modified. + * @param request the current request + * @param attributeName the name of the attribute + * @param attributeValue the attribute value to store + */ + void storeAttribute(WebRequest request, String attributeName, Object attributeValue); + + /** + * Retrieve the specified attribute from the backend session. + *

This will typically be called with the expectation that the + * attribute is already present, with an exception to be thrown + * if this method returns {@code null}. + * @param request the current request + * @param attributeName the name of the attribute + * @return the current attribute value, or {@code null} if none + */ + @Nullable + Object retrieveAttribute(WebRequest request, String attributeName); + + /** + * Clean up the specified attribute in the backend session. + *

Indicates that the attribute name will not be used anymore. + * @param request the current request + * @param attributeName the name of the attribute + */ + void cleanupAttribute(WebRequest request, String attributeName); + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/SessionStatus.java b/spring-web/src/main/java/org/springframework/web/bind/support/SessionStatus.java new file mode 100644 index 0000000000000000000000000000000000000000..395be0e5535e5ecb02b54a58f9db9919bb65c531 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/support/SessionStatus.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2007 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +/** + * Simple interface that can be injected into handler methods, allowing them to + * signal that their session processing is complete. The handler invoker may + * then follow up with appropriate cleanup, e.g. of session attributes which + * have been implicitly created during this handler's processing (according to + * the + * {@link org.springframework.web.bind.annotation.SessionAttributes @SessionAttributes} + * annotation). + * + * @author Juergen Hoeller + * @since 2.5 + * @see org.springframework.web.bind.annotation.RequestMapping + * @see org.springframework.web.bind.annotation.SessionAttributes + */ +public interface SessionStatus { + + /** + * Mark the current handler's session processing as complete, allowing for + * cleanup of session attributes. + */ + void setComplete(); + + /** + * Return whether the current handler's session processing has been marked + * as complete. + */ + boolean isComplete(); + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/SimpleSessionStatus.java b/spring-web/src/main/java/org/springframework/web/bind/support/SimpleSessionStatus.java new file mode 100644 index 0000000000000000000000000000000000000000..fbdb8b085dd7d5b49f3b9e56a5bd0dba7f9caaa3 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/support/SimpleSessionStatus.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +/** + * Simple implementation of the {@link SessionStatus} interface, + * keeping the {@code complete} flag as an instance variable. + * + * @author Juergen Hoeller + * @since 2.5 + */ +public class SimpleSessionStatus implements SessionStatus { + + private boolean complete = false; + + + @Override + public void setComplete() { + this.complete = true; + } + + @Override + public boolean isComplete() { + return this.complete; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/SpringWebConstraintValidatorFactory.java b/spring-web/src/main/java/org/springframework/web/bind/support/SpringWebConstraintValidatorFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..d92cad063d8c536452c01dd710fe9f1d4cb52c1e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/support/SpringWebConstraintValidatorFactory.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +import javax.validation.ConstraintValidator; +import javax.validation.ConstraintValidatorFactory; + +import org.springframework.web.context.ContextLoader; +import org.springframework.web.context.WebApplicationContext; + +/** + * JSR-303 {@link ConstraintValidatorFactory} implementation that delegates to + * the current Spring {@link WebApplicationContext} for creating autowired + * {@link ConstraintValidator} instances. + * + *

In contrast to + * {@link org.springframework.validation.beanvalidation.SpringConstraintValidatorFactory}, + * this variant is meant for declarative use in a standard {@code validation.xml} file, + * e.g. in combination with JAX-RS or JAX-WS. + * + * @author Juergen Hoeller + * @since 4.2.1 + * @see ContextLoader#getCurrentWebApplicationContext() + * @see org.springframework.validation.beanvalidation.SpringConstraintValidatorFactory + */ +public class SpringWebConstraintValidatorFactory implements ConstraintValidatorFactory { + + @Override + public > T getInstance(Class key) { + return getWebApplicationContext().getAutowireCapableBeanFactory().createBean(key); + } + + // Bean Validation 1.1 releaseInstance method + public void releaseInstance(ConstraintValidator instance) { + getWebApplicationContext().getAutowireCapableBeanFactory().destroyBean(instance); + } + + + /** + * Retrieve the Spring {@link WebApplicationContext} to use. + * The default implementation returns the current {@link WebApplicationContext} + * as registered for the thread context class loader. + * @return the current WebApplicationContext (never {@code null}) + * @see ContextLoader#getCurrentWebApplicationContext() + */ + protected WebApplicationContext getWebApplicationContext() { + WebApplicationContext wac = ContextLoader.getCurrentWebApplicationContext(); + if (wac == null) { + throw new IllegalStateException("No WebApplicationContext registered for current thread - " + + "consider overriding SpringWebConstraintValidatorFactory.getWebApplicationContext()"); + } + return wac; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/WebArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/bind/support/WebArgumentResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..58574d4864e38510d8057be0879ab3e95b2e3bf4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/support/WebArgumentResolver.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * SPI for resolving custom arguments for a specific handler method parameter. + * Typically implemented to detect special parameter types, resolving + * well-known argument values for them. + * + *

A typical implementation could look like as follows: + * + *

+ * public class MySpecialArgumentResolver implements WebArgumentResolver {
+ *
+ *   public Object resolveArgument(MethodParameter methodParameter, NativeWebRequest webRequest) {
+ *     if (methodParameter.getParameterType().equals(MySpecialArg.class)) {
+ *       return new MySpecialArg("myValue");
+ *     }
+ *     return UNRESOLVED;
+ *   }
+ * }
+ * + * @author Juergen Hoeller + * @since 2.5.2 + * @see org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter#setCustomArgumentResolvers + */ +@FunctionalInterface +public interface WebArgumentResolver { + + /** + * Marker to be returned when the resolver does not know how to + * handle the given method parameter. + */ + Object UNRESOLVED = new Object(); + + + /** + * Resolve an argument for the given handler method parameter within the given web request. + * @param methodParameter the handler method parameter to resolve + * @param webRequest the current web request, allowing access to the native request as well + * @return the argument value, or {@code UNRESOLVED} if not resolvable + * @throws Exception in case of resolution failure + */ + @Nullable + Object resolveArgument(MethodParameter methodParameter, NativeWebRequest webRequest) throws Exception; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/WebBindingInitializer.java b/spring-web/src/main/java/org/springframework/web/bind/support/WebBindingInitializer.java new file mode 100644 index 0000000000000000000000000000000000000000..58d5df5f73d5f1b5a953e4dfca32f64a5c05f4a8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/support/WebBindingInitializer.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.context.request.WebRequest; + +/** + * Callback interface for initializing a {@link WebDataBinder} for performing + * data binding in the context of a specific web request. + * + * @author Juergen Hoeller + * @author Rossen Stoyanchev + * @since 2.5 + */ +public interface WebBindingInitializer { + + /** + * Initialize the given DataBinder. + * @param binder the DataBinder to initialize + * @since 5.0 + */ + void initBinder(WebDataBinder binder); + + /** + * Initialize the given DataBinder for the given (Servlet) request. + * @param binder the DataBinder to initialize + * @param request the web request that the data binding happens within + * @deprecated as of 5.0 in favor of {@link #initBinder(WebDataBinder)} + */ + @Deprecated + default void initBinder(WebDataBinder binder, WebRequest request) { + initBinder(binder); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/WebDataBinderFactory.java b/spring-web/src/main/java/org/springframework/web/bind/support/WebDataBinderFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..e513d1094aabebe64b77da8abf349835983d98a8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/support/WebDataBinderFactory.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +import org.springframework.lang.Nullable; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * A factory for creating a {@link WebDataBinder} instance for a named target object. + * + * @author Arjen Poutsma + * @since 3.1 + */ +public interface WebDataBinderFactory { + + /** + * Create a {@link WebDataBinder} for the given object. + * @param webRequest the current request + * @param target the object to create a data binder for, + * or {@code null} if creating a binder for a simple type + * @param objectName the name of the target object + * @return the created {@link WebDataBinder} instance, never null + * @throws Exception raised if the creation and initialization of the data binder fails + */ + WebDataBinder createBinder(NativeWebRequest webRequest, @Nullable Object target, String objectName) + throws Exception; + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/WebExchangeBindException.java b/spring-web/src/main/java/org/springframework/web/bind/support/WebExchangeBindException.java new file mode 100644 index 0000000000000000000000000000000000000000..ccb49da9f6b6ee8330b6a6d9bf960fb42e5a3d0f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/support/WebExchangeBindException.java @@ -0,0 +1,307 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +import java.beans.PropertyEditor; +import java.util.List; +import java.util.Map; + +import org.springframework.beans.PropertyEditorRegistry; +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.validation.BeanPropertyBindingResult; +import org.springframework.validation.BindingResult; +import org.springframework.validation.Errors; +import org.springframework.validation.FieldError; +import org.springframework.validation.ObjectError; +import org.springframework.web.server.ServerWebInputException; + +/** + * A specialization of {@link ServerWebInputException} thrown when after data + * binding and validation failure. Implements {@link BindingResult} (and its + * super-interface {@link Errors}) to allow for direct analysis of binding and + * validation errors. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +@SuppressWarnings("serial") +public class WebExchangeBindException extends ServerWebInputException implements BindingResult { + + private final BindingResult bindingResult; + + + public WebExchangeBindException(MethodParameter parameter, BindingResult bindingResult) { + super("Validation failure", parameter); + this.bindingResult = bindingResult; + } + + + /** + * Return the BindingResult that this BindException wraps. + * Will typically be a BeanPropertyBindingResult. + * @see BeanPropertyBindingResult + */ + public final BindingResult getBindingResult() { + return this.bindingResult; + } + + + @Override + public String getObjectName() { + return this.bindingResult.getObjectName(); + } + + @Override + public void setNestedPath(String nestedPath) { + this.bindingResult.setNestedPath(nestedPath); + } + + @Override + public String getNestedPath() { + return this.bindingResult.getNestedPath(); + } + + @Override + public void pushNestedPath(String subPath) { + this.bindingResult.pushNestedPath(subPath); + } + + @Override + public void popNestedPath() throws IllegalStateException { + this.bindingResult.popNestedPath(); + } + + + @Override + public void reject(String errorCode) { + this.bindingResult.reject(errorCode); + } + + @Override + public void reject(String errorCode, String defaultMessage) { + this.bindingResult.reject(errorCode, defaultMessage); + } + + @Override + public void reject(String errorCode, @Nullable Object[] errorArgs, @Nullable String defaultMessage) { + this.bindingResult.reject(errorCode, errorArgs, defaultMessage); + } + + @Override + public void rejectValue(@Nullable String field, String errorCode) { + this.bindingResult.rejectValue(field, errorCode); + } + + @Override + public void rejectValue(@Nullable String field, String errorCode, String defaultMessage) { + this.bindingResult.rejectValue(field, errorCode, defaultMessage); + } + + @Override + public void rejectValue( + @Nullable String field, String errorCode, @Nullable Object[] errorArgs, @Nullable String defaultMessage) { + + this.bindingResult.rejectValue(field, errorCode, errorArgs, defaultMessage); + } + + @Override + public void addAllErrors(Errors errors) { + this.bindingResult.addAllErrors(errors); + } + + + @Override + public boolean hasErrors() { + return this.bindingResult.hasErrors(); + } + + @Override + public int getErrorCount() { + return this.bindingResult.getErrorCount(); + } + + @Override + public List getAllErrors() { + return this.bindingResult.getAllErrors(); + } + + @Override + public boolean hasGlobalErrors() { + return this.bindingResult.hasGlobalErrors(); + } + + @Override + public int getGlobalErrorCount() { + return this.bindingResult.getGlobalErrorCount(); + } + + @Override + public List getGlobalErrors() { + return this.bindingResult.getGlobalErrors(); + } + + @Override + @Nullable + public ObjectError getGlobalError() { + return this.bindingResult.getGlobalError(); + } + + @Override + public boolean hasFieldErrors() { + return this.bindingResult.hasFieldErrors(); + } + + @Override + public int getFieldErrorCount() { + return this.bindingResult.getFieldErrorCount(); + } + + @Override + public List getFieldErrors() { + return this.bindingResult.getFieldErrors(); + } + + @Override + @Nullable + public FieldError getFieldError() { + return this.bindingResult.getFieldError(); + } + + @Override + public boolean hasFieldErrors(String field) { + return this.bindingResult.hasFieldErrors(field); + } + + @Override + public int getFieldErrorCount(String field) { + return this.bindingResult.getFieldErrorCount(field); + } + + @Override + public List getFieldErrors(String field) { + return this.bindingResult.getFieldErrors(field); + } + + @Override + @Nullable + public FieldError getFieldError(String field) { + return this.bindingResult.getFieldError(field); + } + + @Override + @Nullable + public Object getFieldValue(String field) { + return this.bindingResult.getFieldValue(field); + } + + @Override + @Nullable + public Class getFieldType(String field) { + return this.bindingResult.getFieldType(field); + } + + @Override + @Nullable + public Object getTarget() { + return this.bindingResult.getTarget(); + } + + @Override + public Map getModel() { + return this.bindingResult.getModel(); + } + + @Override + @Nullable + public Object getRawFieldValue(String field) { + return this.bindingResult.getRawFieldValue(field); + } + + @Override + @SuppressWarnings("rawtypes") + @Nullable + public PropertyEditor findEditor(@Nullable String field, @Nullable Class valueType) { + return this.bindingResult.findEditor(field, valueType); + } + + @Override + @Nullable + public PropertyEditorRegistry getPropertyEditorRegistry() { + return this.bindingResult.getPropertyEditorRegistry(); + } + + @Override + public String[] resolveMessageCodes(String errorCode) { + return this.bindingResult.resolveMessageCodes(errorCode); + } + + @Override + public String[] resolveMessageCodes(String errorCode, String field) { + return this.bindingResult.resolveMessageCodes(errorCode, field); + } + + @Override + public void addError(ObjectError error) { + this.bindingResult.addError(error); + } + + @Override + public void recordFieldValue(String field, Class type, @Nullable Object value) { + this.bindingResult.recordFieldValue(field, type, value); + } + + @Override + public void recordSuppressedField(String field) { + this.bindingResult.recordSuppressedField(field); + } + + @Override + public String[] getSuppressedFields() { + return this.bindingResult.getSuppressedFields(); + } + + + /** + * Returns diagnostic information about the errors held in this object. + */ + @Override + public String getMessage() { + MethodParameter parameter = getMethodParameter(); + Assert.state(parameter != null, "No MethodParameter"); + StringBuilder sb = new StringBuilder("Validation failed for argument at index ") + .append(parameter.getParameterIndex()).append(" in method: ") + .append(parameter.getExecutable().toGenericString()) + .append(", with ").append(this.bindingResult.getErrorCount()).append(" error(s): "); + for (ObjectError error : this.bindingResult.getAllErrors()) { + sb.append("[").append(error).append("] "); + } + return sb.toString(); + } + + @Override + public boolean equals(Object other) { + return (this == other || this.bindingResult.equals(other)); + } + + @Override + public int hashCode() { + return this.bindingResult.hashCode(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/WebExchangeDataBinder.java b/spring-web/src/main/java/org/springframework/web/bind/support/WebExchangeDataBinder.java new file mode 100644 index 0000000000000000000000000000000000000000..1114de7b61299070ae2a12a4baa15c25df310b71 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/support/WebExchangeDataBinder.java @@ -0,0 +1,119 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.stream.Collectors; + +import reactor.core.publisher.Mono; + +import org.springframework.beans.MutablePropertyValues; +import org.springframework.http.codec.multipart.FormFieldPart; +import org.springframework.http.codec.multipart.Part; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.server.ServerWebExchange; + +/** + * Specialized {@link org.springframework.validation.DataBinder} to perform data + * binding from URL query params or form data in the request data to Java objects. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class WebExchangeDataBinder extends WebDataBinder { + + /** + * Create a new instance, with default object name. + * @param target the target object to bind onto (or {@code null} if the + * binder is just used to convert a plain parameter value) + * @see #DEFAULT_OBJECT_NAME + */ + public WebExchangeDataBinder(@Nullable Object target) { + super(target); + } + + /** + * Create a new instance. + * @param target the target object to bind onto (or {@code null} if the + * binder is just used to convert a plain parameter value) + * @param objectName the name of the target object + */ + public WebExchangeDataBinder(@Nullable Object target, String objectName) { + super(target, objectName); + } + + + /** + * Bind query params, form data, and or multipart form data to the binder target. + * @param exchange the current exchange + * @return a {@code Mono} when binding is complete + */ + public Mono bind(ServerWebExchange exchange) { + return getValuesToBind(exchange) + .doOnNext(values -> doBind(new MutablePropertyValues(values))) + .then(); + } + + /** + * Protected method to obtain the values for data binding. By default this + * method delegates to {@link #extractValuesToBind(ServerWebExchange)}. + */ + protected Mono> getValuesToBind(ServerWebExchange exchange) { + return extractValuesToBind(exchange); + } + + + /** + * Combine query params and form data for multipart form data from the body + * of the request into a {@code Map} of values to use for + * data binding purposes. + * @param exchange the current exchange + * @return a {@code Mono} with the values to bind + * @see org.springframework.http.server.reactive.ServerHttpRequest#getQueryParams() + * @see ServerWebExchange#getFormData() + * @see ServerWebExchange#getMultipartData() + */ + public static Mono> extractValuesToBind(ServerWebExchange exchange) { + MultiValueMap queryParams = exchange.getRequest().getQueryParams(); + Mono> formData = exchange.getFormData(); + Mono> multipartData = exchange.getMultipartData(); + + return Mono.zip(Mono.just(queryParams), formData, multipartData) + .map(tuple -> { + Map result = new TreeMap<>(); + tuple.getT1().forEach((key, values) -> addBindValue(result, key, values)); + tuple.getT2().forEach((key, values) -> addBindValue(result, key, values)); + tuple.getT3().forEach((key, values) -> addBindValue(result, key, values)); + return result; + }); + } + + private static void addBindValue(Map params, String key, List values) { + if (!CollectionUtils.isEmpty(values)) { + values = values.stream() + .map(value -> value instanceof FormFieldPart ? ((FormFieldPart) value).value() : value) + .collect(Collectors.toList()); + params.put(key, values.size() == 1 ? values.get(0) : values); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/WebRequestDataBinder.java b/spring-web/src/main/java/org/springframework/web/bind/support/WebRequestDataBinder.java new file mode 100644 index 0000000000000000000000000000000000000000..c9673eace5770db9ce99fd3664d6ca3e055e02c4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/support/WebRequestDataBinder.java @@ -0,0 +1,171 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.Part; + +import org.springframework.beans.MutablePropertyValues; +import org.springframework.lang.Nullable; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.validation.BindException; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.WebRequest; +import org.springframework.web.multipart.MultipartException; +import org.springframework.web.multipart.MultipartRequest; + +/** + * Special {@link org.springframework.validation.DataBinder} to perform data binding + * from web request parameters to JavaBeans, including support for multipart files. + * + *

See the DataBinder/WebDataBinder superclasses for customization options, + * which include specifying allowed/required fields, and registering custom + * property editors. + * + *

Can also used for manual data binding in custom web controllers or interceptors + * that build on Spring's {@link org.springframework.web.context.request.WebRequest} + * abstraction: e.g. in a {@link org.springframework.web.context.request.WebRequestInterceptor} + * implementation. Simply instantiate a WebRequestDataBinder for each binding + * process, and invoke {@code bind} with the current WebRequest as argument: + * + *

+ * MyBean myBean = new MyBean();
+ * // apply binder to custom target object
+ * WebRequestDataBinder binder = new WebRequestDataBinder(myBean);
+ * // register custom editors, if desired
+ * binder.registerCustomEditor(...);
+ * // trigger actual binding of request parameters
+ * binder.bind(request);
+ * // optionally evaluate binding errors
+ * Errors errors = binder.getErrors();
+ * ...
+ * + * @author Juergen Hoeller + * @author Brian Clozel + * @since 2.5.2 + * @see #bind(org.springframework.web.context.request.WebRequest) + * @see #registerCustomEditor + * @see #setAllowedFields + * @see #setRequiredFields + * @see #setFieldMarkerPrefix + */ +public class WebRequestDataBinder extends WebDataBinder { + + /** + * Create a new WebRequestDataBinder instance, with default object name. + * @param target the target object to bind onto (or {@code null} + * if the binder is just used to convert a plain parameter value) + * @see #DEFAULT_OBJECT_NAME + */ + public WebRequestDataBinder(@Nullable Object target) { + super(target); + } + + /** + * Create a new WebRequestDataBinder instance. + * @param target the target object to bind onto (or {@code null} + * if the binder is just used to convert a plain parameter value) + * @param objectName the name of the target object + */ + public WebRequestDataBinder(@Nullable Object target, String objectName) { + super(target, objectName); + } + + + /** + * Bind the parameters of the given request to this binder's target, + * also binding multipart files in case of a multipart request. + *

This call can create field errors, representing basic binding + * errors like a required field (code "required"), or type mismatch + * between value and bean property (code "typeMismatch"). + *

Multipart files are bound via their parameter name, just like normal + * HTTP parameters: i.e. "uploadedFile" to an "uploadedFile" bean property, + * invoking a "setUploadedFile" setter method. + *

The type of the target property for a multipart file can be Part, MultipartFile, + * byte[], or String. The latter two receive the contents of the uploaded file; + * all metadata like original file name, content type, etc are lost in those cases. + * @param request the request with parameters to bind (can be multipart) + * @see org.springframework.web.multipart.MultipartRequest + * @see org.springframework.web.multipart.MultipartFile + * @see javax.servlet.http.Part + * @see #bind(org.springframework.beans.PropertyValues) + */ + public void bind(WebRequest request) { + MutablePropertyValues mpvs = new MutablePropertyValues(request.getParameterMap()); + if (request instanceof NativeWebRequest) { + MultipartRequest multipartRequest = ((NativeWebRequest) request).getNativeRequest(MultipartRequest.class); + if (multipartRequest != null) { + bindMultipart(multipartRequest.getMultiFileMap(), mpvs); + } + else if (isMultipartRequest(request)) { + HttpServletRequest servletRequest = ((NativeWebRequest) request).getNativeRequest(HttpServletRequest.class); + if (servletRequest != null) { + bindParts(servletRequest, mpvs); + } + } + } + doBind(mpvs); + } + + /** + * Check if the request is a multipart request (by checking its Content-Type header). + * @param request the request with parameters to bind + */ + private boolean isMultipartRequest(WebRequest request) { + String contentType = request.getHeader("Content-Type"); + return StringUtils.startsWithIgnoreCase(contentType, "multipart/"); + } + + private void bindParts(HttpServletRequest request, MutablePropertyValues mpvs) { + try { + MultiValueMap map = new LinkedMultiValueMap<>(); + for (Part part : request.getParts()) { + map.add(part.getName(), part); + } + map.forEach((key, values) -> { + if (values.size() == 1) { + Part part = values.get(0); + if (isBindEmptyMultipartFiles() || part.getSize() > 0) { + mpvs.add(key, part); + } + } + else { + mpvs.add(key, values); + } + }); + } + catch (Exception ex) { + throw new MultipartException("Failed to get request parts", ex); + } + } + + /** + * Treats errors as fatal. + *

Use this method only if it's an error if the input isn't valid. + * This might be appropriate if all input is from dropdowns, for example. + * @throws BindException if binding errors have been encountered + */ + public void closeNoCatch() throws BindException { + if (getBindingResult().hasErrors()) { + throw new BindException(getBindingResult()); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/package-info.java b/spring-web/src/main/java/org/springframework/web/bind/support/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..c55741fc04f281e26a25a63e70d788d88a778a56 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/support/package-info.java @@ -0,0 +1,9 @@ +/** + * Support classes for web data binding. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.bind.support; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/client/AsyncRequestCallback.java b/spring-web/src/main/java/org/springframework/web/client/AsyncRequestCallback.java new file mode 100644 index 0000000000000000000000000000000000000000..4efc74eef602e89d8ccbe93d203a4de2bfaa9fca --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/AsyncRequestCallback.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; + +/** + * Callback interface for code that operates on an + * {@link org.springframework.http.client.AsyncClientHttpRequest}. Allows to + * manipulate the request headers, and write to the request body. + * + *

Used internally by the {@link AsyncRestTemplate}, but also useful for + * application code. + * + * @author Arjen Poutsma + * @see org.springframework.web.client.AsyncRestTemplate#execute + * @since 4.0 + * @deprecated as of Spring 5.0, in favor of + * {@link org.springframework.web.reactive.function.client.ExchangeFilterFunction} + */ +@FunctionalInterface +@Deprecated +public interface AsyncRequestCallback { + + /** + * Gets called by {@link AsyncRestTemplate#execute} with an opened {@code ClientHttpRequest}. + * Does not need to care about closing the request or about handling errors: + * this will all be handled by the {@code RestTemplate}. + * @param request the active HTTP request + * @throws java.io.IOException in case of I/O errors + */ + void doWithRequest(org.springframework.http.client.AsyncClientHttpRequest request) throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/AsyncRestOperations.java b/spring-web/src/main/java/org/springframework/web/client/AsyncRestOperations.java new file mode 100644 index 0000000000000000000000000000000000000000..c9d92a2aff4b8bd3df149e0d4a6ddeba5f106502 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/AsyncRestOperations.java @@ -0,0 +1,464 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.net.URI; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Future; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.ResponseEntity; +import org.springframework.lang.Nullable; +import org.springframework.util.concurrent.ListenableFuture; + +/** + * Interface specifying a basic set of asynchronous RESTful operations. + * Implemented by {@link AsyncRestTemplate}. Not often used directly, but a useful + * option to enhance testability, as it can easily be mocked or stubbed. + * + * @author Arjen Poutsma + * @since 4.0 + * @see AsyncRestTemplate + * @see RestOperations + * @deprecated as of Spring 5.0, in favor of {@link org.springframework.web.reactive.function.client.WebClient} + */ +@Deprecated +public interface AsyncRestOperations { + + /** + * Expose the synchronous Spring RestTemplate to allow synchronous invocation. + */ + RestOperations getRestOperations(); + + + // GET + + /** + * Asynchronously retrieve an entity by doing a GET on the specified URL. + * The response is converted and stored in an {@link ResponseEntity}. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param responseType the type of the return value + * @param uriVariables the variables to expand the template + * @return the entity wrapped in a {@link Future} + */ + ListenableFuture> getForEntity(String url, Class responseType, + Object... uriVariables) throws RestClientException; + + /** + * Asynchronously retrieve a representation by doing a GET on the URI template. + * The response is converted and stored in an {@link ResponseEntity}. + *

URI Template variables are expanded using the given map. + * @param url the URL + * @param responseType the type of the return value + * @param uriVariables the map containing variables for the URI template + * @return the entity wrapped in a {@link Future} + */ + ListenableFuture> getForEntity(String url, Class responseType, + Map uriVariables) throws RestClientException; + + /** + * Asynchronously retrieve a representation by doing a GET on the URL. + * The response is converted and stored in an {@link ResponseEntity}. + * @param url the URL + * @param responseType the type of the return value + * @return the entity wrapped in a {@link Future} + */ + ListenableFuture> getForEntity(URI url, Class responseType) + throws RestClientException; + + + // HEAD + + /** + * Asynchronously retrieve all headers of the resource specified by the URI template. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param uriVariables the variables to expand the template + * @return all HTTP headers of that resource wrapped in a {@link Future} + */ + ListenableFuture headForHeaders(String url, Object... uriVariables) + throws RestClientException; + + /** + * Asynchronously retrieve all headers of the resource specified by the URI template. + *

URI Template variables are expanded using the given map. + * @param url the URL + * @param uriVariables the map containing variables for the URI template + * @return all HTTP headers of that resource wrapped in a {@link Future} + */ + ListenableFuture headForHeaders(String url, Map uriVariables) + throws RestClientException; + + /** + * Asynchronously retrieve all headers of the resource specified by the URL. + * @param url the URL + * @return all HTTP headers of that resource wrapped in a {@link Future} + */ + ListenableFuture headForHeaders(URI url) throws RestClientException; + + + // POST + + /** + * Create a new resource by POSTing the given object to the URI template, and + * asynchronously returns the value of the {@code Location} header. This header + * typically indicates where the new resource is stored. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param request the Object to be POSTed (may be {@code null}) + * @param uriVariables the variables to expand the template + * @return the value for the {@code Location} header wrapped in a {@link Future} + * @see org.springframework.http.HttpEntity + */ + ListenableFuture postForLocation(String url, @Nullable HttpEntity request, Object... uriVariables) + throws RestClientException; + + /** + * Create a new resource by POSTing the given object to the URI template, and + * asynchronously returns the value of the {@code Location} header. This header + * typically indicates where the new resource is stored. + *

URI Template variables are expanded using the given map. + * @param url the URL + * @param request the Object to be POSTed (may be {@code null}) + * @param uriVariables the variables to expand the template + * @return the value for the {@code Location} header wrapped in a {@link Future} + * @see org.springframework.http.HttpEntity + */ + ListenableFuture postForLocation(String url, @Nullable HttpEntity request, Map uriVariables) + throws RestClientException; + + /** + * Create a new resource by POSTing the given object to the URL, and asynchronously + * returns the value of the {@code Location} header. This header typically indicates + * where the new resource is stored. + * @param url the URL + * @param request the Object to be POSTed (may be {@code null}) + * @return the value for the {@code Location} header wrapped in a {@link Future} + * @see org.springframework.http.HttpEntity + */ + ListenableFuture postForLocation(URI url, @Nullable HttpEntity request) throws RestClientException; + + /** + * Create a new resource by POSTing the given object to the URI template, + * and asynchronously returns the response as {@link ResponseEntity}. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param request the Object to be POSTed (may be {@code null}) + * @param uriVariables the variables to expand the template + * @return the entity wrapped in a {@link Future} + * @see org.springframework.http.HttpEntity + */ + ListenableFuture> postForEntity(String url, @Nullable HttpEntity request, + Class responseType, Object... uriVariables) throws RestClientException; + + /** + * Create a new resource by POSTing the given object to the URI template, + * and asynchronously returns the response as {@link ResponseEntity}. + *

URI Template variables are expanded using the given map. + * @param url the URL + * @param request the Object to be POSTed (may be {@code null}) + * @param uriVariables the variables to expand the template + * @return the entity wrapped in a {@link Future} + * @see org.springframework.http.HttpEntity + */ + ListenableFuture> postForEntity(String url, @Nullable HttpEntity request, + Class responseType, Map uriVariables) throws RestClientException; + + /** + * Create a new resource by POSTing the given object to the URL, + * and asynchronously returns the response as {@link ResponseEntity}. + * @param url the URL + * @param request the Object to be POSTed (may be {@code null}) + * @return the entity wrapped in a {@link Future} + * @see org.springframework.http.HttpEntity + */ + ListenableFuture> postForEntity(URI url, @Nullable HttpEntity request, + Class responseType) throws RestClientException; + + + // PUT + + /** + * Create or update a resource by PUTting the given object to the URI. + *

URI Template variables are expanded using the given URI variables, if any. + *

The Future will return a {@code null} result upon completion. + * @param url the URL + * @param request the Object to be PUT (may be {@code null}) + * @param uriVariables the variables to expand the template + * @see HttpEntity + */ + ListenableFuture put(String url, @Nullable HttpEntity request, Object... uriVariables) + throws RestClientException; + + /** + * Creates a new resource by PUTting the given object to URI template. + *

URI Template variables are expanded using the given map. + *

The Future will return a {@code null} result upon completion. + * @param url the URL + * @param request the Object to be PUT (may be {@code null}) + * @param uriVariables the variables to expand the template + * @see HttpEntity + */ + ListenableFuture put(String url, @Nullable HttpEntity request, Map uriVariables) + throws RestClientException; + + /** + * Creates a new resource by PUTting the given object to URL. + *

The Future will return a {@code null} result upon completion. + * @param url the URL + * @param request the Object to be PUT (may be {@code null}) + * @see HttpEntity + */ + ListenableFuture put(URI url, @Nullable HttpEntity request) throws RestClientException; + + + // DELETE + + /** + * Asynchronously delete the resources at the specified URI. + *

URI Template variables are expanded using the given URI variables, if any. + *

The Future will return a {@code null} result upon completion. + * @param url the URL + * @param uriVariables the variables to expand in the template + */ + ListenableFuture delete(String url, Object... uriVariables) throws RestClientException; + + /** + * Asynchronously delete the resources at the specified URI. + *

URI Template variables are expanded using the given URI variables, if any. + *

The Future will return a {@code null} result upon completion. + * @param url the URL + * @param uriVariables the variables to expand in the template + */ + ListenableFuture delete(String url, Map uriVariables) throws RestClientException; + + /** + * Asynchronously delete the resources at the specified URI. + *

URI Template variables are expanded using the given URI variables, if any. + *

The Future will return a {@code null} result upon completion. + * @param url the URL + */ + ListenableFuture delete(URI url) throws RestClientException; + + + // OPTIONS + + /** + * Asynchronously return the value of the Allow header for the given URI. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param uriVariables the variables to expand in the template + * @return the value of the allow header wrapped in a {@link Future} + */ + ListenableFuture> optionsForAllow(String url, Object... uriVariables) + throws RestClientException; + + /** + * Asynchronously return the value of the Allow header for the given URI. + *

URI Template variables are expanded using the given map. + * @param url the URL + * @param uriVariables the variables to expand in the template + * @return the value of the allow header wrapped in a {@link Future} + */ + ListenableFuture> optionsForAllow(String url, Map uriVariables) + throws RestClientException; + + /** + * Asynchronously return the value of the Allow header for the given URL. + * @param url the URL + * @return the value of the allow header wrapped in a {@link Future} + */ + ListenableFuture> optionsForAllow(URI url) throws RestClientException; + + + // exchange + + /** + * Asynchronously execute the HTTP method to the given URI template, writing the + * given request entity to the request, and returns the response as + * {@link ResponseEntity}. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestEntity the entity (headers and/or body) to write to the request + * (may be {@code null}) + * @param responseType the type of the return value + * @param uriVariables the variables to expand in the template + * @return the response as entity wrapped in a {@link Future} + */ + ListenableFuture> exchange(String url, HttpMethod method, + @Nullable HttpEntity requestEntity, Class responseType, Object... uriVariables) + throws RestClientException; + + /** + * Asynchronously execute the HTTP method to the given URI template, writing the + * given request entity to the request, and returns the response as + * {@link ResponseEntity}. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestEntity the entity (headers and/or body) to write to the request + * (may be {@code null}) + * @param responseType the type of the return value + * @param uriVariables the variables to expand in the template + * @return the response as entity wrapped in a {@link Future} + */ + ListenableFuture> exchange(String url, HttpMethod method, + @Nullable HttpEntity requestEntity, Class responseType, + Map uriVariables) throws RestClientException; + + /** + * Asynchronously execute the HTTP method to the given URI template, writing the + * given request entity to the request, and returns the response as + * {@link ResponseEntity}. + * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestEntity the entity (headers and/or body) to write to the request + * (may be {@code null}) + * @param responseType the type of the return value + * @return the response as entity wrapped in a {@link Future} + */ + ListenableFuture> exchange(URI url, HttpMethod method, + @Nullable HttpEntity requestEntity, Class responseType) + throws RestClientException; + + /** + * Asynchronously execute the HTTP method to the given URI template, writing the given + * request entity to the request, and returns the response as {@link ResponseEntity}. + * The given {@link ParameterizedTypeReference} is used to pass generic type + * information: + *

+	 * ParameterizedTypeReference<List<MyBean>> myBean =
+	 *     new ParameterizedTypeReference<List<MyBean>>() {};
+	 *
+	 * ResponseEntity<List<MyBean>> response =
+	 *     template.exchange("https://example.com",HttpMethod.GET, null, myBean);
+	 * 
+ * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestEntity the entity (headers and/or body) to write to the + * request (may be {@code null}) + * @param responseType the type of the return value + * @param uriVariables the variables to expand in the template + * @return the response as entity wrapped in a {@link Future} + */ + ListenableFuture> exchange(String url, HttpMethod method, + @Nullable HttpEntity requestEntity, ParameterizedTypeReference responseType, + Object... uriVariables) throws RestClientException; + + /** + * Asynchronously execute the HTTP method to the given URI template, writing the given + * request entity to the request, and returns the response as {@link ResponseEntity}. + * The given {@link ParameterizedTypeReference} is used to pass generic type + * information: + *
+	 * ParameterizedTypeReference<List<MyBean>> myBean =
+	 *     new ParameterizedTypeReference<List<MyBean>>() {};
+	 *
+	 * ResponseEntity<List<MyBean>> response =
+	 *     template.exchange("https://example.com",HttpMethod.GET, null, myBean);
+	 * 
+ * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestEntity the entity (headers and/or body) to write to the request + * (may be {@code null}) + * @param responseType the type of the return value + * @param uriVariables the variables to expand in the template + * @return the response as entity wrapped in a {@link Future} + */ + ListenableFuture> exchange(String url, HttpMethod method, + @Nullable HttpEntity requestEntity, ParameterizedTypeReference responseType, + Map uriVariables) throws RestClientException; + + /** + * Asynchronously execute the HTTP method to the given URI template, writing the given + * request entity to the request, and returns the response as {@link ResponseEntity}. + * The given {@link ParameterizedTypeReference} is used to pass generic type + * information: + *
+	 * ParameterizedTypeReference<List<MyBean>> myBean =
+	 *     new ParameterizedTypeReference<List<MyBean>>() {};
+	 *
+	 * ResponseEntity<List<MyBean>> response =
+	 *     template.exchange("https://example.com",HttpMethod.GET, null, myBean);
+	 * 
+ * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestEntity the entity (headers and/or body) to write to the request + * (may be {@code null}) + * @param responseType the type of the return value + * @return the response as entity wrapped in a {@link Future} + */ + ListenableFuture> exchange(URI url, HttpMethod method, + @Nullable HttpEntity requestEntity, ParameterizedTypeReference responseType) + throws RestClientException; + + + // general execution + + /** + * Asynchronously execute the HTTP method to the given URI template, preparing the + * request with the {@link AsyncRequestCallback}, and reading the response with a + * {@link ResponseExtractor}. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestCallback object that prepares the request + * @param responseExtractor object that extracts the return value from the response + * @param uriVariables the variables to expand in the template + * @return an arbitrary object, as returned by the {@link ResponseExtractor} + */ + ListenableFuture execute(String url, HttpMethod method, + @Nullable AsyncRequestCallback requestCallback, @Nullable ResponseExtractor responseExtractor, + Object... uriVariables) throws RestClientException; + + /** + * Asynchronously execute the HTTP method to the given URI template, preparing the + * request with the {@link AsyncRequestCallback}, and reading the response with a + * {@link ResponseExtractor}. + *

URI Template variables are expanded using the given URI variables map. + * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestCallback object that prepares the request + * @param responseExtractor object that extracts the return value from the response + * @param uriVariables the variables to expand in the template + * @return an arbitrary object, as returned by the {@link ResponseExtractor} + */ + ListenableFuture execute(String url, HttpMethod method, + @Nullable AsyncRequestCallback requestCallback, @Nullable ResponseExtractor responseExtractor, + Map uriVariables) throws RestClientException; + + /** + * Asynchronously execute the HTTP method to the given URL, preparing the request + * with the {@link AsyncRequestCallback}, and reading the response with a + * {@link ResponseExtractor}. + * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestCallback object that prepares the request + * @param responseExtractor object that extracts the return value from the response + * @return an arbitrary object, as returned by the {@link ResponseExtractor} + */ + ListenableFuture execute(URI url, HttpMethod method, + @Nullable AsyncRequestCallback requestCallback, @Nullable ResponseExtractor responseExtractor) + throws RestClientException; + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/AsyncRestTemplate.java b/spring-web/src/main/java/org/springframework/web/client/AsyncRestTemplate.java new file mode 100644 index 0000000000000000000000000000000000000000..f94359de4aeb77a523ec1bea90dacc0bb52709f1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/AsyncRestTemplate.java @@ -0,0 +1,714 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; +import java.io.OutputStream; +import java.lang.reflect.Type; +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.task.AsyncListenableTaskExecutor; +import org.springframework.core.task.AsyncTaskExecutor; +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.ListenableFutureAdapter; +import org.springframework.web.util.DefaultUriBuilderFactory; +import org.springframework.web.util.UriTemplateHandler; + +/** + * Spring's central class for asynchronous client-side HTTP access. + * Exposes similar methods as {@link RestTemplate}, but returns {@link ListenableFuture} + * wrappers as opposed to concrete results. + * + *

The {@code AsyncRestTemplate} exposes a synchronous {@link RestTemplate} via the + * {@link #getRestOperations()} method and shares its {@linkplain #setErrorHandler error handler} + * and {@linkplain #setMessageConverters message converters} with that {@code RestTemplate}. + * + *

Note: by default {@code AsyncRestTemplate} relies on + * standard JDK facilities to establish HTTP connections. You can switch to use + * a different HTTP library such as Apache HttpComponents, Netty, and OkHttp by + * using a constructor accepting an {@link org.springframework.http.client.AsyncClientHttpRequestFactory}. + * + *

For more information, please refer to the {@link RestTemplate} API documentation. + * + * @author Arjen Poutsma + * @since 4.0 + * @see RestTemplate + * @deprecated as of Spring 5.0, in favor of {@link org.springframework.web.reactive.function.client.WebClient} + */ +@Deprecated +public class AsyncRestTemplate extends org.springframework.http.client.support.InterceptingAsyncHttpAccessor + implements AsyncRestOperations { + + private final RestTemplate syncTemplate; + + + /** + * Create a new instance of the {@code AsyncRestTemplate} using default settings. + *

This constructor uses a {@link SimpleClientHttpRequestFactory} in combination + * with a {@link SimpleAsyncTaskExecutor} for asynchronous execution. + */ + public AsyncRestTemplate() { + this(new SimpleAsyncTaskExecutor()); + } + + /** + * Create a new instance of the {@code AsyncRestTemplate} using the given + * {@link AsyncTaskExecutor}. + *

This constructor uses a {@link SimpleClientHttpRequestFactory} in combination + * with the given {@code AsyncTaskExecutor} for asynchronous execution. + */ + public AsyncRestTemplate(AsyncListenableTaskExecutor taskExecutor) { + Assert.notNull(taskExecutor, "AsyncTaskExecutor must not be null"); + SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); + requestFactory.setTaskExecutor(taskExecutor); + this.syncTemplate = new RestTemplate(requestFactory); + setAsyncRequestFactory(requestFactory); + } + + /** + * Create a new instance of the {@code AsyncRestTemplate} using the given + * {@link org.springframework.http.client.AsyncClientHttpRequestFactory}. + *

This constructor will cast the given asynchronous + * {@code AsyncClientHttpRequestFactory} to a {@link ClientHttpRequestFactory}. Since + * all implementations of {@code ClientHttpRequestFactory} provided in Spring also + * implement {@code AsyncClientHttpRequestFactory}, this should not result in a + * {@code ClassCastException}. + */ + public AsyncRestTemplate(org.springframework.http.client.AsyncClientHttpRequestFactory asyncRequestFactory) { + this(asyncRequestFactory, (ClientHttpRequestFactory) asyncRequestFactory); + } + + /** + * Creates a new instance of the {@code AsyncRestTemplate} using the given + * asynchronous and synchronous request factories. + * @param asyncRequestFactory the asynchronous request factory + * @param syncRequestFactory the synchronous request factory + */ + public AsyncRestTemplate(org.springframework.http.client.AsyncClientHttpRequestFactory asyncRequestFactory, + ClientHttpRequestFactory syncRequestFactory) { + + this(asyncRequestFactory, new RestTemplate(syncRequestFactory)); + } + + /** + * Create a new instance of the {@code AsyncRestTemplate} using the given + * {@link org.springframework.http.client.AsyncClientHttpRequestFactory} and synchronous {@link RestTemplate}. + * @param requestFactory the asynchronous request factory to use + * @param restTemplate the synchronous template to use + */ + public AsyncRestTemplate(org.springframework.http.client.AsyncClientHttpRequestFactory requestFactory, + RestTemplate restTemplate) { + + Assert.notNull(restTemplate, "RestTemplate must not be null"); + this.syncTemplate = restTemplate; + setAsyncRequestFactory(requestFactory); + } + + + /** + * Set the error handler. + *

By default, AsyncRestTemplate uses a + * {@link org.springframework.web.client.DefaultResponseErrorHandler}. + */ + public void setErrorHandler(ResponseErrorHandler errorHandler) { + this.syncTemplate.setErrorHandler(errorHandler); + } + + /** + * Return the error handler. + */ + public ResponseErrorHandler getErrorHandler() { + return this.syncTemplate.getErrorHandler(); + } + + /** + * Configure default URI variable values. This is a shortcut for: + *

+	 * DefaultUriTemplateHandler handler = new DefaultUriTemplateHandler();
+	 * handler.setDefaultUriVariables(...);
+	 *
+	 * AsyncRestTemplate restTemplate = new AsyncRestTemplate();
+	 * restTemplate.setUriTemplateHandler(handler);
+	 * 
+ * @param defaultUriVariables the default URI variable values + * @since 4.3 + */ + @SuppressWarnings("deprecation") + public void setDefaultUriVariables(Map defaultUriVariables) { + UriTemplateHandler handler = this.syncTemplate.getUriTemplateHandler(); + if (handler instanceof DefaultUriBuilderFactory) { + ((DefaultUriBuilderFactory) handler).setDefaultUriVariables(defaultUriVariables); + } + else if (handler instanceof org.springframework.web.util.AbstractUriTemplateHandler) { + ((org.springframework.web.util.AbstractUriTemplateHandler) handler) + .setDefaultUriVariables(defaultUriVariables); + } + else { + throw new IllegalArgumentException( + "This property is not supported with the configured UriTemplateHandler."); + } + } + + /** + * This property has the same purpose as the corresponding property on the + * {@code RestTemplate}. For more details see + * {@link RestTemplate#setUriTemplateHandler}. + * @param handler the URI template handler to use + */ + public void setUriTemplateHandler(UriTemplateHandler handler) { + this.syncTemplate.setUriTemplateHandler(handler); + } + + /** + * Return the configured URI template handler. + */ + public UriTemplateHandler getUriTemplateHandler() { + return this.syncTemplate.getUriTemplateHandler(); + } + + @Override + public RestOperations getRestOperations() { + return this.syncTemplate; + } + + /** + * Set the message body converters to use. + *

These converters are used to convert from and to HTTP requests and responses. + */ + public void setMessageConverters(List> messageConverters) { + this.syncTemplate.setMessageConverters(messageConverters); + } + + /** + * Return the message body converters. + */ + public List> getMessageConverters() { + return this.syncTemplate.getMessageConverters(); + } + + + // GET + + @Override + public ListenableFuture> getForEntity(String url, Class responseType, Object... uriVariables) + throws RestClientException { + + AsyncRequestCallback requestCallback = acceptHeaderRequestCallback(responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return execute(url, HttpMethod.GET, requestCallback, responseExtractor, uriVariables); + } + + @Override + public ListenableFuture> getForEntity(String url, Class responseType, + Map uriVariables) throws RestClientException { + + AsyncRequestCallback requestCallback = acceptHeaderRequestCallback(responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return execute(url, HttpMethod.GET, requestCallback, responseExtractor, uriVariables); + } + + @Override + public ListenableFuture> getForEntity(URI url, Class responseType) + throws RestClientException { + + AsyncRequestCallback requestCallback = acceptHeaderRequestCallback(responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return execute(url, HttpMethod.GET, requestCallback, responseExtractor); + } + + + // HEAD + + @Override + public ListenableFuture headForHeaders(String url, Object... uriVariables) + throws RestClientException { + + ResponseExtractor headersExtractor = headersExtractor(); + return execute(url, HttpMethod.HEAD, null, headersExtractor, uriVariables); + } + + @Override + public ListenableFuture headForHeaders(String url, Map uriVariables) + throws RestClientException { + + ResponseExtractor headersExtractor = headersExtractor(); + return execute(url, HttpMethod.HEAD, null, headersExtractor, uriVariables); + } + + @Override + public ListenableFuture headForHeaders(URI url) throws RestClientException { + ResponseExtractor headersExtractor = headersExtractor(); + return execute(url, HttpMethod.HEAD, null, headersExtractor); + } + + + // POST + + @Override + public ListenableFuture postForLocation(String url, @Nullable HttpEntity request, Object... uriVars) + throws RestClientException { + + AsyncRequestCallback callback = httpEntityCallback(request); + ResponseExtractor extractor = headersExtractor(); + ListenableFuture future = execute(url, HttpMethod.POST, callback, extractor, uriVars); + return adaptToLocationHeader(future); + } + + @Override + public ListenableFuture postForLocation(String url, @Nullable HttpEntity request, Map uriVars) + throws RestClientException { + + AsyncRequestCallback callback = httpEntityCallback(request); + ResponseExtractor extractor = headersExtractor(); + ListenableFuture future = execute(url, HttpMethod.POST, callback, extractor, uriVars); + return adaptToLocationHeader(future); + } + + @Override + public ListenableFuture postForLocation(URI url, @Nullable HttpEntity request) + throws RestClientException { + + AsyncRequestCallback callback = httpEntityCallback(request); + ResponseExtractor extractor = headersExtractor(); + ListenableFuture future = execute(url, HttpMethod.POST, callback, extractor); + return adaptToLocationHeader(future); + } + + private static ListenableFuture adaptToLocationHeader(ListenableFuture future) { + return new ListenableFutureAdapter(future) { + @Override + @Nullable + protected URI adapt(HttpHeaders headers) throws ExecutionException { + return headers.getLocation(); + } + }; + } + + @Override + public ListenableFuture> postForEntity(String url, @Nullable HttpEntity request, + Class responseType, Object... uriVariables) throws RestClientException { + + AsyncRequestCallback requestCallback = httpEntityCallback(request, responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return execute(url, HttpMethod.POST, requestCallback, responseExtractor, uriVariables); + } + + @Override + public ListenableFuture> postForEntity(String url, @Nullable HttpEntity request, + Class responseType, Map uriVariables) throws RestClientException { + + AsyncRequestCallback requestCallback = httpEntityCallback(request, responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return execute(url, HttpMethod.POST, requestCallback, responseExtractor, uriVariables); + } + + @Override + public ListenableFuture> postForEntity(URI url, + @Nullable HttpEntity request, Class responseType) throws RestClientException { + + AsyncRequestCallback requestCallback = httpEntityCallback(request, responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return execute(url, HttpMethod.POST, requestCallback, responseExtractor); + } + + + // PUT + + @Override + public ListenableFuture put(String url, @Nullable HttpEntity request, Object... uriVars) + throws RestClientException { + + AsyncRequestCallback requestCallback = httpEntityCallback(request); + return execute(url, HttpMethod.PUT, requestCallback, null, uriVars); + } + + @Override + public ListenableFuture put(String url, @Nullable HttpEntity request, Map uriVars) + throws RestClientException { + + AsyncRequestCallback requestCallback = httpEntityCallback(request); + return execute(url, HttpMethod.PUT, requestCallback, null, uriVars); + } + + @Override + public ListenableFuture put(URI url, @Nullable HttpEntity request) throws RestClientException { + AsyncRequestCallback requestCallback = httpEntityCallback(request); + return execute(url, HttpMethod.PUT, requestCallback, null); + } + + + // DELETE + + @Override + public ListenableFuture delete(String url, Object... uriVariables) throws RestClientException { + return execute(url, HttpMethod.DELETE, null, null, uriVariables); + } + + @Override + public ListenableFuture delete(String url, Map uriVariables) throws RestClientException { + return execute(url, HttpMethod.DELETE, null, null, uriVariables); + } + + @Override + public ListenableFuture delete(URI url) throws RestClientException { + return execute(url, HttpMethod.DELETE, null, null); + } + + + // OPTIONS + + @Override + public ListenableFuture> optionsForAllow(String url, Object... uriVars) + throws RestClientException { + + ResponseExtractor extractor = headersExtractor(); + ListenableFuture future = execute(url, HttpMethod.OPTIONS, null, extractor, uriVars); + return adaptToAllowHeader(future); + } + + @Override + public ListenableFuture> optionsForAllow(String url, Map uriVars) + throws RestClientException { + + ResponseExtractor extractor = headersExtractor(); + ListenableFuture future = execute(url, HttpMethod.OPTIONS, null, extractor, uriVars); + return adaptToAllowHeader(future); + } + + @Override + public ListenableFuture> optionsForAllow(URI url) throws RestClientException { + ResponseExtractor extractor = headersExtractor(); + ListenableFuture future = execute(url, HttpMethod.OPTIONS, null, extractor); + return adaptToAllowHeader(future); + } + + private static ListenableFuture> adaptToAllowHeader(ListenableFuture future) { + return new ListenableFutureAdapter, HttpHeaders>(future) { + @Override + protected Set adapt(HttpHeaders headers) throws ExecutionException { + return headers.getAllow(); + } + }; + } + + // exchange + + @Override + public ListenableFuture> exchange(String url, HttpMethod method, + @Nullable HttpEntity requestEntity, Class responseType, Object... uriVariables) + throws RestClientException { + + AsyncRequestCallback requestCallback = httpEntityCallback(requestEntity, responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return execute(url, method, requestCallback, responseExtractor, uriVariables); + } + + @Override + public ListenableFuture> exchange(String url, HttpMethod method, + @Nullable HttpEntity requestEntity, Class responseType, Map uriVariables) + throws RestClientException { + + AsyncRequestCallback requestCallback = httpEntityCallback(requestEntity, responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return execute(url, method, requestCallback, responseExtractor, uriVariables); + } + + @Override + public ListenableFuture> exchange(URI url, HttpMethod method, + @Nullable HttpEntity requestEntity, Class responseType) throws RestClientException { + + AsyncRequestCallback requestCallback = httpEntityCallback(requestEntity, responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return execute(url, method, requestCallback, responseExtractor); + } + + @Override + public ListenableFuture> exchange(String url, HttpMethod method, + @Nullable HttpEntity requestEntity, ParameterizedTypeReference responseType, + Object... uriVariables) throws RestClientException { + + Type type = responseType.getType(); + AsyncRequestCallback requestCallback = httpEntityCallback(requestEntity, type); + ResponseExtractor> responseExtractor = responseEntityExtractor(type); + return execute(url, method, requestCallback, responseExtractor, uriVariables); + } + + @Override + public ListenableFuture> exchange(String url, HttpMethod method, + @Nullable HttpEntity requestEntity, ParameterizedTypeReference responseType, + Map uriVariables) throws RestClientException { + + Type type = responseType.getType(); + AsyncRequestCallback requestCallback = httpEntityCallback(requestEntity, type); + ResponseExtractor> responseExtractor = responseEntityExtractor(type); + return execute(url, method, requestCallback, responseExtractor, uriVariables); + } + + @Override + public ListenableFuture> exchange(URI url, HttpMethod method, + @Nullable HttpEntity requestEntity, ParameterizedTypeReference responseType) + throws RestClientException { + + Type type = responseType.getType(); + AsyncRequestCallback requestCallback = httpEntityCallback(requestEntity, type); + ResponseExtractor> responseExtractor = responseEntityExtractor(type); + return execute(url, method, requestCallback, responseExtractor); + } + + + // general execution + + @Override + public ListenableFuture execute(String url, HttpMethod method, @Nullable AsyncRequestCallback requestCallback, + @Nullable ResponseExtractor responseExtractor, Object... uriVariables) throws RestClientException { + + URI expanded = getUriTemplateHandler().expand(url, uriVariables); + return doExecute(expanded, method, requestCallback, responseExtractor); + } + + @Override + public ListenableFuture execute(String url, HttpMethod method, + @Nullable AsyncRequestCallback requestCallback, @Nullable ResponseExtractor responseExtractor, + Map uriVariables) throws RestClientException { + + URI expanded = getUriTemplateHandler().expand(url, uriVariables); + return doExecute(expanded, method, requestCallback, responseExtractor); + } + + @Override + public ListenableFuture execute(URI url, HttpMethod method, + @Nullable AsyncRequestCallback requestCallback, + @Nullable ResponseExtractor responseExtractor) throws RestClientException { + + return doExecute(url, method, requestCallback, responseExtractor); + } + + /** + * Execute the given method on the provided URI. The + * {@link org.springframework.http.client.ClientHttpRequest} + * is processed using the {@link RequestCallback}; the response with + * the {@link ResponseExtractor}. + * @param url the fully-expanded URL to connect to + * @param method the HTTP method to execute (GET, POST, etc.) + * @param requestCallback object that prepares the request (can be {@code null}) + * @param responseExtractor object that extracts the return value from the response (can + * be {@code null}) + * @return an arbitrary object, as returned by the {@link ResponseExtractor} + */ + protected ListenableFuture doExecute(URI url, HttpMethod method, + @Nullable AsyncRequestCallback requestCallback, + @Nullable ResponseExtractor responseExtractor) throws RestClientException { + + Assert.notNull(url, "'url' must not be null"); + Assert.notNull(method, "'method' must not be null"); + try { + org.springframework.http.client.AsyncClientHttpRequest request = createAsyncRequest(url, method); + if (requestCallback != null) { + requestCallback.doWithRequest(request); + } + ListenableFuture responseFuture = request.executeAsync(); + return new ResponseExtractorFuture<>(method, url, responseFuture, responseExtractor); + } + catch (IOException ex) { + throw new ResourceAccessException("I/O error on " + method.name() + + " request for \"" + url + "\":" + ex.getMessage(), ex); + } + } + + private void logResponseStatus(HttpMethod method, URI url, ClientHttpResponse response) { + if (logger.isDebugEnabled()) { + try { + logger.debug("Async " + method.name() + " request for \"" + url + "\" resulted in " + + response.getRawStatusCode() + " (" + response.getStatusText() + ")"); + } + catch (IOException ex) { + // ignore + } + } + } + + private void handleResponseError(HttpMethod method, URI url, ClientHttpResponse response) throws IOException { + if (logger.isWarnEnabled()) { + try { + logger.warn("Async " + method.name() + " request for \"" + url + "\" resulted in " + + response.getRawStatusCode() + " (" + response.getStatusText() + "); invoking error handler"); + } + catch (IOException ex) { + // ignore + } + } + getErrorHandler().handleError(url, method, response); + } + + /** + * Returns a request callback implementation that prepares the request {@code Accept} + * headers based on the given response type and configured {@linkplain + * #getMessageConverters() message converters}. + */ + protected AsyncRequestCallback acceptHeaderRequestCallback(Class responseType) { + return new AsyncRequestCallbackAdapter(this.syncTemplate.acceptHeaderRequestCallback(responseType)); + } + + /** + * Returns a request callback implementation that writes the given object to the + * request stream. + */ + protected AsyncRequestCallback httpEntityCallback(@Nullable HttpEntity requestBody) { + return new AsyncRequestCallbackAdapter(this.syncTemplate.httpEntityCallback(requestBody)); + } + + /** + * Returns a request callback implementation that writes the given object to the + * request stream. + */ + protected AsyncRequestCallback httpEntityCallback(@Nullable HttpEntity request, Type responseType) { + return new AsyncRequestCallbackAdapter(this.syncTemplate.httpEntityCallback(request, responseType)); + } + + /** + * Returns a response extractor for {@link ResponseEntity}. + */ + protected ResponseExtractor> responseEntityExtractor(Type responseType) { + return this.syncTemplate.responseEntityExtractor(responseType); + } + + /** + * Returns a response extractor for {@link HttpHeaders}. + */ + protected ResponseExtractor headersExtractor() { + return this.syncTemplate.headersExtractor(); + } + + + /** + * Future returned from + * {@link #doExecute(URI, HttpMethod, AsyncRequestCallback, ResponseExtractor)}. + */ + private class ResponseExtractorFuture extends ListenableFutureAdapter { + + private final HttpMethod method; + + private final URI url; + + @Nullable + private final ResponseExtractor responseExtractor; + + public ResponseExtractorFuture(HttpMethod method, URI url, + ListenableFuture clientHttpResponseFuture, + @Nullable ResponseExtractor responseExtractor) { + + super(clientHttpResponseFuture); + this.method = method; + this.url = url; + this.responseExtractor = responseExtractor; + } + + @Override + @Nullable + protected final T adapt(ClientHttpResponse response) throws ExecutionException { + try { + if (!getErrorHandler().hasError(response)) { + logResponseStatus(this.method, this.url, response); + } + else { + handleResponseError(this.method, this.url, response); + } + return convertResponse(response); + } + catch (Throwable ex) { + throw new ExecutionException(ex); + } + finally { + response.close(); + } + } + + @Nullable + protected T convertResponse(ClientHttpResponse response) throws IOException { + return (this.responseExtractor != null ? this.responseExtractor.extractData(response) : null); + } + } + + + /** + * Adapts a {@link RequestCallback} to the {@link AsyncRequestCallback} interface. + */ + private static class AsyncRequestCallbackAdapter implements AsyncRequestCallback { + + private final RequestCallback adaptee; + + /** + * Create a new {@code AsyncRequestCallbackAdapter} from the given + * {@link RequestCallback}. + * @param requestCallback the callback to base this adapter on + */ + public AsyncRequestCallbackAdapter(RequestCallback requestCallback) { + this.adaptee = requestCallback; + } + + @Override + public void doWithRequest(final org.springframework.http.client.AsyncClientHttpRequest request) + throws IOException { + + this.adaptee.doWithRequest(new ClientHttpRequest() { + @Override + public ClientHttpResponse execute() throws IOException { + throw new UnsupportedOperationException("execute not supported"); + } + @Override + public OutputStream getBody() throws IOException { + return request.getBody(); + } + @Override + @Nullable + public HttpMethod getMethod() { + return request.getMethod(); + } + @Override + public String getMethodValue() { + return request.getMethodValue(); + } + @Override + public URI getURI() { + return request.getURI(); + } + @Override + public HttpHeaders getHeaders() { + return request.getHeaders(); + } + }); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultResponseErrorHandler.java b/spring-web/src/main/java/org/springframework/web/client/DefaultResponseErrorHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..97e3561a3526dc37931b651c351889c38eba4a63 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultResponseErrorHandler.java @@ -0,0 +1,180 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; +import java.nio.charset.Charset; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.FileCopyUtils; + +/** + * Spring's default implementation of the {@link ResponseErrorHandler} interface. + * + *

This error handler checks for the status code on the {@link ClientHttpResponse}: + * Any code with series {@link org.springframework.http.HttpStatus.Series#CLIENT_ERROR} + * or {@link org.springframework.http.HttpStatus.Series#SERVER_ERROR} is considered to be + * an error; this behavior can be changed by overriding the {@link #hasError(HttpStatus)} + * method. Unknown status codes will be ignored by {@link #hasError(ClientHttpResponse)}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.0 + * @see RestTemplate#setErrorHandler + */ +public class DefaultResponseErrorHandler implements ResponseErrorHandler { + + /** + * Delegates to {@link #hasError(HttpStatus)} (for a standard status enum value) or + * {@link #hasError(int)} (for an unknown status code) with the response status code. + * @see ClientHttpResponse#getRawStatusCode() + * @see #hasError(HttpStatus) + * @see #hasError(int) + */ + @Override + public boolean hasError(ClientHttpResponse response) throws IOException { + int rawStatusCode = response.getRawStatusCode(); + HttpStatus statusCode = HttpStatus.resolve(rawStatusCode); + return (statusCode != null ? hasError(statusCode) : hasError(rawStatusCode)); + } + + /** + * Template method called from {@link #hasError(ClientHttpResponse)}. + *

The default implementation checks {@link HttpStatus#isError()}. + * Can be overridden in subclasses. + * @param statusCode the HTTP status code as enum value + * @return {@code true} if the response indicates an error; {@code false} otherwise + * @see HttpStatus#isError() + */ + protected boolean hasError(HttpStatus statusCode) { + return statusCode.isError(); + } + + /** + * Template method called from {@link #hasError(ClientHttpResponse)}. + *

The default implementation checks if the given status code is + * {@code HttpStatus.Series#CLIENT_ERROR CLIENT_ERROR} or + * {@code HttpStatus.Series#SERVER_ERROR SERVER_ERROR}. + * Can be overridden in subclasses. + * @param unknownStatusCode the HTTP status code as raw value + * @return {@code true} if the response indicates an error; {@code false} otherwise + * @since 4.3.21 + * @see HttpStatus.Series#CLIENT_ERROR + * @see HttpStatus.Series#SERVER_ERROR + */ + protected boolean hasError(int unknownStatusCode) { + HttpStatus.Series series = HttpStatus.Series.resolve(unknownStatusCode); + return (series == HttpStatus.Series.CLIENT_ERROR || series == HttpStatus.Series.SERVER_ERROR); + } + + /** + * Delegates to {@link #handleError(ClientHttpResponse, HttpStatus)} with the + * response status code. + * @throws UnknownHttpStatusCodeException in case of an unresolvable status code + * @see #handleError(ClientHttpResponse, HttpStatus) + */ + @Override + public void handleError(ClientHttpResponse response) throws IOException { + HttpStatus statusCode = HttpStatus.resolve(response.getRawStatusCode()); + if (statusCode == null) { + throw new UnknownHttpStatusCodeException(response.getRawStatusCode(), response.getStatusText(), + response.getHeaders(), getResponseBody(response), getCharset(response)); + } + handleError(response, statusCode); + } + + /** + * Handle the error in the given response with the given resolved status code. + *

The default implementation throws an {@link HttpClientErrorException} + * if the status code is {@link HttpStatus.Series#CLIENT_ERROR}, an + * {@link HttpServerErrorException} if it is {@link HttpStatus.Series#SERVER_ERROR}, + * and an {@link UnknownHttpStatusCodeException} in other cases. + * @since 5.0 + * @see HttpClientErrorException#create + * @see HttpServerErrorException#create + */ + protected void handleError(ClientHttpResponse response, HttpStatus statusCode) throws IOException { + String statusText = response.getStatusText(); + HttpHeaders headers = response.getHeaders(); + byte[] body = getResponseBody(response); + Charset charset = getCharset(response); + switch (statusCode.series()) { + case CLIENT_ERROR: + throw HttpClientErrorException.create(statusCode, statusText, headers, body, charset); + case SERVER_ERROR: + throw HttpServerErrorException.create(statusCode, statusText, headers, body, charset); + default: + throw new UnknownHttpStatusCodeException(statusCode.value(), statusText, headers, body, charset); + } + } + + /** + * Determine the HTTP status of the given response. + * @param response the response to inspect + * @return the associated HTTP status + * @throws IOException in case of I/O errors + * @throws UnknownHttpStatusCodeException in case of an unknown status code + * that cannot be represented with the {@link HttpStatus} enum + * @since 4.3.8 + * @deprecated as of 5.0, in favor of {@link #handleError(ClientHttpResponse, HttpStatus)} + */ + @Deprecated + protected HttpStatus getHttpStatusCode(ClientHttpResponse response) throws IOException { + HttpStatus statusCode = HttpStatus.resolve(response.getRawStatusCode()); + if (statusCode == null) { + throw new UnknownHttpStatusCodeException(response.getRawStatusCode(), response.getStatusText(), + response.getHeaders(), getResponseBody(response), getCharset(response)); + } + return statusCode; + } + + /** + * Read the body of the given response (for inclusion in a status exception). + * @param response the response to inspect + * @return the response body as a byte array, + * or an empty byte array if the body could not be read + * @since 4.3.8 + */ + protected byte[] getResponseBody(ClientHttpResponse response) { + try { + return FileCopyUtils.copyToByteArray(response.getBody()); + } + catch (IOException ex) { + // ignore + } + return new byte[0]; + } + + /** + * Determine the charset of the response (for inclusion in a status exception). + * @param response the response to inspect + * @return the associated charset, or {@code null} if none + * @since 4.3.8 + */ + @Nullable + protected Charset getCharset(ClientHttpResponse response) { + HttpHeaders headers = response.getHeaders(); + MediaType contentType = headers.getContentType(); + return (contentType != null ? contentType.getCharset() : null); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/ExtractingResponseErrorHandler.java b/spring-web/src/main/java/org/springframework/web/client/ExtractingResponseErrorHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..bceee0da2ed3223b59d02a455bbc61c780b1a772 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/ExtractingResponseErrorHandler.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.springframework.http.HttpStatus; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; + +/** + * Implementation of {@link ResponseErrorHandler} that uses {@link HttpMessageConverter + * HttpMessageConverters} to convert HTTP error responses to {@link RestClientException + * RestClientExceptions}. + * + *

To use this error handler, you must specify a + * {@linkplain #setStatusMapping(Map) status mapping} and/or a + * {@linkplain #setSeriesMapping(Map) series mapping}. If either of these mappings has a match + * for the {@linkplain ClientHttpResponse#getStatusCode() status code} of a given + * {@code ClientHttpResponse}, {@link #hasError(ClientHttpResponse)} will return + * {@code true}, and {@link #handleError(ClientHttpResponse)} will attempt to use the + * {@linkplain #setMessageConverters(List) configured message converters} to convert the response + * into the mapped subclass of {@link RestClientException}. Note that the + * {@linkplain #setStatusMapping(Map) status mapping} takes precedence over + * {@linkplain #setSeriesMapping(Map) series mapping}. + * + *

If there is no match, this error handler will default to the behavior of + * {@link DefaultResponseErrorHandler}. Note that you can override this default behavior + * by specifying a {@linkplain #setSeriesMapping(Map) series mapping} from + * {@code HttpStatus.Series#CLIENT_ERROR} and/or {@code HttpStatus.Series#SERVER_ERROR} + * to {@code null}. + * + * @author Simon Galperin + * @author Arjen Poutsma + * @since 5.0 + * @see RestTemplate#setErrorHandler(ResponseErrorHandler) + */ +public class ExtractingResponseErrorHandler extends DefaultResponseErrorHandler { + + private List> messageConverters = Collections.emptyList(); + + private final Map> statusMapping = new LinkedHashMap<>(); + + private final Map> seriesMapping = new LinkedHashMap<>(); + + + /** + * Create a new, empty {@code ExtractingResponseErrorHandler}. + *

Note that {@link #setMessageConverters(List)} must be called when using this constructor. + */ + public ExtractingResponseErrorHandler() { + } + + /** + * Create a new {@code ExtractingResponseErrorHandler} with the given + * {@link HttpMessageConverter} instances. + * @param messageConverters the message converters to use + */ + public ExtractingResponseErrorHandler(List> messageConverters) { + this.messageConverters = messageConverters; + } + + + /** + * Set the message converters to use by this extractor. + */ + public void setMessageConverters(List> messageConverters) { + this.messageConverters = messageConverters; + } + + /** + * Set the mapping from HTTP status code to {@code RestClientException} subclass. + * If this mapping has a match + * for the {@linkplain ClientHttpResponse#getStatusCode() status code} of a given + * {@code ClientHttpResponse}, {@link #hasError(ClientHttpResponse)} will return + * {@code true} and {@link #handleError(ClientHttpResponse)} will attempt to use the + * {@linkplain #setMessageConverters(List) configured message converters} to convert the + * response into the mapped subclass of {@link RestClientException}. + */ + public void setStatusMapping(Map> statusMapping) { + if (!CollectionUtils.isEmpty(statusMapping)) { + this.statusMapping.putAll(statusMapping); + } + } + + /** + * Set the mapping from HTTP status series to {@code RestClientException} subclass. + * If this mapping has a match + * for the {@linkplain ClientHttpResponse#getStatusCode() status code} of a given + * {@code ClientHttpResponse}, {@link #hasError(ClientHttpResponse)} will return + * {@code true} and {@link #handleError(ClientHttpResponse)} will attempt to use the + * {@linkplain #setMessageConverters(List) configured message converters} to convert the + * response into the mapped subclass of {@link RestClientException}. + */ + public void setSeriesMapping(Map> seriesMapping) { + if (!CollectionUtils.isEmpty(seriesMapping)) { + this.seriesMapping.putAll(seriesMapping); + } + } + + + @Override + protected boolean hasError(HttpStatus statusCode) { + if (this.statusMapping.containsKey(statusCode)) { + return this.statusMapping.get(statusCode) != null; + } + else if (this.seriesMapping.containsKey(statusCode.series())) { + return this.seriesMapping.get(statusCode.series()) != null; + } + else { + return super.hasError(statusCode); + } + } + + @Override + public void handleError(ClientHttpResponse response, HttpStatus statusCode) throws IOException { + if (this.statusMapping.containsKey(statusCode)) { + extract(this.statusMapping.get(statusCode), response); + } + else if (this.seriesMapping.containsKey(statusCode.series())) { + extract(this.seriesMapping.get(statusCode.series()), response); + } + else { + super.handleError(response, statusCode); + } + } + + private void extract(@Nullable Class exceptionClass, + ClientHttpResponse response) throws IOException { + + if (exceptionClass == null) { + return; + } + + HttpMessageConverterExtractor extractor = + new HttpMessageConverterExtractor<>(exceptionClass, this.messageConverters); + RestClientException exception = extractor.extractData(response); + if (exception != null) { + throw exception; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/HttpClientErrorException.java b/spring-web/src/main/java/org/springframework/web/client/HttpClientErrorException.java new file mode 100644 index 0000000000000000000000000000000000000000..a8f3328af2a8b4859508f559a537f94300c7307d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/HttpClientErrorException.java @@ -0,0 +1,240 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.nio.charset.Charset; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; + +/** + * Exception thrown when an HTTP 4xx is received. + * + * @author Arjen Poutsma + * @since 3.0 + * @see DefaultResponseErrorHandler + */ +public class HttpClientErrorException extends HttpStatusCodeException { + + private static final long serialVersionUID = 5177019431887513952L; + + + /** + * Constructor with a status code only. + */ + public HttpClientErrorException(HttpStatus statusCode) { + super(statusCode); + } + + /** + * Constructor with a status code and status text. + */ + public HttpClientErrorException(HttpStatus statusCode, String statusText) { + super(statusCode, statusText); + } + + /** + * Constructor with a status code and status text, and content. + */ + public HttpClientErrorException( + HttpStatus statusCode, String statusText, @Nullable byte[] body, @Nullable Charset responseCharset) { + + super(statusCode, statusText, body, responseCharset); + } + + /** + * Constructor with a status code and status text, headers, and content. + */ + public HttpClientErrorException(HttpStatus statusCode, String statusText, + @Nullable HttpHeaders headers, @Nullable byte[] body, @Nullable Charset responseCharset) { + + super(statusCode, statusText, headers, body, responseCharset); + } + + + /** + * Create {@code HttpClientErrorException} or an HTTP status specific sub-class. + * @since 5.1 + */ + public static HttpClientErrorException create( + HttpStatus statusCode, String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + + switch (statusCode) { + case BAD_REQUEST: + return new HttpClientErrorException.BadRequest(statusText, headers, body, charset); + case UNAUTHORIZED: + return new HttpClientErrorException.Unauthorized(statusText, headers, body, charset); + case FORBIDDEN: + return new HttpClientErrorException.Forbidden(statusText, headers, body, charset); + case NOT_FOUND: + return new HttpClientErrorException.NotFound(statusText, headers, body, charset); + case METHOD_NOT_ALLOWED: + return new HttpClientErrorException.MethodNotAllowed(statusText, headers, body, charset); + case NOT_ACCEPTABLE: + return new HttpClientErrorException.NotAcceptable(statusText, headers, body, charset); + case CONFLICT: + return new HttpClientErrorException.Conflict(statusText, headers, body, charset); + case GONE: + return new HttpClientErrorException.Gone(statusText, headers, body, charset); + case UNSUPPORTED_MEDIA_TYPE: + return new HttpClientErrorException.UnsupportedMediaType(statusText, headers, body, charset); + case TOO_MANY_REQUESTS: + return new HttpClientErrorException.TooManyRequests(statusText, headers, body, charset); + case UNPROCESSABLE_ENTITY: + return new HttpClientErrorException.UnprocessableEntity(statusText, headers, body, charset); + default: + return new HttpClientErrorException(statusCode, statusText, headers, body, charset); + } + } + + + // Subclasses for specific HTTP status codes + + /** + * {@link HttpClientErrorException} for status HTTP 400 Bad Request. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class BadRequest extends HttpClientErrorException { + + BadRequest(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.BAD_REQUEST, statusText, headers, body, charset); + } + } + + /** + * {@link HttpClientErrorException} for status HTTP 401 Unauthorized. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class Unauthorized extends HttpClientErrorException { + + Unauthorized(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.UNAUTHORIZED, statusText, headers, body, charset); + } + } + + /** + * {@link HttpClientErrorException} for status HTTP 403 Forbidden. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class Forbidden extends HttpClientErrorException { + + Forbidden(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.FORBIDDEN, statusText, headers, body, charset); + } + } + + /** + * {@link HttpClientErrorException} for status HTTP 404 Not Found. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class NotFound extends HttpClientErrorException { + + NotFound(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.NOT_FOUND, statusText, headers, body, charset); + } + } + + /** + * {@link HttpClientErrorException} for status HTTP 405 Method Not Allowed. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class MethodNotAllowed extends HttpClientErrorException { + + MethodNotAllowed(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.METHOD_NOT_ALLOWED, statusText, headers, body, charset); + } + } + + /** + * {@link HttpClientErrorException} for status HTTP 406 Not Acceptable. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class NotAcceptable extends HttpClientErrorException { + + NotAcceptable(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.NOT_ACCEPTABLE, statusText, headers, body, charset); + } + } + + /** + * {@link HttpClientErrorException} for status HTTP 409 Conflict. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class Conflict extends HttpClientErrorException { + + Conflict(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.CONFLICT, statusText, headers, body, charset); + } + } + + /** + * {@link HttpClientErrorException} for status HTTP 410 Gone. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class Gone extends HttpClientErrorException { + + Gone(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.GONE, statusText, headers, body, charset); + } + } + + /** + * {@link HttpClientErrorException} for status HTTP 415 Unsupported Media Type. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class UnsupportedMediaType extends HttpClientErrorException { + + UnsupportedMediaType(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.UNSUPPORTED_MEDIA_TYPE, statusText, headers, body, charset); + } + } + + /** + * {@link HttpClientErrorException} for status HTTP 422 Unprocessable Entity. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class UnprocessableEntity extends HttpClientErrorException { + + UnprocessableEntity(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.UNPROCESSABLE_ENTITY, statusText, headers, body, charset); + } + } + + /** + * {@link HttpClientErrorException} for status HTTP 429 Too Many Requests. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class TooManyRequests extends HttpClientErrorException { + + TooManyRequests(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.TOO_MANY_REQUESTS, statusText, headers, body, charset); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/HttpMessageConverterExtractor.java b/spring-web/src/main/java/org/springframework/web/client/HttpMessageConverterExtractor.java new file mode 100644 index 0000000000000000000000000000000000000000..f13766e510f6fcb435dd0944116b2295a1ef1f8e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/HttpMessageConverterExtractor.java @@ -0,0 +1,143 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; +import java.lang.reflect.Type; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.ResolvableType; +import org.springframework.http.MediaType; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.converter.GenericHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Response extractor that uses the given {@linkplain HttpMessageConverter entity converters} + * to convert the response into a type {@code T}. + * + * @author Arjen Poutsma + * @since 3.0 + * @param the data type + * @see RestTemplate + */ +public class HttpMessageConverterExtractor implements ResponseExtractor { + + private final Type responseType; + + @Nullable + private final Class responseClass; + + private final List> messageConverters; + + private final Log logger; + + + /** + * Create a new instance of the {@code HttpMessageConverterExtractor} with the given response + * type and message converters. The given converters must support the response type. + */ + public HttpMessageConverterExtractor(Class responseType, List> messageConverters) { + this((Type) responseType, messageConverters); + } + + /** + * Creates a new instance of the {@code HttpMessageConverterExtractor} with the given response + * type and message converters. The given converters must support the response type. + */ + public HttpMessageConverterExtractor(Type responseType, List> messageConverters) { + this(responseType, messageConverters, LogFactory.getLog(HttpMessageConverterExtractor.class)); + } + + @SuppressWarnings("unchecked") + HttpMessageConverterExtractor(Type responseType, List> messageConverters, Log logger) { + Assert.notNull(responseType, "'responseType' must not be null"); + Assert.notEmpty(messageConverters, "'messageConverters' must not be empty"); + this.responseType = responseType; + this.responseClass = (responseType instanceof Class ? (Class) responseType : null); + this.messageConverters = messageConverters; + this.logger = logger; + } + + + @Override + @SuppressWarnings({"unchecked", "rawtypes", "resource"}) + public T extractData(ClientHttpResponse response) throws IOException { + MessageBodyClientHttpResponseWrapper responseWrapper = new MessageBodyClientHttpResponseWrapper(response); + if (!responseWrapper.hasMessageBody() || responseWrapper.hasEmptyMessageBody()) { + return null; + } + MediaType contentType = getContentType(responseWrapper); + + try { + for (HttpMessageConverter messageConverter : this.messageConverters) { + if (messageConverter instanceof GenericHttpMessageConverter) { + GenericHttpMessageConverter genericMessageConverter = + (GenericHttpMessageConverter) messageConverter; + if (genericMessageConverter.canRead(this.responseType, null, contentType)) { + if (logger.isDebugEnabled()) { + ResolvableType resolvableType = ResolvableType.forType(this.responseType); + logger.debug("Reading to [" + resolvableType + "]"); + } + return (T) genericMessageConverter.read(this.responseType, null, responseWrapper); + } + } + if (this.responseClass != null) { + if (messageConverter.canRead(this.responseClass, contentType)) { + if (logger.isDebugEnabled()) { + String className = this.responseClass.getName(); + logger.debug("Reading to [" + className + "] as \"" + contentType + "\""); + } + return (T) messageConverter.read((Class) this.responseClass, responseWrapper); + } + } + } + } + catch (IOException | HttpMessageNotReadableException ex) { + throw new RestClientException("Error while extracting response for type [" + + this.responseType + "] and content type [" + contentType + "]", ex); + } + + throw new RestClientException("Could not extract response: no suitable HttpMessageConverter found " + + "for response type [" + this.responseType + "] and content type [" + contentType + "]"); + } + + /** + * Determine the Content-Type of the response based on the "Content-Type" + * header or otherwise default to {@link MediaType#APPLICATION_OCTET_STREAM}. + * @param response the response + * @return the MediaType, possibly {@code null}. + */ + @Nullable + protected MediaType getContentType(ClientHttpResponse response) { + MediaType contentType = response.getHeaders().getContentType(); + if (contentType == null) { + if (logger.isTraceEnabled()) { + logger.trace("No content-type, using 'application/octet-stream'"); + } + contentType = MediaType.APPLICATION_OCTET_STREAM; + } + return contentType; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/HttpServerErrorException.java b/spring-web/src/main/java/org/springframework/web/client/HttpServerErrorException.java new file mode 100644 index 0000000000000000000000000000000000000000..b0f7cc85f3f50c21eb9e275287f63c89c93a1a7d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/HttpServerErrorException.java @@ -0,0 +1,156 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.nio.charset.Charset; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; + +/** + * Exception thrown when an HTTP 5xx is received. + * + * @author Arjen Poutsma + * @since 3.0 + * @see DefaultResponseErrorHandler + */ +public class HttpServerErrorException extends HttpStatusCodeException { + + private static final long serialVersionUID = -2915754006618138282L; + + + /** + * Constructor with a status code only. + */ + public HttpServerErrorException(HttpStatus statusCode) { + super(statusCode); + } + + /** + * Constructor with a status code and status text. + */ + public HttpServerErrorException(HttpStatus statusCode, String statusText) { + super(statusCode, statusText); + } + + /** + * Constructor with a status code and status text, and content. + */ + public HttpServerErrorException( + HttpStatus statusCode, String statusText, @Nullable byte[] body, @Nullable Charset charset) { + + super(statusCode, statusText, body, charset); + } + + /** + * Constructor with a status code and status text, headers, and content. + */ + public HttpServerErrorException(HttpStatus statusCode, String statusText, + @Nullable HttpHeaders headers, @Nullable byte[] body, @Nullable Charset charset) { + + super(statusCode, statusText, headers, body, charset); + } + + + /** + * Create an {@code HttpServerErrorException} or an HTTP status specific sub-class. + * @since 5.1 + */ + public static HttpServerErrorException create( + HttpStatus statusCode, String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + + switch (statusCode) { + case INTERNAL_SERVER_ERROR: + return new HttpServerErrorException.InternalServerError(statusText, headers, body, charset); + case NOT_IMPLEMENTED: + return new HttpServerErrorException.NotImplemented(statusText, headers, body, charset); + case BAD_GATEWAY: + return new HttpServerErrorException.BadGateway(statusText, headers, body, charset); + case SERVICE_UNAVAILABLE: + return new HttpServerErrorException.ServiceUnavailable(statusText, headers, body, charset); + case GATEWAY_TIMEOUT: + return new HttpServerErrorException.GatewayTimeout(statusText, headers, body, charset); + default: + return new HttpServerErrorException(statusCode, statusText, headers, body, charset); + } + } + + + // Subclasses for specific HTTP status codes + + /** + * {@link HttpServerErrorException} for status HTTP 500 Internal Server Error. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class InternalServerError extends HttpServerErrorException { + + InternalServerError(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.INTERNAL_SERVER_ERROR, statusText, headers, body, charset); + } + } + + /** + * {@link HttpServerErrorException} for status HTTP 501 Not Implemented. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class NotImplemented extends HttpServerErrorException { + + NotImplemented(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.NOT_IMPLEMENTED, statusText, headers, body, charset); + } + } + + /** + * {@link HttpServerErrorException} for status HTTP HTTP 502 Bad Gateway. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class BadGateway extends HttpServerErrorException { + + BadGateway(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.BAD_GATEWAY, statusText, headers, body, charset); + } + } + + /** + * {@link HttpServerErrorException} for status HTTP 503 Service Unavailable. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class ServiceUnavailable extends HttpServerErrorException { + + ServiceUnavailable(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.SERVICE_UNAVAILABLE, statusText, headers, body, charset); + } + } + + /** + * {@link HttpServerErrorException} for status HTTP 504 Gateway Timeout. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class GatewayTimeout extends HttpServerErrorException { + + GatewayTimeout(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + super(HttpStatus.GATEWAY_TIMEOUT, statusText, headers, body, charset); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/HttpStatusCodeException.java b/spring-web/src/main/java/org/springframework/web/client/HttpStatusCodeException.java new file mode 100644 index 0000000000000000000000000000000000000000..b224cbd44047fb70f7a9c6e5ccc8b0aee33492cc --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/HttpStatusCodeException.java @@ -0,0 +1,98 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.nio.charset.Charset; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; + +/** + * Abstract base class for exceptions based on an {@link HttpStatus}. + * + * @author Arjen Poutsma + * @author Chris Beams + * @author Rossen Stoyanchev + * @since 3.0 + */ +public abstract class HttpStatusCodeException extends RestClientResponseException { + + private static final long serialVersionUID = 5696801857651587810L; + + + private final HttpStatus statusCode; + + + /** + * Construct a new instance with an {@link HttpStatus}. + * @param statusCode the status code + */ + protected HttpStatusCodeException(HttpStatus statusCode) { + this(statusCode, statusCode.name(), null, null, null); + } + + /** + * Construct a new instance with an {@link HttpStatus} and status text. + * @param statusCode the status code + * @param statusText the status text + */ + protected HttpStatusCodeException(HttpStatus statusCode, String statusText) { + this(statusCode, statusText, null, null, null); + } + + /** + * Construct instance with an {@link HttpStatus}, status text, and content. + * @param statusCode the status code + * @param statusText the status text + * @param responseBody the response body content, may be {@code null} + * @param responseCharset the response body charset, may be {@code null} + * @since 3.0.5 + */ + protected HttpStatusCodeException(HttpStatus statusCode, String statusText, + @Nullable byte[] responseBody, @Nullable Charset responseCharset) { + + this(statusCode, statusText, null, responseBody, responseCharset); + } + + /** + * Construct instance with an {@link HttpStatus}, status text, content, and + * a response charset. + * @param statusCode the status code + * @param statusText the status text + * @param responseHeaders the response headers, may be {@code null} + * @param responseBody the response body content, may be {@code null} + * @param responseCharset the response body charset, may be {@code null} + * @since 3.1.2 + */ + protected HttpStatusCodeException(HttpStatus statusCode, String statusText, + @Nullable HttpHeaders responseHeaders, @Nullable byte[] responseBody, @Nullable Charset responseCharset) { + + super(statusCode.value() + " " + statusText, statusCode.value(), statusText, + responseHeaders, responseBody, responseCharset); + this.statusCode = statusCode; + } + + + /** + * Return the HTTP status code. + */ + public HttpStatus getStatusCode() { + return this.statusCode; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/MessageBodyClientHttpResponseWrapper.java b/spring-web/src/main/java/org/springframework/web/client/MessageBodyClientHttpResponseWrapper.java new file mode 100644 index 0000000000000000000000000000000000000000..85e1c8f1b7e8dd4c217b79b8210f7bd0d0c74959 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/MessageBodyClientHttpResponseWrapper.java @@ -0,0 +1,143 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; +import java.io.InputStream; +import java.io.PushbackInputStream; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.lang.Nullable; + +/** + * Implementation of {@link ClientHttpResponse} that can not only check if + * the response has a message body, but also if its length is 0 (i.e. empty) + * by actually reading the input stream. + * + * @author Brian Clozel + * @since 4.1.5 + * @see RFC 7230 Section 3.3.3 + */ +class MessageBodyClientHttpResponseWrapper implements ClientHttpResponse { + + private final ClientHttpResponse response; + + @Nullable + private PushbackInputStream pushbackInputStream; + + + public MessageBodyClientHttpResponseWrapper(ClientHttpResponse response) throws IOException { + this.response = response; + } + + + /** + * Indicates whether the response has a message body. + *

Implementation returns {@code false} for: + *

    + *
  • a response status of {@code 1XX}, {@code 204} or {@code 304}
  • + *
  • a {@code Content-Length} header of {@code 0}
  • + *
+ * @return {@code true} if the response has a message body, {@code false} otherwise + * @throws IOException in case of I/O errors + */ + public boolean hasMessageBody() throws IOException { + HttpStatus status = HttpStatus.resolve(getRawStatusCode()); + if (status != null && (status.is1xxInformational() || status == HttpStatus.NO_CONTENT || + status == HttpStatus.NOT_MODIFIED)) { + return false; + } + if (getHeaders().getContentLength() == 0) { + return false; + } + return true; + } + + /** + * Indicates whether the response has an empty message body. + *

Implementation tries to read the first bytes of the response stream: + *

    + *
  • if no bytes are available, the message body is empty
  • + *
  • otherwise it is not empty and the stream is reset to its start for further reading
  • + *
+ * @return {@code true} if the response has a zero-length message body, {@code false} otherwise + * @throws IOException in case of I/O errors + */ + @SuppressWarnings("ConstantConditions") + public boolean hasEmptyMessageBody() throws IOException { + InputStream body = this.response.getBody(); + // Per contract body shouldn't be null, but check anyway.. + if (body == null) { + return true; + } + if (body.markSupported()) { + body.mark(1); + if (body.read() == -1) { + return true; + } + else { + body.reset(); + return false; + } + } + else { + this.pushbackInputStream = new PushbackInputStream(body); + int b = this.pushbackInputStream.read(); + if (b == -1) { + return true; + } + else { + this.pushbackInputStream.unread(b); + return false; + } + } + } + + + @Override + public HttpHeaders getHeaders() { + return this.response.getHeaders(); + } + + @Override + public InputStream getBody() throws IOException { + return (this.pushbackInputStream != null ? this.pushbackInputStream : this.response.getBody()); + } + + @Override + public HttpStatus getStatusCode() throws IOException { + return this.response.getStatusCode(); + } + + @Override + public int getRawStatusCode() throws IOException { + return this.response.getRawStatusCode(); + } + + @Override + public String getStatusText() throws IOException { + return this.response.getStatusText(); + } + + @Override + public void close() { + this.response.close(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/RequestCallback.java b/spring-web/src/main/java/org/springframework/web/client/RequestCallback.java new file mode 100644 index 0000000000000000000000000000000000000000..e3d32f526f1494afb5de47950e8a74898076c495 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/RequestCallback.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; +import java.lang.reflect.Type; + +import org.springframework.http.client.ClientHttpRequest; + +/** + * Callback interface for code that operates on a {@link ClientHttpRequest}. + * Allows manipulating the request headers, and write to the request body. + * + *

Used internally by the {@link RestTemplate}, but also useful for + * application code. There several available factory methods: + *

    + *
  • {@link RestTemplate#acceptHeaderRequestCallback(Class)} + *
  • {@link RestTemplate#httpEntityCallback(Object)} + *
  • {@link RestTemplate#httpEntityCallback(Object, Type)} + *
+ * + * @author Arjen Poutsma + * @see RestTemplate#execute + * @since 3.0 + */ +@FunctionalInterface +public interface RequestCallback { + + /** + * Gets called by {@link RestTemplate#execute} with an opened {@code ClientHttpRequest}. + * Does not need to care about closing the request or about handling errors: + * this will all be handled by the {@code RestTemplate}. + * @param request the active HTTP request + * @throws IOException in case of I/O errors + */ + void doWithRequest(ClientHttpRequest request) throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/ResourceAccessException.java b/spring-web/src/main/java/org/springframework/web/client/ResourceAccessException.java new file mode 100644 index 0000000000000000000000000000000000000000..a62207fd7b2c9b8396ed0464e113be1d3ac31477 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/ResourceAccessException.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; + +/** + * Exception thrown when an I/O error occurs. + * + * @author Arjen Poutsma + * @since 3.0 + */ +public class ResourceAccessException extends RestClientException { + + private static final long serialVersionUID = -8513182514355844870L; + + + /** + * Construct a new {@code ResourceAccessException} with the given message. + * @param msg the message + */ + public ResourceAccessException(String msg) { + super(msg); + } + + /** + * Construct a new {@code ResourceAccessException} with the given message and {@link IOException}. + * @param msg the message + * @param ex the {@code IOException} + */ + public ResourceAccessException(String msg, IOException ex) { + super(msg, ex); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/ResponseErrorHandler.java b/spring-web/src/main/java/org/springframework/web/client/ResponseErrorHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..db2329f8fef09502a670579231a6598c87d6ab22 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/ResponseErrorHandler.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; +import java.net.URI; + +import org.springframework.http.HttpMethod; +import org.springframework.http.client.ClientHttpResponse; + +/** + * Strategy interface used by the {@link RestTemplate} to determine + * whether a particular response has an error or not. + * + * @author Arjen Poutsma + * @since 3.0 + */ +public interface ResponseErrorHandler { + + /** + * Indicate whether the given response has any errors. + *

Implementations will typically inspect the + * {@link ClientHttpResponse#getStatusCode() HttpStatus} of the response. + * @param response the response to inspect + * @return {@code true} if the response indicates an error; {@code false} otherwise + * @throws IOException in case of I/O errors + */ + boolean hasError(ClientHttpResponse response) throws IOException; + + /** + * Handle the error in the given response. + *

This method is only called when {@link #hasError(ClientHttpResponse)} + * has returned {@code true}. + * @param response the response with the error + * @throws IOException in case of I/O errors + */ + void handleError(ClientHttpResponse response) throws IOException; + + /** + * Alternative to {@link #handleError(ClientHttpResponse)} with extra + * information providing access to the request URL and HTTP method. + * @param url the request URL + * @param method the HTTP method + * @param response the response with the error + * @throws IOException in case of I/O errors + * @since 5.0 + */ + default void handleError(URI url, HttpMethod method, ClientHttpResponse response) throws IOException { + handleError(response); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/ResponseExtractor.java b/spring-web/src/main/java/org/springframework/web/client/ResponseExtractor.java new file mode 100644 index 0000000000000000000000000000000000000000..e36bb071563414762858f94c1d79a25d0d098de1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/ResponseExtractor.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; +import java.lang.reflect.Type; + +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.lang.Nullable; + +/** + * Generic callback interface used by {@link RestTemplate}'s retrieval methods + * Implementations of this interface perform the actual work of extracting data + * from a {@link ClientHttpResponse}, but don't need to worry about exception + * handling or closing resources. + * + *

Used internally by the {@link RestTemplate}, but also useful for + * application code. There is one available factory method, see + * {@link RestTemplate#responseEntityExtractor(Type)}. + * + * @author Arjen Poutsma + * @since 3.0 + * @param the data type + * @see RestTemplate#execute + */ +@FunctionalInterface +public interface ResponseExtractor { + + /** + * Extract data from the given {@code ClientHttpResponse} and return it. + * @param response the HTTP response + * @return the extracted data + * @throws IOException in case of I/O errors + */ + @Nullable + T extractData(ClientHttpResponse response) throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/RestClientException.java b/spring-web/src/main/java/org/springframework/web/client/RestClientException.java new file mode 100644 index 0000000000000000000000000000000000000000..4e25c7cf39089daf28b53300039eafa597dab50c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/RestClientException.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import org.springframework.core.NestedRuntimeException; + +/** + * Base class for exceptions thrown by {@link RestTemplate} whenever it encounters + * client-side HTTP errors. + * + * @author Arjen Poutsma + * @since 3.0 + */ +public class RestClientException extends NestedRuntimeException { + + private static final long serialVersionUID = -4084444984163796577L; + + + /** + * Construct a new instance of {@code RestClientException} with the given message. + * @param msg the message + */ + public RestClientException(String msg) { + super(msg); + } + + /** + * Construct a new instance of {@code RestClientException} with the given message and + * exception. + * @param msg the message + * @param ex the exception + */ + public RestClientException(String msg, Throwable ex) { + super(msg, ex); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/RestClientResponseException.java b/spring-web/src/main/java/org/springframework/web/client/RestClientResponseException.java new file mode 100644 index 0000000000000000000000000000000000000000..f93f2334fa55f2caa9ee8af14c5d54ff2c03ee62 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/RestClientResponseException.java @@ -0,0 +1,128 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.UnsupportedEncodingException; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; + +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; + +/** + * Common base class for exceptions that contain actual HTTP response data. + * + * @author Rossen Stoyanchev + * @since 4.3 + */ +public class RestClientResponseException extends RestClientException { + + private static final long serialVersionUID = -8803556342728481792L; + + private static final Charset DEFAULT_CHARSET = StandardCharsets.ISO_8859_1; + + + private final int rawStatusCode; + + private final String statusText; + + private final byte[] responseBody; + + @Nullable + private final HttpHeaders responseHeaders; + + @Nullable + private final String responseCharset; + + + /** + * Construct a new instance of with the given response data. + * @param statusCode the raw status code value + * @param statusText the status text + * @param responseHeaders the response headers (may be {@code null}) + * @param responseBody the response body content (may be {@code null}) + * @param responseCharset the response body charset (may be {@code null}) + */ + public RestClientResponseException(String message, int statusCode, String statusText, + @Nullable HttpHeaders responseHeaders, @Nullable byte[] responseBody, @Nullable Charset responseCharset) { + + super(message); + this.rawStatusCode = statusCode; + this.statusText = statusText; + this.responseHeaders = responseHeaders; + this.responseBody = (responseBody != null ? responseBody : new byte[0]); + this.responseCharset = (responseCharset != null ? responseCharset.name() : null); + } + + + /** + * Return the raw HTTP status code value. + */ + public int getRawStatusCode() { + return this.rawStatusCode; + } + + /** + * Return the HTTP status text. + */ + public String getStatusText() { + return this.statusText; + } + + /** + * Return the HTTP response headers. + */ + @Nullable + public HttpHeaders getResponseHeaders() { + return this.responseHeaders; + } + + /** + * Return the response body as a byte array. + */ + public byte[] getResponseBodyAsByteArray() { + return this.responseBody; + } + + /** + * Return the response body converted to String. The charset used is that + * of the response "Content-Type" or otherwise {@code "UTF-8"}. + */ + public String getResponseBodyAsString() { + return getResponseBodyAsString(DEFAULT_CHARSET); + } + + /** + * Return the response body converted to String. The charset used is that + * of the response "Content-Type" or otherwise the one given. + * @param fallbackCharset the charset to use on if the response doesn't specify. + * @since 5.1.11 + */ + public String getResponseBodyAsString(Charset fallbackCharset) { + if (this.responseCharset == null) { + return new String(this.responseBody, fallbackCharset); + } + try { + return new String(this.responseBody, this.responseCharset); + } + catch (UnsupportedEncodingException ex) { + // should not occur + throw new IllegalStateException(ex); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/RestOperations.java b/spring-web/src/main/java/org/springframework/web/client/RestOperations.java new file mode 100644 index 0000000000000000000000000000000000000000..03d3cfa3505dfb77d3ec06492267f6f40f53e989 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/RestOperations.java @@ -0,0 +1,698 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.net.URI; +import java.util.Map; +import java.util.Set; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.lang.Nullable; + +/** + * Interface specifying a basic set of RESTful operations. + * Implemented by {@link RestTemplate}. Not often used directly, but a useful + * option to enhance testability, as it can easily be mocked or stubbed. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.0 + * @see RestTemplate + */ +public interface RestOperations { + + // GET + + /** + * Retrieve a representation by doing a GET on the specified URL. + * The response (if any) is converted and returned. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param responseType the type of the return value + * @param uriVariables the variables to expand the template + * @return the converted object + */ + @Nullable + T getForObject(String url, Class responseType, Object... uriVariables) throws RestClientException; + + /** + * Retrieve a representation by doing a GET on the URI template. + * The response (if any) is converted and returned. + *

URI Template variables are expanded using the given map. + * @param url the URL + * @param responseType the type of the return value + * @param uriVariables the map containing variables for the URI template + * @return the converted object + */ + @Nullable + T getForObject(String url, Class responseType, Map uriVariables) throws RestClientException; + + /** + * Retrieve a representation by doing a GET on the URL . + * The response (if any) is converted and returned. + * @param url the URL + * @param responseType the type of the return value + * @return the converted object + */ + @Nullable + T getForObject(URI url, Class responseType) throws RestClientException; + + /** + * Retrieve an entity by doing a GET on the specified URL. + * The response is converted and stored in an {@link ResponseEntity}. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param responseType the type of the return value + * @param uriVariables the variables to expand the template + * @return the entity + * @since 3.0.2 + */ + ResponseEntity getForEntity(String url, Class responseType, Object... uriVariables) + throws RestClientException; + + /** + * Retrieve a representation by doing a GET on the URI template. + * The response is converted and stored in an {@link ResponseEntity}. + *

URI Template variables are expanded using the given map. + * @param url the URL + * @param responseType the type of the return value + * @param uriVariables the map containing variables for the URI template + * @return the converted object + * @since 3.0.2 + */ + ResponseEntity getForEntity(String url, Class responseType, Map uriVariables) + throws RestClientException; + + /** + * Retrieve a representation by doing a GET on the URL . + * The response is converted and stored in an {@link ResponseEntity}. + * @param url the URL + * @param responseType the type of the return value + * @return the converted object + * @since 3.0.2 + */ + ResponseEntity getForEntity(URI url, Class responseType) throws RestClientException; + + + // HEAD + + /** + * Retrieve all headers of the resource specified by the URI template. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param uriVariables the variables to expand the template + * @return all HTTP headers of that resource + */ + HttpHeaders headForHeaders(String url, Object... uriVariables) throws RestClientException; + + /** + * Retrieve all headers of the resource specified by the URI template. + *

URI Template variables are expanded using the given map. + * @param url the URL + * @param uriVariables the map containing variables for the URI template + * @return all HTTP headers of that resource + */ + HttpHeaders headForHeaders(String url, Map uriVariables) throws RestClientException; + + /** + * Retrieve all headers of the resource specified by the URL. + * @param url the URL + * @return all HTTP headers of that resource + */ + HttpHeaders headForHeaders(URI url) throws RestClientException; + + + // POST + + /** + * Create a new resource by POSTing the given object to the URI template, and returns the value of + * the {@code Location} header. This header typically indicates where the new resource is stored. + *

URI Template variables are expanded using the given URI variables, if any. + *

The {@code request} parameter can be a {@link HttpEntity} in order to + * add additional HTTP headers to the request. + *

The body of the entity, or {@code request} itself, can be a + * {@link org.springframework.util.MultiValueMap MultiValueMap} to create a multipart request. + * The values in the {@code MultiValueMap} can be any Object representing the body of the part, + * or an {@link org.springframework.http.HttpEntity HttpEntity} representing a part with body + * and headers. + * @param url the URL + * @param request the Object to be POSTed (may be {@code null}) + * @param uriVariables the variables to expand the template + * @return the value for the {@code Location} header + * @see HttpEntity + */ + @Nullable + URI postForLocation(String url, @Nullable Object request, Object... uriVariables) throws RestClientException; + + /** + * Create a new resource by POSTing the given object to the URI template, and returns the value of + * the {@code Location} header. This header typically indicates where the new resource is stored. + *

URI Template variables are expanded using the given map. + *

The {@code request} parameter can be a {@link HttpEntity} in order to + * add additional HTTP headers to the request + *

The body of the entity, or {@code request} itself, can be a + * {@link org.springframework.util.MultiValueMap MultiValueMap} to create a multipart request. + * The values in the {@code MultiValueMap} can be any Object representing the body of the part, + * or an {@link org.springframework.http.HttpEntity HttpEntity} representing a part with body + * and headers. + * @param url the URL + * @param request the Object to be POSTed (may be {@code null}) + * @param uriVariables the variables to expand the template + * @return the value for the {@code Location} header + * @see HttpEntity + */ + @Nullable + URI postForLocation(String url, @Nullable Object request, Map uriVariables) + throws RestClientException; + + /** + * Create a new resource by POSTing the given object to the URL, and returns the value of the + * {@code Location} header. This header typically indicates where the new resource is stored. + *

The {@code request} parameter can be a {@link HttpEntity} in order to + * add additional HTTP headers to the request. + *

The body of the entity, or {@code request} itself, can be a + * {@link org.springframework.util.MultiValueMap MultiValueMap} to create a multipart request. + * The values in the {@code MultiValueMap} can be any Object representing the body of the part, + * or an {@link org.springframework.http.HttpEntity HttpEntity} representing a part with body + * and headers. + * @param url the URL + * @param request the Object to be POSTed (may be {@code null}) + * @return the value for the {@code Location} header + * @see HttpEntity + */ + @Nullable + URI postForLocation(URI url, @Nullable Object request) throws RestClientException; + + /** + * Create a new resource by POSTing the given object to the URI template, + * and returns the representation found in the response. + *

URI Template variables are expanded using the given URI variables, if any. + *

The {@code request} parameter can be a {@link HttpEntity} in order to + * add additional HTTP headers to the request. + *

The body of the entity, or {@code request} itself, can be a + * {@link org.springframework.util.MultiValueMap MultiValueMap} to create a multipart request. + * The values in the {@code MultiValueMap} can be any Object representing the body of the part, + * or an {@link org.springframework.http.HttpEntity HttpEntity} representing a part with body + * and headers. + * @param url the URL + * @param request the Object to be POSTed (may be {@code null}) + * @param responseType the type of the return value + * @param uriVariables the variables to expand the template + * @return the converted object + * @see HttpEntity + */ + @Nullable + T postForObject(String url, @Nullable Object request, Class responseType, + Object... uriVariables) throws RestClientException; + + /** + * Create a new resource by POSTing the given object to the URI template, + * and returns the representation found in the response. + *

URI Template variables are expanded using the given map. + *

The {@code request} parameter can be a {@link HttpEntity} in order to + * add additional HTTP headers to the request. + *

The body of the entity, or {@code request} itself, can be a + * {@link org.springframework.util.MultiValueMap MultiValueMap} to create a multipart request. + * The values in the {@code MultiValueMap} can be any Object representing the body of the part, + * or an {@link org.springframework.http.HttpEntity HttpEntity} representing a part with body + * and headers. + * @param url the URL + * @param request the Object to be POSTed (may be {@code null}) + * @param responseType the type of the return value + * @param uriVariables the variables to expand the template + * @return the converted object + * @see HttpEntity + */ + @Nullable + T postForObject(String url, @Nullable Object request, Class responseType, + Map uriVariables) throws RestClientException; + + /** + * Create a new resource by POSTing the given object to the URL, + * and returns the representation found in the response. + *

The {@code request} parameter can be a {@link HttpEntity} in order to + * add additional HTTP headers to the request. + *

The body of the entity, or {@code request} itself, can be a + * {@link org.springframework.util.MultiValueMap MultiValueMap} to create a multipart request. + * The values in the {@code MultiValueMap} can be any Object representing the body of the part, + * or an {@link org.springframework.http.HttpEntity HttpEntity} representing a part with body + * and headers. + * @param url the URL + * @param request the Object to be POSTed (may be {@code null}) + * @param responseType the type of the return value + * @return the converted object + * @see HttpEntity + */ + @Nullable + T postForObject(URI url, @Nullable Object request, Class responseType) throws RestClientException; + + /** + * Create a new resource by POSTing the given object to the URI template, + * and returns the response as {@link ResponseEntity}. + *

URI Template variables are expanded using the given URI variables, if any. + *

The {@code request} parameter can be a {@link HttpEntity} in order to + * add additional HTTP headers to the request. + *

The body of the entity, or {@code request} itself, can be a + * {@link org.springframework.util.MultiValueMap MultiValueMap} to create a multipart request. + * The values in the {@code MultiValueMap} can be any Object representing the body of the part, + * or an {@link org.springframework.http.HttpEntity HttpEntity} representing a part with body + * and headers. + * @param url the URL + * @param request the Object to be POSTed (may be {@code null}) + * @param uriVariables the variables to expand the template + * @return the converted object + * @since 3.0.2 + * @see HttpEntity + */ + ResponseEntity postForEntity(String url, @Nullable Object request, Class responseType, + Object... uriVariables) throws RestClientException; + + /** + * Create a new resource by POSTing the given object to the URI template, + * and returns the response as {@link HttpEntity}. + *

URI Template variables are expanded using the given map. + *

The {@code request} parameter can be a {@link HttpEntity} in order to + * add additional HTTP headers to the request. + *

The body of the entity, or {@code request} itself, can be a + * {@link org.springframework.util.MultiValueMap MultiValueMap} to create a multipart request. + * The values in the {@code MultiValueMap} can be any Object representing the body of the part, + * or an {@link org.springframework.http.HttpEntity HttpEntity} representing a part with body + * and headers. + * @param url the URL + * @param request the Object to be POSTed (may be {@code null}) + * @param uriVariables the variables to expand the template + * @return the converted object + * @since 3.0.2 + * @see HttpEntity + */ + ResponseEntity postForEntity(String url, @Nullable Object request, Class responseType, + Map uriVariables) throws RestClientException; + + /** + * Create a new resource by POSTing the given object to the URL, + * and returns the response as {@link ResponseEntity}. + *

The {@code request} parameter can be a {@link HttpEntity} in order to + * add additional HTTP headers to the request. + *

The body of the entity, or {@code request} itself, can be a + * {@link org.springframework.util.MultiValueMap MultiValueMap} to create a multipart request. + * The values in the {@code MultiValueMap} can be any Object representing the body of the part, + * or an {@link org.springframework.http.HttpEntity HttpEntity} representing a part with body + * and headers. + * @param url the URL + * @param request the Object to be POSTed (may be {@code null}) + * @return the converted object + * @since 3.0.2 + * @see HttpEntity + */ + ResponseEntity postForEntity(URI url, @Nullable Object request, Class responseType) + throws RestClientException; + + + // PUT + + /** + * Create or update a resource by PUTting the given object to the URI. + *

URI Template variables are expanded using the given URI variables, if any. + *

The {@code request} parameter can be a {@link HttpEntity} in order to + * add additional HTTP headers to the request. + * @param url the URL + * @param request the Object to be PUT (may be {@code null}) + * @param uriVariables the variables to expand the template + * @see HttpEntity + */ + void put(String url, @Nullable Object request, Object... uriVariables) throws RestClientException; + + /** + * Creates a new resource by PUTting the given object to URI template. + *

URI Template variables are expanded using the given map. + *

The {@code request} parameter can be a {@link HttpEntity} in order to + * add additional HTTP headers to the request. + * @param url the URL + * @param request the Object to be PUT (may be {@code null}) + * @param uriVariables the variables to expand the template + * @see HttpEntity + */ + void put(String url, @Nullable Object request, Map uriVariables) throws RestClientException; + + /** + * Creates a new resource by PUTting the given object to URL. + *

The {@code request} parameter can be a {@link HttpEntity} in order to + * add additional HTTP headers to the request. + * @param url the URL + * @param request the Object to be PUT (may be {@code null}) + * @see HttpEntity + */ + void put(URI url, @Nullable Object request) throws RestClientException; + + + // PATCH + + /** + * Update a resource by PATCHing the given object to the URI template, + * and return the representation found in the response. + *

URI Template variables are expanded using the given URI variables, if any. + *

The {@code request} parameter can be a {@link HttpEntity} in order to + * add additional HTTP headers to the request. + *

NOTE: The standard JDK HTTP library does not support HTTP PATCH. + * You need to use the Apache HttpComponents or OkHttp request factory. + * @param url the URL + * @param request the object to be PATCHed (may be {@code null}) + * @param responseType the type of the return value + * @param uriVariables the variables to expand the template + * @return the converted object + * @since 4.3.5 + * @see HttpEntity + * @see RestTemplate#setRequestFactory + * @see org.springframework.http.client.HttpComponentsAsyncClientHttpRequestFactory + * @see org.springframework.http.client.OkHttp3ClientHttpRequestFactory + */ + @Nullable + T patchForObject(String url, @Nullable Object request, Class responseType, Object... uriVariables) + throws RestClientException; + + /** + * Update a resource by PATCHing the given object to the URI template, + * and return the representation found in the response. + *

URI Template variables are expanded using the given map. + *

The {@code request} parameter can be a {@link HttpEntity} in order to + * add additional HTTP headers to the request. + *

NOTE: The standard JDK HTTP library does not support HTTP PATCH. + * You need to use the Apache HttpComponents or OkHttp request factory. + * @param url the URL + * @param request the object to be PATCHed (may be {@code null}) + * @param responseType the type of the return value + * @param uriVariables the variables to expand the template + * @return the converted object + * @since 4.3.5 + * @see HttpEntity + * @see RestTemplate#setRequestFactory + * @see org.springframework.http.client.HttpComponentsAsyncClientHttpRequestFactory + * @see org.springframework.http.client.OkHttp3ClientHttpRequestFactory + */ + @Nullable + T patchForObject(String url, @Nullable Object request, Class responseType, + Map uriVariables) throws RestClientException; + + /** + * Update a resource by PATCHing the given object to the URL, + * and return the representation found in the response. + *

The {@code request} parameter can be a {@link HttpEntity} in order to + * add additional HTTP headers to the request. + *

NOTE: The standard JDK HTTP library does not support HTTP PATCH. + * You need to use the Apache HttpComponents or OkHttp request factory. + * @param url the URL + * @param request the object to be PATCHed (may be {@code null}) + * @param responseType the type of the return value + * @return the converted object + * @since 4.3.5 + * @see HttpEntity + * @see RestTemplate#setRequestFactory + * @see org.springframework.http.client.HttpComponentsAsyncClientHttpRequestFactory + * @see org.springframework.http.client.OkHttp3ClientHttpRequestFactory + */ + @Nullable + T patchForObject(URI url, @Nullable Object request, Class responseType) + throws RestClientException; + + + + // DELETE + + /** + * Delete the resources at the specified URI. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param uriVariables the variables to expand in the template + */ + void delete(String url, Object... uriVariables) throws RestClientException; + + /** + * Delete the resources at the specified URI. + *

URI Template variables are expanded using the given map. + * + * @param url the URL + * @param uriVariables the variables to expand the template + */ + void delete(String url, Map uriVariables) throws RestClientException; + + /** + * Delete the resources at the specified URL. + * @param url the URL + */ + void delete(URI url) throws RestClientException; + + + // OPTIONS + + /** + * Return the value of the Allow header for the given URI. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param uriVariables the variables to expand in the template + * @return the value of the allow header + */ + Set optionsForAllow(String url, Object... uriVariables) throws RestClientException; + + /** + * Return the value of the Allow header for the given URI. + *

URI Template variables are expanded using the given map. + * @param url the URL + * @param uriVariables the variables to expand in the template + * @return the value of the allow header + */ + Set optionsForAllow(String url, Map uriVariables) throws RestClientException; + + /** + * Return the value of the Allow header for the given URL. + * @param url the URL + * @return the value of the allow header + */ + Set optionsForAllow(URI url) throws RestClientException; + + + // exchange + + /** + * Execute the HTTP method to the given URI template, writing the given request entity to the request, and + * returns the response as {@link ResponseEntity}. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestEntity the entity (headers and/or body) to write to the request + * may be {@code null}) + * @param responseType the type of the return value + * @param uriVariables the variables to expand in the template + * @return the response as entity + * @since 3.0.2 + */ + ResponseEntity exchange(String url, HttpMethod method, @Nullable HttpEntity requestEntity, + Class responseType, Object... uriVariables) throws RestClientException; + + /** + * Execute the HTTP method to the given URI template, writing the given request entity to the request, and + * returns the response as {@link ResponseEntity}. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestEntity the entity (headers and/or body) to write to the request + * (may be {@code null}) + * @param responseType the type of the return value + * @param uriVariables the variables to expand in the template + * @return the response as entity + * @since 3.0.2 + */ + ResponseEntity exchange(String url, HttpMethod method, @Nullable HttpEntity requestEntity, + Class responseType, Map uriVariables) throws RestClientException; + + /** + * Execute the HTTP method to the given URI template, writing the given request entity to the request, and + * returns the response as {@link ResponseEntity}. + * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestEntity the entity (headers and/or body) to write to the request + * (may be {@code null}) + * @param responseType the type of the return value + * @return the response as entity + * @since 3.0.2 + */ + ResponseEntity exchange(URI url, HttpMethod method, @Nullable HttpEntity requestEntity, + Class responseType) throws RestClientException; + + /** + * Execute the HTTP method to the given URI template, writing the given + * request entity to the request, and returns the response as {@link ResponseEntity}. + * The given {@link ParameterizedTypeReference} is used to pass generic type information: + *

+	 * ParameterizedTypeReference<List<MyBean>> myBean =
+	 *     new ParameterizedTypeReference<List<MyBean>>() {};
+	 *
+	 * ResponseEntity<List<MyBean>> response =
+	 *     template.exchange("https://example.com",HttpMethod.GET, null, myBean);
+	 * 
+ * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestEntity the entity (headers and/or body) to write to the + * request (may be {@code null}) + * @param responseType the type of the return value + * @param uriVariables the variables to expand in the template + * @return the response as entity + * @since 3.2 + */ + ResponseEntity exchange(String url,HttpMethod method, @Nullable HttpEntity requestEntity, + ParameterizedTypeReference responseType, Object... uriVariables) throws RestClientException; + + /** + * Execute the HTTP method to the given URI template, writing the given + * request entity to the request, and returns the response as {@link ResponseEntity}. + * The given {@link ParameterizedTypeReference} is used to pass generic type information: + *
+	 * ParameterizedTypeReference<List<MyBean>> myBean =
+	 *     new ParameterizedTypeReference<List<MyBean>>() {};
+	 *
+	 * ResponseEntity<List<MyBean>> response =
+	 *     template.exchange("https://example.com",HttpMethod.GET, null, myBean);
+	 * 
+ * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestEntity the entity (headers and/or body) to write to the request + * (may be {@code null}) + * @param responseType the type of the return value + * @param uriVariables the variables to expand in the template + * @return the response as entity + * @since 3.2 + */ + ResponseEntity exchange(String url, HttpMethod method, @Nullable HttpEntity requestEntity, + ParameterizedTypeReference responseType, Map uriVariables) throws RestClientException; + + /** + * Execute the HTTP method to the given URI template, writing the given + * request entity to the request, and returns the response as {@link ResponseEntity}. + * The given {@link ParameterizedTypeReference} is used to pass generic type information: + *
+	 * ParameterizedTypeReference<List<MyBean>> myBean =
+	 *     new ParameterizedTypeReference<List<MyBean>>() {};
+	 *
+	 * ResponseEntity<List<MyBean>> response =
+	 *     template.exchange("https://example.com",HttpMethod.GET, null, myBean);
+	 * 
+ * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestEntity the entity (headers and/or body) to write to the request + * (may be {@code null}) + * @param responseType the type of the return value + * @return the response as entity + * @since 3.2 + */ + ResponseEntity exchange(URI url, HttpMethod method, @Nullable HttpEntity requestEntity, + ParameterizedTypeReference responseType) throws RestClientException; + + /** + * Execute the request specified in the given {@link RequestEntity} and return + * the response as {@link ResponseEntity}. Typically used in combination + * with the static builder methods on {@code RequestEntity}, for instance: + *
+	 * MyRequest body = ...
+	 * RequestEntity request = RequestEntity
+	 *     .post(new URI("https://example.com/foo"))
+	 *     .accept(MediaType.APPLICATION_JSON)
+	 *     .body(body);
+	 * ResponseEntity<MyResponse> response = template.exchange(request, MyResponse.class);
+	 * 
+ * @param requestEntity the entity to write to the request + * @param responseType the type of the return value + * @return the response as entity + * @since 4.1 + */ + ResponseEntity exchange(RequestEntity requestEntity, Class responseType) + throws RestClientException; + + /** + * Execute the request specified in the given {@link RequestEntity} and return + * the response as {@link ResponseEntity}. The given + * {@link ParameterizedTypeReference} is used to pass generic type information: + *
+	 * MyRequest body = ...
+	 * RequestEntity request = RequestEntity
+	 *     .post(new URI("https://example.com/foo"))
+	 *     .accept(MediaType.APPLICATION_JSON)
+	 *     .body(body);
+	 * ParameterizedTypeReference<List<MyResponse>> myBean =
+	 *     new ParameterizedTypeReference<List<MyResponse>>() {};
+	 * ResponseEntity<List<MyResponse>> response = template.exchange(request, myBean);
+	 * 
+ * @param requestEntity the entity to write to the request + * @param responseType the type of the return value + * @return the response as entity + * @since 4.1 + */ + ResponseEntity exchange(RequestEntity requestEntity, ParameterizedTypeReference responseType) + throws RestClientException; + + + // General execution + + /** + * Execute the HTTP method to the given URI template, preparing the request with the + * {@link RequestCallback}, and reading the response with a {@link ResponseExtractor}. + *

URI Template variables are expanded using the given URI variables, if any. + * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestCallback object that prepares the request + * @param responseExtractor object that extracts the return value from the response + * @param uriVariables the variables to expand in the template + * @return an arbitrary object, as returned by the {@link ResponseExtractor} + */ + @Nullable + T execute(String url, HttpMethod method, @Nullable RequestCallback requestCallback, + @Nullable ResponseExtractor responseExtractor, Object... uriVariables) + throws RestClientException; + + /** + * Execute the HTTP method to the given URI template, preparing the request with the + * {@link RequestCallback}, and reading the response with a {@link ResponseExtractor}. + *

URI Template variables are expanded using the given URI variables map. + * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestCallback object that prepares the request + * @param responseExtractor object that extracts the return value from the response + * @param uriVariables the variables to expand in the template + * @return an arbitrary object, as returned by the {@link ResponseExtractor} + */ + @Nullable + T execute(String url, HttpMethod method, @Nullable RequestCallback requestCallback, + @Nullable ResponseExtractor responseExtractor, Map uriVariables) + throws RestClientException; + + /** + * Execute the HTTP method to the given URL, preparing the request with the + * {@link RequestCallback}, and reading the response with a {@link ResponseExtractor}. + * @param url the URL + * @param method the HTTP method (GET, POST, etc) + * @param requestCallback object that prepares the request + * @param responseExtractor object that extracts the return value from the response + * @return an arbitrary object, as returned by the {@link ResponseExtractor} + */ + @Nullable + T execute(URI url, HttpMethod method, @Nullable RequestCallback requestCallback, + @Nullable ResponseExtractor responseExtractor) throws RestClientException; + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java new file mode 100644 index 0000000000000000000000000000000000000000..6cb0194ca4399280b9366ac5376c8b0732ddbbc5 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java @@ -0,0 +1,1011 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; +import java.lang.reflect.Type; +import java.net.URI; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.client.support.InterceptingHttpAccessor; +import org.springframework.http.converter.ByteArrayHttpMessageConverter; +import org.springframework.http.converter.GenericHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.ResourceHttpMessageConverter; +import org.springframework.http.converter.StringHttpMessageConverter; +import org.springframework.http.converter.cbor.MappingJackson2CborHttpMessageConverter; +import org.springframework.http.converter.feed.AtomFeedHttpMessageConverter; +import org.springframework.http.converter.feed.RssChannelHttpMessageConverter; +import org.springframework.http.converter.json.GsonHttpMessageConverter; +import org.springframework.http.converter.json.JsonbHttpMessageConverter; +import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; +import org.springframework.http.converter.smile.MappingJackson2SmileHttpMessageConverter; +import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; +import org.springframework.http.converter.xml.Jaxb2RootElementHttpMessageConverter; +import org.springframework.http.converter.xml.MappingJackson2XmlHttpMessageConverter; +import org.springframework.http.converter.xml.SourceHttpMessageConverter; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.web.util.DefaultUriBuilderFactory; +import org.springframework.web.util.DefaultUriBuilderFactory.EncodingMode; +import org.springframework.web.util.UriTemplateHandler; + +/** + * Synchronous client to perform HTTP requests, exposing a simple, template + * method API over underlying HTTP client libraries such as the JDK + * {@code HttpURLConnection}, Apache HttpComponents, and others. + * + *

The RestTemplate offers templates for common scenarios by HTTP method, in + * addition to the generalized {@code exchange} and {@code execute} methods that + * support of less frequent cases. + * + *

NOTE: As of 5.0 this class is in maintenance mode, with + * only minor requests for changes and bugs to be accepted going forward. Please, + * consider using the {@code org.springframework.web.reactive.client.WebClient} + * which has a more modern API and supports sync, async, and streaming scenarios. + * + * @author Arjen Poutsma + * @author Brian Clozel + * @author Roy Clarkson + * @author Juergen Hoeller + * @since 3.0 + * @see HttpMessageConverter + * @see RequestCallback + * @see ResponseExtractor + * @see ResponseErrorHandler + */ +public class RestTemplate extends InterceptingHttpAccessor implements RestOperations { + + private static final boolean romePresent; + + private static final boolean jaxb2Present; + + private static final boolean jackson2Present; + + private static final boolean jackson2XmlPresent; + + private static final boolean jackson2SmilePresent; + + private static final boolean jackson2CborPresent; + + private static final boolean gsonPresent; + + private static final boolean jsonbPresent; + + static { + ClassLoader classLoader = RestTemplate.class.getClassLoader(); + romePresent = ClassUtils.isPresent("com.rometools.rome.feed.WireFeed", classLoader); + jaxb2Present = ClassUtils.isPresent("javax.xml.bind.Binder", classLoader); + jackson2Present = ClassUtils.isPresent("com.fasterxml.jackson.databind.ObjectMapper", classLoader) && + ClassUtils.isPresent("com.fasterxml.jackson.core.JsonGenerator", classLoader); + jackson2XmlPresent = ClassUtils.isPresent("com.fasterxml.jackson.dataformat.xml.XmlMapper", classLoader); + jackson2SmilePresent = ClassUtils.isPresent("com.fasterxml.jackson.dataformat.smile.SmileFactory", classLoader); + jackson2CborPresent = ClassUtils.isPresent("com.fasterxml.jackson.dataformat.cbor.CBORFactory", classLoader); + gsonPresent = ClassUtils.isPresent("com.google.gson.Gson", classLoader); + jsonbPresent = ClassUtils.isPresent("javax.json.bind.Jsonb", classLoader); + } + + + private final List> messageConverters = new ArrayList<>(); + + private ResponseErrorHandler errorHandler = new DefaultResponseErrorHandler(); + + private UriTemplateHandler uriTemplateHandler; + + private final ResponseExtractor headersExtractor = new HeadersExtractor(); + + + /** + * Create a new instance of the {@link RestTemplate} using default settings. + * Default {@link HttpMessageConverter HttpMessageConverters} are initialized. + */ + public RestTemplate() { + this.messageConverters.add(new ByteArrayHttpMessageConverter()); + this.messageConverters.add(new StringHttpMessageConverter()); + this.messageConverters.add(new ResourceHttpMessageConverter(false)); + try { + this.messageConverters.add(new SourceHttpMessageConverter<>()); + } + catch (Error err) { + // Ignore when no TransformerFactory implementation is available + } + this.messageConverters.add(new AllEncompassingFormHttpMessageConverter()); + + if (romePresent) { + this.messageConverters.add(new AtomFeedHttpMessageConverter()); + this.messageConverters.add(new RssChannelHttpMessageConverter()); + } + + if (jackson2XmlPresent) { + this.messageConverters.add(new MappingJackson2XmlHttpMessageConverter()); + } + else if (jaxb2Present) { + this.messageConverters.add(new Jaxb2RootElementHttpMessageConverter()); + } + + if (jackson2Present) { + this.messageConverters.add(new MappingJackson2HttpMessageConverter()); + } + else if (gsonPresent) { + this.messageConverters.add(new GsonHttpMessageConverter()); + } + else if (jsonbPresent) { + this.messageConverters.add(new JsonbHttpMessageConverter()); + } + + if (jackson2SmilePresent) { + this.messageConverters.add(new MappingJackson2SmileHttpMessageConverter()); + } + if (jackson2CborPresent) { + this.messageConverters.add(new MappingJackson2CborHttpMessageConverter()); + } + + this.uriTemplateHandler = initUriTemplateHandler(); + } + + /** + * Create a new instance of the {@link RestTemplate} based on the given {@link ClientHttpRequestFactory}. + * @param requestFactory the HTTP request factory to use + * @see org.springframework.http.client.SimpleClientHttpRequestFactory + * @see org.springframework.http.client.HttpComponentsClientHttpRequestFactory + */ + public RestTemplate(ClientHttpRequestFactory requestFactory) { + this(); + setRequestFactory(requestFactory); + } + + /** + * Create a new instance of the {@link RestTemplate} using the given list of + * {@link HttpMessageConverter} to use. + * @param messageConverters the list of {@link HttpMessageConverter} to use + * @since 3.2.7 + */ + public RestTemplate(List> messageConverters) { + Assert.notEmpty(messageConverters, "At least one HttpMessageConverter required"); + this.messageConverters.addAll(messageConverters); + this.uriTemplateHandler = initUriTemplateHandler(); + } + + + private static DefaultUriBuilderFactory initUriTemplateHandler() { + DefaultUriBuilderFactory uriFactory = new DefaultUriBuilderFactory(); + uriFactory.setEncodingMode(EncodingMode.URI_COMPONENT); // for backwards compatibility.. + return uriFactory; + } + + + /** + * Set the message body converters to use. + *

These converters are used to convert from and to HTTP requests and responses. + */ + public void setMessageConverters(List> messageConverters) { + Assert.notEmpty(messageConverters, "At least one HttpMessageConverter required"); + // Take getMessageConverters() List as-is when passed in here + if (this.messageConverters != messageConverters) { + this.messageConverters.clear(); + this.messageConverters.addAll(messageConverters); + } + } + + /** + * Return the list of message body converters. + *

The returned {@link List} is active and may get appended to. + */ + public List> getMessageConverters() { + return this.messageConverters; + } + + /** + * Set the error handler. + *

By default, RestTemplate uses a {@link DefaultResponseErrorHandler}. + */ + public void setErrorHandler(ResponseErrorHandler errorHandler) { + Assert.notNull(errorHandler, "ResponseErrorHandler must not be null"); + this.errorHandler = errorHandler; + } + + /** + * Return the error handler. + */ + public ResponseErrorHandler getErrorHandler() { + return this.errorHandler; + } + + /** + * Configure default URI variable values. This is a shortcut for: + *

+	 * DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory();
+	 * handler.setDefaultUriVariables(...);
+	 *
+	 * RestTemplate restTemplate = new RestTemplate();
+	 * restTemplate.setUriTemplateHandler(handler);
+	 * 
+ * @param uriVars the default URI variable values + * @since 4.3 + */ + @SuppressWarnings("deprecation") + public void setDefaultUriVariables(Map uriVars) { + if (this.uriTemplateHandler instanceof DefaultUriBuilderFactory) { + ((DefaultUriBuilderFactory) this.uriTemplateHandler).setDefaultUriVariables(uriVars); + } + else if (this.uriTemplateHandler instanceof org.springframework.web.util.AbstractUriTemplateHandler) { + ((org.springframework.web.util.AbstractUriTemplateHandler) this.uriTemplateHandler) + .setDefaultUriVariables(uriVars); + } + else { + throw new IllegalArgumentException( + "This property is not supported with the configured UriTemplateHandler."); + } + } + + /** + * Configure a strategy for expanding URI templates. + *

By default, {@link DefaultUriBuilderFactory} is used and for + * backwards compatibility, the encoding mode is set to + * {@link EncodingMode#URI_COMPONENT URI_COMPONENT}. As of 5.0.8, prefer + * using {@link EncodingMode#TEMPLATE_AND_VALUES TEMPLATE_AND_VALUES}. + *

Note: in 5.0 the switch from + * {@link org.springframework.web.util.DefaultUriTemplateHandler + * DefaultUriTemplateHandler} (deprecated in 4.3), as the default to use, to + * {@link DefaultUriBuilderFactory} brings in a different default for the + * {@code parsePath} property (switching from false to true). + * @param handler the URI template handler to use + */ + public void setUriTemplateHandler(UriTemplateHandler handler) { + Assert.notNull(handler, "UriTemplateHandler must not be null"); + this.uriTemplateHandler = handler; + } + + /** + * Return the configured URI template handler. + */ + public UriTemplateHandler getUriTemplateHandler() { + return this.uriTemplateHandler; + } + + + // GET + + @Override + @Nullable + public T getForObject(String url, Class responseType, Object... uriVariables) throws RestClientException { + RequestCallback requestCallback = acceptHeaderRequestCallback(responseType); + HttpMessageConverterExtractor responseExtractor = + new HttpMessageConverterExtractor<>(responseType, getMessageConverters(), logger); + return execute(url, HttpMethod.GET, requestCallback, responseExtractor, uriVariables); + } + + @Override + @Nullable + public T getForObject(String url, Class responseType, Map uriVariables) throws RestClientException { + RequestCallback requestCallback = acceptHeaderRequestCallback(responseType); + HttpMessageConverterExtractor responseExtractor = + new HttpMessageConverterExtractor<>(responseType, getMessageConverters(), logger); + return execute(url, HttpMethod.GET, requestCallback, responseExtractor, uriVariables); + } + + @Override + @Nullable + public T getForObject(URI url, Class responseType) throws RestClientException { + RequestCallback requestCallback = acceptHeaderRequestCallback(responseType); + HttpMessageConverterExtractor responseExtractor = + new HttpMessageConverterExtractor<>(responseType, getMessageConverters(), logger); + return execute(url, HttpMethod.GET, requestCallback, responseExtractor); + } + + @Override + public ResponseEntity getForEntity(String url, Class responseType, Object... uriVariables) + throws RestClientException { + + RequestCallback requestCallback = acceptHeaderRequestCallback(responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return nonNull(execute(url, HttpMethod.GET, requestCallback, responseExtractor, uriVariables)); + } + + @Override + public ResponseEntity getForEntity(String url, Class responseType, Map uriVariables) + throws RestClientException { + + RequestCallback requestCallback = acceptHeaderRequestCallback(responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return nonNull(execute(url, HttpMethod.GET, requestCallback, responseExtractor, uriVariables)); + } + + @Override + public ResponseEntity getForEntity(URI url, Class responseType) throws RestClientException { + RequestCallback requestCallback = acceptHeaderRequestCallback(responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return nonNull(execute(url, HttpMethod.GET, requestCallback, responseExtractor)); + } + + + // HEAD + + @Override + public HttpHeaders headForHeaders(String url, Object... uriVariables) throws RestClientException { + return nonNull(execute(url, HttpMethod.HEAD, null, headersExtractor(), uriVariables)); + } + + @Override + public HttpHeaders headForHeaders(String url, Map uriVariables) throws RestClientException { + return nonNull(execute(url, HttpMethod.HEAD, null, headersExtractor(), uriVariables)); + } + + @Override + public HttpHeaders headForHeaders(URI url) throws RestClientException { + return nonNull(execute(url, HttpMethod.HEAD, null, headersExtractor())); + } + + + // POST + + @Override + @Nullable + public URI postForLocation(String url, @Nullable Object request, Object... uriVariables) + throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(request); + HttpHeaders headers = execute(url, HttpMethod.POST, requestCallback, headersExtractor(), uriVariables); + return (headers != null ? headers.getLocation() : null); + } + + @Override + @Nullable + public URI postForLocation(String url, @Nullable Object request, Map uriVariables) + throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(request); + HttpHeaders headers = execute(url, HttpMethod.POST, requestCallback, headersExtractor(), uriVariables); + return (headers != null ? headers.getLocation() : null); + } + + @Override + @Nullable + public URI postForLocation(URI url, @Nullable Object request) throws RestClientException { + RequestCallback requestCallback = httpEntityCallback(request); + HttpHeaders headers = execute(url, HttpMethod.POST, requestCallback, headersExtractor()); + return (headers != null ? headers.getLocation() : null); + } + + @Override + @Nullable + public T postForObject(String url, @Nullable Object request, Class responseType, + Object... uriVariables) throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(request, responseType); + HttpMessageConverterExtractor responseExtractor = + new HttpMessageConverterExtractor<>(responseType, getMessageConverters(), logger); + return execute(url, HttpMethod.POST, requestCallback, responseExtractor, uriVariables); + } + + @Override + @Nullable + public T postForObject(String url, @Nullable Object request, Class responseType, + Map uriVariables) throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(request, responseType); + HttpMessageConverterExtractor responseExtractor = + new HttpMessageConverterExtractor<>(responseType, getMessageConverters(), logger); + return execute(url, HttpMethod.POST, requestCallback, responseExtractor, uriVariables); + } + + @Override + @Nullable + public T postForObject(URI url, @Nullable Object request, Class responseType) + throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(request, responseType); + HttpMessageConverterExtractor responseExtractor = + new HttpMessageConverterExtractor<>(responseType, getMessageConverters()); + return execute(url, HttpMethod.POST, requestCallback, responseExtractor); + } + + @Override + public ResponseEntity postForEntity(String url, @Nullable Object request, + Class responseType, Object... uriVariables) throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(request, responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return nonNull(execute(url, HttpMethod.POST, requestCallback, responseExtractor, uriVariables)); + } + + @Override + public ResponseEntity postForEntity(String url, @Nullable Object request, + Class responseType, Map uriVariables) throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(request, responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return nonNull(execute(url, HttpMethod.POST, requestCallback, responseExtractor, uriVariables)); + } + + @Override + public ResponseEntity postForEntity(URI url, @Nullable Object request, Class responseType) + throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(request, responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return nonNull(execute(url, HttpMethod.POST, requestCallback, responseExtractor)); + } + + + // PUT + + @Override + public void put(String url, @Nullable Object request, Object... uriVariables) + throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(request); + execute(url, HttpMethod.PUT, requestCallback, null, uriVariables); + } + + @Override + public void put(String url, @Nullable Object request, Map uriVariables) + throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(request); + execute(url, HttpMethod.PUT, requestCallback, null, uriVariables); + } + + @Override + public void put(URI url, @Nullable Object request) throws RestClientException { + RequestCallback requestCallback = httpEntityCallback(request); + execute(url, HttpMethod.PUT, requestCallback, null); + } + + + // PATCH + + @Override + @Nullable + public T patchForObject(String url, @Nullable Object request, Class responseType, + Object... uriVariables) throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(request, responseType); + HttpMessageConverterExtractor responseExtractor = + new HttpMessageConverterExtractor<>(responseType, getMessageConverters(), logger); + return execute(url, HttpMethod.PATCH, requestCallback, responseExtractor, uriVariables); + } + + @Override + @Nullable + public T patchForObject(String url, @Nullable Object request, Class responseType, + Map uriVariables) throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(request, responseType); + HttpMessageConverterExtractor responseExtractor = + new HttpMessageConverterExtractor<>(responseType, getMessageConverters(), logger); + return execute(url, HttpMethod.PATCH, requestCallback, responseExtractor, uriVariables); + } + + @Override + @Nullable + public T patchForObject(URI url, @Nullable Object request, Class responseType) + throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(request, responseType); + HttpMessageConverterExtractor responseExtractor = + new HttpMessageConverterExtractor<>(responseType, getMessageConverters()); + return execute(url, HttpMethod.PATCH, requestCallback, responseExtractor); + } + + + // DELETE + + @Override + public void delete(String url, Object... uriVariables) throws RestClientException { + execute(url, HttpMethod.DELETE, null, null, uriVariables); + } + + @Override + public void delete(String url, Map uriVariables) throws RestClientException { + execute(url, HttpMethod.DELETE, null, null, uriVariables); + } + + @Override + public void delete(URI url) throws RestClientException { + execute(url, HttpMethod.DELETE, null, null); + } + + + // OPTIONS + + @Override + public Set optionsForAllow(String url, Object... uriVariables) throws RestClientException { + ResponseExtractor headersExtractor = headersExtractor(); + HttpHeaders headers = execute(url, HttpMethod.OPTIONS, null, headersExtractor, uriVariables); + return (headers != null ? headers.getAllow() : Collections.emptySet()); + } + + @Override + public Set optionsForAllow(String url, Map uriVariables) throws RestClientException { + ResponseExtractor headersExtractor = headersExtractor(); + HttpHeaders headers = execute(url, HttpMethod.OPTIONS, null, headersExtractor, uriVariables); + return (headers != null ? headers.getAllow() : Collections.emptySet()); + } + + @Override + public Set optionsForAllow(URI url) throws RestClientException { + ResponseExtractor headersExtractor = headersExtractor(); + HttpHeaders headers = execute(url, HttpMethod.OPTIONS, null, headersExtractor); + return (headers != null ? headers.getAllow() : Collections.emptySet()); + } + + + // exchange + + @Override + public ResponseEntity exchange(String url, HttpMethod method, + @Nullable HttpEntity requestEntity, Class responseType, Object... uriVariables) + throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(requestEntity, responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return nonNull(execute(url, method, requestCallback, responseExtractor, uriVariables)); + } + + @Override + public ResponseEntity exchange(String url, HttpMethod method, + @Nullable HttpEntity requestEntity, Class responseType, Map uriVariables) + throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(requestEntity, responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return nonNull(execute(url, method, requestCallback, responseExtractor, uriVariables)); + } + + @Override + public ResponseEntity exchange(URI url, HttpMethod method, @Nullable HttpEntity requestEntity, + Class responseType) throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(requestEntity, responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return nonNull(execute(url, method, requestCallback, responseExtractor)); + } + + @Override + public ResponseEntity exchange(String url, HttpMethod method, @Nullable HttpEntity requestEntity, + ParameterizedTypeReference responseType, Object... uriVariables) throws RestClientException { + + Type type = responseType.getType(); + RequestCallback requestCallback = httpEntityCallback(requestEntity, type); + ResponseExtractor> responseExtractor = responseEntityExtractor(type); + return nonNull(execute(url, method, requestCallback, responseExtractor, uriVariables)); + } + + @Override + public ResponseEntity exchange(String url, HttpMethod method, @Nullable HttpEntity requestEntity, + ParameterizedTypeReference responseType, Map uriVariables) throws RestClientException { + + Type type = responseType.getType(); + RequestCallback requestCallback = httpEntityCallback(requestEntity, type); + ResponseExtractor> responseExtractor = responseEntityExtractor(type); + return nonNull(execute(url, method, requestCallback, responseExtractor, uriVariables)); + } + + @Override + public ResponseEntity exchange(URI url, HttpMethod method, @Nullable HttpEntity requestEntity, + ParameterizedTypeReference responseType) throws RestClientException { + + Type type = responseType.getType(); + RequestCallback requestCallback = httpEntityCallback(requestEntity, type); + ResponseExtractor> responseExtractor = responseEntityExtractor(type); + return nonNull(execute(url, method, requestCallback, responseExtractor)); + } + + @Override + public ResponseEntity exchange(RequestEntity requestEntity, Class responseType) + throws RestClientException { + + RequestCallback requestCallback = httpEntityCallback(requestEntity, responseType); + ResponseExtractor> responseExtractor = responseEntityExtractor(responseType); + return nonNull(doExecute(requestEntity.getUrl(), requestEntity.getMethod(), requestCallback, responseExtractor)); + } + + @Override + public ResponseEntity exchange(RequestEntity requestEntity, ParameterizedTypeReference responseType) + throws RestClientException { + + Type type = responseType.getType(); + RequestCallback requestCallback = httpEntityCallback(requestEntity, type); + ResponseExtractor> responseExtractor = responseEntityExtractor(type); + return nonNull(doExecute(requestEntity.getUrl(), requestEntity.getMethod(), requestCallback, responseExtractor)); + } + + + // General execution + + /** + * {@inheritDoc} + *

To provide a {@code RequestCallback} or {@code ResponseExtractor} only, + * but not both, consider using: + *

    + *
  • {@link #acceptHeaderRequestCallback(Class)} + *
  • {@link #httpEntityCallback(Object)} + *
  • {@link #httpEntityCallback(Object, Type)} + *
  • {@link #responseEntityExtractor(Type)} + *
+ */ + @Override + @Nullable + public T execute(String url, HttpMethod method, @Nullable RequestCallback requestCallback, + @Nullable ResponseExtractor responseExtractor, Object... uriVariables) throws RestClientException { + + URI expanded = getUriTemplateHandler().expand(url, uriVariables); + return doExecute(expanded, method, requestCallback, responseExtractor); + } + + /** + * {@inheritDoc} + *

To provide a {@code RequestCallback} or {@code ResponseExtractor} only, + * but not both, consider using: + *

    + *
  • {@link #acceptHeaderRequestCallback(Class)} + *
  • {@link #httpEntityCallback(Object)} + *
  • {@link #httpEntityCallback(Object, Type)} + *
  • {@link #responseEntityExtractor(Type)} + *
+ */ + @Override + @Nullable + public T execute(String url, HttpMethod method, @Nullable RequestCallback requestCallback, + @Nullable ResponseExtractor responseExtractor, Map uriVariables) + throws RestClientException { + + URI expanded = getUriTemplateHandler().expand(url, uriVariables); + return doExecute(expanded, method, requestCallback, responseExtractor); + } + + /** + * {@inheritDoc} + *

To provide a {@code RequestCallback} or {@code ResponseExtractor} only, + * but not both, consider using: + *

    + *
  • {@link #acceptHeaderRequestCallback(Class)} + *
  • {@link #httpEntityCallback(Object)} + *
  • {@link #httpEntityCallback(Object, Type)} + *
  • {@link #responseEntityExtractor(Type)} + *
+ */ + @Override + @Nullable + public T execute(URI url, HttpMethod method, @Nullable RequestCallback requestCallback, + @Nullable ResponseExtractor responseExtractor) throws RestClientException { + + return doExecute(url, method, requestCallback, responseExtractor); + } + + /** + * Execute the given method on the provided URI. + *

The {@link ClientHttpRequest} is processed using the {@link RequestCallback}; + * the response with the {@link ResponseExtractor}. + * @param url the fully-expanded URL to connect to + * @param method the HTTP method to execute (GET, POST, etc.) + * @param requestCallback object that prepares the request (can be {@code null}) + * @param responseExtractor object that extracts the return value from the response (can be {@code null}) + * @return an arbitrary object, as returned by the {@link ResponseExtractor} + */ + @Nullable + protected T doExecute(URI url, @Nullable HttpMethod method, @Nullable RequestCallback requestCallback, + @Nullable ResponseExtractor responseExtractor) throws RestClientException { + + Assert.notNull(url, "URI is required"); + Assert.notNull(method, "HttpMethod is required"); + ClientHttpResponse response = null; + try { + ClientHttpRequest request = createRequest(url, method); + if (requestCallback != null) { + requestCallback.doWithRequest(request); + } + response = request.execute(); + handleResponse(url, method, response); + return (responseExtractor != null ? responseExtractor.extractData(response) : null); + } + catch (IOException ex) { + String resource = url.toString(); + String query = url.getRawQuery(); + resource = (query != null ? resource.substring(0, resource.indexOf('?')) : resource); + throw new ResourceAccessException("I/O error on " + method.name() + + " request for \"" + resource + "\": " + ex.getMessage(), ex); + } + finally { + if (response != null) { + response.close(); + } + } + } + + /** + * Handle the given response, performing appropriate logging and + * invoking the {@link ResponseErrorHandler} if necessary. + *

Can be overridden in subclasses. + * @param url the fully-expanded URL to connect to + * @param method the HTTP method to execute (GET, POST, etc.) + * @param response the resulting {@link ClientHttpResponse} + * @throws IOException if propagated from {@link ResponseErrorHandler} + * @since 4.1.6 + * @see #setErrorHandler + */ + protected void handleResponse(URI url, HttpMethod method, ClientHttpResponse response) throws IOException { + ResponseErrorHandler errorHandler = getErrorHandler(); + boolean hasError = errorHandler.hasError(response); + if (logger.isDebugEnabled()) { + try { + int code = response.getRawStatusCode(); + HttpStatus status = HttpStatus.resolve(code); + logger.debug("Response " + (status != null ? status : code)); + } + catch (IOException ex) { + // ignore + } + } + if (hasError) { + errorHandler.handleError(url, method, response); + } + } + + /** + * Return a {@code RequestCallback} that sets the request {@code Accept} + * header based on the given response type, cross-checked against the + * configured message converters. + */ + public RequestCallback acceptHeaderRequestCallback(Class responseType) { + return new AcceptHeaderRequestCallback(responseType); + } + + /** + * Return a {@code RequestCallback} implementation that writes the given + * object to the request stream. + */ + public RequestCallback httpEntityCallback(@Nullable Object requestBody) { + return new HttpEntityRequestCallback(requestBody); + } + + /** + * Return a {@code RequestCallback} implementation that: + *

    + *
  1. Sets the request {@code Accept} header based on the given response + * type, cross-checked against the configured message converters. + *
  2. Writes the given object to the request stream. + *
+ */ + public RequestCallback httpEntityCallback(@Nullable Object requestBody, Type responseType) { + return new HttpEntityRequestCallback(requestBody, responseType); + } + + /** + * Return a {@code ResponseExtractor} that prepares a {@link ResponseEntity}. + */ + public ResponseExtractor> responseEntityExtractor(Type responseType) { + return new ResponseEntityResponseExtractor<>(responseType); + } + + /** + * Return a response extractor for {@link HttpHeaders}. + */ + protected ResponseExtractor headersExtractor() { + return this.headersExtractor; + } + + private static T nonNull(@Nullable T result) { + Assert.state(result != null, "No result"); + return result; + } + + + /** + * Request callback implementation that prepares the request's accept headers. + */ + private class AcceptHeaderRequestCallback implements RequestCallback { + + @Nullable + private final Type responseType; + + public AcceptHeaderRequestCallback(@Nullable Type responseType) { + this.responseType = responseType; + } + + @Override + public void doWithRequest(ClientHttpRequest request) throws IOException { + if (this.responseType != null) { + List allSupportedMediaTypes = getMessageConverters().stream() + .filter(converter -> canReadResponse(this.responseType, converter)) + .flatMap(this::getSupportedMediaTypes) + .distinct() + .sorted(MediaType.SPECIFICITY_COMPARATOR) + .collect(Collectors.toList()); + if (logger.isDebugEnabled()) { + logger.debug("Accept=" + allSupportedMediaTypes); + } + request.getHeaders().setAccept(allSupportedMediaTypes); + } + } + + private boolean canReadResponse(Type responseType, HttpMessageConverter converter) { + Class responseClass = (responseType instanceof Class ? (Class) responseType : null); + if (responseClass != null) { + return converter.canRead(responseClass, null); + } + else if (converter instanceof GenericHttpMessageConverter) { + GenericHttpMessageConverter genericConverter = (GenericHttpMessageConverter) converter; + return genericConverter.canRead(responseType, null, null); + } + return false; + } + + private Stream getSupportedMediaTypes(HttpMessageConverter messageConverter) { + return messageConverter.getSupportedMediaTypes() + .stream() + .map(mediaType -> { + if (mediaType.getCharset() != null) { + return new MediaType(mediaType.getType(), mediaType.getSubtype()); + } + return mediaType; + }); + } + } + + + /** + * Request callback implementation that writes the given object to the request stream. + */ + private class HttpEntityRequestCallback extends AcceptHeaderRequestCallback { + + private final HttpEntity requestEntity; + + public HttpEntityRequestCallback(@Nullable Object requestBody) { + this(requestBody, null); + } + + public HttpEntityRequestCallback(@Nullable Object requestBody, @Nullable Type responseType) { + super(responseType); + if (requestBody instanceof HttpEntity) { + this.requestEntity = (HttpEntity) requestBody; + } + else if (requestBody != null) { + this.requestEntity = new HttpEntity<>(requestBody); + } + else { + this.requestEntity = HttpEntity.EMPTY; + } + } + + @Override + @SuppressWarnings("unchecked") + public void doWithRequest(ClientHttpRequest httpRequest) throws IOException { + super.doWithRequest(httpRequest); + Object requestBody = this.requestEntity.getBody(); + if (requestBody == null) { + HttpHeaders httpHeaders = httpRequest.getHeaders(); + HttpHeaders requestHeaders = this.requestEntity.getHeaders(); + if (!requestHeaders.isEmpty()) { + requestHeaders.forEach((key, values) -> httpHeaders.put(key, new ArrayList<>(values))); + } + if (httpHeaders.getContentLength() < 0) { + httpHeaders.setContentLength(0L); + } + } + else { + Class requestBodyClass = requestBody.getClass(); + Type requestBodyType = (this.requestEntity instanceof RequestEntity ? + ((RequestEntity)this.requestEntity).getType() : requestBodyClass); + HttpHeaders httpHeaders = httpRequest.getHeaders(); + HttpHeaders requestHeaders = this.requestEntity.getHeaders(); + MediaType requestContentType = requestHeaders.getContentType(); + for (HttpMessageConverter messageConverter : getMessageConverters()) { + if (messageConverter instanceof GenericHttpMessageConverter) { + GenericHttpMessageConverter genericConverter = + (GenericHttpMessageConverter) messageConverter; + if (genericConverter.canWrite(requestBodyType, requestBodyClass, requestContentType)) { + if (!requestHeaders.isEmpty()) { + requestHeaders.forEach((key, values) -> httpHeaders.put(key, new ArrayList<>(values))); + } + logBody(requestBody, requestContentType, genericConverter); + genericConverter.write(requestBody, requestBodyType, requestContentType, httpRequest); + return; + } + } + else if (messageConverter.canWrite(requestBodyClass, requestContentType)) { + if (!requestHeaders.isEmpty()) { + requestHeaders.forEach((key, values) -> httpHeaders.put(key, new ArrayList<>(values))); + } + logBody(requestBody, requestContentType, messageConverter); + ((HttpMessageConverter) messageConverter).write( + requestBody, requestContentType, httpRequest); + return; + } + } + String message = "No HttpMessageConverter for " + requestBodyClass.getName(); + if (requestContentType != null) { + message += " and content type \"" + requestContentType + "\""; + } + throw new RestClientException(message); + } + } + + private void logBody(Object body, @Nullable MediaType mediaType, HttpMessageConverter converter) { + if (logger.isDebugEnabled()) { + if (mediaType != null) { + logger.debug("Writing [" + body + "] as \"" + mediaType + "\""); + } + else { + logger.debug("Writing [" + body + "] with " + converter.getClass().getName()); + } + } + } + } + + + /** + * Response extractor for {@link HttpEntity}. + */ + private class ResponseEntityResponseExtractor implements ResponseExtractor> { + + @Nullable + private final HttpMessageConverterExtractor delegate; + + public ResponseEntityResponseExtractor(@Nullable Type responseType) { + if (responseType != null && Void.class != responseType) { + this.delegate = new HttpMessageConverterExtractor<>(responseType, getMessageConverters(), logger); + } + else { + this.delegate = null; + } + } + + @Override + public ResponseEntity extractData(ClientHttpResponse response) throws IOException { + if (this.delegate != null) { + T body = this.delegate.extractData(response); + return ResponseEntity.status(response.getRawStatusCode()).headers(response.getHeaders()).body(body); + } + else { + return ResponseEntity.status(response.getRawStatusCode()).headers(response.getHeaders()).build(); + } + } + } + + + /** + * Response extractor that extracts the response {@link HttpHeaders}. + */ + private static class HeadersExtractor implements ResponseExtractor { + + @Override + public HttpHeaders extractData(ClientHttpResponse response) { + return response.getHeaders(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/UnknownHttpStatusCodeException.java b/spring-web/src/main/java/org/springframework/web/client/UnknownHttpStatusCodeException.java new file mode 100644 index 0000000000000000000000000000000000000000..ac239513ae903d321f070a2f7056c431107cc2f8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/UnknownHttpStatusCodeException.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.nio.charset.Charset; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; + +/** + * Exception thrown when an unknown (or custom) HTTP status code is received. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public class UnknownHttpStatusCodeException extends RestClientResponseException { + + private static final long serialVersionUID = 7103980251635005491L; + + + /** + * Construct a new instance of {@code HttpStatusCodeException} based on an + * {@link HttpStatus}, status text, and response body content. + * @param rawStatusCode the raw status code value + * @param statusText the status text + * @param responseHeaders the response headers (may be {@code null}) + * @param responseBody the response body content (may be {@code null}) + * @param responseCharset the response body charset (may be {@code null}) + */ + public UnknownHttpStatusCodeException(int rawStatusCode, String statusText, @Nullable HttpHeaders responseHeaders, + @Nullable byte[] responseBody, @Nullable Charset responseCharset) { + + super("Unknown status code [" + rawStatusCode + "]" + " " + statusText, + rawStatusCode, statusText, responseHeaders, responseBody, responseCharset); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/package-info.java b/spring-web/src/main/java/org/springframework/web/client/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..0b20179c445f56b76e0108b817dbfa041dc49fd5 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/package-info.java @@ -0,0 +1,10 @@ +/** + * Core package of the client-side web support. + * Provides a RestTemplate class and various callback interfaces. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.client; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/client/support/RestGatewaySupport.java b/spring-web/src/main/java/org/springframework/web/client/support/RestGatewaySupport.java new file mode 100644 index 0000000000000000000000000000000000000000..80854c2a1bdc7beed8f2ff58c268ec342e7e54f9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/support/RestGatewaySupport.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client.support; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.util.Assert; +import org.springframework.web.client.RestTemplate; + +/** + * Convenient super class for application classes that need REST access. + * + *

Requires a {@link ClientHttpRequestFactory} or a {@link RestTemplate} instance to be set. + * + * @author Arjen Poutsma + * @since 3.0 + * @see #setRestTemplate + * @see org.springframework.web.client.RestTemplate + */ +public class RestGatewaySupport { + + /** Logger available to subclasses. */ + protected final Log logger = LogFactory.getLog(getClass()); + + private RestTemplate restTemplate; + + + /** + * Construct a new instance of the {@link RestGatewaySupport}, with default parameters. + */ + public RestGatewaySupport() { + this.restTemplate = new RestTemplate(); + } + + /** + * Construct a new instance of the {@link RestGatewaySupport}, with the given {@link ClientHttpRequestFactory}. + * @see RestTemplate#RestTemplate(ClientHttpRequestFactory) + */ + public RestGatewaySupport(ClientHttpRequestFactory requestFactory) { + Assert.notNull(requestFactory, "'requestFactory' must not be null"); + this.restTemplate = new RestTemplate(requestFactory); + } + + + /** + * Sets the {@link RestTemplate} for the gateway. + */ + public void setRestTemplate(RestTemplate restTemplate) { + Assert.notNull(restTemplate, "'restTemplate' must not be null"); + this.restTemplate = restTemplate; + } + + /** + * Returns the {@link RestTemplate} for the gateway. + */ + public RestTemplate getRestTemplate() { + return this.restTemplate; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/support/package-info.java b/spring-web/src/main/java/org/springframework/web/client/support/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..ba9c45a63fe6994a4d87f29a70906c6b6ee21322 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/support/package-info.java @@ -0,0 +1,10 @@ +/** + * Classes supporting the {@code org.springframework.web.client} package. + * Contains a base class for RestTemplate usage. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.client.support; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/context/AbstractContextLoaderInitializer.java b/spring-web/src/main/java/org/springframework/web/context/AbstractContextLoaderInitializer.java new file mode 100644 index 0000000000000000000000000000000000000000..6cb40688e34e1165e529807404f21df83e1ea59f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/AbstractContextLoaderInitializer.java @@ -0,0 +1,98 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context; + +import javax.servlet.ServletContext; +import javax.servlet.ServletException; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.lang.Nullable; +import org.springframework.web.WebApplicationInitializer; + +/** + * Convenient base class for {@link WebApplicationInitializer} implementations + * that register a {@link ContextLoaderListener} in the servlet context. + * + *

The only method required to be implemented by subclasses is + * {@link #createRootApplicationContext()}, which gets invoked from + * {@link #registerContextLoaderListener(ServletContext)}. + * + * @author Arjen Poutsma + * @author Chris Beams + * @author Juergen Hoeller + * @since 3.2 + */ +public abstract class AbstractContextLoaderInitializer implements WebApplicationInitializer { + + /** Logger available to subclasses. */ + protected final Log logger = LogFactory.getLog(getClass()); + + + @Override + public void onStartup(ServletContext servletContext) throws ServletException { + registerContextLoaderListener(servletContext); + } + + /** + * Register a {@link ContextLoaderListener} against the given servlet context. The + * {@code ContextLoaderListener} is initialized with the application context returned + * from the {@link #createRootApplicationContext()} template method. + * @param servletContext the servlet context to register the listener against + */ + protected void registerContextLoaderListener(ServletContext servletContext) { + WebApplicationContext rootAppContext = createRootApplicationContext(); + if (rootAppContext != null) { + ContextLoaderListener listener = new ContextLoaderListener(rootAppContext); + listener.setContextInitializers(getRootApplicationContextInitializers()); + servletContext.addListener(listener); + } + else { + logger.debug("No ContextLoaderListener registered, as " + + "createRootApplicationContext() did not return an application context"); + } + } + + /** + * Create the "root" application context to be provided to the + * {@code ContextLoaderListener}. + *

The returned context is delegated to + * {@link ContextLoaderListener#ContextLoaderListener(WebApplicationContext)} and will + * be established as the parent context for any {@code DispatcherServlet} application + * contexts. As such, it typically contains middle-tier services, data sources, etc. + * @return the root application context, or {@code null} if a root context is not + * desired + * @see org.springframework.web.servlet.support.AbstractDispatcherServletInitializer + */ + @Nullable + protected abstract WebApplicationContext createRootApplicationContext(); + + /** + * Specify application context initializers to be applied to the root application + * context that the {@code ContextLoaderListener} is being created with. + * @since 4.2 + * @see #createRootApplicationContext() + * @see ContextLoaderListener#setContextInitializers + */ + @Nullable + protected ApplicationContextInitializer[] getRootApplicationContextInitializers() { + return null; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/ConfigurableWebApplicationContext.java b/spring-web/src/main/java/org/springframework/web/context/ConfigurableWebApplicationContext.java new file mode 100644 index 0000000000000000000000000000000000000000..c08b195fcdb3f0acea6a4b33013f3a206b34e933 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/ConfigurableWebApplicationContext.java @@ -0,0 +1,111 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context; + +import javax.servlet.ServletConfig; +import javax.servlet.ServletContext; + +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.lang.Nullable; + +/** + * Interface to be implemented by configurable web application contexts. + * Supported by {@link ContextLoader} and + * {@link org.springframework.web.servlet.FrameworkServlet}. + * + *

Note: The setters of this interface need to be called before an + * invocation of the {@link #refresh} method inherited from + * {@link org.springframework.context.ConfigurableApplicationContext}. + * They do not cause an initialization of the context on their own. + * + * @author Juergen Hoeller + * @since 05.12.2003 + * @see #refresh + * @see ContextLoader#createWebApplicationContext + * @see org.springframework.web.servlet.FrameworkServlet#createWebApplicationContext + */ +public interface ConfigurableWebApplicationContext extends WebApplicationContext, ConfigurableApplicationContext { + + /** + * Prefix for ApplicationContext ids that refer to context path and/or servlet name. + */ + String APPLICATION_CONTEXT_ID_PREFIX = WebApplicationContext.class.getName() + ":"; + + /** + * Name of the ServletConfig environment bean in the factory. + * @see javax.servlet.ServletConfig + */ + String SERVLET_CONFIG_BEAN_NAME = "servletConfig"; + + + /** + * Set the ServletContext for this web application context. + *

Does not cause an initialization of the context: refresh needs to be + * called after the setting of all configuration properties. + * @see #refresh() + */ + void setServletContext(@Nullable ServletContext servletContext); + + /** + * Set the ServletConfig for this web application context. + * Only called for a WebApplicationContext that belongs to a specific Servlet. + * @see #refresh() + */ + void setServletConfig(@Nullable ServletConfig servletConfig); + + /** + * Return the ServletConfig for this web application context, if any. + */ + @Nullable + ServletConfig getServletConfig(); + + /** + * Set the namespace for this web application context, + * to be used for building a default context config location. + * The root web application context does not have a namespace. + */ + void setNamespace(@Nullable String namespace); + + /** + * Return the namespace for this web application context, if any. + */ + @Nullable + String getNamespace(); + + /** + * Set the config locations for this web application context in init-param style, + * i.e. with distinct locations separated by commas, semicolons or whitespace. + *

If not set, the implementation is supposed to use a default for the + * given namespace or the root web application context, as appropriate. + */ + void setConfigLocation(String configLocation); + + /** + * Set the config locations for this web application context. + *

If not set, the implementation is supposed to use a default for the + * given namespace or the root web application context, as appropriate. + */ + void setConfigLocations(String... configLocations); + + /** + * Return the config locations for this web application context, + * or {@code null} if none specified. + */ + @Nullable + String[] getConfigLocations(); + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/ConfigurableWebEnvironment.java b/spring-web/src/main/java/org/springframework/web/context/ConfigurableWebEnvironment.java new file mode 100644 index 0000000000000000000000000000000000000000..6d92f54530a2431b17299c61a385305d5cf812e3 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/ConfigurableWebEnvironment.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context; + +import javax.servlet.ServletConfig; +import javax.servlet.ServletContext; + +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.lang.Nullable; + +/** + * Specialization of {@link ConfigurableEnvironment} allowing initialization of + * servlet-related {@link org.springframework.core.env.PropertySource} objects at the + * earliest moment that the {@link ServletContext} and (optionally) {@link ServletConfig} + * become available. + * + * @author Chris Beams + * @since 3.1.2 + * @see ConfigurableWebApplicationContext#getEnvironment() + */ +public interface ConfigurableWebEnvironment extends ConfigurableEnvironment { + + /** + * Replace any {@linkplain + * org.springframework.core.env.PropertySource.StubPropertySource stub property source} + * instances acting as placeholders with real servlet context/config property sources + * using the given parameters. + * @param servletContext the {@link ServletContext} (may not be {@code null}) + * @param servletConfig the {@link ServletConfig} ({@code null} if not available) + * @see org.springframework.web.context.support.WebApplicationContextUtils#initServletPropertySources( + * org.springframework.core.env.MutablePropertySources, ServletContext, ServletConfig) + */ + void initPropertySources(@Nullable ServletContext servletContext, @Nullable ServletConfig servletConfig); + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/ContextCleanupListener.java b/spring-web/src/main/java/org/springframework/web/context/ContextCleanupListener.java new file mode 100644 index 0000000000000000000000000000000000000000..4225e0aba1877c1cb58da9da909a3b1cc241eb69 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/ContextCleanupListener.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context; + +import java.util.Enumeration; + +import javax.servlet.ServletContext; +import javax.servlet.ServletContextEvent; +import javax.servlet.ServletContextListener; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.DisposableBean; + +/** + * Web application listener that cleans up remaining disposable attributes + * in the ServletContext, i.e. attributes which implement {@link DisposableBean} + * and haven't been removed before. This is typically used for destroying objects + * in "application" scope, for which the lifecycle implies destruction at the + * very end of the web application's shutdown phase. + * + * @author Juergen Hoeller + * @since 3.0 + * @see org.springframework.web.context.support.ServletContextScope + * @see ContextLoaderListener + */ +public class ContextCleanupListener implements ServletContextListener { + + private static final Log logger = LogFactory.getLog(ContextCleanupListener.class); + + + @Override + public void contextInitialized(ServletContextEvent event) { + } + + @Override + public void contextDestroyed(ServletContextEvent event) { + cleanupAttributes(event.getServletContext()); + } + + + /** + * Find all Spring-internal ServletContext attributes which implement + * {@link DisposableBean} and invoke the destroy method on them. + * @param servletContext the ServletContext to check + * @see DisposableBean#destroy() + */ + static void cleanupAttributes(ServletContext servletContext) { + Enumeration attrNames = servletContext.getAttributeNames(); + while (attrNames.hasMoreElements()) { + String attrName = attrNames.nextElement(); + if (attrName.startsWith("org.springframework.")) { + Object attrValue = servletContext.getAttribute(attrName); + if (attrValue instanceof DisposableBean) { + try { + ((DisposableBean) attrValue).destroy(); + } + catch (Throwable ex) { + if (logger.isWarnEnabled()) { + logger.warn("Invocation of destroy method failed on ServletContext " + + "attribute with name '" + attrName + "'", ex); + } + } + } + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/ContextLoader.java b/spring-web/src/main/java/org/springframework/web/context/ContextLoader.java new file mode 100644 index 0000000000000000000000000000000000000000..dfd4b0196caa5dba34d4609a8dbfc9115ee616ce --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/ContextLoader.java @@ -0,0 +1,553 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.ConcurrentHashMap; + +import javax.servlet.ServletContext; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.BeanUtils; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextException; +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.core.GenericTypeResolver; +import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.support.PropertiesLoaderUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.ClassUtils; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; + +/** + * Performs the actual initialization work for the root application context. + * Called by {@link ContextLoaderListener}. + * + *

Looks for a {@link #CONTEXT_CLASS_PARAM "contextClass"} parameter at the + * {@code web.xml} context-param level to specify the context class type, falling + * back to {@link org.springframework.web.context.support.XmlWebApplicationContext} + * if not found. With the default ContextLoader implementation, any context class + * specified needs to implement the {@link ConfigurableWebApplicationContext} interface. + * + *

Processes a {@link #CONFIG_LOCATION_PARAM "contextConfigLocation"} context-param + * and passes its value to the context instance, parsing it into potentially multiple + * file paths which can be separated by any number of commas and spaces, e.g. + * "WEB-INF/applicationContext1.xml, WEB-INF/applicationContext2.xml". + * Ant-style path patterns are supported as well, e.g. + * "WEB-INF/*Context.xml,WEB-INF/spring*.xml" or "WEB-INF/**/*Context.xml". + * If not explicitly specified, the context implementation is supposed to use a + * default location (with XmlWebApplicationContext: "/WEB-INF/applicationContext.xml"). + * + *

Note: In case of multiple config locations, later bean definitions will + * override ones defined in previously loaded files, at least when using one of + * Spring's default ApplicationContext implementations. This can be leveraged + * to deliberately override certain bean definitions via an extra XML file. + * + *

Above and beyond loading the root application context, this class can optionally + * load or obtain and hook up a shared parent context to the root application context. + * See the {@link #loadParentContext(ServletContext)} method for more information. + * + *

As of Spring 3.1, {@code ContextLoader} supports injecting the root web + * application context via the {@link #ContextLoader(WebApplicationContext)} + * constructor, allowing for programmatic configuration in Servlet 3.0+ environments. + * See {@link org.springframework.web.WebApplicationInitializer} for usage examples. + * + * @author Juergen Hoeller + * @author Colin Sampaleanu + * @author Sam Brannen + * @since 17.02.2003 + * @see ContextLoaderListener + * @see ConfigurableWebApplicationContext + * @see org.springframework.web.context.support.XmlWebApplicationContext + */ +public class ContextLoader { + + /** + * Config param for the root WebApplicationContext id, + * to be used as serialization id for the underlying BeanFactory: {@value}. + */ + public static final String CONTEXT_ID_PARAM = "contextId"; + + /** + * Name of servlet context parameter (i.e., {@value}) that can specify the + * config location for the root context, falling back to the implementation's + * default otherwise. + * @see org.springframework.web.context.support.XmlWebApplicationContext#DEFAULT_CONFIG_LOCATION + */ + public static final String CONFIG_LOCATION_PARAM = "contextConfigLocation"; + + /** + * Config param for the root WebApplicationContext implementation class to use: {@value}. + * @see #determineContextClass(ServletContext) + */ + public static final String CONTEXT_CLASS_PARAM = "contextClass"; + + /** + * Config param for {@link ApplicationContextInitializer} classes to use + * for initializing the root web application context: {@value}. + * @see #customizeContext(ServletContext, ConfigurableWebApplicationContext) + */ + public static final String CONTEXT_INITIALIZER_CLASSES_PARAM = "contextInitializerClasses"; + + /** + * Config param for global {@link ApplicationContextInitializer} classes to use + * for initializing all web application contexts in the current application: {@value}. + * @see #customizeContext(ServletContext, ConfigurableWebApplicationContext) + */ + public static final String GLOBAL_INITIALIZER_CLASSES_PARAM = "globalInitializerClasses"; + + /** + * Any number of these characters are considered delimiters between + * multiple values in a single init-param String value. + */ + private static final String INIT_PARAM_DELIMITERS = ",; \t\n"; + + /** + * Name of the class path resource (relative to the ContextLoader class) + * that defines ContextLoader's default strategy names. + */ + private static final String DEFAULT_STRATEGIES_PATH = "ContextLoader.properties"; + + + private static final Properties defaultStrategies; + + static { + // Load default strategy implementations from properties file. + // This is currently strictly internal and not meant to be customized + // by application developers. + try { + ClassPathResource resource = new ClassPathResource(DEFAULT_STRATEGIES_PATH, ContextLoader.class); + defaultStrategies = PropertiesLoaderUtils.loadProperties(resource); + } + catch (IOException ex) { + throw new IllegalStateException("Could not load 'ContextLoader.properties': " + ex.getMessage()); + } + } + + + /** + * Map from (thread context) ClassLoader to corresponding 'current' WebApplicationContext. + */ + private static final Map currentContextPerThread = + new ConcurrentHashMap<>(1); + + /** + * The 'current' WebApplicationContext, if the ContextLoader class is + * deployed in the web app ClassLoader itself. + */ + @Nullable + private static volatile WebApplicationContext currentContext; + + + /** + * The root WebApplicationContext instance that this loader manages. + */ + @Nullable + private WebApplicationContext context; + + /** Actual ApplicationContextInitializer instances to apply to the context. */ + private final List> contextInitializers = + new ArrayList<>(); + + + /** + * Create a new {@code ContextLoader} that will create a web application context + * based on the "contextClass" and "contextConfigLocation" servlet context-params. + * See class-level documentation for details on default values for each. + *

This constructor is typically used when declaring the {@code + * ContextLoaderListener} subclass as a {@code } within {@code web.xml}, as + * a no-arg constructor is required. + *

The created application context will be registered into the ServletContext under + * the attribute name {@link WebApplicationContext#ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE} + * and subclasses are free to call the {@link #closeWebApplicationContext} method on + * container shutdown to close the application context. + * @see #ContextLoader(WebApplicationContext) + * @see #initWebApplicationContext(ServletContext) + * @see #closeWebApplicationContext(ServletContext) + */ + public ContextLoader() { + } + + /** + * Create a new {@code ContextLoader} with the given application context. This + * constructor is useful in Servlet 3.0+ environments where instance-based + * registration of listeners is possible through the {@link ServletContext#addListener} + * API. + *

The context may or may not yet be {@linkplain + * ConfigurableApplicationContext#refresh() refreshed}. If it (a) is an implementation + * of {@link ConfigurableWebApplicationContext} and (b) has not + * already been refreshed (the recommended approach), then the following will occur: + *

    + *
  • If the given context has not already been assigned an {@linkplain + * ConfigurableApplicationContext#setId id}, one will be assigned to it
  • + *
  • {@code ServletContext} and {@code ServletConfig} objects will be delegated to + * the application context
  • + *
  • {@link #customizeContext} will be called
  • + *
  • Any {@link ApplicationContextInitializer ApplicationContextInitializers} specified through the + * "contextInitializerClasses" init-param will be applied.
  • + *
  • {@link ConfigurableApplicationContext#refresh refresh()} will be called
  • + *
+ * If the context has already been refreshed or does not implement + * {@code ConfigurableWebApplicationContext}, none of the above will occur under the + * assumption that the user has performed these actions (or not) per his or her + * specific needs. + *

See {@link org.springframework.web.WebApplicationInitializer} for usage examples. + *

In any case, the given application context will be registered into the + * ServletContext under the attribute name {@link + * WebApplicationContext#ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE} and subclasses are + * free to call the {@link #closeWebApplicationContext} method on container shutdown + * to close the application context. + * @param context the application context to manage + * @see #initWebApplicationContext(ServletContext) + * @see #closeWebApplicationContext(ServletContext) + */ + public ContextLoader(WebApplicationContext context) { + this.context = context; + } + + + /** + * Specify which {@link ApplicationContextInitializer} instances should be used + * to initialize the application context used by this {@code ContextLoader}. + * @since 4.2 + * @see #configureAndRefreshWebApplicationContext + * @see #customizeContext + */ + @SuppressWarnings("unchecked") + public void setContextInitializers(@Nullable ApplicationContextInitializer... initializers) { + if (initializers != null) { + for (ApplicationContextInitializer initializer : initializers) { + this.contextInitializers.add((ApplicationContextInitializer) initializer); + } + } + } + + + /** + * Initialize Spring's web application context for the given servlet context, + * using the application context provided at construction time, or creating a new one + * according to the "{@link #CONTEXT_CLASS_PARAM contextClass}" and + * "{@link #CONFIG_LOCATION_PARAM contextConfigLocation}" context-params. + * @param servletContext current servlet context + * @return the new WebApplicationContext + * @see #ContextLoader(WebApplicationContext) + * @see #CONTEXT_CLASS_PARAM + * @see #CONFIG_LOCATION_PARAM + */ + public WebApplicationContext initWebApplicationContext(ServletContext servletContext) { + if (servletContext.getAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE) != null) { + throw new IllegalStateException( + "Cannot initialize context because there is already a root application context present - " + + "check whether you have multiple ContextLoader* definitions in your web.xml!"); + } + + servletContext.log("Initializing Spring root WebApplicationContext"); + Log logger = LogFactory.getLog(ContextLoader.class); + if (logger.isInfoEnabled()) { + logger.info("Root WebApplicationContext: initialization started"); + } + long startTime = System.currentTimeMillis(); + + try { + // Store context in local instance variable, to guarantee that + // it is available on ServletContext shutdown. + if (this.context == null) { + this.context = createWebApplicationContext(servletContext); + } + if (this.context instanceof ConfigurableWebApplicationContext) { + ConfigurableWebApplicationContext cwac = (ConfigurableWebApplicationContext) this.context; + if (!cwac.isActive()) { + // The context has not yet been refreshed -> provide services such as + // setting the parent context, setting the application context id, etc + if (cwac.getParent() == null) { + // The context instance was injected without an explicit parent -> + // determine parent for root web application context, if any. + ApplicationContext parent = loadParentContext(servletContext); + cwac.setParent(parent); + } + configureAndRefreshWebApplicationContext(cwac, servletContext); + } + } + servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, this.context); + + ClassLoader ccl = Thread.currentThread().getContextClassLoader(); + if (ccl == ContextLoader.class.getClassLoader()) { + currentContext = this.context; + } + else if (ccl != null) { + currentContextPerThread.put(ccl, this.context); + } + + if (logger.isInfoEnabled()) { + long elapsedTime = System.currentTimeMillis() - startTime; + logger.info("Root WebApplicationContext initialized in " + elapsedTime + " ms"); + } + + return this.context; + } + catch (RuntimeException | Error ex) { + logger.error("Context initialization failed", ex); + servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, ex); + throw ex; + } + } + + /** + * Instantiate the root WebApplicationContext for this loader, either the + * default context class or a custom context class if specified. + *

This implementation expects custom contexts to implement the + * {@link ConfigurableWebApplicationContext} interface. + * Can be overridden in subclasses. + *

In addition, {@link #customizeContext} gets called prior to refreshing the + * context, allowing subclasses to perform custom modifications to the context. + * @param sc current servlet context + * @return the root WebApplicationContext + * @see ConfigurableWebApplicationContext + */ + protected WebApplicationContext createWebApplicationContext(ServletContext sc) { + Class contextClass = determineContextClass(sc); + if (!ConfigurableWebApplicationContext.class.isAssignableFrom(contextClass)) { + throw new ApplicationContextException("Custom context class [" + contextClass.getName() + + "] is not of type [" + ConfigurableWebApplicationContext.class.getName() + "]"); + } + return (ConfigurableWebApplicationContext) BeanUtils.instantiateClass(contextClass); + } + + /** + * Return the WebApplicationContext implementation class to use, either the + * default XmlWebApplicationContext or a custom context class if specified. + * @param servletContext current servlet context + * @return the WebApplicationContext implementation class to use + * @see #CONTEXT_CLASS_PARAM + * @see org.springframework.web.context.support.XmlWebApplicationContext + */ + protected Class determineContextClass(ServletContext servletContext) { + String contextClassName = servletContext.getInitParameter(CONTEXT_CLASS_PARAM); + if (contextClassName != null) { + try { + return ClassUtils.forName(contextClassName, ClassUtils.getDefaultClassLoader()); + } + catch (ClassNotFoundException ex) { + throw new ApplicationContextException( + "Failed to load custom context class [" + contextClassName + "]", ex); + } + } + else { + contextClassName = defaultStrategies.getProperty(WebApplicationContext.class.getName()); + try { + return ClassUtils.forName(contextClassName, ContextLoader.class.getClassLoader()); + } + catch (ClassNotFoundException ex) { + throw new ApplicationContextException( + "Failed to load default context class [" + contextClassName + "]", ex); + } + } + } + + protected void configureAndRefreshWebApplicationContext(ConfigurableWebApplicationContext wac, ServletContext sc) { + if (ObjectUtils.identityToString(wac).equals(wac.getId())) { + // The application context id is still set to its original default value + // -> assign a more useful id based on available information + String idParam = sc.getInitParameter(CONTEXT_ID_PARAM); + if (idParam != null) { + wac.setId(idParam); + } + else { + // Generate default id... + wac.setId(ConfigurableWebApplicationContext.APPLICATION_CONTEXT_ID_PREFIX + + ObjectUtils.getDisplayString(sc.getContextPath())); + } + } + + wac.setServletContext(sc); + String configLocationParam = sc.getInitParameter(CONFIG_LOCATION_PARAM); + if (configLocationParam != null) { + wac.setConfigLocation(configLocationParam); + } + + // The wac environment's #initPropertySources will be called in any case when the context + // is refreshed; do it eagerly here to ensure servlet property sources are in place for + // use in any post-processing or initialization that occurs below prior to #refresh + ConfigurableEnvironment env = wac.getEnvironment(); + if (env instanceof ConfigurableWebEnvironment) { + ((ConfigurableWebEnvironment) env).initPropertySources(sc, null); + } + + customizeContext(sc, wac); + wac.refresh(); + } + + /** + * Customize the {@link ConfigurableWebApplicationContext} created by this + * ContextLoader after config locations have been supplied to the context + * but before the context is refreshed. + *

The default implementation {@linkplain #determineContextInitializerClasses(ServletContext) + * determines} what (if any) context initializer classes have been specified through + * {@linkplain #CONTEXT_INITIALIZER_CLASSES_PARAM context init parameters} and + * {@linkplain ApplicationContextInitializer#initialize invokes each} with the + * given web application context. + *

Any {@code ApplicationContextInitializers} implementing + * {@link org.springframework.core.Ordered Ordered} or marked with @{@link + * org.springframework.core.annotation.Order Order} will be sorted appropriately. + * @param sc the current servlet context + * @param wac the newly created application context + * @see #CONTEXT_INITIALIZER_CLASSES_PARAM + * @see ApplicationContextInitializer#initialize(ConfigurableApplicationContext) + */ + protected void customizeContext(ServletContext sc, ConfigurableWebApplicationContext wac) { + List>> initializerClasses = + determineContextInitializerClasses(sc); + + for (Class> initializerClass : initializerClasses) { + Class initializerContextClass = + GenericTypeResolver.resolveTypeArgument(initializerClass, ApplicationContextInitializer.class); + if (initializerContextClass != null && !initializerContextClass.isInstance(wac)) { + throw new ApplicationContextException(String.format( + "Could not apply context initializer [%s] since its generic parameter [%s] " + + "is not assignable from the type of application context used by this " + + "context loader: [%s]", initializerClass.getName(), initializerContextClass.getName(), + wac.getClass().getName())); + } + this.contextInitializers.add(BeanUtils.instantiateClass(initializerClass)); + } + + AnnotationAwareOrderComparator.sort(this.contextInitializers); + for (ApplicationContextInitializer initializer : this.contextInitializers) { + initializer.initialize(wac); + } + } + + /** + * Return the {@link ApplicationContextInitializer} implementation classes to use + * if any have been specified by {@link #CONTEXT_INITIALIZER_CLASSES_PARAM}. + * @param servletContext current servlet context + * @see #CONTEXT_INITIALIZER_CLASSES_PARAM + */ + protected List>> + determineContextInitializerClasses(ServletContext servletContext) { + + List>> classes = + new ArrayList<>(); + + String globalClassNames = servletContext.getInitParameter(GLOBAL_INITIALIZER_CLASSES_PARAM); + if (globalClassNames != null) { + for (String className : StringUtils.tokenizeToStringArray(globalClassNames, INIT_PARAM_DELIMITERS)) { + classes.add(loadInitializerClass(className)); + } + } + + String localClassNames = servletContext.getInitParameter(CONTEXT_INITIALIZER_CLASSES_PARAM); + if (localClassNames != null) { + for (String className : StringUtils.tokenizeToStringArray(localClassNames, INIT_PARAM_DELIMITERS)) { + classes.add(loadInitializerClass(className)); + } + } + + return classes; + } + + @SuppressWarnings("unchecked") + private Class> loadInitializerClass(String className) { + try { + Class clazz = ClassUtils.forName(className, ClassUtils.getDefaultClassLoader()); + if (!ApplicationContextInitializer.class.isAssignableFrom(clazz)) { + throw new ApplicationContextException( + "Initializer class does not implement ApplicationContextInitializer interface: " + clazz); + } + return (Class>) clazz; + } + catch (ClassNotFoundException ex) { + throw new ApplicationContextException("Failed to load context initializer class [" + className + "]", ex); + } + } + + /** + * Template method with default implementation (which may be overridden by a + * subclass), to load or obtain an ApplicationContext instance which will be + * used as the parent context of the root WebApplicationContext. If the + * return value from the method is null, no parent context is set. + *

The main reason to load a parent context here is to allow multiple root + * web application contexts to all be children of a shared EAR context, or + * alternately to also share the same parent context that is visible to + * EJBs. For pure web applications, there is usually no need to worry about + * having a parent context to the root web application context. + *

The default implementation simply returns {@code null}, as of 5.0. + * @param servletContext current servlet context + * @return the parent application context, or {@code null} if none + */ + @Nullable + protected ApplicationContext loadParentContext(ServletContext servletContext) { + return null; + } + + /** + * Close Spring's web application context for the given servlet context. + *

If overriding {@link #loadParentContext(ServletContext)}, you may have + * to override this method as well. + * @param servletContext the ServletContext that the WebApplicationContext runs in + */ + public void closeWebApplicationContext(ServletContext servletContext) { + servletContext.log("Closing Spring root WebApplicationContext"); + try { + if (this.context instanceof ConfigurableWebApplicationContext) { + ((ConfigurableWebApplicationContext) this.context).close(); + } + } + finally { + ClassLoader ccl = Thread.currentThread().getContextClassLoader(); + if (ccl == ContextLoader.class.getClassLoader()) { + currentContext = null; + } + else if (ccl != null) { + currentContextPerThread.remove(ccl); + } + servletContext.removeAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE); + } + } + + + /** + * Obtain the Spring root web application context for the current thread + * (i.e. for the current thread's context ClassLoader, which needs to be + * the web application's ClassLoader). + * @return the current root web application context, or {@code null} + * if none found + * @see org.springframework.web.context.support.SpringBeanAutowiringSupport + */ + @Nullable + public static WebApplicationContext getCurrentWebApplicationContext() { + ClassLoader ccl = Thread.currentThread().getContextClassLoader(); + if (ccl != null) { + WebApplicationContext ccpt = currentContextPerThread.get(ccl); + if (ccpt != null) { + return ccpt; + } + } + return currentContext; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/ContextLoaderListener.java b/spring-web/src/main/java/org/springframework/web/context/ContextLoaderListener.java new file mode 100644 index 0000000000000000000000000000000000000000..3c4d40722258b5715b6f20110b3c17e367e3e444 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/ContextLoaderListener.java @@ -0,0 +1,116 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context; + +import javax.servlet.ServletContextEvent; +import javax.servlet.ServletContextListener; + +/** + * Bootstrap listener to start up and shut down Spring's root {@link WebApplicationContext}. + * Simply delegates to {@link ContextLoader} as well as to {@link ContextCleanupListener}. + * + *

As of Spring 3.1, {@code ContextLoaderListener} supports injecting the root web + * application context via the {@link #ContextLoaderListener(WebApplicationContext)} + * constructor, allowing for programmatic configuration in Servlet 3.0+ environments. + * See {@link org.springframework.web.WebApplicationInitializer} for usage examples. + * + * @author Juergen Hoeller + * @author Chris Beams + * @since 17.02.2003 + * @see #setContextInitializers + * @see org.springframework.web.WebApplicationInitializer + */ +public class ContextLoaderListener extends ContextLoader implements ServletContextListener { + + /** + * Create a new {@code ContextLoaderListener} that will create a web application + * context based on the "contextClass" and "contextConfigLocation" servlet + * context-params. See {@link ContextLoader} superclass documentation for details on + * default values for each. + *

This constructor is typically used when declaring {@code ContextLoaderListener} + * as a {@code } within {@code web.xml}, where a no-arg constructor is + * required. + *

The created application context will be registered into the ServletContext under + * the attribute name {@link WebApplicationContext#ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE} + * and the Spring application context will be closed when the {@link #contextDestroyed} + * lifecycle method is invoked on this listener. + * @see ContextLoader + * @see #ContextLoaderListener(WebApplicationContext) + * @see #contextInitialized(ServletContextEvent) + * @see #contextDestroyed(ServletContextEvent) + */ + public ContextLoaderListener() { + } + + /** + * Create a new {@code ContextLoaderListener} with the given application context. This + * constructor is useful in Servlet 3.0+ environments where instance-based + * registration of listeners is possible through the {@link javax.servlet.ServletContext#addListener} + * API. + *

The context may or may not yet be {@linkplain + * org.springframework.context.ConfigurableApplicationContext#refresh() refreshed}. If it + * (a) is an implementation of {@link ConfigurableWebApplicationContext} and + * (b) has not already been refreshed (the recommended approach), + * then the following will occur: + *

    + *
  • If the given context has not already been assigned an {@linkplain + * org.springframework.context.ConfigurableApplicationContext#setId id}, one will be assigned to it
  • + *
  • {@code ServletContext} and {@code ServletConfig} objects will be delegated to + * the application context
  • + *
  • {@link #customizeContext} will be called
  • + *
  • Any {@link org.springframework.context.ApplicationContextInitializer ApplicationContextInitializer org.springframework.context.ApplicationContextInitializer ApplicationContextInitializers} + * specified through the "contextInitializerClasses" init-param will be applied.
  • + *
  • {@link org.springframework.context.ConfigurableApplicationContext#refresh refresh()} will be called
  • + *
+ * If the context has already been refreshed or does not implement + * {@code ConfigurableWebApplicationContext}, none of the above will occur under the + * assumption that the user has performed these actions (or not) per his or her + * specific needs. + *

See {@link org.springframework.web.WebApplicationInitializer} for usage examples. + *

In any case, the given application context will be registered into the + * ServletContext under the attribute name {@link + * WebApplicationContext#ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE} and the Spring + * application context will be closed when the {@link #contextDestroyed} lifecycle + * method is invoked on this listener. + * @param context the application context to manage + * @see #contextInitialized(ServletContextEvent) + * @see #contextDestroyed(ServletContextEvent) + */ + public ContextLoaderListener(WebApplicationContext context) { + super(context); + } + + + /** + * Initialize the root web application context. + */ + @Override + public void contextInitialized(ServletContextEvent event) { + initWebApplicationContext(event.getServletContext()); + } + + + /** + * Close the root web application context. + */ + @Override + public void contextDestroyed(ServletContextEvent event) { + closeWebApplicationContext(event.getServletContext()); + ContextCleanupListener.cleanupAttributes(event.getServletContext()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/ServletConfigAware.java b/spring-web/src/main/java/org/springframework/web/context/ServletConfigAware.java new file mode 100644 index 0000000000000000000000000000000000000000..aaa5a903909696e93d2321894001e81b14b588f9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/ServletConfigAware.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context; + +import javax.servlet.ServletConfig; + +import org.springframework.beans.factory.Aware; + +/** + * Interface to be implemented by any object that wishes to be notified of the + * {@link ServletConfig} (typically determined by the {@link WebApplicationContext}) + * that it runs in. + * + *

Note: Only satisfied if actually running within a Servlet-specific + * WebApplicationContext. Otherwise, no ServletConfig will be set. + * + * @author Juergen Hoeller + * @author Chris Beams + * @since 2.0 + * @see ServletContextAware + */ +public interface ServletConfigAware extends Aware { + + /** + * Set the {@link ServletConfig} that this object runs in. + *

Invoked after population of normal bean properties but before an init + * callback like InitializingBean's {@code afterPropertiesSet} or a + * custom init-method. Invoked after ApplicationContextAware's + * {@code setApplicationContext}. + * @param servletConfig the {@link ServletConfig} to be used by this object + * @see org.springframework.beans.factory.InitializingBean#afterPropertiesSet + * @see org.springframework.context.ApplicationContextAware#setApplicationContext + */ + void setServletConfig(ServletConfig servletConfig); + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/ServletContextAware.java b/spring-web/src/main/java/org/springframework/web/context/ServletContextAware.java new file mode 100644 index 0000000000000000000000000000000000000000..709cce1c6ace0c8727536efd40dd4020b1910edc --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/ServletContextAware.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context; + +import javax.servlet.ServletContext; + +import org.springframework.beans.factory.Aware; + +/** + * Interface to be implemented by any object that wishes to be notified of the + * {@link ServletContext} (typically determined by the {@link WebApplicationContext}) + * that it runs in. + * + * @author Juergen Hoeller + * @author Chris Beams + * @since 12.03.2004 + * @see ServletConfigAware + */ +public interface ServletContextAware extends Aware { + + /** + * Set the {@link ServletContext} that this object runs in. + *

Invoked after population of normal bean properties but before an init + * callback like InitializingBean's {@code afterPropertiesSet} or a + * custom init-method. Invoked after ApplicationContextAware's + * {@code setApplicationContext}. + * @param servletContext the ServletContext object to be used by this object + * @see org.springframework.beans.factory.InitializingBean#afterPropertiesSet + * @see org.springframework.context.ApplicationContextAware#setApplicationContext + */ + void setServletContext(ServletContext servletContext); + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/WebApplicationContext.java b/spring-web/src/main/java/org/springframework/web/context/WebApplicationContext.java new file mode 100644 index 0000000000000000000000000000000000000000..b2c15d545c9adea765301208ce9c3a34a703d4b2 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/WebApplicationContext.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context; + +import javax.servlet.ServletContext; + +import org.springframework.context.ApplicationContext; +import org.springframework.lang.Nullable; + +/** + * Interface to provide configuration for a web application. This is read-only while + * the application is running, but may be reloaded if the implementation supports this. + * + *

This interface adds a {@code getServletContext()} method to the generic + * ApplicationContext interface, and defines a well-known application attribute name + * that the root context must be bound to in the bootstrap process. + * + *

Like generic application contexts, web application contexts are hierarchical. + * There is a single root context per application, while each servlet in the application + * (including a dispatcher servlet in the MVC framework) has its own child context. + * + *

In addition to standard application context lifecycle capabilities, + * WebApplicationContext implementations need to detect {@link ServletContextAware} + * beans and invoke the {@code setServletContext} method accordingly. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since January 19, 2001 + * @see ServletContextAware#setServletContext + */ +public interface WebApplicationContext extends ApplicationContext { + + /** + * Context attribute to bind root WebApplicationContext to on successful startup. + *

Note: If the startup of the root context fails, this attribute can contain + * an exception or error as value. Use WebApplicationContextUtils for convenient + * lookup of the root WebApplicationContext. + * @see org.springframework.web.context.support.WebApplicationContextUtils#getWebApplicationContext + * @see org.springframework.web.context.support.WebApplicationContextUtils#getRequiredWebApplicationContext + */ + String ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE = WebApplicationContext.class.getName() + ".ROOT"; + + /** + * Scope identifier for request scope: "request". + * Supported in addition to the standard scopes "singleton" and "prototype". + */ + String SCOPE_REQUEST = "request"; + + /** + * Scope identifier for session scope: "session". + * Supported in addition to the standard scopes "singleton" and "prototype". + */ + String SCOPE_SESSION = "session"; + + /** + * Scope identifier for the global web application scope: "application". + * Supported in addition to the standard scopes "singleton" and "prototype". + */ + String SCOPE_APPLICATION = "application"; + + /** + * Name of the ServletContext environment bean in the factory. + * @see javax.servlet.ServletContext + */ + String SERVLET_CONTEXT_BEAN_NAME = "servletContext"; + + /** + * Name of the ServletContext init-params environment bean in the factory. + *

Note: Possibly merged with ServletConfig parameters. + * ServletConfig parameters override ServletContext parameters of the same name. + * @see javax.servlet.ServletContext#getInitParameterNames() + * @see javax.servlet.ServletContext#getInitParameter(String) + * @see javax.servlet.ServletConfig#getInitParameterNames() + * @see javax.servlet.ServletConfig#getInitParameter(String) + */ + String CONTEXT_PARAMETERS_BEAN_NAME = "contextParameters"; + + /** + * Name of the ServletContext attributes environment bean in the factory. + * @see javax.servlet.ServletContext#getAttributeNames() + * @see javax.servlet.ServletContext#getAttribute(String) + */ + String CONTEXT_ATTRIBUTES_BEAN_NAME = "contextAttributes"; + + + /** + * Return the standard Servlet API ServletContext for this application. + */ + @Nullable + ServletContext getServletContext(); + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/annotation/ApplicationScope.java b/spring-web/src/main/java/org/springframework/web/context/annotation/ApplicationScope.java new file mode 100644 index 0000000000000000000000000000000000000000..2abe61cbb00997fa8efc698e77969747ef1932ca --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/annotation/ApplicationScope.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.context.annotation.Scope; +import org.springframework.context.annotation.ScopedProxyMode; +import org.springframework.core.annotation.AliasFor; +import org.springframework.web.context.WebApplicationContext; + +/** + * {@code @ApplicationScope} is a specialization of {@link Scope @Scope} for a + * component whose lifecycle is bound to the current web application. + * + *

Specifically, {@code @ApplicationScope} is a composed annotation that + * acts as a shortcut for {@code @Scope("application")} with the default + * {@link #proxyMode} set to {@link ScopedProxyMode#TARGET_CLASS TARGET_CLASS}. + * + *

{@code @ApplicationScope} may be used as a meta-annotation to create custom + * composed annotations. + * + * @author Sam Brannen + * @since 4.3 + * @see RequestScope + * @see SessionScope + * @see org.springframework.context.annotation.Scope + * @see org.springframework.web.context.WebApplicationContext#SCOPE_APPLICATION + * @see org.springframework.web.context.support.ServletContextScope + * @see org.springframework.stereotype.Component + * @see org.springframework.context.annotation.Bean + */ +@Target({ElementType.TYPE, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@Scope(WebApplicationContext.SCOPE_APPLICATION) +public @interface ApplicationScope { + + /** + * Alias for {@link Scope#proxyMode}. + *

Defaults to {@link ScopedProxyMode#TARGET_CLASS}. + */ + @AliasFor(annotation = Scope.class) + ScopedProxyMode proxyMode() default ScopedProxyMode.TARGET_CLASS; + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/annotation/RequestScope.java b/spring-web/src/main/java/org/springframework/web/context/annotation/RequestScope.java new file mode 100644 index 0000000000000000000000000000000000000000..fe00544094b769a3156496567fb5c348e19fcc93 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/annotation/RequestScope.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.context.annotation.Scope; +import org.springframework.context.annotation.ScopedProxyMode; +import org.springframework.core.annotation.AliasFor; +import org.springframework.web.context.WebApplicationContext; + +/** + * {@code @RequestScope} is a specialization of {@link Scope @Scope} for a + * component whose lifecycle is bound to the current web request. + * + *

Specifically, {@code @RequestScope} is a composed annotation that + * acts as a shortcut for {@code @Scope("request")} with the default + * {@link #proxyMode} set to {@link ScopedProxyMode#TARGET_CLASS TARGET_CLASS}. + * + *

{@code @RequestScope} may be used as a meta-annotation to create custom + * composed annotations. + * + * @author Sam Brannen + * @since 4.3 + * @see SessionScope + * @see ApplicationScope + * @see org.springframework.context.annotation.Scope + * @see org.springframework.web.context.WebApplicationContext#SCOPE_REQUEST + * @see org.springframework.web.context.request.RequestScope + * @see org.springframework.stereotype.Component + * @see org.springframework.context.annotation.Bean + */ +@Target({ElementType.TYPE, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@Scope(WebApplicationContext.SCOPE_REQUEST) +public @interface RequestScope { + + /** + * Alias for {@link Scope#proxyMode}. + *

Defaults to {@link ScopedProxyMode#TARGET_CLASS}. + */ + @AliasFor(annotation = Scope.class) + ScopedProxyMode proxyMode() default ScopedProxyMode.TARGET_CLASS; + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/annotation/SessionScope.java b/spring-web/src/main/java/org/springframework/web/context/annotation/SessionScope.java new file mode 100644 index 0000000000000000000000000000000000000000..2b718835ee2b4847c7df2bab11a8c6b8e1320c26 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/annotation/SessionScope.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.context.annotation.Scope; +import org.springframework.context.annotation.ScopedProxyMode; +import org.springframework.core.annotation.AliasFor; +import org.springframework.web.context.WebApplicationContext; + +/** + * {@code @SessionScope} is a specialization of {@link Scope @Scope} for a + * component whose lifecycle is bound to the current web session. + * + *

Specifically, {@code @SessionScope} is a composed annotation that + * acts as a shortcut for {@code @Scope("session")} with the default + * {@link #proxyMode} set to {@link ScopedProxyMode#TARGET_CLASS TARGET_CLASS}. + * + *

{@code @SessionScope} may be used as a meta-annotation to create custom + * composed annotations. + * + * @author Sam Brannen + * @since 4.3 + * @see RequestScope + * @see ApplicationScope + * @see org.springframework.context.annotation.Scope + * @see org.springframework.web.context.WebApplicationContext#SCOPE_SESSION + * @see org.springframework.web.context.request.SessionScope + * @see org.springframework.stereotype.Component + * @see org.springframework.context.annotation.Bean + */ +@Target({ElementType.TYPE, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@Scope(WebApplicationContext.SCOPE_SESSION) +public @interface SessionScope { + + /** + * Alias for {@link Scope#proxyMode}. + *

Defaults to {@link ScopedProxyMode#TARGET_CLASS}. + */ + @AliasFor(annotation = Scope.class) + ScopedProxyMode proxyMode() default ScopedProxyMode.TARGET_CLASS; + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/annotation/package-info.java b/spring-web/src/main/java/org/springframework/web/context/annotation/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..a339a8b0e5fffec35b0ba68928c4890295c7e606 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/annotation/package-info.java @@ -0,0 +1,9 @@ +/** + * Provides convenience annotations for web scopes. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.context.annotation; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/context/package-info.java b/spring-web/src/main/java/org/springframework/web/context/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..01aa3900014ad3b8a0eac96e67502b76bd9936e1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/package-info.java @@ -0,0 +1,10 @@ +/** + * Contains a variant of the application context interface for web applications, + * and the ContextLoaderListener that bootstraps a root web application context. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.context; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/context/request/AbstractRequestAttributes.java b/spring-web/src/main/java/org/springframework/web/context/request/AbstractRequestAttributes.java new file mode 100644 index 0000000000000000000000000000000000000000..f939c3129b0100204dce494860eaa25133309e9f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/AbstractRequestAttributes.java @@ -0,0 +1,103 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import java.util.LinkedHashMap; +import java.util.Map; + +import org.springframework.util.Assert; + +/** + * Abstract support class for RequestAttributes implementations, + * offering a request completion mechanism for request-specific destruction + * callbacks and for updating accessed session attributes. + * + * @author Juergen Hoeller + * @since 2.0 + * @see #requestCompleted() + */ +public abstract class AbstractRequestAttributes implements RequestAttributes { + + /** Map from attribute name String to destruction callback Runnable. */ + protected final Map requestDestructionCallbacks = new LinkedHashMap<>(8); + + private volatile boolean requestActive = true; + + + /** + * Signal that the request has been completed. + *

Executes all request destruction callbacks and updates the + * session attributes that have been accessed during request processing. + */ + public void requestCompleted() { + executeRequestDestructionCallbacks(); + updateAccessedSessionAttributes(); + this.requestActive = false; + } + + /** + * Determine whether the original request is still active. + * @see #requestCompleted() + */ + protected final boolean isRequestActive() { + return this.requestActive; + } + + /** + * Register the given callback as to be executed after request completion. + * @param name the name of the attribute to register the callback for + * @param callback the callback to be executed for destruction + */ + protected final void registerRequestDestructionCallback(String name, Runnable callback) { + Assert.notNull(name, "Name must not be null"); + Assert.notNull(callback, "Callback must not be null"); + synchronized (this.requestDestructionCallbacks) { + this.requestDestructionCallbacks.put(name, callback); + } + } + + /** + * Remove the request destruction callback for the specified attribute, if any. + * @param name the name of the attribute to remove the callback for + */ + protected final void removeRequestDestructionCallback(String name) { + Assert.notNull(name, "Name must not be null"); + synchronized (this.requestDestructionCallbacks) { + this.requestDestructionCallbacks.remove(name); + } + } + + /** + * Execute all callbacks that have been registered for execution + * after request completion. + */ + private void executeRequestDestructionCallbacks() { + synchronized (this.requestDestructionCallbacks) { + for (Runnable runnable : this.requestDestructionCallbacks.values()) { + runnable.run(); + } + this.requestDestructionCallbacks.clear(); + } + } + + /** + * Update all session attributes that have been accessed during request processing, + * to expose their potentially updated state to the underlying session manager. + */ + protected abstract void updateAccessedSessionAttributes(); + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/AbstractRequestAttributesScope.java b/spring-web/src/main/java/org/springframework/web/context/request/AbstractRequestAttributesScope.java new file mode 100644 index 0000000000000000000000000000000000000000..26813bd10353fa50fe2a2cfd18f53760b279ca7c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/AbstractRequestAttributesScope.java @@ -0,0 +1,96 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import org.springframework.beans.factory.ObjectFactory; +import org.springframework.beans.factory.config.Scope; +import org.springframework.lang.Nullable; + +/** + * Abstract {@link Scope} implementation that reads from a particular scope + * in the current thread-bound {@link RequestAttributes} object. + * + *

Subclasses simply need to implement {@link #getScope()} to instruct + * this class which {@link RequestAttributes} scope to read attributes from. + * + *

Subclasses may wish to override the {@link #get} and {@link #remove} + * methods to add synchronization around the call back into this super class. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @author Rob Harrop + * @since 2.0 + */ +public abstract class AbstractRequestAttributesScope implements Scope { + + @Override + public Object get(String name, ObjectFactory objectFactory) { + RequestAttributes attributes = RequestContextHolder.currentRequestAttributes(); + Object scopedObject = attributes.getAttribute(name, getScope()); + if (scopedObject == null) { + scopedObject = objectFactory.getObject(); + attributes.setAttribute(name, scopedObject, getScope()); + // Retrieve object again, registering it for implicit session attribute updates. + // As a bonus, we also allow for potential decoration at the getAttribute level. + Object retrievedObject = attributes.getAttribute(name, getScope()); + if (retrievedObject != null) { + // Only proceed with retrieved object if still present (the expected case). + // If it disappeared concurrently, we return our locally created instance. + scopedObject = retrievedObject; + } + } + return scopedObject; + } + + @Override + @Nullable + public Object remove(String name) { + RequestAttributes attributes = RequestContextHolder.currentRequestAttributes(); + Object scopedObject = attributes.getAttribute(name, getScope()); + if (scopedObject != null) { + attributes.removeAttribute(name, getScope()); + return scopedObject; + } + else { + return null; + } + } + + @Override + public void registerDestructionCallback(String name, Runnable callback) { + RequestAttributes attributes = RequestContextHolder.currentRequestAttributes(); + attributes.registerDestructionCallback(name, callback, getScope()); + } + + @Override + @Nullable + public Object resolveContextualObject(String key) { + RequestAttributes attributes = RequestContextHolder.currentRequestAttributes(); + return attributes.resolveReference(key); + } + + + /** + * Template method that determines the actual target scope. + * @return the target scope, in the form of an appropriate + * {@link RequestAttributes} constant + * @see RequestAttributes#SCOPE_REQUEST + * @see RequestAttributes#SCOPE_SESSION + */ + protected abstract int getScope(); + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/AsyncWebRequestInterceptor.java b/spring-web/src/main/java/org/springframework/web/context/request/AsyncWebRequestInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..8e02e8650134171e9177bbd20c55dce256470626 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/AsyncWebRequestInterceptor.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.context.request; + +/** + * Extends {@code WebRequestInterceptor} with a callback method invoked during + * asynchronous request handling. + * + *

When a handler starts asynchronous request handling, the DispatcherServlet + * exits without invoking {@code postHandle} and {@code afterCompletion}, as it + * normally does, since the results of request handling (e.g. ModelAndView) are + * not available in the current thread and handling is not yet complete. + * In such scenarios, the {@link #afterConcurrentHandlingStarted(WebRequest)} + * method is invoked instead allowing implementations to perform tasks such as + * cleaning up thread bound attributes. + * + *

When asynchronous handling completes, the request is dispatched to the + * container for further processing. At this stage the DispatcherServlet invokes + * {@code preHandle}, {@code postHandle} and {@code afterCompletion} as usual. + * + * @author Rossen Stoyanchev + * @since 3.2 + * + * @see org.springframework.web.context.request.async.WebAsyncManager + */ +public interface AsyncWebRequestInterceptor extends WebRequestInterceptor{ + + /** + * Called instead of {@code postHandle} and {@code afterCompletion}, when the + * handler started handling the request concurrently. + * + * @param request the current request + */ + void afterConcurrentHandlingStarted(WebRequest request); + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/DestructionCallbackBindingListener.java b/spring-web/src/main/java/org/springframework/web/context/request/DestructionCallbackBindingListener.java new file mode 100644 index 0000000000000000000000000000000000000000..69e280bd7d9a378affdd8d11885a996911a0ee23 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/DestructionCallbackBindingListener.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import java.io.Serializable; + +import javax.servlet.http.HttpSessionBindingEvent; +import javax.servlet.http.HttpSessionBindingListener; + +/** + * Adapter that implements the Servlet HttpSessionBindingListener interface, + * wrapping a session destruction callback. + * + * @author Juergen Hoeller + * @since 3.0 + * @see RequestAttributes#registerDestructionCallback + * @see ServletRequestAttributes#registerSessionDestructionCallback + */ +@SuppressWarnings("serial") +public class DestructionCallbackBindingListener implements HttpSessionBindingListener, Serializable { + + private final Runnable destructionCallback; + + + /** + * Create a new DestructionCallbackBindingListener for the given callback. + * @param destructionCallback the Runnable to execute when this listener + * object gets unbound from the session + */ + public DestructionCallbackBindingListener(Runnable destructionCallback) { + this.destructionCallback = destructionCallback; + } + + + @Override + public void valueBound(HttpSessionBindingEvent event) { + } + + @Override + public void valueUnbound(HttpSessionBindingEvent event) { + this.destructionCallback.run(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/FacesRequestAttributes.java b/spring-web/src/main/java/org/springframework/web/context/request/FacesRequestAttributes.java new file mode 100644 index 0000000000000000000000000000000000000000..9937a0c076fd7857ab96ef846ef225881c513215 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/FacesRequestAttributes.java @@ -0,0 +1,216 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import java.lang.reflect.Method; +import java.util.Map; + +import javax.faces.context.ExternalContext; +import javax.faces.context.FacesContext; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.util.WebUtils; + +/** + * {@link RequestAttributes} adapter for a JSF {@link javax.faces.context.FacesContext}. + * Used as default in a JSF environment, wrapping the current FacesContext. + * + *

NOTE: In contrast to {@link ServletRequestAttributes}, this variant does + * not support destruction callbacks for scoped attributes, neither for the + * request scope nor for the session scope. If you rely on such implicit destruction + * callbacks, consider defining a Spring {@link RequestContextListener} in your + * {@code web.xml}. + * + *

Requires JSF 2.0 or higher, as of Spring 4.0. + * + * @author Juergen Hoeller + * @since 2.5.2 + * @see javax.faces.context.FacesContext#getExternalContext() + * @see javax.faces.context.ExternalContext#getRequestMap() + * @see javax.faces.context.ExternalContext#getSessionMap() + * @see RequestContextHolder#currentRequestAttributes() + */ +public class FacesRequestAttributes implements RequestAttributes { + + /** + * We'll create a lot of these objects, so we don't want a new logger every time. + */ + private static final Log logger = LogFactory.getLog(FacesRequestAttributes.class); + + private final FacesContext facesContext; + + + /** + * Create a new FacesRequestAttributes adapter for the given FacesContext. + * @param facesContext the current FacesContext + * @see javax.faces.context.FacesContext#getCurrentInstance() + */ + public FacesRequestAttributes(FacesContext facesContext) { + Assert.notNull(facesContext, "FacesContext must not be null"); + this.facesContext = facesContext; + } + + + /** + * Return the JSF FacesContext that this adapter operates on. + */ + protected final FacesContext getFacesContext() { + return this.facesContext; + } + + /** + * Return the JSF ExternalContext that this adapter operates on. + * @see javax.faces.context.FacesContext#getExternalContext() + */ + protected final ExternalContext getExternalContext() { + return getFacesContext().getExternalContext(); + } + + /** + * Return the JSF attribute Map for the specified scope. + * @param scope constant indicating request or session scope + * @return the Map representation of the attributes in the specified scope + * @see #SCOPE_REQUEST + * @see #SCOPE_SESSION + */ + protected Map getAttributeMap(int scope) { + if (scope == SCOPE_REQUEST) { + return getExternalContext().getRequestMap(); + } + else { + return getExternalContext().getSessionMap(); + } + } + + + @Override + public Object getAttribute(String name, int scope) { + return getAttributeMap(scope).get(name); + } + + @Override + public void setAttribute(String name, Object value, int scope) { + getAttributeMap(scope).put(name, value); + } + + @Override + public void removeAttribute(String name, int scope) { + getAttributeMap(scope).remove(name); + } + + @Override + public String[] getAttributeNames(int scope) { + return StringUtils.toStringArray(getAttributeMap(scope).keySet()); + } + + @Override + public void registerDestructionCallback(String name, Runnable callback, int scope) { + if (logger.isWarnEnabled()) { + logger.warn("Could not register destruction callback [" + callback + "] for attribute '" + name + + "' because FacesRequestAttributes does not support such callbacks"); + } + } + + @Override + public Object resolveReference(String key) { + if (REFERENCE_REQUEST.equals(key)) { + return getExternalContext().getRequest(); + } + else if (REFERENCE_SESSION.equals(key)) { + return getExternalContext().getSession(true); + } + else if ("application".equals(key)) { + return getExternalContext().getContext(); + } + else if ("requestScope".equals(key)) { + return getExternalContext().getRequestMap(); + } + else if ("sessionScope".equals(key)) { + return getExternalContext().getSessionMap(); + } + else if ("applicationScope".equals(key)) { + return getExternalContext().getApplicationMap(); + } + else if ("facesContext".equals(key)) { + return getFacesContext(); + } + else if ("cookie".equals(key)) { + return getExternalContext().getRequestCookieMap(); + } + else if ("header".equals(key)) { + return getExternalContext().getRequestHeaderMap(); + } + else if ("headerValues".equals(key)) { + return getExternalContext().getRequestHeaderValuesMap(); + } + else if ("param".equals(key)) { + return getExternalContext().getRequestParameterMap(); + } + else if ("paramValues".equals(key)) { + return getExternalContext().getRequestParameterValuesMap(); + } + else if ("initParam".equals(key)) { + return getExternalContext().getInitParameterMap(); + } + else if ("view".equals(key)) { + return getFacesContext().getViewRoot(); + } + else if ("viewScope".equals(key)) { + return getFacesContext().getViewRoot().getViewMap(); + } + else if ("flash".equals(key)) { + return getExternalContext().getFlash(); + } + else if ("resource".equals(key)) { + return getFacesContext().getApplication().getResourceHandler(); + } + else { + return null; + } + } + + @Override + public String getSessionId() { + Object session = getExternalContext().getSession(true); + try { + // HttpSession has a getId() method. + Method getIdMethod = session.getClass().getMethod("getId"); + return String.valueOf(ReflectionUtils.invokeMethod(getIdMethod, session)); + } + catch (NoSuchMethodException ex) { + throw new IllegalStateException("Session object [" + session + "] does not have a getId() method"); + } + } + + @Override + public Object getSessionMutex() { + // Enforce presence of a session first to allow listeners to create the mutex attribute + ExternalContext externalContext = getExternalContext(); + Object session = externalContext.getSession(true); + Object mutex = externalContext.getSessionMap().get(WebUtils.SESSION_MUTEX_ATTRIBUTE); + if (mutex == null) { + mutex = (session != null ? session : externalContext); + } + return mutex; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/FacesWebRequest.java b/spring-web/src/main/java/org/springframework/web/context/request/FacesWebRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..d3d56213fa8d49d379d604dd29dcbec2130c462d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/FacesWebRequest.java @@ -0,0 +1,195 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import java.security.Principal; +import java.util.Iterator; +import java.util.Locale; +import java.util.Map; + +import javax.faces.context.ExternalContext; +import javax.faces.context.FacesContext; + +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +/** + * {@link WebRequest} adapter for a JSF {@link javax.faces.context.FacesContext}. + * + *

Requires JSF 2.0 or higher, as of Spring 4.0. + * + * @author Juergen Hoeller + * @since 2.5.2 + */ +public class FacesWebRequest extends FacesRequestAttributes implements NativeWebRequest { + + /** + * Create a new FacesWebRequest adapter for the given FacesContext. + * @param facesContext the current FacesContext + * @see javax.faces.context.FacesContext#getCurrentInstance() + */ + public FacesWebRequest(FacesContext facesContext) { + super(facesContext); + } + + + @Override + public Object getNativeRequest() { + return getExternalContext().getRequest(); + } + + @Override + public Object getNativeResponse() { + return getExternalContext().getResponse(); + } + + @Override + @SuppressWarnings("unchecked") + public T getNativeRequest(@Nullable Class requiredType) { + if (requiredType != null) { + Object request = getExternalContext().getRequest(); + if (requiredType.isInstance(request)) { + return (T) request; + } + } + return null; + } + + @Override + @SuppressWarnings("unchecked") + public T getNativeResponse(@Nullable Class requiredType) { + if (requiredType != null) { + Object response = getExternalContext().getResponse(); + if (requiredType.isInstance(response)) { + return (T) response; + } + } + return null; + } + + + @Override + @Nullable + public String getHeader(String headerName) { + return getExternalContext().getRequestHeaderMap().get(headerName); + } + + @Override + @Nullable + public String[] getHeaderValues(String headerName) { + return getExternalContext().getRequestHeaderValuesMap().get(headerName); + } + + @Override + public Iterator getHeaderNames() { + return getExternalContext().getRequestHeaderMap().keySet().iterator(); + } + + @Override + @Nullable + public String getParameter(String paramName) { + return getExternalContext().getRequestParameterMap().get(paramName); + } + + @Override + public Iterator getParameterNames() { + return getExternalContext().getRequestParameterNames(); + } + + @Override + @Nullable + public String[] getParameterValues(String paramName) { + return getExternalContext().getRequestParameterValuesMap().get(paramName); + } + + @Override + public Map getParameterMap() { + return getExternalContext().getRequestParameterValuesMap(); + } + + @Override + public Locale getLocale() { + return getFacesContext().getExternalContext().getRequestLocale(); + } + + @Override + public String getContextPath() { + return getFacesContext().getExternalContext().getRequestContextPath(); + } + + @Override + @Nullable + public String getRemoteUser() { + return getFacesContext().getExternalContext().getRemoteUser(); + } + + @Override + @Nullable + public Principal getUserPrincipal() { + return getFacesContext().getExternalContext().getUserPrincipal(); + } + + @Override + public boolean isUserInRole(String role) { + return getFacesContext().getExternalContext().isUserInRole(role); + } + + @Override + public boolean isSecure() { + return false; + } + + @Override + public boolean checkNotModified(long lastModifiedTimestamp) { + return false; + } + + @Override + public boolean checkNotModified(@Nullable String eTag) { + return false; + } + + @Override + public boolean checkNotModified(@Nullable String etag, long lastModifiedTimestamp) { + return false; + } + + @Override + public String getDescription(boolean includeClientInfo) { + ExternalContext externalContext = getExternalContext(); + StringBuilder sb = new StringBuilder(); + sb.append("context=").append(externalContext.getRequestContextPath()); + if (includeClientInfo) { + Object session = externalContext.getSession(false); + if (session != null) { + sb.append(";session=").append(getSessionId()); + } + String user = externalContext.getRemoteUser(); + if (StringUtils.hasLength(user)) { + sb.append(";user=").append(user); + } + } + return sb.toString(); + } + + + @Override + public String toString() { + return "FacesWebRequest: " + getDescription(true); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/NativeWebRequest.java b/spring-web/src/main/java/org/springframework/web/context/request/NativeWebRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..cf2e3247fb7409397f288a79d9d60a18a97d3074 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/NativeWebRequest.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import org.springframework.lang.Nullable; + +/** + * Extension of the {@link WebRequest} interface, exposing the + * native request and response objects in a generic fashion. + * + *

Mainly intended for framework-internal usage, + * in particular for generic argument resolution code. + * + * @author Juergen Hoeller + * @since 2.5.2 + */ +public interface NativeWebRequest extends WebRequest { + + /** + * Return the underlying native request object. + * @see javax.servlet.http.HttpServletRequest + */ + Object getNativeRequest(); + + /** + * Return the underlying native response object, if any. + * @see javax.servlet.http.HttpServletResponse + */ + @Nullable + Object getNativeResponse(); + + /** + * Return the underlying native request object, if available. + * @param requiredType the desired type of request object + * @return the matching request object, or {@code null} if none + * of that type is available + * @see javax.servlet.http.HttpServletRequest + */ + @Nullable + T getNativeRequest(@Nullable Class requiredType); + + /** + * Return the underlying native response object, if available. + * @param requiredType the desired type of response object + * @return the matching response object, or {@code null} if none + * of that type is available + * @see javax.servlet.http.HttpServletResponse + */ + @Nullable + T getNativeResponse(@Nullable Class requiredType); + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/RequestAttributes.java b/spring-web/src/main/java/org/springframework/web/context/request/RequestAttributes.java new file mode 100644 index 0000000000000000000000000000000000000000..00becf5e3d975163d083a534d3e6105e5fe434b1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/RequestAttributes.java @@ -0,0 +1,144 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import org.springframework.lang.Nullable; + +/** + * Abstraction for accessing attribute objects associated with a request. + * Supports access to request-scoped attributes as well as to session-scoped + * attributes, with the optional notion of a "global session". + * + *

Can be implemented for any kind of request/session mechanism, + * in particular for servlet requests. + * + * @author Juergen Hoeller + * @since 2.0 + * @see ServletRequestAttributes + */ +public interface RequestAttributes { + + /** + * Constant that indicates request scope. + */ + int SCOPE_REQUEST = 0; + + /** + * Constant that indicates session scope. + *

This preferably refers to a locally isolated session, if such + * a distinction is available. + * Else, it simply refers to the common session. + */ + int SCOPE_SESSION = 1; + + + /** + * Name of the standard reference to the request object: "request". + * @see #resolveReference + */ + String REFERENCE_REQUEST = "request"; + + /** + * Name of the standard reference to the session object: "session". + * @see #resolveReference + */ + String REFERENCE_SESSION = "session"; + + + /** + * Return the value for the scoped attribute of the given name, if any. + * @param name the name of the attribute + * @param scope the scope identifier + * @return the current attribute value, or {@code null} if not found + */ + @Nullable + Object getAttribute(String name, int scope); + + /** + * Set the value for the scoped attribute of the given name, + * replacing an existing value (if any). + * @param name the name of the attribute + * @param scope the scope identifier + * @param value the value for the attribute + */ + void setAttribute(String name, Object value, int scope); + + /** + * Remove the scoped attribute of the given name, if it exists. + *

Note that an implementation should also remove a registered destruction + * callback for the specified attribute, if any. It does, however, not + * need to execute a registered destruction callback in this case, + * since the object will be destroyed by the caller (if appropriate). + * @param name the name of the attribute + * @param scope the scope identifier + */ + void removeAttribute(String name, int scope); + + /** + * Retrieve the names of all attributes in the scope. + * @param scope the scope identifier + * @return the attribute names as String array + */ + String[] getAttributeNames(int scope); + + /** + * Register a callback to be executed on destruction of the + * specified attribute in the given scope. + *

Implementations should do their best to execute the callback + * at the appropriate time: that is, at request completion or session + * termination, respectively. If such a callback is not supported by the + * underlying runtime environment, the callback must be ignored + * and a corresponding warning should be logged. + *

Note that 'destruction' usually corresponds to destruction of the + * entire scope, not to the individual attribute having been explicitly + * removed by the application. If an attribute gets removed via this + * facade's {@link #removeAttribute(String, int)} method, any registered + * destruction callback should be disabled as well, assuming that the + * removed object will be reused or manually destroyed. + *

NOTE: Callback objects should generally be serializable if + * they are being registered for a session scope. Otherwise the callback + * (or even the entire session) might not survive web app restarts. + * @param name the name of the attribute to register the callback for + * @param callback the destruction callback to be executed + * @param scope the scope identifier + */ + void registerDestructionCallback(String name, Runnable callback, int scope); + + /** + * Resolve the contextual reference for the given key, if any. + *

At a minimum: the HttpServletRequest reference for key "request", and + * the HttpSession reference for key "session". + * @param key the contextual key + * @return the corresponding object, or {@code null} if none found + */ + @Nullable + Object resolveReference(String key); + + /** + * Return an id for the current underlying session. + * @return the session id as String (never {@code null}) + */ + String getSessionId(); + + /** + * Expose the best available mutex for the underlying session: + * that is, an object to synchronize on for the underlying session. + * @return the session mutex to use (never {@code null}) + */ + Object getSessionMutex(); + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/RequestContextHolder.java b/spring-web/src/main/java/org/springframework/web/context/request/RequestContextHolder.java new file mode 100644 index 0000000000000000000000000000000000000000..a52c9915afcc2d2429ad2fd20855bfe98a7b163d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/RequestContextHolder.java @@ -0,0 +1,155 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import javax.faces.context.FacesContext; + +import org.springframework.core.NamedInheritableThreadLocal; +import org.springframework.core.NamedThreadLocal; +import org.springframework.lang.Nullable; +import org.springframework.util.ClassUtils; + +/** + * Holder class to expose the web request in the form of a thread-bound + * {@link RequestAttributes} object. The request will be inherited + * by any child threads spawned by the current thread if the + * {@code inheritable} flag is set to {@code true}. + * + *

Use {@link RequestContextListener} or + * {@link org.springframework.web.filter.RequestContextFilter} to expose + * the current web request. Note that + * {@link org.springframework.web.servlet.DispatcherServlet} + * already exposes the current request by default. + * + * @author Juergen Hoeller + * @author Rod Johnson + * @since 2.0 + * @see RequestContextListener + * @see org.springframework.web.filter.RequestContextFilter + * @see org.springframework.web.servlet.DispatcherServlet + */ +public abstract class RequestContextHolder { + + private static final boolean jsfPresent = + ClassUtils.isPresent("javax.faces.context.FacesContext", RequestContextHolder.class.getClassLoader()); + + private static final ThreadLocal requestAttributesHolder = + new NamedThreadLocal<>("Request attributes"); + + private static final ThreadLocal inheritableRequestAttributesHolder = + new NamedInheritableThreadLocal<>("Request context"); + + + /** + * Reset the RequestAttributes for the current thread. + */ + public static void resetRequestAttributes() { + requestAttributesHolder.remove(); + inheritableRequestAttributesHolder.remove(); + } + + /** + * Bind the given RequestAttributes to the current thread, + * not exposing it as inheritable for child threads. + * @param attributes the RequestAttributes to expose + * @see #setRequestAttributes(RequestAttributes, boolean) + */ + public static void setRequestAttributes(@Nullable RequestAttributes attributes) { + setRequestAttributes(attributes, false); + } + + /** + * Bind the given RequestAttributes to the current thread. + * @param attributes the RequestAttributes to expose, + * or {@code null} to reset the thread-bound context + * @param inheritable whether to expose the RequestAttributes as inheritable + * for child threads (using an {@link InheritableThreadLocal}) + */ + public static void setRequestAttributes(@Nullable RequestAttributes attributes, boolean inheritable) { + if (attributes == null) { + resetRequestAttributes(); + } + else { + if (inheritable) { + inheritableRequestAttributesHolder.set(attributes); + requestAttributesHolder.remove(); + } + else { + requestAttributesHolder.set(attributes); + inheritableRequestAttributesHolder.remove(); + } + } + } + + /** + * Return the RequestAttributes currently bound to the thread. + * @return the RequestAttributes currently bound to the thread, + * or {@code null} if none bound + */ + @Nullable + public static RequestAttributes getRequestAttributes() { + RequestAttributes attributes = requestAttributesHolder.get(); + if (attributes == null) { + attributes = inheritableRequestAttributesHolder.get(); + } + return attributes; + } + + /** + * Return the RequestAttributes currently bound to the thread. + *

Exposes the previously bound RequestAttributes instance, if any. + * Falls back to the current JSF FacesContext, if any. + * @return the RequestAttributes currently bound to the thread + * @throws IllegalStateException if no RequestAttributes object + * is bound to the current thread + * @see #setRequestAttributes + * @see ServletRequestAttributes + * @see FacesRequestAttributes + * @see javax.faces.context.FacesContext#getCurrentInstance() + */ + public static RequestAttributes currentRequestAttributes() throws IllegalStateException { + RequestAttributes attributes = getRequestAttributes(); + if (attributes == null) { + if (jsfPresent) { + attributes = FacesRequestAttributesFactory.getFacesRequestAttributes(); + } + if (attributes == null) { + throw new IllegalStateException("No thread-bound request found: " + + "Are you referring to request attributes outside of an actual web request, " + + "or processing a request outside of the originally receiving thread? " + + "If you are actually operating within a web request and still receive this message, " + + "your code is probably running outside of DispatcherServlet: " + + "In this case, use RequestContextListener or RequestContextFilter to expose the current request."); + } + } + return attributes; + } + + + /** + * Inner class to avoid hard-coded JSF dependency. + */ + private static class FacesRequestAttributesFactory { + + @Nullable + public static RequestAttributes getFacesRequestAttributes() { + FacesContext facesContext = FacesContext.getCurrentInstance(); + return (facesContext != null ? new FacesRequestAttributes(facesContext) : null); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/RequestContextListener.java b/spring-web/src/main/java/org/springframework/web/context/request/RequestContextListener.java new file mode 100644 index 0000000000000000000000000000000000000000..ef42dba98eee7b69af615dd7e5ed4c3aedf94858 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/RequestContextListener.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import javax.servlet.ServletRequestEvent; +import javax.servlet.ServletRequestListener; +import javax.servlet.http.HttpServletRequest; + +import org.springframework.context.i18n.LocaleContextHolder; + +/** + * Servlet listener that exposes the request to the current thread, + * through both {@link org.springframework.context.i18n.LocaleContextHolder} and + * {@link RequestContextHolder}. To be registered as listener in {@code web.xml}. + * + *

Alternatively, Spring's {@link org.springframework.web.filter.RequestContextFilter} + * and Spring's {@link org.springframework.web.servlet.DispatcherServlet} also expose + * the same request context to the current thread. In contrast to this listener, + * advanced options are available there (e.g. "threadContextInheritable"). + * + *

This listener is mainly for use with third-party servlets, e.g. the JSF FacesServlet. + * Within Spring's own web support, DispatcherServlet's processing is perfectly sufficient. + * + * @author Juergen Hoeller + * @since 2.0 + * @see javax.servlet.ServletRequestListener + * @see org.springframework.context.i18n.LocaleContextHolder + * @see RequestContextHolder + * @see org.springframework.web.filter.RequestContextFilter + * @see org.springframework.web.servlet.DispatcherServlet + */ +public class RequestContextListener implements ServletRequestListener { + + private static final String REQUEST_ATTRIBUTES_ATTRIBUTE = + RequestContextListener.class.getName() + ".REQUEST_ATTRIBUTES"; + + + @Override + public void requestInitialized(ServletRequestEvent requestEvent) { + if (!(requestEvent.getServletRequest() instanceof HttpServletRequest)) { + throw new IllegalArgumentException( + "Request is not an HttpServletRequest: " + requestEvent.getServletRequest()); + } + HttpServletRequest request = (HttpServletRequest) requestEvent.getServletRequest(); + ServletRequestAttributes attributes = new ServletRequestAttributes(request); + request.setAttribute(REQUEST_ATTRIBUTES_ATTRIBUTE, attributes); + LocaleContextHolder.setLocale(request.getLocale()); + RequestContextHolder.setRequestAttributes(attributes); + } + + @Override + public void requestDestroyed(ServletRequestEvent requestEvent) { + ServletRequestAttributes attributes = null; + Object reqAttr = requestEvent.getServletRequest().getAttribute(REQUEST_ATTRIBUTES_ATTRIBUTE); + if (reqAttr instanceof ServletRequestAttributes) { + attributes = (ServletRequestAttributes) reqAttr; + } + RequestAttributes threadAttributes = RequestContextHolder.getRequestAttributes(); + if (threadAttributes != null) { + // We're assumably within the original request thread... + LocaleContextHolder.resetLocaleContext(); + RequestContextHolder.resetRequestAttributes(); + if (attributes == null && threadAttributes instanceof ServletRequestAttributes) { + attributes = (ServletRequestAttributes) threadAttributes; + } + } + if (attributes != null) { + attributes.requestCompleted(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/RequestScope.java b/spring-web/src/main/java/org/springframework/web/context/request/RequestScope.java new file mode 100644 index 0000000000000000000000000000000000000000..4919ba598d12285bb9f52162587e71e39de6234b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/RequestScope.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import org.springframework.lang.Nullable; + +/** + * Request-backed {@link org.springframework.beans.factory.config.Scope} + * implementation. + * + *

Relies on a thread-bound {@link RequestAttributes} instance, which + * can be exported through {@link RequestContextListener}, + * {@link org.springframework.web.filter.RequestContextFilter} or + * {@link org.springframework.web.servlet.DispatcherServlet}. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @author Rob Harrop + * @since 2.0 + * @see RequestContextHolder#currentRequestAttributes() + * @see RequestAttributes#SCOPE_REQUEST + * @see RequestContextListener + * @see org.springframework.web.filter.RequestContextFilter + * @see org.springframework.web.servlet.DispatcherServlet + */ +public class RequestScope extends AbstractRequestAttributesScope { + + @Override + protected int getScope() { + return RequestAttributes.SCOPE_REQUEST; + } + + /** + * There is no conversation id concept for a request, so this method + * returns {@code null}. + */ + @Override + @Nullable + public String getConversationId() { + return null; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/ServletRequestAttributes.java b/spring-web/src/main/java/org/springframework/web/context/request/ServletRequestAttributes.java new file mode 100644 index 0000000000000000000000000000000000000000..88a46f8d84fbc74b881a401c92b5d142c5732c05 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/ServletRequestAttributes.java @@ -0,0 +1,331 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.NumberUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.util.WebUtils; + +/** + * Servlet-based implementation of the {@link RequestAttributes} interface. + * + *

Accesses objects from servlet request and HTTP session scope, + * with no distinction between "session" and "global session". + * + * @author Juergen Hoeller + * @since 2.0 + * @see javax.servlet.ServletRequest#getAttribute + * @see javax.servlet.http.HttpSession#getAttribute + */ +public class ServletRequestAttributes extends AbstractRequestAttributes { + + /** + * Constant identifying the {@link String} prefixed to the name of a + * destruction callback when it is stored in a {@link HttpSession}. + */ + public static final String DESTRUCTION_CALLBACK_NAME_PREFIX = + ServletRequestAttributes.class.getName() + ".DESTRUCTION_CALLBACK."; + + protected static final Set> immutableValueTypes = new HashSet<>(16); + + static { + immutableValueTypes.addAll(NumberUtils.STANDARD_NUMBER_TYPES); + immutableValueTypes.add(Boolean.class); + immutableValueTypes.add(Character.class); + immutableValueTypes.add(String.class); + } + + + private final HttpServletRequest request; + + @Nullable + private HttpServletResponse response; + + @Nullable + private volatile HttpSession session; + + private final Map sessionAttributesToUpdate = new ConcurrentHashMap<>(1); + + + /** + * Create a new ServletRequestAttributes instance for the given request. + * @param request current HTTP request + */ + public ServletRequestAttributes(HttpServletRequest request) { + Assert.notNull(request, "Request must not be null"); + this.request = request; + } + + /** + * Create a new ServletRequestAttributes instance for the given request. + * @param request current HTTP request + * @param response current HTTP response (for optional exposure) + */ + public ServletRequestAttributes(HttpServletRequest request, @Nullable HttpServletResponse response) { + this(request); + this.response = response; + } + + + /** + * Exposes the native {@link HttpServletRequest} that we're wrapping. + */ + public final HttpServletRequest getRequest() { + return this.request; + } + + /** + * Exposes the native {@link HttpServletResponse} that we're wrapping (if any). + */ + @Nullable + public final HttpServletResponse getResponse() { + return this.response; + } + + /** + * Exposes the {@link HttpSession} that we're wrapping. + * @param allowCreate whether to allow creation of a new session if none exists yet + */ + @Nullable + protected final HttpSession getSession(boolean allowCreate) { + if (isRequestActive()) { + HttpSession session = this.request.getSession(allowCreate); + this.session = session; + return session; + } + else { + // Access through stored session reference, if any... + HttpSession session = this.session; + if (session == null) { + if (allowCreate) { + throw new IllegalStateException( + "No session found and request already completed - cannot create new session!"); + } + else { + session = this.request.getSession(false); + this.session = session; + } + } + return session; + } + } + + private HttpSession obtainSession() { + HttpSession session = getSession(true); + Assert.state(session != null, "No HttpSession"); + return session; + } + + + @Override + public Object getAttribute(String name, int scope) { + if (scope == SCOPE_REQUEST) { + if (!isRequestActive()) { + throw new IllegalStateException( + "Cannot ask for request attribute - request is not active anymore!"); + } + return this.request.getAttribute(name); + } + else { + HttpSession session = getSession(false); + if (session != null) { + try { + Object value = session.getAttribute(name); + if (value != null) { + this.sessionAttributesToUpdate.put(name, value); + } + return value; + } + catch (IllegalStateException ex) { + // Session invalidated - shouldn't usually happen. + } + } + return null; + } + } + + @Override + public void setAttribute(String name, Object value, int scope) { + if (scope == SCOPE_REQUEST) { + if (!isRequestActive()) { + throw new IllegalStateException( + "Cannot set request attribute - request is not active anymore!"); + } + this.request.setAttribute(name, value); + } + else { + HttpSession session = obtainSession(); + this.sessionAttributesToUpdate.remove(name); + session.setAttribute(name, value); + } + } + + @Override + public void removeAttribute(String name, int scope) { + if (scope == SCOPE_REQUEST) { + if (isRequestActive()) { + removeRequestDestructionCallback(name); + this.request.removeAttribute(name); + } + } + else { + HttpSession session = getSession(false); + if (session != null) { + this.sessionAttributesToUpdate.remove(name); + try { + session.removeAttribute(DESTRUCTION_CALLBACK_NAME_PREFIX + name); + session.removeAttribute(name); + } + catch (IllegalStateException ex) { + // Session invalidated - shouldn't usually happen. + } + } + } + } + + @Override + public String[] getAttributeNames(int scope) { + if (scope == SCOPE_REQUEST) { + if (!isRequestActive()) { + throw new IllegalStateException( + "Cannot ask for request attributes - request is not active anymore!"); + } + return StringUtils.toStringArray(this.request.getAttributeNames()); + } + else { + HttpSession session = getSession(false); + if (session != null) { + try { + return StringUtils.toStringArray(session.getAttributeNames()); + } + catch (IllegalStateException ex) { + // Session invalidated - shouldn't usually happen. + } + } + return new String[0]; + } + } + + @Override + public void registerDestructionCallback(String name, Runnable callback, int scope) { + if (scope == SCOPE_REQUEST) { + registerRequestDestructionCallback(name, callback); + } + else { + registerSessionDestructionCallback(name, callback); + } + } + + @Override + public Object resolveReference(String key) { + if (REFERENCE_REQUEST.equals(key)) { + return this.request; + } + else if (REFERENCE_SESSION.equals(key)) { + return getSession(true); + } + else { + return null; + } + } + + @Override + public String getSessionId() { + return obtainSession().getId(); + } + + @Override + public Object getSessionMutex() { + return WebUtils.getSessionMutex(obtainSession()); + } + + + /** + * Update all accessed session attributes through {@code session.setAttribute} + * calls, explicitly indicating to the container that they might have been modified. + */ + @Override + protected void updateAccessedSessionAttributes() { + if (!this.sessionAttributesToUpdate.isEmpty()) { + // Update all affected session attributes. + HttpSession session = getSession(false); + if (session != null) { + try { + for (Map.Entry entry : this.sessionAttributesToUpdate.entrySet()) { + String name = entry.getKey(); + Object newValue = entry.getValue(); + Object oldValue = session.getAttribute(name); + if (oldValue == newValue && !isImmutableSessionAttribute(name, newValue)) { + session.setAttribute(name, newValue); + } + } + } + catch (IllegalStateException ex) { + // Session invalidated - shouldn't usually happen. + } + } + this.sessionAttributesToUpdate.clear(); + } + } + + /** + * Determine whether the given value is to be considered as an immutable session + * attribute, that is, doesn't have to be re-set via {@code session.setAttribute} + * since its value cannot meaningfully change internally. + *

The default implementation returns {@code true} for {@code String}, + * {@code Character}, {@code Boolean} and standard {@code Number} values. + * @param name the name of the attribute + * @param value the corresponding value to check + * @return {@code true} if the value is to be considered as immutable for the + * purposes of session attribute management; {@code false} otherwise + * @see #updateAccessedSessionAttributes() + */ + protected boolean isImmutableSessionAttribute(String name, @Nullable Object value) { + return (value == null || immutableValueTypes.contains(value.getClass())); + } + + /** + * Register the given callback as to be executed after session termination. + *

Note: The callback object should be serializable in order to survive + * web app restarts. + * @param name the name of the attribute to register the callback for + * @param callback the callback to be executed for destruction + */ + protected void registerSessionDestructionCallback(String name, Runnable callback) { + HttpSession session = obtainSession(); + session.setAttribute(DESTRUCTION_CALLBACK_NAME_PREFIX + name, + new DestructionCallbackBindingListener(callback)); + } + + + @Override + public String toString() { + return this.request.toString(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/ServletWebRequest.java b/spring-web/src/main/java/org/springframework/web/context/request/ServletWebRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..dc85a71d32534b85280f4096db55a444c409ee0d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/ServletWebRequest.java @@ -0,0 +1,405 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import java.security.Principal; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.util.Arrays; +import java.util.Enumeration; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.TimeZone; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; + +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.util.WebUtils; + +/** + * {@link WebRequest} adapter for an {@link javax.servlet.http.HttpServletRequest}. + * + * @author Juergen Hoeller + * @author Brian Clozel + * @author Markus Malkusch + * @since 2.0 + */ +public class ServletWebRequest extends ServletRequestAttributes implements NativeWebRequest { + + private static final String ETAG = "ETag"; + + private static final String IF_MODIFIED_SINCE = "If-Modified-Since"; + + private static final String IF_UNMODIFIED_SINCE = "If-Unmodified-Since"; + + private static final String IF_NONE_MATCH = "If-None-Match"; + + private static final String LAST_MODIFIED = "Last-Modified"; + + private static final List SAFE_METHODS = Arrays.asList("GET", "HEAD"); + + /** + * Pattern matching ETag multiple field values in headers such as "If-Match", "If-None-Match". + * @see Section 2.3 of RFC 7232 + */ + private static final Pattern ETAG_HEADER_VALUE_PATTERN = Pattern.compile("\\*|\\s*((W\\/)?(\"[^\"]*\"))\\s*,?"); + + /** + * Date formats as specified in the HTTP RFC. + * @see Section 7.1.1.1 of RFC 7231 + */ + private static final String[] DATE_FORMATS = new String[] { + "EEE, dd MMM yyyy HH:mm:ss zzz", + "EEE, dd-MMM-yy HH:mm:ss zzz", + "EEE MMM dd HH:mm:ss yyyy" + }; + + private static final TimeZone GMT = TimeZone.getTimeZone("GMT"); + + private boolean notModified = false; + + + /** + * Create a new ServletWebRequest instance for the given request. + * @param request current HTTP request + */ + public ServletWebRequest(HttpServletRequest request) { + super(request); + } + + /** + * Create a new ServletWebRequest instance for the given request/response pair. + * @param request current HTTP request + * @param response current HTTP response (for automatic last-modified handling) + */ + public ServletWebRequest(HttpServletRequest request, @Nullable HttpServletResponse response) { + super(request, response); + } + + + @Override + public Object getNativeRequest() { + return getRequest(); + } + + @Override + public Object getNativeResponse() { + return getResponse(); + } + + @Override + public T getNativeRequest(@Nullable Class requiredType) { + return WebUtils.getNativeRequest(getRequest(), requiredType); + } + + @Override + public T getNativeResponse(@Nullable Class requiredType) { + HttpServletResponse response = getResponse(); + return (response != null ? WebUtils.getNativeResponse(response, requiredType) : null); + } + + /** + * Return the HTTP method of the request. + * @since 4.0.2 + */ + @Nullable + public HttpMethod getHttpMethod() { + return HttpMethod.resolve(getRequest().getMethod()); + } + + @Override + @Nullable + public String getHeader(String headerName) { + return getRequest().getHeader(headerName); + } + + @Override + @Nullable + public String[] getHeaderValues(String headerName) { + String[] headerValues = StringUtils.toStringArray(getRequest().getHeaders(headerName)); + return (!ObjectUtils.isEmpty(headerValues) ? headerValues : null); + } + + @Override + public Iterator getHeaderNames() { + return CollectionUtils.toIterator(getRequest().getHeaderNames()); + } + + @Override + @Nullable + public String getParameter(String paramName) { + return getRequest().getParameter(paramName); + } + + @Override + @Nullable + public String[] getParameterValues(String paramName) { + return getRequest().getParameterValues(paramName); + } + + @Override + public Iterator getParameterNames() { + return CollectionUtils.toIterator(getRequest().getParameterNames()); + } + + @Override + public Map getParameterMap() { + return getRequest().getParameterMap(); + } + + @Override + public Locale getLocale() { + return getRequest().getLocale(); + } + + @Override + public String getContextPath() { + return getRequest().getContextPath(); + } + + @Override + @Nullable + public String getRemoteUser() { + return getRequest().getRemoteUser(); + } + + @Override + @Nullable + public Principal getUserPrincipal() { + return getRequest().getUserPrincipal(); + } + + @Override + public boolean isUserInRole(String role) { + return getRequest().isUserInRole(role); + } + + @Override + public boolean isSecure() { + return getRequest().isSecure(); + } + + + @Override + public boolean checkNotModified(long lastModifiedTimestamp) { + return checkNotModified(null, lastModifiedTimestamp); + } + + @Override + public boolean checkNotModified(String etag) { + return checkNotModified(etag, -1); + } + + @Override + public boolean checkNotModified(@Nullable String etag, long lastModifiedTimestamp) { + HttpServletResponse response = getResponse(); + if (this.notModified || (response != null && HttpStatus.OK.value() != response.getStatus())) { + return this.notModified; + } + + // Evaluate conditions in order of precedence. + // See https://tools.ietf.org/html/rfc7232#section-6 + + if (validateIfUnmodifiedSince(lastModifiedTimestamp)) { + if (this.notModified && response != null) { + response.setStatus(HttpStatus.PRECONDITION_FAILED.value()); + } + return this.notModified; + } + + boolean validated = validateIfNoneMatch(etag); + if (!validated) { + validateIfModifiedSince(lastModifiedTimestamp); + } + + // Update response + if (response != null) { + boolean isHttpGetOrHead = SAFE_METHODS.contains(getRequest().getMethod()); + if (this.notModified) { + response.setStatus(isHttpGetOrHead ? + HttpStatus.NOT_MODIFIED.value() : HttpStatus.PRECONDITION_FAILED.value()); + } + if (isHttpGetOrHead) { + if (lastModifiedTimestamp > 0 && parseDateValue(response.getHeader(LAST_MODIFIED)) == -1) { + response.setDateHeader(LAST_MODIFIED, lastModifiedTimestamp); + } + if (StringUtils.hasLength(etag) && response.getHeader(ETAG) == null) { + response.setHeader(ETAG, padEtagIfNecessary(etag)); + } + } + } + + return this.notModified; + } + + private boolean validateIfUnmodifiedSince(long lastModifiedTimestamp) { + if (lastModifiedTimestamp < 0) { + return false; + } + long ifUnmodifiedSince = parseDateHeader(IF_UNMODIFIED_SINCE); + if (ifUnmodifiedSince == -1) { + return false; + } + // We will perform this validation... + this.notModified = (ifUnmodifiedSince < (lastModifiedTimestamp / 1000 * 1000)); + return true; + } + + private boolean validateIfNoneMatch(@Nullable String etag) { + if (!StringUtils.hasLength(etag)) { + return false; + } + + Enumeration ifNoneMatch; + try { + ifNoneMatch = getRequest().getHeaders(IF_NONE_MATCH); + } + catch (IllegalArgumentException ex) { + return false; + } + if (!ifNoneMatch.hasMoreElements()) { + return false; + } + + // We will perform this validation... + etag = padEtagIfNecessary(etag); + if (etag.startsWith("W/")) { + etag = etag.substring(2); + } + while (ifNoneMatch.hasMoreElements()) { + String clientETags = ifNoneMatch.nextElement(); + Matcher etagMatcher = ETAG_HEADER_VALUE_PATTERN.matcher(clientETags); + // Compare weak/strong ETags as per https://tools.ietf.org/html/rfc7232#section-2.3 + while (etagMatcher.find()) { + if (StringUtils.hasLength(etagMatcher.group()) && etag.equals(etagMatcher.group(3))) { + this.notModified = true; + break; + } + } + } + + return true; + } + + private String padEtagIfNecessary(String etag) { + if (!StringUtils.hasLength(etag)) { + return etag; + } + if ((etag.startsWith("\"") || etag.startsWith("W/\"")) && etag.endsWith("\"")) { + return etag; + } + return "\"" + etag + "\""; + } + + private boolean validateIfModifiedSince(long lastModifiedTimestamp) { + if (lastModifiedTimestamp < 0) { + return false; + } + long ifModifiedSince = parseDateHeader(IF_MODIFIED_SINCE); + if (ifModifiedSince == -1) { + return false; + } + // We will perform this validation... + this.notModified = ifModifiedSince >= (lastModifiedTimestamp / 1000 * 1000); + return true; + } + + public boolean isNotModified() { + return this.notModified; + } + + private long parseDateHeader(String headerName) { + long dateValue = -1; + try { + dateValue = getRequest().getDateHeader(headerName); + } + catch (IllegalArgumentException ex) { + String headerValue = getHeader(headerName); + // Possibly an IE 10 style value: "Wed, 09 Apr 2014 09:57:42 GMT; length=13774" + if (headerValue != null) { + int separatorIndex = headerValue.indexOf(';'); + if (separatorIndex != -1) { + String datePart = headerValue.substring(0, separatorIndex); + dateValue = parseDateValue(datePart); + } + } + } + return dateValue; + } + + private long parseDateValue(@Nullable String headerValue) { + if (headerValue == null) { + // No header value sent at all + return -1; + } + if (headerValue.length() >= 3) { + // Short "0" or "-1" like values are never valid HTTP date headers... + // Let's only bother with SimpleDateFormat parsing for long enough values. + for (String dateFormat : DATE_FORMATS) { + SimpleDateFormat simpleDateFormat = new SimpleDateFormat(dateFormat, Locale.US); + simpleDateFormat.setTimeZone(GMT); + try { + return simpleDateFormat.parse(headerValue).getTime(); + } + catch (ParseException ex) { + // ignore + } + } + } + return -1; + } + + @Override + public String getDescription(boolean includeClientInfo) { + HttpServletRequest request = getRequest(); + StringBuilder sb = new StringBuilder(); + sb.append("uri=").append(request.getRequestURI()); + if (includeClientInfo) { + String client = request.getRemoteAddr(); + if (StringUtils.hasLength(client)) { + sb.append(";client=").append(client); + } + HttpSession session = request.getSession(false); + if (session != null) { + sb.append(";session=").append(session.getId()); + } + String user = request.getRemoteUser(); + if (StringUtils.hasLength(user)) { + sb.append(";user=").append(user); + } + } + return sb.toString(); + } + + + @Override + public String toString() { + return "ServletWebRequest: " + getDescription(true); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/SessionScope.java b/spring-web/src/main/java/org/springframework/web/context/request/SessionScope.java new file mode 100644 index 0000000000000000000000000000000000000000..3c2c9d2ad05012afa3534d2e02cf6301321cc77f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/SessionScope.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import org.springframework.beans.factory.ObjectFactory; +import org.springframework.lang.Nullable; + +/** + * Session-backed {@link org.springframework.beans.factory.config.Scope} + * implementation. + * + *

Relies on a thread-bound {@link RequestAttributes} instance, which + * can be exported through {@link RequestContextListener}, + * {@link org.springframework.web.filter.RequestContextFilter} or + * {@link org.springframework.web.servlet.DispatcherServlet}. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @author Rob Harrop + * @since 2.0 + * @see RequestContextHolder#currentRequestAttributes() + * @see RequestAttributes#SCOPE_SESSION + * @see RequestContextListener + * @see org.springframework.web.filter.RequestContextFilter + * @see org.springframework.web.servlet.DispatcherServlet + */ +public class SessionScope extends AbstractRequestAttributesScope { + + @Override + protected int getScope() { + return RequestAttributes.SCOPE_SESSION; + } + + @Override + public String getConversationId() { + return RequestContextHolder.currentRequestAttributes().getSessionId(); + } + + @Override + public Object get(String name, ObjectFactory objectFactory) { + Object mutex = RequestContextHolder.currentRequestAttributes().getSessionMutex(); + synchronized (mutex) { + return super.get(name, objectFactory); + } + } + + @Override + @Nullable + public Object remove(String name) { + Object mutex = RequestContextHolder.currentRequestAttributes().getSessionMutex(); + synchronized (mutex) { + return super.remove(name); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/WebRequest.java b/spring-web/src/main/java/org/springframework/web/context/request/WebRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..6985a3e564bf508580698c94803ac9d05a78a36b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/WebRequest.java @@ -0,0 +1,248 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import java.security.Principal; +import java.util.Iterator; +import java.util.Locale; +import java.util.Map; + +import org.springframework.lang.Nullable; + +/** + * Generic interface for a web request. Mainly intended for generic web + * request interceptors, giving them access to general request metadata, + * not for actual handling of the request. + * + * @author Juergen Hoeller + * @author Brian Clozel + * @since 2.0 + * @see WebRequestInterceptor + */ +public interface WebRequest extends RequestAttributes { + + /** + * Return the request header of the given name, or {@code null} if none. + *

Retrieves the first header value in case of a multi-value header. + * @since 3.0 + * @see javax.servlet.http.HttpServletRequest#getHeader(String) + */ + @Nullable + String getHeader(String headerName); + + /** + * Return the request header values for the given header name, + * or {@code null} if none. + *

A single-value header will be exposed as an array with a single element. + * @since 3.0 + * @see javax.servlet.http.HttpServletRequest#getHeaders(String) + */ + @Nullable + String[] getHeaderValues(String headerName); + + /** + * Return a Iterator over request header names. + * @since 3.0 + * @see javax.servlet.http.HttpServletRequest#getHeaderNames() + */ + Iterator getHeaderNames(); + + /** + * Return the request parameter of the given name, or {@code null} if none. + *

Retrieves the first parameter value in case of a multi-value parameter. + * @see javax.servlet.http.HttpServletRequest#getParameter(String) + */ + @Nullable + String getParameter(String paramName); + + /** + * Return the request parameter values for the given parameter name, + * or {@code null} if none. + *

A single-value parameter will be exposed as an array with a single element. + * @see javax.servlet.http.HttpServletRequest#getParameterValues(String) + */ + @Nullable + String[] getParameterValues(String paramName); + + /** + * Return a Iterator over request parameter names. + * @since 3.0 + * @see javax.servlet.http.HttpServletRequest#getParameterNames() + */ + Iterator getParameterNames(); + + /** + * Return a immutable Map of the request parameters, with parameter names as map keys + * and parameter values as map values. The map values will be of type String array. + *

A single-value parameter will be exposed as an array with a single element. + * @see javax.servlet.http.HttpServletRequest#getParameterMap() + */ + Map getParameterMap(); + + /** + * Return the primary Locale for this request. + * @see javax.servlet.http.HttpServletRequest#getLocale() + */ + Locale getLocale(); + + /** + * Return the context path for this request + * (usually the base path that the current web application is mapped to). + * @see javax.servlet.http.HttpServletRequest#getContextPath() + */ + String getContextPath(); + + /** + * Return the remote user for this request, if any. + * @see javax.servlet.http.HttpServletRequest#getRemoteUser() + */ + @Nullable + String getRemoteUser(); + + /** + * Return the user principal for this request, if any. + * @see javax.servlet.http.HttpServletRequest#getUserPrincipal() + */ + @Nullable + Principal getUserPrincipal(); + + /** + * Determine whether the user is in the given role for this request. + * @see javax.servlet.http.HttpServletRequest#isUserInRole(String) + */ + boolean isUserInRole(String role); + + /** + * Return whether this request has been sent over a secure transport + * mechanism (such as SSL). + * @see javax.servlet.http.HttpServletRequest#isSecure() + */ + boolean isSecure(); + + /** + * Check whether the requested resource has been modified given the + * supplied last-modified timestamp (as determined by the application). + *

This will also transparently set the "Last-Modified" response header + * and HTTP status when applicable. + *

Typical usage: + *

+	 * public String myHandleMethod(WebRequest request, Model model) {
+	 *   long lastModified = // application-specific calculation
+	 *   if (request.checkNotModified(lastModified)) {
+	 *     // shortcut exit - no further processing necessary
+	 *     return null;
+	 *   }
+	 *   // further request processing, actually building content
+	 *   model.addAttribute(...);
+	 *   return "myViewName";
+	 * }
+ *

This method works with conditional GET/HEAD requests, but + * also with conditional POST/PUT/DELETE requests. + *

Note: you can use either + * this {@code #checkNotModified(long)} method; or + * {@link #checkNotModified(String)}. If you want enforce both + * a strong entity tag and a Last-Modified value, + * as recommended by the HTTP specification, + * then you should use {@link #checkNotModified(String, long)}. + *

If the "If-Modified-Since" header is set but cannot be parsed + * to a date value, this method will ignore the header and proceed + * with setting the last-modified timestamp on the response. + * @param lastModifiedTimestamp the last-modified timestamp in + * milliseconds that the application determined for the underlying + * resource + * @return whether the request qualifies as not modified, + * allowing to abort request processing and relying on the response + * telling the client that the content has not been modified + */ + boolean checkNotModified(long lastModifiedTimestamp); + + /** + * Check whether the requested resource has been modified given the + * supplied {@code ETag} (entity tag), as determined by the application. + *

This will also transparently set the "ETag" response header + * and HTTP status when applicable. + *

Typical usage: + *

+	 * public String myHandleMethod(WebRequest request, Model model) {
+	 *   String eTag = // application-specific calculation
+	 *   if (request.checkNotModified(eTag)) {
+	 *     // shortcut exit - no further processing necessary
+	 *     return null;
+	 *   }
+	 *   // further request processing, actually building content
+	 *   model.addAttribute(...);
+	 *   return "myViewName";
+	 * }
+ *

Note: you can use either + * this {@code #checkNotModified(String)} method; or + * {@link #checkNotModified(long)}. If you want enforce both + * a strong entity tag and a Last-Modified value, + * as recommended by the HTTP specification, + * then you should use {@link #checkNotModified(String, long)}. + * @param etag the entity tag that the application determined + * for the underlying resource. This parameter will be padded + * with quotes (") if necessary. + * @return true if the request does not require further processing. + */ + boolean checkNotModified(String etag); + + /** + * Check whether the requested resource has been modified given the + * supplied {@code ETag} (entity tag) and last-modified timestamp, + * as determined by the application. + *

This will also transparently set the "ETag" and "Last-Modified" + * response headers, and HTTP status when applicable. + *

Typical usage: + *

+	 * public String myHandleMethod(WebRequest request, Model model) {
+	 *   String eTag = // application-specific calculation
+	 *   long lastModified = // application-specific calculation
+	 *   if (request.checkNotModified(eTag, lastModified)) {
+	 *     // shortcut exit - no further processing necessary
+	 *     return null;
+	 *   }
+	 *   // further request processing, actually building content
+	 *   model.addAttribute(...);
+	 *   return "myViewName";
+	 * }
+ *

This method works with conditional GET/HEAD requests, but + * also with conditional POST/PUT/DELETE requests. + *

Note: The HTTP specification recommends + * setting both ETag and Last-Modified values, but you can also + * use {@code #checkNotModified(String)} or + * {@link #checkNotModified(long)}. + * @param etag the entity tag that the application determined + * for the underlying resource. This parameter will be padded + * with quotes (") if necessary. + * @param lastModifiedTimestamp the last-modified timestamp in + * milliseconds that the application determined for the underlying + * resource + * @return true if the request does not require further processing. + * @since 4.2 + */ + boolean checkNotModified(@Nullable String etag, long lastModifiedTimestamp); + + /** + * Get a short description of this request, + * typically containing request URI and session id. + * @param includeClientInfo whether to include client-specific + * information such as session id and user name + * @return the requested description as String + */ + String getDescription(boolean includeClientInfo); + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/WebRequestInterceptor.java b/spring-web/src/main/java/org/springframework/web/context/request/WebRequestInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..46a75bade575ac258bbb3537ae5a933913cdca28 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/WebRequestInterceptor.java @@ -0,0 +1,85 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import org.springframework.lang.Nullable; +import org.springframework.ui.ModelMap; + +/** + * Interface for general web request interception. Allows for being applied + * to Servlet request by building on the {@link WebRequest} abstraction. + * + *

This interface assumes MVC-style request processing: A handler gets executed, + * exposes a set of model objects, then a view gets rendered based on that model. + * Alternatively, a handler may also process the request completely, with no + * view to be rendered. + * + *

In an async processing scenario, the handler may be executed in a separate + * thread while the main thread exits without rendering or invoking the + * {@code postHandle} and {@code afterCompletion} callbacks. When concurrent + * handler execution completes, the request is dispatched back in order to + * proceed with rendering the model and all methods of this contract are invoked + * again. For further options and comments see + * {@code org.springframework.web.context.request.async.AsyncWebRequestInterceptor} + * + *

This interface is deliberately minimalistic to keep the dependencies of + * generic request interceptors as minimal as feasible. + * + * @author Juergen Hoeller + * @since 2.0 + * @see ServletWebRequest + * @see org.springframework.web.servlet.DispatcherServlet + * @see org.springframework.web.servlet.handler.AbstractHandlerMapping#setInterceptors + * @see org.springframework.web.servlet.HandlerInterceptor + */ +public interface WebRequestInterceptor { + + /** + * Intercept the execution of a request handler before its invocation. + *

Allows for preparing context resources (such as a Hibernate Session) + * and expose them as request attributes or as thread-local objects. + * @param request the current web request + * @throws Exception in case of errors + */ + void preHandle(WebRequest request) throws Exception; + + /** + * Intercept the execution of a request handler after its successful + * invocation, right before view rendering (if any). + *

Allows for modifying context resources after successful handler + * execution (for example, flushing a Hibernate Session). + * @param request the current web request + * @param model the map of model objects that will be exposed to the view + * (may be {@code null}). Can be used to analyze the exposed model + * and/or to add further model attributes, if desired. + * @throws Exception in case of errors + */ + void postHandle(WebRequest request, @Nullable ModelMap model) throws Exception; + + /** + * Callback after completion of request processing, that is, after rendering + * the view. Will be called on any outcome of handler execution, thus allows + * for proper resource cleanup. + *

Note: Will only be called if this interceptor's {@code preHandle} + * method has successfully completed! + * @param request the current web request + * @param ex exception thrown on handler execution, if any + * @throws Exception in case of errors + */ + void afterCompletion(WebRequest request, @Nullable Exception ex) throws Exception; + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncRequestTimeoutException.java b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncRequestTimeoutException.java new file mode 100644 index 0000000000000000000000000000000000000000..084f8a516c2e21e24fe6b041db2854a8c3a9ffd6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncRequestTimeoutException.java @@ -0,0 +1,35 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +/** + * Exception to be thrown when an async request times out. + * Alternatively an applications can register a + * {@link DeferredResultProcessingInterceptor} or a + * {@link CallableProcessingInterceptor} to handle the timeout through + * the MVC Java config or the MVC XML namespace or directly through properties + * of the {@code RequestMappingHandlerAdapter}. + * + *

By default the exception will be handled as a 503 error. + * + * @author Rossen Stoyanchev + * @since 4.2.8 + */ +@SuppressWarnings("serial") +public class AsyncRequestTimeoutException extends RuntimeException { + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncWebRequest.java b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncWebRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..f20c68d5c7b2f69031e83ebd37d1defac47b0bb7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncWebRequest.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.util.function.Consumer; + +import org.springframework.lang.Nullable; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Extends {@link NativeWebRequest} with methods for asynchronous request processing. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public interface AsyncWebRequest extends NativeWebRequest { + + /** + * Set the time required for concurrent handling to complete. + * This property should not be set when concurrent handling is in progress, + * i.e. when {@link #isAsyncStarted()} is {@code true}. + * @param timeout amount of time in milliseconds; {@code null} means no + * timeout, i.e. rely on the default timeout of the container. + */ + void setTimeout(@Nullable Long timeout); + + /** + * Add a handler to invoke when concurrent handling has timed out. + */ + void addTimeoutHandler(Runnable runnable); + + /** + * Add a handler to invoke when an error occurred while concurrent + * handling of a request. + * @since 5.0 + */ + void addErrorHandler(Consumer exceptionHandler); + + /** + * Add a handler to invoke when request processing completes. + */ + void addCompletionHandler(Runnable runnable); + + /** + * Mark the start of asynchronous request processing so that when the main + * processing thread exits, the response remains open for further processing + * in another thread. + * @throws IllegalStateException if async processing has completed or is not supported + */ + void startAsync(); + + /** + * Whether the request is in async mode following a call to {@link #startAsync()}. + * Returns "false" if asynchronous processing never started, has completed, + * or the request was dispatched for further processing. + */ + boolean isAsyncStarted(); + + /** + * Dispatch the request to the container in order to resume processing after + * concurrent execution in an application thread. + */ + void dispatch(); + + /** + * Whether asynchronous processing has completed. + */ + boolean isAsyncComplete(); + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/CallableInterceptorChain.java b/spring-web/src/main/java/org/springframework/web/context/request/async/CallableInterceptorChain.java new file mode 100644 index 0000000000000000000000000000000000000000..1d3dab6e6ff796b71903b815f9b7cc02a1f1f917 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/CallableInterceptorChain.java @@ -0,0 +1,153 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.Future; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Assists with the invocation of {@link CallableProcessingInterceptor}'s. + * + * @author Rossen Stoyanchev + * @author Rob Winch + * @since 3.2 + */ +class CallableInterceptorChain { + + private static final Log logger = LogFactory.getLog(CallableInterceptorChain.class); + + private final List interceptors; + + private int preProcessIndex = -1; + + private volatile Future taskFuture; + + + public CallableInterceptorChain(List interceptors) { + this.interceptors = interceptors; + } + + + public void setTaskFuture(Future taskFuture) { + this.taskFuture = taskFuture; + } + + + public void applyBeforeConcurrentHandling(NativeWebRequest request, Callable task) throws Exception { + for (CallableProcessingInterceptor interceptor : this.interceptors) { + interceptor.beforeConcurrentHandling(request, task); + } + } + + public void applyPreProcess(NativeWebRequest request, Callable task) throws Exception { + for (CallableProcessingInterceptor interceptor : this.interceptors) { + interceptor.preProcess(request, task); + this.preProcessIndex++; + } + } + + public Object applyPostProcess(NativeWebRequest request, Callable task, Object concurrentResult) { + Throwable exceptionResult = null; + for (int i = this.preProcessIndex; i >= 0; i--) { + try { + this.interceptors.get(i).postProcess(request, task, concurrentResult); + } + catch (Throwable ex) { + // Save the first exception but invoke all interceptors + if (exceptionResult != null) { + if (logger.isTraceEnabled()) { + logger.trace("Ignoring failure in postProcess method", ex); + } + } + else { + exceptionResult = ex; + } + } + } + return (exceptionResult != null) ? exceptionResult : concurrentResult; + } + + public Object triggerAfterTimeout(NativeWebRequest request, Callable task) { + cancelTask(); + for (CallableProcessingInterceptor interceptor : this.interceptors) { + try { + Object result = interceptor.handleTimeout(request, task); + if (result == CallableProcessingInterceptor.RESPONSE_HANDLED) { + break; + } + else if (result != CallableProcessingInterceptor.RESULT_NONE) { + return result; + } + } + catch (Throwable ex) { + return ex; + } + } + return CallableProcessingInterceptor.RESULT_NONE; + } + + private void cancelTask() { + Future future = this.taskFuture; + if (future != null) { + try { + future.cancel(true); + } + catch (Throwable ex) { + // Ignore + } + } + } + + public Object triggerAfterError(NativeWebRequest request, Callable task, Throwable throwable) { + cancelTask(); + for (CallableProcessingInterceptor interceptor : this.interceptors) { + try { + Object result = interceptor.handleError(request, task, throwable); + if (result == CallableProcessingInterceptor.RESPONSE_HANDLED) { + break; + } + else if (result != CallableProcessingInterceptor.RESULT_NONE) { + return result; + } + } + catch (Throwable ex) { + return ex; + } + } + return CallableProcessingInterceptor.RESULT_NONE; + } + + public void triggerAfterCompletion(NativeWebRequest request, Callable task) { + for (int i = this.interceptors.size()-1; i >= 0; i--) { + try { + this.interceptors.get(i).afterCompletion(request, task); + } + catch (Throwable ex) { + if (logger.isTraceEnabled()) { + logger.trace("Ignoring failure in afterCompletion method", ex); + } + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/CallableProcessingInterceptor.java b/spring-web/src/main/java/org/springframework/web/context/request/async/CallableProcessingInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..22fbc9835fa6725f17a5976f9d52a931b4ad211b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/CallableProcessingInterceptor.java @@ -0,0 +1,157 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.util.concurrent.Callable; + +import org.springframework.core.task.AsyncTaskExecutor; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Intercepts concurrent request handling, where the concurrent result is + * obtained by executing a {@link Callable} on behalf of the application with + * an {@link AsyncTaskExecutor}. + * + *

A {@code CallableProcessingInterceptor} is invoked before and after the + * invocation of the {@code Callable} task in the asynchronous thread, as well + * as on timeout/error from a container thread, or after completing for any reason + * including a timeout or network error. + * + *

As a general rule exceptions raised by interceptor methods will cause + * async processing to resume by dispatching back to the container and using + * the Exception instance as the concurrent result. Such exceptions will then + * be processed through the {@code HandlerExceptionResolver} mechanism. + * + *

The {@link #handleTimeout(NativeWebRequest, Callable) handleTimeout} method + * can select a value to be used to resume processing. + * + * @author Rossen Stoyanchev + * @author Rob Winch + * @since 3.2 + */ +public interface CallableProcessingInterceptor { + + /** + * Constant indicating that no result has been determined by this + * interceptor, giving subsequent interceptors a chance. + * @see #handleTimeout + * @see #handleError + */ + Object RESULT_NONE = new Object(); + + /** + * Constant indicating that the response has been handled by this interceptor + * without a result and that no further interceptors are to be invoked. + * @see #handleTimeout + * @see #handleError + */ + Object RESPONSE_HANDLED = new Object(); + + + /** + * Invoked before the start of concurrent handling in the original + * thread in which the {@code Callable} is submitted for concurrent handling. + *

This is useful for capturing the state of the current thread just prior to + * invoking the {@link Callable}. Once the state is captured, it can then be + * transferred to the new {@link Thread} in + * {@link #preProcess(NativeWebRequest, Callable)}. Capturing the state of + * Spring Security's SecurityContextHolder and migrating it to the new Thread + * is a concrete example of where this is useful. + *

The default implementation is empty. + * @param request the current request + * @param task the task for the current async request + * @throws Exception in case of errors + */ + default void beforeConcurrentHandling(NativeWebRequest request, Callable task) throws Exception { + } + + /** + * Invoked after the start of concurrent handling in the async + * thread in which the {@code Callable} is executed and before the + * actual invocation of the {@code Callable}. + *

The default implementation is empty. + * @param request the current request + * @param task the task for the current async request + * @throws Exception in case of errors + */ + default void preProcess(NativeWebRequest request, Callable task) throws Exception { + } + + /** + * Invoked after the {@code Callable} has produced a result in the + * async thread in which the {@code Callable} is executed. This method may + * be invoked later than {@code afterTimeout} or {@code afterCompletion} + * depending on when the {@code Callable} finishes processing. + *

The default implementation is empty. + * @param request the current request + * @param task the task for the current async request + * @param concurrentResult the result of concurrent processing, which could + * be a {@link Throwable} if the {@code Callable} raised an exception + * @throws Exception in case of errors + */ + default void postProcess(NativeWebRequest request, Callable task, + Object concurrentResult) throws Exception { + } + + /** + * Invoked from a container thread when the async request times out before + * the {@code Callable} task completes. Implementations may return a value, + * including an {@link Exception}, to use instead of the value the + * {@link Callable} did not return in time. + *

The default implementation always returns {@link #RESULT_NONE}. + * @param request the current request + * @param task the task for the current async request + * @return a concurrent result value; if the value is anything other than + * {@link #RESULT_NONE} or {@link #RESPONSE_HANDLED}, concurrent processing + * is resumed and subsequent interceptors are not invoked + * @throws Exception in case of errors + */ + default Object handleTimeout(NativeWebRequest request, Callable task) throws Exception { + return RESULT_NONE; + } + + /** + * Invoked from a container thread when an error occurred while processing + * the async request before the {@code Callable} task completes. + * Implementations may return a value, including an {@link Exception}, to + * use instead of the value the {@link Callable} did not return in time. + *

The default implementation always returns {@link #RESULT_NONE}. + * @param request the current request + * @param task the task for the current async request + * @param t the error that occurred while request processing + * @return a concurrent result value; if the value is anything other than + * {@link #RESULT_NONE} or {@link #RESPONSE_HANDLED}, concurrent processing + * is resumed and subsequent interceptors are not invoked + * @throws Exception in case of errors + * @since 5.0 + */ + default Object handleError(NativeWebRequest request, Callable task, Throwable t) throws Exception { + return RESULT_NONE; + } + + /** + * Invoked from a container thread when async processing completes for any + * reason including timeout or network error. + *

The default implementation is empty. + * @param request the current request + * @param task the task for the current async request + * @throws Exception in case of errors + */ + default void afterCompletion(NativeWebRequest request, Callable task) throws Exception { + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/CallableProcessingInterceptorAdapter.java b/spring-web/src/main/java/org/springframework/web/context/request/async/CallableProcessingInterceptorAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..4c04274e7406093be887c6ff4a6e4b73759e6360 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/CallableProcessingInterceptorAdapter.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.util.concurrent.Callable; + +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Abstract adapter class for the {@link CallableProcessingInterceptor} interface, + * for simplified implementation of individual methods. + * + * @author Rossen Stoyanchev + * @author Rob Winch + * @since 3.2 + * @deprecated as of 5.0 where CallableProcessingInterceptor has default methods + */ +@Deprecated +public abstract class CallableProcessingInterceptorAdapter implements CallableProcessingInterceptor { + + @Override + public void beforeConcurrentHandling(NativeWebRequest request, Callable task) throws Exception { + } + + @Override + public void preProcess(NativeWebRequest request, Callable task) throws Exception { + } + + @Override + public void postProcess(NativeWebRequest request, Callable task, Object concurrentResult) throws Exception { + } + + @Override + public Object handleTimeout(NativeWebRequest request, Callable task) throws Exception { + return RESULT_NONE; + } + + @Override + public Object handleError(NativeWebRequest request, Callable task, Throwable t) throws Exception { + return RESULT_NONE; + } + + @Override + public void afterCompletion(NativeWebRequest request, Callable task) throws Exception { + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResult.java b/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResult.java new file mode 100644 index 0000000000000000000000000000000000000000..3ecdf9b6ff014f95e933ee54480c3246814d7ffb --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResult.java @@ -0,0 +1,347 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.util.PriorityQueue; +import java.util.concurrent.Callable; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * {@code DeferredResult} provides an alternative to using a {@link Callable} for + * asynchronous request processing. While a {@code Callable} is executed concurrently + * on behalf of the application, with a {@code DeferredResult} the application can + * produce the result from a thread of its choice. + * + *

Subclasses can extend this class to easily associate additional data or behavior + * with the {@link DeferredResult}. For example, one might want to associate the user + * used to create the {@link DeferredResult} by extending the class and adding an + * additional property for the user. In this way, the user could easily be accessed + * later without the need to use a data structure to do the mapping. + * + *

An example of associating additional behavior to this class might be realized + * by extending the class to implement an additional interface. For example, one + * might want to implement {@link Comparable} so that when the {@link DeferredResult} + * is added to a {@link PriorityQueue} it is handled in the correct order. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @author Rob Winch + * @since 3.2 + * @param the result type + */ +public class DeferredResult { + + private static final Object RESULT_NONE = new Object(); + + private static final Log logger = LogFactory.getLog(DeferredResult.class); + + + @Nullable + private final Long timeoutValue; + + private final Supplier timeoutResult; + + private Runnable timeoutCallback; + + private Consumer errorCallback; + + private Runnable completionCallback; + + private DeferredResultHandler resultHandler; + + private volatile Object result = RESULT_NONE; + + private volatile boolean expired = false; + + + /** + * Create a DeferredResult. + */ + public DeferredResult() { + this(null, () -> RESULT_NONE); + } + + /** + * Create a DeferredResult with a custom timeout value. + *

By default not set in which case the default configured in the MVC + * Java Config or the MVC namespace is used, or if that's not set, then the + * timeout depends on the default of the underlying server. + * @param timeoutValue timeout value in milliseconds + */ + public DeferredResult(Long timeoutValue) { + this(timeoutValue, () -> RESULT_NONE); + } + + /** + * Create a DeferredResult with a timeout value and a default result to use + * in case of timeout. + * @param timeoutValue timeout value in milliseconds (ignored if {@code null}) + * @param timeoutResult the result to use + */ + public DeferredResult(@Nullable Long timeoutValue, Object timeoutResult) { + this.timeoutValue = timeoutValue; + this.timeoutResult = () -> timeoutResult; + } + + /** + * Variant of {@link #DeferredResult(Long, Object)} that accepts a dynamic + * fallback value based on a {@link Supplier}. + * @param timeoutValue timeout value in milliseconds (ignored if {@code null}) + * @param timeoutResult the result supplier to use + * @since 5.1.1 + */ + public DeferredResult(@Nullable Long timeoutValue, Supplier timeoutResult) { + this.timeoutValue = timeoutValue; + this.timeoutResult = timeoutResult; + } + + + /** + * Return {@code true} if this DeferredResult is no longer usable either + * because it was previously set or because the underlying request expired. + *

The result may have been set with a call to {@link #setResult(Object)}, + * or {@link #setErrorResult(Object)}, or as a result of a timeout, if a + * timeout result was provided to the constructor. The request may also + * expire due to a timeout or network error. + */ + public final boolean isSetOrExpired() { + return (this.result != RESULT_NONE || this.expired); + } + + /** + * Return {@code true} if the DeferredResult has been set. + * @since 4.0 + */ + public boolean hasResult() { + return (this.result != RESULT_NONE); + } + + /** + * Return the result, or {@code null} if the result wasn't set. Since the result + * can also be {@code null}, it is recommended to use {@link #hasResult()} first + * to check if there is a result prior to calling this method. + * @since 4.0 + */ + @Nullable + public Object getResult() { + Object resultToCheck = this.result; + return (resultToCheck != RESULT_NONE ? resultToCheck : null); + } + + /** + * Return the configured timeout value in milliseconds. + */ + @Nullable + final Long getTimeoutValue() { + return this.timeoutValue; + } + + /** + * Register code to invoke when the async request times out. + *

This method is called from a container thread when an async request + * times out before the {@code DeferredResult} has been populated. + * It may invoke {@link DeferredResult#setResult setResult} or + * {@link DeferredResult#setErrorResult setErrorResult} to resume processing. + */ + public void onTimeout(Runnable callback) { + this.timeoutCallback = callback; + } + + /** + * Register code to invoke when an error occurred during the async request. + *

This method is called from a container thread when an error occurs + * while processing an async request before the {@code DeferredResult} has + * been populated. It may invoke {@link DeferredResult#setResult setResult} + * or {@link DeferredResult#setErrorResult setErrorResult} to resume + * processing. + * @since 5.0 + */ + public void onError(Consumer callback) { + this.errorCallback = callback; + } + + /** + * Register code to invoke when the async request completes. + *

This method is called from a container thread when an async request + * completed for any reason including timeout and network error. This is useful + * for detecting that a {@code DeferredResult} instance is no longer usable. + */ + public void onCompletion(Runnable callback) { + this.completionCallback = callback; + } + + /** + * Provide a handler to use to handle the result value. + * @param resultHandler the handler + * @see DeferredResultProcessingInterceptor + */ + public final void setResultHandler(DeferredResultHandler resultHandler) { + Assert.notNull(resultHandler, "DeferredResultHandler is required"); + // Immediate expiration check outside of the result lock + if (this.expired) { + return; + } + Object resultToHandle; + synchronized (this) { + // Got the lock in the meantime: double-check expiration status + if (this.expired) { + return; + } + resultToHandle = this.result; + if (resultToHandle == RESULT_NONE) { + // No result yet: store handler for processing once it comes in + this.resultHandler = resultHandler; + return; + } + } + // If we get here, we need to process an existing result object immediately. + // The decision is made within the result lock; just the handle call outside + // of it, avoiding any deadlock potential with Servlet container locks. + try { + resultHandler.handleResult(resultToHandle); + } + catch (Throwable ex) { + logger.debug("Failed to process async result", ex); + } + } + + /** + * Set the value for the DeferredResult and handle it. + * @param result the value to set + * @return {@code true} if the result was set and passed on for handling; + * {@code false} if the result was already set or the async request expired + * @see #isSetOrExpired() + */ + public boolean setResult(T result) { + return setResultInternal(result); + } + + private boolean setResultInternal(Object result) { + // Immediate expiration check outside of the result lock + if (isSetOrExpired()) { + return false; + } + DeferredResultHandler resultHandlerToUse; + synchronized (this) { + // Got the lock in the meantime: double-check expiration status + if (isSetOrExpired()) { + return false; + } + // At this point, we got a new result to process + this.result = result; + resultHandlerToUse = this.resultHandler; + if (resultHandlerToUse == null) { + // No result handler set yet -> let the setResultHandler implementation + // pick up the result object and invoke the result handler for it. + return true; + } + // Result handler available -> let's clear the stored reference since + // we don't need it anymore. + this.resultHandler = null; + } + // If we get here, we need to process an existing result object immediately. + // The decision is made within the result lock; just the handle call outside + // of it, avoiding any deadlock potential with Servlet container locks. + resultHandlerToUse.handleResult(result); + return true; + } + + /** + * Set an error value for the {@link DeferredResult} and handle it. + * The value may be an {@link Exception} or {@link Throwable} in which case + * it will be processed as if a handler raised the exception. + * @param result the error result value + * @return {@code true} if the result was set to the error value and passed on + * for handling; {@code false} if the result was already set or the async + * request expired + * @see #isSetOrExpired() + */ + public boolean setErrorResult(Object result) { + return setResultInternal(result); + } + + + final DeferredResultProcessingInterceptor getInterceptor() { + return new DeferredResultProcessingInterceptor() { + @Override + public boolean handleTimeout(NativeWebRequest request, DeferredResult deferredResult) { + boolean continueProcessing = true; + try { + if (timeoutCallback != null) { + timeoutCallback.run(); + } + } + finally { + Object value = timeoutResult.get(); + if (value != RESULT_NONE) { + continueProcessing = false; + try { + setResultInternal(value); + } + catch (Throwable ex) { + logger.debug("Failed to handle timeout result", ex); + } + } + } + return continueProcessing; + } + @Override + public boolean handleError(NativeWebRequest request, DeferredResult deferredResult, Throwable t) { + try { + if (errorCallback != null) { + errorCallback.accept(t); + } + } + finally { + try { + setResultInternal(t); + } + catch (Throwable ex) { + logger.debug("Failed to handle error result", ex); + } + } + return false; + } + @Override + public void afterCompletion(NativeWebRequest request, DeferredResult deferredResult) { + expired = true; + if (completionCallback != null) { + completionCallback.run(); + } + } + }; + } + + + /** + * Handles a DeferredResult value when set. + */ + @FunctionalInterface + public interface DeferredResultHandler { + + void handleResult(Object result); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResultInterceptorChain.java b/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResultInterceptorChain.java new file mode 100644 index 0000000000000000000000000000000000000000..3b8d35660e12b5f6b2a806df1d936f36e9de08d9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResultInterceptorChain.java @@ -0,0 +1,115 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Assists with the invocation of {@link DeferredResultProcessingInterceptor}'s. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +class DeferredResultInterceptorChain { + + private static final Log logger = LogFactory.getLog(DeferredResultInterceptorChain.class); + + private final List interceptors; + + private int preProcessingIndex = -1; + + + public DeferredResultInterceptorChain(List interceptors) { + this.interceptors = interceptors; + } + + public void applyBeforeConcurrentHandling(NativeWebRequest request, DeferredResult deferredResult) + throws Exception { + + for (DeferredResultProcessingInterceptor interceptor : this.interceptors) { + interceptor.beforeConcurrentHandling(request, deferredResult); + } + } + + public void applyPreProcess(NativeWebRequest request, DeferredResult deferredResult) throws Exception { + for (DeferredResultProcessingInterceptor interceptor : this.interceptors) { + interceptor.preProcess(request, deferredResult); + this.preProcessingIndex++; + } + } + + public Object applyPostProcess(NativeWebRequest request, DeferredResult deferredResult, + Object concurrentResult) { + + try { + for (int i = this.preProcessingIndex; i >= 0; i--) { + this.interceptors.get(i).postProcess(request, deferredResult, concurrentResult); + } + } + catch (Throwable ex) { + return ex; + } + return concurrentResult; + } + + public void triggerAfterTimeout(NativeWebRequest request, DeferredResult deferredResult) throws Exception { + for (DeferredResultProcessingInterceptor interceptor : this.interceptors) { + if (deferredResult.isSetOrExpired()) { + return; + } + if (!interceptor.handleTimeout(request, deferredResult)){ + break; + } + } + } + + /** + * Determine if further error handling should be bypassed. + * @return {@code true} to continue error handling, or false to bypass any further + * error handling + */ + public boolean triggerAfterError(NativeWebRequest request, DeferredResult deferredResult, Throwable ex) + throws Exception { + + for (DeferredResultProcessingInterceptor interceptor : this.interceptors) { + if (deferredResult.isSetOrExpired()) { + return false; + } + if (!interceptor.handleError(request, deferredResult, ex)){ + return false; + } + } + return true; + } + + public void triggerAfterCompletion(NativeWebRequest request, DeferredResult deferredResult) { + for (int i = this.preProcessingIndex; i >= 0; i--) { + try { + this.interceptors.get(i).afterCompletion(request, deferredResult); + } + catch (Throwable ex) { + logger.trace("Ignoring failure in afterCompletion method", ex); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResultProcessingInterceptor.java b/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResultProcessingInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..9891e491acaff6ba95d25e0ae1db15a04987d9dd --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResultProcessingInterceptor.java @@ -0,0 +1,139 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Intercepts concurrent request handling, where the concurrent result is + * obtained by waiting for a {@link DeferredResult} to be set from a thread + * chosen by the application (e.g. in response to some external event). + * + *

A {@code DeferredResultProcessingInterceptor} is invoked before the start + * of async processing, after the {@code DeferredResult} is set as well as on + * timeout/error, or after completing for any reason including a timeout or network + * error. + * + *

As a general rule exceptions raised by interceptor methods will cause + * async processing to resume by dispatching back to the container and using + * the Exception instance as the concurrent result. Such exceptions will then + * be processed through the {@code HandlerExceptionResolver} mechanism. + * + *

The {@link #handleTimeout(NativeWebRequest, DeferredResult) handleTimeout} + * method can set the {@code DeferredResult} in order to resume processing. + * + * @author Rossen Stoyanchev + * @author Rob Winch + * @since 3.2 + */ +public interface DeferredResultProcessingInterceptor { + + /** + * Invoked immediately before the start of concurrent handling, in the same + * thread that started it. This method may be used to capture state just prior + * to the start of concurrent processing with the given {@code DeferredResult}. + * @param request the current request + * @param deferredResult the DeferredResult for the current request + * @throws Exception in case of errors + */ + default void beforeConcurrentHandling(NativeWebRequest request, DeferredResult deferredResult) + throws Exception { + } + + /** + * Invoked immediately after the start of concurrent handling, in the same + * thread that started it. This method may be used to detect the start of + * concurrent processing with the given {@code DeferredResult}. + *

The {@code DeferredResult} may have already been set, for example at + * the time of its creation or by another thread. + * @param request the current request + * @param deferredResult the DeferredResult for the current request + * @throws Exception in case of errors + */ + default void preProcess(NativeWebRequest request, DeferredResult deferredResult) + throws Exception { + } + + /** + * Invoked after a {@code DeferredResult} has been set, via + * {@link DeferredResult#setResult(Object)} or + * {@link DeferredResult#setErrorResult(Object)}, and is also ready to + * handle the concurrent result. + *

This method may also be invoked after a timeout when the + * {@code DeferredResult} was created with a constructor accepting a default + * timeout result. + * @param request the current request + * @param deferredResult the DeferredResult for the current request + * @param concurrentResult the result to which the {@code DeferredResult} + * @throws Exception in case of errors + */ + default void postProcess(NativeWebRequest request, DeferredResult deferredResult, + Object concurrentResult) throws Exception { + } + + /** + * Invoked from a container thread when an async request times out before + * the {@code DeferredResult} has been set. Implementations may invoke + * {@link DeferredResult#setResult(Object) setResult} or + * {@link DeferredResult#setErrorResult(Object) setErrorResult} to resume processing. + * @param request the current request + * @param deferredResult the DeferredResult for the current request; if the + * {@code DeferredResult} is set, then concurrent processing is resumed and + * subsequent interceptors are not invoked + * @return {@code true} if processing should continue, or {@code false} if + * other interceptors should not be invoked + * @throws Exception in case of errors + */ + default boolean handleTimeout(NativeWebRequest request, DeferredResult deferredResult) + throws Exception { + + return true; + } + + /** + * Invoked from a container thread when an error occurred while processing an async request + * before the {@code DeferredResult} has been set. Implementations may invoke + * {@link DeferredResult#setResult(Object) setResult} or + * {@link DeferredResult#setErrorResult(Object) setErrorResult} to resume processing. + * @param request the current request + * @param deferredResult the DeferredResult for the current request; if the + * {@code DeferredResult} is set, then concurrent processing is resumed and + * subsequent interceptors are not invoked + * @param t the error that occurred while request processing + * @return {@code true} if error handling should continue, or {@code false} if + * other interceptors should by bypassed and not be invoked + * @throws Exception in case of errors + */ + default boolean handleError(NativeWebRequest request, DeferredResult deferredResult, + Throwable t) throws Exception { + + return true; + } + + /** + * Invoked from a container thread when an async request completed for any + * reason including timeout and network error. This method is useful for + * detecting that a {@code DeferredResult} instance is no longer usable. + * @param request the current request + * @param deferredResult the DeferredResult for the current request + * @throws Exception in case of errors + */ + default void afterCompletion(NativeWebRequest request, DeferredResult deferredResult) + throws Exception { + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResultProcessingInterceptorAdapter.java b/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResultProcessingInterceptorAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..9afac024a51a2e9230c8affa04211c0f05b208de --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResultProcessingInterceptorAdapter.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Abstract adapter class for the {@link DeferredResultProcessingInterceptor} + * interface for simplified implementation of individual methods. + * + * @author Rossen Stoyanchev + * @author Rob Winch + * @since 3.2 + * @deprecated as of 5.0 where DeferredResultProcessingInterceptor has default methods + */ +@Deprecated +public abstract class DeferredResultProcessingInterceptorAdapter implements DeferredResultProcessingInterceptor { + + /** + * This implementation is empty. + */ + @Override + public void beforeConcurrentHandling(NativeWebRequest request, DeferredResult deferredResult) + throws Exception { + } + + /** + * This implementation is empty. + */ + @Override + public void preProcess(NativeWebRequest request, DeferredResult deferredResult) throws Exception { + } + + /** + * This implementation is empty. + */ + @Override + public void postProcess(NativeWebRequest request, DeferredResult deferredResult, + Object concurrentResult) throws Exception { + } + + /** + * This implementation returns {@code true} by default allowing other interceptors + * to be given a chance to handle the timeout. + */ + @Override + public boolean handleTimeout(NativeWebRequest request, DeferredResult deferredResult) throws Exception { + return true; + } + + /** + * This implementation returns {@code true} by default allowing other interceptors + * to be given a chance to handle the error. + */ + @Override + public boolean handleError(NativeWebRequest request, DeferredResult deferredResult, Throwable t) + throws Exception { + return true; + } + + /** + * This implementation is empty. + */ + @Override + public void afterCompletion(NativeWebRequest request, DeferredResult deferredResult) throws Exception { + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..9edd8a48d0fa59a0e2d05232dbc3ebf8d3aa1d7d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java @@ -0,0 +1,161 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; + +import javax.servlet.AsyncContext; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.util.Assert; +import org.springframework.web.context.request.ServletWebRequest; + +/** + * A Servlet 3.0 implementation of {@link AsyncWebRequest}. + * + *

The servlet and all filters involved in an async request must have async + * support enabled using the Servlet API or by adding an + * <async-supported>true</async-supported> element to servlet and filter + * declarations in {@code web.xml}. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public class StandardServletAsyncWebRequest extends ServletWebRequest implements AsyncWebRequest, AsyncListener { + + private Long timeout; + + private AsyncContext asyncContext; + + private AtomicBoolean asyncCompleted = new AtomicBoolean(false); + + private final List timeoutHandlers = new ArrayList<>(); + + private final List> exceptionHandlers = new ArrayList<>(); + + private final List completionHandlers = new ArrayList<>(); + + + /** + * Create a new instance for the given request/response pair. + * @param request current HTTP request + * @param response current HTTP response + */ + public StandardServletAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) { + super(request, response); + } + + + /** + * In Servlet 3 async processing, the timeout period begins after the + * container processing thread has exited. + */ + @Override + public void setTimeout(Long timeout) { + Assert.state(!isAsyncStarted(), "Cannot change the timeout with concurrent handling in progress"); + this.timeout = timeout; + } + + @Override + public void addTimeoutHandler(Runnable timeoutHandler) { + this.timeoutHandlers.add(timeoutHandler); + } + + @Override + public void addErrorHandler(Consumer exceptionHandler) { + this.exceptionHandlers.add(exceptionHandler); + } + + @Override + public void addCompletionHandler(Runnable runnable) { + this.completionHandlers.add(runnable); + } + + @Override + public boolean isAsyncStarted() { + return (this.asyncContext != null && getRequest().isAsyncStarted()); + } + + /** + * Whether async request processing has completed. + *

It is important to avoid use of request and response objects after async + * processing has completed. Servlet containers often re-use them. + */ + @Override + public boolean isAsyncComplete() { + return this.asyncCompleted.get(); + } + + @Override + public void startAsync() { + Assert.state(getRequest().isAsyncSupported(), + "Async support must be enabled on a servlet and for all filters involved " + + "in async request processing. This is done in Java code using the Servlet API " + + "or by adding \"true\" to servlet and " + + "filter declarations in web.xml."); + Assert.state(!isAsyncComplete(), "Async processing has already completed"); + + if (isAsyncStarted()) { + return; + } + this.asyncContext = getRequest().startAsync(getRequest(), getResponse()); + this.asyncContext.addListener(this); + if (this.timeout != null) { + this.asyncContext.setTimeout(this.timeout); + } + } + + @Override + public void dispatch() { + Assert.notNull(this.asyncContext, "Cannot dispatch without an AsyncContext"); + this.asyncContext.dispatch(); + } + + + // --------------------------------------------------------------------- + // Implementation of AsyncListener methods + // --------------------------------------------------------------------- + + @Override + public void onStartAsync(AsyncEvent event) throws IOException { + } + + @Override + public void onError(AsyncEvent event) throws IOException { + this.exceptionHandlers.forEach(consumer -> consumer.accept(event.getThrowable())); + } + + @Override + public void onTimeout(AsyncEvent event) throws IOException { + this.timeoutHandlers.forEach(Runnable::run); + } + + @Override + public void onComplete(AsyncEvent event) throws IOException { + this.completionHandlers.forEach(Runnable::run); + this.asyncContext = null; + this.asyncCompleted.set(true); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/TimeoutCallableProcessingInterceptor.java b/spring-web/src/main/java/org/springframework/web/context/request/async/TimeoutCallableProcessingInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..3aae3aea27eae4f3b3e83b0c812cd9791f5bb076 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/TimeoutCallableProcessingInterceptor.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.util.concurrent.Callable; + +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Sends a 503 (SERVICE_UNAVAILABLE) in case of a timeout if the response is not + * already committed. As of 4.2.8 this is done indirectly by setting the result + * to an {@link AsyncRequestTimeoutException} which is then handled by + * Spring MVC's default exception handling as a 503 error. + * + *

Registered at the end, after all other interceptors and + * therefore invoked only if no other interceptor handles the timeout. + * + *

Note that according to RFC 7231, a 503 without a 'Retry-After' header is + * interpreted as a 500 error and the client should not retry. Applications + * can install their own interceptor to handle a timeout and add a 'Retry-After' + * header if necessary. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public class TimeoutCallableProcessingInterceptor implements CallableProcessingInterceptor { + + @Override + public Object handleTimeout(NativeWebRequest request, Callable task) throws Exception { + return new AsyncRequestTimeoutException(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/TimeoutDeferredResultProcessingInterceptor.java b/spring-web/src/main/java/org/springframework/web/context/request/async/TimeoutDeferredResultProcessingInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..cf9dfa2e91b596694eeaaa89093d1842ffb3a28e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/TimeoutDeferredResultProcessingInterceptor.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Sends a 503 (SERVICE_UNAVAILABLE) in case of a timeout if the response is not + * already committed. As of 4.2.8 this is done indirectly by returning + * {@link AsyncRequestTimeoutException} as the result of processing which is + * then handled by Spring MVC's default exception handling as a 503 error. + * + *

Registered at the end, after all other interceptors and + * therefore invoked only if no other interceptor handles the timeout. + * + *

Note that according to RFC 7231, a 503 without a 'Retry-After' header is + * interpreted as a 500 error and the client should not retry. Applications + * can install their own interceptor to handle a timeout and add a 'Retry-After' + * header if necessary. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public class TimeoutDeferredResultProcessingInterceptor implements DeferredResultProcessingInterceptor { + + @Override + public boolean handleTimeout(NativeWebRequest request, DeferredResult result) throws Exception { + result.setErrorResult(new AsyncRequestTimeoutException()); + return false; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java new file mode 100644 index 0000000000000000000000000000000000000000..223effdeb5f19014433876eb3ec3518dbf1dfffe --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java @@ -0,0 +1,477 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.Future; +import java.util.concurrent.RejectedExecutionException; + +import javax.servlet.http.HttpServletRequest; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.task.AsyncTaskExecutor; +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.core.task.SyncTaskExecutor; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.async.DeferredResult.DeferredResultHandler; + +/** + * The central class for managing asynchronous request processing, mainly intended + * as an SPI and not typically used directly by application classes. + * + *

An async scenario starts with request processing as usual in a thread (T1). + * Concurrent request handling can be initiated by calling + * {@link #startCallableProcessing(Callable, Object...) startCallableProcessing} or + * {@link #startDeferredResultProcessing(DeferredResult, Object...) startDeferredResultProcessing}, + * both of which produce a result in a separate thread (T2). The result is saved + * and the request dispatched to the container, to resume processing with the saved + * result in a third thread (T3). Within the dispatched thread (T3), the saved + * result can be accessed via {@link #getConcurrentResult()} or its presence + * detected via {@link #hasConcurrentResult()}. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.2 + * @see org.springframework.web.context.request.AsyncWebRequestInterceptor + * @see org.springframework.web.servlet.AsyncHandlerInterceptor + * @see org.springframework.web.filter.OncePerRequestFilter#shouldNotFilterAsyncDispatch + * @see org.springframework.web.filter.OncePerRequestFilter#isAsyncDispatch + */ +public final class WebAsyncManager { + + private static final Object RESULT_NONE = new Object(); + + private static final AsyncTaskExecutor DEFAULT_TASK_EXECUTOR = + new SimpleAsyncTaskExecutor(WebAsyncManager.class.getSimpleName()); + + private static final Log logger = LogFactory.getLog(WebAsyncManager.class); + + private static final CallableProcessingInterceptor timeoutCallableInterceptor = + new TimeoutCallableProcessingInterceptor(); + + private static final DeferredResultProcessingInterceptor timeoutDeferredResultInterceptor = + new TimeoutDeferredResultProcessingInterceptor(); + + private static Boolean taskExecutorWarning = true; + + + private AsyncWebRequest asyncWebRequest; + + private AsyncTaskExecutor taskExecutor = DEFAULT_TASK_EXECUTOR; + + private volatile Object concurrentResult = RESULT_NONE; + + private volatile Object[] concurrentResultContext; + + private final Map callableInterceptors = new LinkedHashMap<>(); + + private final Map deferredResultInterceptors = new LinkedHashMap<>(); + + + /** + * Package-private constructor. + * @see WebAsyncUtils#getAsyncManager(javax.servlet.ServletRequest) + * @see WebAsyncUtils#getAsyncManager(org.springframework.web.context.request.WebRequest) + */ + WebAsyncManager() { + } + + + /** + * Configure the {@link AsyncWebRequest} to use. This property may be set + * more than once during a single request to accurately reflect the current + * state of the request (e.g. following a forward, request/response + * wrapping, etc). However, it should not be set while concurrent handling + * is in progress, i.e. while {@link #isConcurrentHandlingStarted()} is + * {@code true}. + * @param asyncWebRequest the web request to use + */ + public void setAsyncWebRequest(AsyncWebRequest asyncWebRequest) { + Assert.notNull(asyncWebRequest, "AsyncWebRequest must not be null"); + this.asyncWebRequest = asyncWebRequest; + this.asyncWebRequest.addCompletionHandler(() -> asyncWebRequest.removeAttribute( + WebAsyncUtils.WEB_ASYNC_MANAGER_ATTRIBUTE, RequestAttributes.SCOPE_REQUEST)); + } + + /** + * Configure an AsyncTaskExecutor for use with concurrent processing via + * {@link #startCallableProcessing(Callable, Object...)}. + *

By default a {@link SimpleAsyncTaskExecutor} instance is used. + */ + public void setTaskExecutor(AsyncTaskExecutor taskExecutor) { + this.taskExecutor = taskExecutor; + } + + /** + * Whether the selected handler for the current request chose to handle the + * request asynchronously. A return value of "true" indicates concurrent + * handling is under way and the response will remain open. A return value + * of "false" means concurrent handling was either not started or possibly + * that it has completed and the request was dispatched for further + * processing of the concurrent result. + */ + public boolean isConcurrentHandlingStarted() { + return (this.asyncWebRequest != null && this.asyncWebRequest.isAsyncStarted()); + } + + /** + * Whether a result value exists as a result of concurrent handling. + */ + public boolean hasConcurrentResult() { + return (this.concurrentResult != RESULT_NONE); + } + + /** + * Provides access to the result from concurrent handling. + * @return an Object, possibly an {@code Exception} or {@code Throwable} if + * concurrent handling raised one. + * @see #clearConcurrentResult() + */ + public Object getConcurrentResult() { + return this.concurrentResult; + } + + /** + * Provides access to additional processing context saved at the start of + * concurrent handling. + * @see #clearConcurrentResult() + */ + public Object[] getConcurrentResultContext() { + return this.concurrentResultContext; + } + + /** + * Get the {@link CallableProcessingInterceptor} registered under the given key. + * @param key the key + * @return the interceptor registered under that key, or {@code null} if none + */ + @Nullable + public CallableProcessingInterceptor getCallableInterceptor(Object key) { + return this.callableInterceptors.get(key); + } + + /** + * Get the {@link DeferredResultProcessingInterceptor} registered under the given key. + * @param key the key + * @return the interceptor registered under that key, or {@code null} if none + */ + @Nullable + public DeferredResultProcessingInterceptor getDeferredResultInterceptor(Object key) { + return this.deferredResultInterceptors.get(key); + } + + /** + * Register a {@link CallableProcessingInterceptor} under the given key. + * @param key the key + * @param interceptor the interceptor to register + */ + public void registerCallableInterceptor(Object key, CallableProcessingInterceptor interceptor) { + Assert.notNull(key, "Key is required"); + Assert.notNull(interceptor, "CallableProcessingInterceptor is required"); + this.callableInterceptors.put(key, interceptor); + } + + /** + * Register a {@link CallableProcessingInterceptor} without a key. + * The key is derived from the class name and hashcode. + * @param interceptors one or more interceptors to register + */ + public void registerCallableInterceptors(CallableProcessingInterceptor... interceptors) { + Assert.notNull(interceptors, "A CallableProcessingInterceptor is required"); + for (CallableProcessingInterceptor interceptor : interceptors) { + String key = interceptor.getClass().getName() + ":" + interceptor.hashCode(); + this.callableInterceptors.put(key, interceptor); + } + } + + /** + * Register a {@link DeferredResultProcessingInterceptor} under the given key. + * @param key the key + * @param interceptor the interceptor to register + */ + public void registerDeferredResultInterceptor(Object key, DeferredResultProcessingInterceptor interceptor) { + Assert.notNull(key, "Key is required"); + Assert.notNull(interceptor, "DeferredResultProcessingInterceptor is required"); + this.deferredResultInterceptors.put(key, interceptor); + } + + /** + * Register one or more {@link DeferredResultProcessingInterceptor DeferredResultProcessingInterceptors} without a specified key. + * The default key is derived from the interceptor class name and hash code. + * @param interceptors one or more interceptors to register + */ + public void registerDeferredResultInterceptors(DeferredResultProcessingInterceptor... interceptors) { + Assert.notNull(interceptors, "A DeferredResultProcessingInterceptor is required"); + for (DeferredResultProcessingInterceptor interceptor : interceptors) { + String key = interceptor.getClass().getName() + ":" + interceptor.hashCode(); + this.deferredResultInterceptors.put(key, interceptor); + } + } + + /** + * Clear {@linkplain #getConcurrentResult() concurrentResult} and + * {@linkplain #getConcurrentResultContext() concurrentResultContext}. + */ + public void clearConcurrentResult() { + synchronized (WebAsyncManager.this) { + this.concurrentResult = RESULT_NONE; + this.concurrentResultContext = null; + } + } + + /** + * Start concurrent request processing and execute the given task with an + * {@link #setTaskExecutor(AsyncTaskExecutor) AsyncTaskExecutor}. The result + * from the task execution is saved and the request dispatched in order to + * resume processing of that result. If the task raises an Exception then + * the saved result will be the raised Exception. + * @param callable a unit of work to be executed asynchronously + * @param processingContext additional context to save that can be accessed + * via {@link #getConcurrentResultContext()} + * @throws Exception if concurrent processing failed to start + * @see #getConcurrentResult() + * @see #getConcurrentResultContext() + */ + @SuppressWarnings({"rawtypes", "unchecked"}) + public void startCallableProcessing(Callable callable, Object... processingContext) throws Exception { + Assert.notNull(callable, "Callable must not be null"); + startCallableProcessing(new WebAsyncTask(callable), processingContext); + } + + /** + * Use the given {@link WebAsyncTask} to configure the task executor as well as + * the timeout value of the {@code AsyncWebRequest} before delegating to + * {@link #startCallableProcessing(Callable, Object...)}. + * @param webAsyncTask a WebAsyncTask containing the target {@code Callable} + * @param processingContext additional context to save that can be accessed + * via {@link #getConcurrentResultContext()} + * @throws Exception if concurrent processing failed to start + */ + public void startCallableProcessing(final WebAsyncTask webAsyncTask, Object... processingContext) + throws Exception { + + Assert.notNull(webAsyncTask, "WebAsyncTask must not be null"); + Assert.state(this.asyncWebRequest != null, "AsyncWebRequest must not be null"); + + Long timeout = webAsyncTask.getTimeout(); + if (timeout != null) { + this.asyncWebRequest.setTimeout(timeout); + } + + AsyncTaskExecutor executor = webAsyncTask.getExecutor(); + if (executor != null) { + this.taskExecutor = executor; + } + else { + logExecutorWarning(); + } + + List interceptors = new ArrayList<>(); + interceptors.add(webAsyncTask.getInterceptor()); + interceptors.addAll(this.callableInterceptors.values()); + interceptors.add(timeoutCallableInterceptor); + + final Callable callable = webAsyncTask.getCallable(); + final CallableInterceptorChain interceptorChain = new CallableInterceptorChain(interceptors); + + this.asyncWebRequest.addTimeoutHandler(() -> { + if (logger.isDebugEnabled()) { + logger.debug("Async request timeout for " + formatRequestUri()); + } + Object result = interceptorChain.triggerAfterTimeout(this.asyncWebRequest, callable); + if (result != CallableProcessingInterceptor.RESULT_NONE) { + setConcurrentResultAndDispatch(result); + } + }); + + this.asyncWebRequest.addErrorHandler(ex -> { + if (logger.isDebugEnabled()) { + logger.debug("Async request error for " + formatRequestUri() + ": " + ex); + } + Object result = interceptorChain.triggerAfterError(this.asyncWebRequest, callable, ex); + result = (result != CallableProcessingInterceptor.RESULT_NONE ? result : ex); + setConcurrentResultAndDispatch(result); + }); + + this.asyncWebRequest.addCompletionHandler(() -> + interceptorChain.triggerAfterCompletion(this.asyncWebRequest, callable)); + + interceptorChain.applyBeforeConcurrentHandling(this.asyncWebRequest, callable); + startAsyncProcessing(processingContext); + try { + Future future = this.taskExecutor.submit(() -> { + Object result = null; + try { + interceptorChain.applyPreProcess(this.asyncWebRequest, callable); + result = callable.call(); + } + catch (Throwable ex) { + result = ex; + } + finally { + result = interceptorChain.applyPostProcess(this.asyncWebRequest, callable, result); + } + setConcurrentResultAndDispatch(result); + }); + interceptorChain.setTaskFuture(future); + } + catch (RejectedExecutionException ex) { + Object result = interceptorChain.applyPostProcess(this.asyncWebRequest, callable, ex); + setConcurrentResultAndDispatch(result); + throw ex; + } + } + + private void logExecutorWarning() { + if (taskExecutorWarning && logger.isWarnEnabled()) { + synchronized (DEFAULT_TASK_EXECUTOR) { + AsyncTaskExecutor executor = this.taskExecutor; + if (taskExecutorWarning && + (executor instanceof SimpleAsyncTaskExecutor || executor instanceof SyncTaskExecutor)) { + String executorTypeName = executor.getClass().getSimpleName(); + logger.warn("\n!!!\n" + + "An Executor is required to handle java.util.concurrent.Callable return values.\n" + + "Please, configure a TaskExecutor in the MVC config under \"async support\".\n" + + "The " + executorTypeName + " currently in use is not suitable under load.\n" + + "-------------------------------\n" + + "Request URI: '" + formatRequestUri() + "'\n" + + "!!!"); + taskExecutorWarning = false; + } + } + } + } + + private String formatRequestUri() { + HttpServletRequest request = this.asyncWebRequest.getNativeRequest(HttpServletRequest.class); + return request != null ? request.getRequestURI() : "servlet container"; + } + + private void setConcurrentResultAndDispatch(Object result) { + synchronized (WebAsyncManager.this) { + if (this.concurrentResult != RESULT_NONE) { + return; + } + this.concurrentResult = result; + } + + if (this.asyncWebRequest.isAsyncComplete()) { + if (logger.isDebugEnabled()) { + logger.debug("Async result set but request already complete: " + formatRequestUri()); + } + return; + } + + if (logger.isDebugEnabled()) { + boolean isError = result instanceof Throwable; + logger.debug("Async " + (isError ? "error" : "result set") + ", dispatch to " + formatRequestUri()); + } + this.asyncWebRequest.dispatch(); + } + + /** + * Start concurrent request processing and initialize the given + * {@link DeferredResult} with a {@link DeferredResultHandler} that saves + * the result and dispatches the request to resume processing of that + * result. The {@code AsyncWebRequest} is also updated with a completion + * handler that expires the {@code DeferredResult} and a timeout handler + * assuming the {@code DeferredResult} has a default timeout result. + * @param deferredResult the DeferredResult instance to initialize + * @param processingContext additional context to save that can be accessed + * via {@link #getConcurrentResultContext()} + * @throws Exception if concurrent processing failed to start + * @see #getConcurrentResult() + * @see #getConcurrentResultContext() + */ + public void startDeferredResultProcessing( + final DeferredResult deferredResult, Object... processingContext) throws Exception { + + Assert.notNull(deferredResult, "DeferredResult must not be null"); + Assert.state(this.asyncWebRequest != null, "AsyncWebRequest must not be null"); + + Long timeout = deferredResult.getTimeoutValue(); + if (timeout != null) { + this.asyncWebRequest.setTimeout(timeout); + } + + List interceptors = new ArrayList<>(); + interceptors.add(deferredResult.getInterceptor()); + interceptors.addAll(this.deferredResultInterceptors.values()); + interceptors.add(timeoutDeferredResultInterceptor); + + final DeferredResultInterceptorChain interceptorChain = new DeferredResultInterceptorChain(interceptors); + + this.asyncWebRequest.addTimeoutHandler(() -> { + try { + interceptorChain.triggerAfterTimeout(this.asyncWebRequest, deferredResult); + } + catch (Throwable ex) { + setConcurrentResultAndDispatch(ex); + } + }); + + this.asyncWebRequest.addErrorHandler(ex -> { + try { + if (!interceptorChain.triggerAfterError(this.asyncWebRequest, deferredResult, ex)) { + return; + } + deferredResult.setErrorResult(ex); + } + catch (Throwable interceptorEx) { + setConcurrentResultAndDispatch(interceptorEx); + } + }); + + this.asyncWebRequest.addCompletionHandler(() + -> interceptorChain.triggerAfterCompletion(this.asyncWebRequest, deferredResult)); + + interceptorChain.applyBeforeConcurrentHandling(this.asyncWebRequest, deferredResult); + startAsyncProcessing(processingContext); + + try { + interceptorChain.applyPreProcess(this.asyncWebRequest, deferredResult); + deferredResult.setResultHandler(result -> { + result = interceptorChain.applyPostProcess(this.asyncWebRequest, deferredResult, result); + setConcurrentResultAndDispatch(result); + }); + } + catch (Throwable ex) { + setConcurrentResultAndDispatch(ex); + } + } + + private void startAsyncProcessing(Object[] processingContext) { + synchronized (WebAsyncManager.this) { + this.concurrentResult = RESULT_NONE; + this.concurrentResultContext = processingContext; + } + this.asyncWebRequest.startAsync(); + + if (logger.isDebugEnabled()) { + logger.debug("Started async request"); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncTask.java b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncTask.java new file mode 100644 index 0000000000000000000000000000000000000000..5d3a91121c3e0b4d0ddb53b97b67fa1917539796 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncTask.java @@ -0,0 +1,198 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.util.concurrent.Callable; + +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.core.task.AsyncTaskExecutor; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Holder for a {@link Callable}, a timeout value, and a task executor. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.2 + * @param the value type + */ +public class WebAsyncTask implements BeanFactoryAware { + + private final Callable callable; + + private Long timeout; + + private AsyncTaskExecutor executor; + + private String executorName; + + private BeanFactory beanFactory; + + private Callable timeoutCallback; + + private Callable errorCallback; + + private Runnable completionCallback; + + + /** + * Create a {@code WebAsyncTask} wrapping the given {@link Callable}. + * @param callable the callable for concurrent handling + */ + public WebAsyncTask(Callable callable) { + Assert.notNull(callable, "Callable must not be null"); + this.callable = callable; + } + + /** + * Create a {@code WebAsyncTask} with a timeout value and a {@link Callable}. + * @param timeout a timeout value in milliseconds + * @param callable the callable for concurrent handling + */ + public WebAsyncTask(long timeout, Callable callable) { + this(callable); + this.timeout = timeout; + } + + /** + * Create a {@code WebAsyncTask} with a timeout value, an executor name, and a {@link Callable}. + * @param timeout timeout value in milliseconds; ignored if {@code null} + * @param executorName the name of an executor bean to use + * @param callable the callable for concurrent handling + */ + public WebAsyncTask(@Nullable Long timeout, String executorName, Callable callable) { + this(callable); + Assert.notNull(executorName, "Executor name must not be null"); + this.executorName = executorName; + this.timeout = timeout; + } + + /** + * Create a {@code WebAsyncTask} with a timeout value, an executor instance, and a Callable. + * @param timeout timeout value in milliseconds; ignored if {@code null} + * @param executor the executor to use + * @param callable the callable for concurrent handling + */ + public WebAsyncTask(@Nullable Long timeout, AsyncTaskExecutor executor, Callable callable) { + this(callable); + Assert.notNull(executor, "Executor must not be null"); + this.executor = executor; + this.timeout = timeout; + } + + + /** + * Return the {@link Callable} to use for concurrent handling (never {@code null}). + */ + public Callable getCallable() { + return this.callable; + } + + /** + * Return the timeout value in milliseconds, or {@code null} if no timeout is set. + */ + @Nullable + public Long getTimeout() { + return this.timeout; + } + + /** + * A {@link BeanFactory} to use for resolving an executor name. + *

This factory reference will automatically be set when + * {@code WebAsyncTask} is used within a Spring MVC controller. + */ + public void setBeanFactory(BeanFactory beanFactory) { + this.beanFactory = beanFactory; + } + + /** + * Return the AsyncTaskExecutor to use for concurrent handling, + * or {@code null} if none specified. + */ + @Nullable + public AsyncTaskExecutor getExecutor() { + if (this.executor != null) { + return this.executor; + } + else if (this.executorName != null) { + Assert.state(this.beanFactory != null, "BeanFactory is required to look up an executor bean by name"); + return this.beanFactory.getBean(this.executorName, AsyncTaskExecutor.class); + } + else { + return null; + } + } + + + /** + * Register code to invoke when the async request times out. + *

This method is called from a container thread when an async request times + * out before the {@code Callable} has completed. The callback is executed in + * the same thread and therefore should return without blocking. It may return + * an alternative value to use, including an {@link Exception} or return + * {@link CallableProcessingInterceptor#RESULT_NONE RESULT_NONE}. + */ + public void onTimeout(Callable callback) { + this.timeoutCallback = callback; + } + + /** + * Register code to invoke for an error during async request processing. + *

This method is called from a container thread when an error occurred + * while processing an async request before the {@code Callable} has + * completed. The callback is executed in the same thread and therefore + * should return without blocking. It may return an alternative value to + * use, including an {@link Exception} or return + * {@link CallableProcessingInterceptor#RESULT_NONE RESULT_NONE}. + * @since 5.0 + */ + public void onError(Callable callback) { + this.errorCallback = callback; + } + + /** + * Register code to invoke when the async request completes. + *

This method is called from a container thread when an async request + * completed for any reason, including timeout and network error. + */ + public void onCompletion(Runnable callback) { + this.completionCallback = callback; + } + + CallableProcessingInterceptor getInterceptor() { + return new CallableProcessingInterceptor() { + @Override + public Object handleTimeout(NativeWebRequest request, Callable task) throws Exception { + return (timeoutCallback != null ? timeoutCallback.call() : CallableProcessingInterceptor.RESULT_NONE); + } + @Override + public Object handleError(NativeWebRequest request, Callable task, Throwable t) throws Exception { + return (errorCallback != null ? errorCallback.call() : CallableProcessingInterceptor.RESULT_NONE); + } + @Override + public void afterCompletion(NativeWebRequest request, Callable task) throws Exception { + if (completionCallback != null) { + completionCallback.run(); + } + } + }; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncUtils.java b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..abd2336ab1d90dadcf77d27bc834ea3267896894 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncUtils.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import javax.servlet.ServletRequest; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.WebRequest; + +/** + * Utility methods related to processing asynchronous web requests. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.2 + */ +public abstract class WebAsyncUtils { + + /** + * The name attribute containing the {@link WebAsyncManager}. + */ + public static final String WEB_ASYNC_MANAGER_ATTRIBUTE = + WebAsyncManager.class.getName() + ".WEB_ASYNC_MANAGER"; + + + /** + * Obtain the {@link WebAsyncManager} for the current request, or if not + * found, create and associate it with the request. + */ + public static WebAsyncManager getAsyncManager(ServletRequest servletRequest) { + WebAsyncManager asyncManager = null; + Object asyncManagerAttr = servletRequest.getAttribute(WEB_ASYNC_MANAGER_ATTRIBUTE); + if (asyncManagerAttr instanceof WebAsyncManager) { + asyncManager = (WebAsyncManager) asyncManagerAttr; + } + if (asyncManager == null) { + asyncManager = new WebAsyncManager(); + servletRequest.setAttribute(WEB_ASYNC_MANAGER_ATTRIBUTE, asyncManager); + } + return asyncManager; + } + + /** + * Obtain the {@link WebAsyncManager} for the current request, or if not + * found, create and associate it with the request. + */ + public static WebAsyncManager getAsyncManager(WebRequest webRequest) { + int scope = RequestAttributes.SCOPE_REQUEST; + WebAsyncManager asyncManager = null; + Object asyncManagerAttr = webRequest.getAttribute(WEB_ASYNC_MANAGER_ATTRIBUTE, scope); + if (asyncManagerAttr instanceof WebAsyncManager) { + asyncManager = (WebAsyncManager) asyncManagerAttr; + } + if (asyncManager == null) { + asyncManager = new WebAsyncManager(); + webRequest.setAttribute(WEB_ASYNC_MANAGER_ATTRIBUTE, asyncManager, scope); + } + return asyncManager; + } + + /** + * Create an AsyncWebRequest instance. By default, an instance of + * {@link StandardServletAsyncWebRequest} gets created. + * @param request the current request + * @param response the current response + * @return an AsyncWebRequest instance (never {@code null}) + */ + public static AsyncWebRequest createAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) { + return new StandardServletAsyncWebRequest(request, response); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/package-info.java b/spring-web/src/main/java/org/springframework/web/context/request/async/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..775d53e0d5c59b18249af246e939f1ce3a069fcd --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/package-info.java @@ -0,0 +1,4 @@ +/** + * Support for asynchronous request processing. + */ +package org.springframework.web.context.request.async; diff --git a/spring-web/src/main/java/org/springframework/web/context/request/package-info.java b/spring-web/src/main/java/org/springframework/web/context/request/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..ed3aae5f0b76778f67c75f6892306acb65b64f8f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/package-info.java @@ -0,0 +1,10 @@ +/** + * Support for generic request context holding, in particular for + * scoping of application objects per HTTP request or HTTP session. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.context.request; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/context/support/AbstractRefreshableWebApplicationContext.java b/spring-web/src/main/java/org/springframework/web/context/support/AbstractRefreshableWebApplicationContext.java new file mode 100644 index 0000000000000000000000000000000000000000..6c47fc938fe8f5ecd80adb1a2e4f3cebe0c9db3f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/AbstractRefreshableWebApplicationContext.java @@ -0,0 +1,223 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import javax.servlet.ServletConfig; +import javax.servlet.ServletContext; + +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.context.support.AbstractRefreshableConfigApplicationContext; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.io.Resource; +import org.springframework.core.io.support.ResourcePatternResolver; +import org.springframework.lang.Nullable; +import org.springframework.ui.context.Theme; +import org.springframework.ui.context.ThemeSource; +import org.springframework.ui.context.support.UiApplicationContextUtils; +import org.springframework.util.Assert; +import org.springframework.web.context.ConfigurableWebApplicationContext; +import org.springframework.web.context.ConfigurableWebEnvironment; +import org.springframework.web.context.ServletConfigAware; +import org.springframework.web.context.ServletContextAware; + +/** + * {@link org.springframework.context.support.AbstractRefreshableApplicationContext} + * subclass which implements the + * {@link org.springframework.web.context.ConfigurableWebApplicationContext} + * interface for web environments. Provides a "configLocations" property, + * to be populated through the ConfigurableWebApplicationContext interface + * on web application startup. + * + *

This class is as easy to subclass as AbstractRefreshableApplicationContext: + * All you need to implements is the {@link #loadBeanDefinitions} method; + * see the superclass javadoc for details. Note that implementations are supposed + * to load bean definitions from the files specified by the locations returned + * by the {@link #getConfigLocations} method. + * + *

Interprets resource paths as servlet context resources, i.e. as paths beneath + * the web application root. Absolute paths, e.g. for files outside the web app root, + * can be accessed via "file:" URLs, as implemented by + * {@link org.springframework.core.io.DefaultResourceLoader}. + * + *

In addition to the special beans detected by + * {@link org.springframework.context.support.AbstractApplicationContext}, + * this class detects a bean of type {@link org.springframework.ui.context.ThemeSource} + * in the context, under the special bean name "themeSource". + * + *

This is the web context to be subclassed for a different bean definition format. + * Such a context implementation can be specified as "contextClass" context-param + * for {@link org.springframework.web.context.ContextLoader} or as "contextClass" + * init-param for {@link org.springframework.web.servlet.FrameworkServlet}, + * replacing the default {@link XmlWebApplicationContext}. It will then automatically + * receive the "contextConfigLocation" context-param or init-param, respectively. + * + *

Note that WebApplicationContext implementations are generally supposed + * to configure themselves based on the configuration received through the + * {@link ConfigurableWebApplicationContext} interface. In contrast, a standalone + * application context might allow for configuration in custom startup code + * (for example, {@link org.springframework.context.support.GenericApplicationContext}). + * + * @author Juergen Hoeller + * @since 1.1.3 + * @see #loadBeanDefinitions + * @see org.springframework.web.context.ConfigurableWebApplicationContext#setConfigLocations + * @see org.springframework.ui.context.ThemeSource + * @see XmlWebApplicationContext + */ +public abstract class AbstractRefreshableWebApplicationContext extends AbstractRefreshableConfigApplicationContext + implements ConfigurableWebApplicationContext, ThemeSource { + + /** Servlet context that this context runs in. */ + @Nullable + private ServletContext servletContext; + + /** Servlet config that this context runs in, if any. */ + @Nullable + private ServletConfig servletConfig; + + /** Namespace of this context, or {@code null} if root. */ + @Nullable + private String namespace; + + /** the ThemeSource for this ApplicationContext. */ + @Nullable + private ThemeSource themeSource; + + + public AbstractRefreshableWebApplicationContext() { + setDisplayName("Root WebApplicationContext"); + } + + + @Override + public void setServletContext(@Nullable ServletContext servletContext) { + this.servletContext = servletContext; + } + + @Override + @Nullable + public ServletContext getServletContext() { + return this.servletContext; + } + + @Override + public void setServletConfig(@Nullable ServletConfig servletConfig) { + this.servletConfig = servletConfig; + if (servletConfig != null && this.servletContext == null) { + setServletContext(servletConfig.getServletContext()); + } + } + + @Override + @Nullable + public ServletConfig getServletConfig() { + return this.servletConfig; + } + + @Override + public void setNamespace(@Nullable String namespace) { + this.namespace = namespace; + if (namespace != null) { + setDisplayName("WebApplicationContext for namespace '" + namespace + "'"); + } + } + + @Override + @Nullable + public String getNamespace() { + return this.namespace; + } + + @Override + public String[] getConfigLocations() { + return super.getConfigLocations(); + } + + @Override + public String getApplicationName() { + return (this.servletContext != null ? this.servletContext.getContextPath() : ""); + } + + /** + * Create and return a new {@link StandardServletEnvironment}. Subclasses may override + * in order to configure the environment or specialize the environment type returned. + */ + @Override + protected ConfigurableEnvironment createEnvironment() { + return new StandardServletEnvironment(); + } + + /** + * Register request/session scopes, a {@link ServletContextAwareProcessor}, etc. + */ + @Override + protected void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) { + beanFactory.addBeanPostProcessor(new ServletContextAwareProcessor(this.servletContext, this.servletConfig)); + beanFactory.ignoreDependencyInterface(ServletContextAware.class); + beanFactory.ignoreDependencyInterface(ServletConfigAware.class); + + WebApplicationContextUtils.registerWebApplicationScopes(beanFactory, this.servletContext); + WebApplicationContextUtils.registerEnvironmentBeans(beanFactory, this.servletContext, this.servletConfig); + } + + /** + * This implementation supports file paths beneath the root of the ServletContext. + * @see ServletContextResource + */ + @Override + protected Resource getResourceByPath(String path) { + Assert.state(this.servletContext != null, "No ServletContext available"); + return new ServletContextResource(this.servletContext, path); + } + + /** + * This implementation supports pattern matching in unexpanded WARs too. + * @see ServletContextResourcePatternResolver + */ + @Override + protected ResourcePatternResolver getResourcePatternResolver() { + return new ServletContextResourcePatternResolver(this); + } + + /** + * Initialize the theme capability. + */ + @Override + protected void onRefresh() { + this.themeSource = UiApplicationContextUtils.initThemeSource(this); + } + + /** + * {@inheritDoc} + *

Replace {@code Servlet}-related property sources. + */ + @Override + protected void initPropertySources() { + ConfigurableEnvironment env = getEnvironment(); + if (env instanceof ConfigurableWebEnvironment) { + ((ConfigurableWebEnvironment) env).initPropertySources(this.servletContext, this.servletConfig); + } + } + + @Override + @Nullable + public Theme getTheme(String themeName) { + Assert.state(this.themeSource != null, "No ThemeSource available"); + return this.themeSource.getTheme(themeName); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/AnnotationConfigWebApplicationContext.java b/spring-web/src/main/java/org/springframework/web/context/support/AnnotationConfigWebApplicationContext.java new file mode 100644 index 0000000000000000000000000000000000000000..5de3b4352b2e93d5d7c38f989b81170e93281377 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/AnnotationConfigWebApplicationContext.java @@ -0,0 +1,284 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Set; + +import org.springframework.beans.factory.support.BeanNameGenerator; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.context.annotation.AnnotatedBeanDefinitionReader; +import org.springframework.context.annotation.AnnotationConfigRegistry; +import org.springframework.context.annotation.AnnotationConfigUtils; +import org.springframework.context.annotation.ClassPathBeanDefinitionScanner; +import org.springframework.context.annotation.ScopeMetadataResolver; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.context.ContextLoader; + +/** + * {@link org.springframework.web.context.WebApplicationContext WebApplicationContext} + * implementation which accepts component classes as input — in particular + * {@link org.springframework.context.annotation.Configuration @Configuration}-annotated + * classes, but also plain {@link org.springframework.stereotype.Component @Component} + * classes and JSR-330 compliant classes using {@code javax.inject} annotations. + * + *

Allows for registering classes one by one (specifying class names as config + * location) as well as for classpath scanning (specifying base packages as config location). + * + *

This is essentially the equivalent of + * {@link org.springframework.context.annotation.AnnotationConfigApplicationContext + * AnnotationConfigApplicationContext} for a web environment. + * + *

To make use of this application context, the + * {@linkplain ContextLoader#CONTEXT_CLASS_PARAM "contextClass"} context-param for + * ContextLoader and/or "contextClass" init-param for FrameworkServlet must be set to + * the fully-qualified name of this class. + * + *

As of Spring 3.1, this class may also be directly instantiated and injected into + * Spring's {@code DispatcherServlet} or {@code ContextLoaderListener} when using the + * {@link org.springframework.web.WebApplicationInitializer WebApplicationInitializer} + * code-based alternative to {@code web.xml}. See its Javadoc for details and usage examples. + * + *

Unlike {@link XmlWebApplicationContext}, no default configuration class locations + * are assumed. Rather, it is a requirement to set the + * {@linkplain ContextLoader#CONFIG_LOCATION_PARAM "contextConfigLocation"} + * context-param for {@link ContextLoader} and/or "contextConfigLocation" init-param for + * FrameworkServlet. The param-value may contain both fully-qualified + * class names and base packages to scan for components. See {@link #loadBeanDefinitions} + * for exact details on how these locations are processed. + * + *

As an alternative to setting the "contextConfigLocation" parameter, users may + * implement an {@link org.springframework.context.ApplicationContextInitializer + * ApplicationContextInitializer} and set the + * {@linkplain ContextLoader#CONTEXT_INITIALIZER_CLASSES_PARAM "contextInitializerClasses"} + * context-param / init-param. In such cases, users should favor the {@link #refresh()} + * and {@link #scan(String...)} methods over the {@link #setConfigLocation(String)} + * method, which is primarily for use by {@code ContextLoader}. + * + *

Note: In case of multiple {@code @Configuration} classes, later {@code @Bean} + * definitions will override ones defined in earlier loaded files. This can be leveraged + * to deliberately override certain bean definitions via an extra {@code @Configuration} + * class. + * + * @author Chris Beams + * @author Juergen Hoeller + * @since 3.0 + * @see org.springframework.context.annotation.AnnotationConfigApplicationContext + */ +public class AnnotationConfigWebApplicationContext extends AbstractRefreshableWebApplicationContext + implements AnnotationConfigRegistry { + + @Nullable + private BeanNameGenerator beanNameGenerator; + + @Nullable + private ScopeMetadataResolver scopeMetadataResolver; + + private final Set> componentClasses = new LinkedHashSet<>(); + + private final Set basePackages = new LinkedHashSet<>(); + + + /** + * Set a custom {@link BeanNameGenerator} for use with {@link AnnotatedBeanDefinitionReader} + * and/or {@link ClassPathBeanDefinitionScanner}. + *

Default is {@link org.springframework.context.annotation.AnnotationBeanNameGenerator}. + * @see AnnotatedBeanDefinitionReader#setBeanNameGenerator + * @see ClassPathBeanDefinitionScanner#setBeanNameGenerator + */ + public void setBeanNameGenerator(@Nullable BeanNameGenerator beanNameGenerator) { + this.beanNameGenerator = beanNameGenerator; + } + + /** + * Return the custom {@link BeanNameGenerator} for use with {@link AnnotatedBeanDefinitionReader} + * and/or {@link ClassPathBeanDefinitionScanner}, if any. + */ + @Nullable + protected BeanNameGenerator getBeanNameGenerator() { + return this.beanNameGenerator; + } + + /** + * Set a custom {@link ScopeMetadataResolver} for use with {@link AnnotatedBeanDefinitionReader} + * and/or {@link ClassPathBeanDefinitionScanner}. + *

Default is an {@link org.springframework.context.annotation.AnnotationScopeMetadataResolver}. + * @see AnnotatedBeanDefinitionReader#setScopeMetadataResolver + * @see ClassPathBeanDefinitionScanner#setScopeMetadataResolver + */ + public void setScopeMetadataResolver(@Nullable ScopeMetadataResolver scopeMetadataResolver) { + this.scopeMetadataResolver = scopeMetadataResolver; + } + + /** + * Return the custom {@link ScopeMetadataResolver} for use with {@link AnnotatedBeanDefinitionReader} + * and/or {@link ClassPathBeanDefinitionScanner}, if any. + */ + @Nullable + protected ScopeMetadataResolver getScopeMetadataResolver() { + return this.scopeMetadataResolver; + } + + + /** + * Register one or more component classes to be processed. + *

Note that {@link #refresh()} must be called in order for the context + * to fully process the new classes. + * @param componentClasses one or more component classes, + * e.g. {@link org.springframework.context.annotation.Configuration @Configuration} classes + * @see #scan(String...) + * @see #loadBeanDefinitions(DefaultListableBeanFactory) + * @see #setConfigLocation(String) + * @see #refresh() + */ + @Override + public void register(Class... componentClasses) { + Assert.notEmpty(componentClasses, "At least one component class must be specified"); + Collections.addAll(this.componentClasses, componentClasses); + } + + /** + * Perform a scan within the specified base packages. + *

Note that {@link #refresh()} must be called in order for the context + * to fully process the new classes. + * @param basePackages the packages to check for component classes + * @see #loadBeanDefinitions(DefaultListableBeanFactory) + * @see #register(Class...) + * @see #setConfigLocation(String) + * @see #refresh() + */ + @Override + public void scan(String... basePackages) { + Assert.notEmpty(basePackages, "At least one base package must be specified"); + Collections.addAll(this.basePackages, basePackages); + } + + + /** + * Register a {@link org.springframework.beans.factory.config.BeanDefinition} for + * any classes specified by {@link #register(Class...)} and scan any packages + * specified by {@link #scan(String...)}. + *

For any values specified by {@link #setConfigLocation(String)} or + * {@link #setConfigLocations(String[])}, attempt first to load each location as a + * class, registering a {@code BeanDefinition} if class loading is successful, + * and if class loading fails (i.e. a {@code ClassNotFoundException} is raised), + * assume the value is a package and attempt to scan it for component classes. + *

Enables the default set of annotation configuration post processors, such that + * {@code @Autowired}, {@code @Required}, and associated annotations can be used. + *

Configuration class bean definitions are registered with generated bean + * definition names unless the {@code value} attribute is provided to the stereotype + * annotation. + * @param beanFactory the bean factory to load bean definitions into + * @see #register(Class...) + * @see #scan(String...) + * @see #setConfigLocation(String) + * @see #setConfigLocations(String[]) + * @see AnnotatedBeanDefinitionReader + * @see ClassPathBeanDefinitionScanner + */ + @Override + protected void loadBeanDefinitions(DefaultListableBeanFactory beanFactory) { + AnnotatedBeanDefinitionReader reader = getAnnotatedBeanDefinitionReader(beanFactory); + ClassPathBeanDefinitionScanner scanner = getClassPathBeanDefinitionScanner(beanFactory); + + BeanNameGenerator beanNameGenerator = getBeanNameGenerator(); + if (beanNameGenerator != null) { + reader.setBeanNameGenerator(beanNameGenerator); + scanner.setBeanNameGenerator(beanNameGenerator); + beanFactory.registerSingleton(AnnotationConfigUtils.CONFIGURATION_BEAN_NAME_GENERATOR, beanNameGenerator); + } + + ScopeMetadataResolver scopeMetadataResolver = getScopeMetadataResolver(); + if (scopeMetadataResolver != null) { + reader.setScopeMetadataResolver(scopeMetadataResolver); + scanner.setScopeMetadataResolver(scopeMetadataResolver); + } + + if (!this.componentClasses.isEmpty()) { + if (logger.isDebugEnabled()) { + logger.debug("Registering component classes: [" + + StringUtils.collectionToCommaDelimitedString(this.componentClasses) + "]"); + } + reader.register(ClassUtils.toClassArray(this.componentClasses)); + } + + if (!this.basePackages.isEmpty()) { + if (logger.isDebugEnabled()) { + logger.debug("Scanning base packages: [" + + StringUtils.collectionToCommaDelimitedString(this.basePackages) + "]"); + } + scanner.scan(StringUtils.toStringArray(this.basePackages)); + } + + String[] configLocations = getConfigLocations(); + if (configLocations != null) { + for (String configLocation : configLocations) { + try { + Class clazz = ClassUtils.forName(configLocation, getClassLoader()); + if (logger.isTraceEnabled()) { + logger.trace("Registering [" + configLocation + "]"); + } + reader.register(clazz); + } + catch (ClassNotFoundException ex) { + if (logger.isTraceEnabled()) { + logger.trace("Could not load class for config location [" + configLocation + + "] - trying package scan. " + ex); + } + int count = scanner.scan(configLocation); + if (count == 0 && logger.isDebugEnabled()) { + logger.debug("No component classes found for specified class/package [" + configLocation + "]"); + } + } + } + } + } + + + /** + * Build an {@link AnnotatedBeanDefinitionReader} for the given bean factory. + *

This should be pre-configured with the {@code Environment} (if desired) + * but not with a {@code BeanNameGenerator} or {@code ScopeMetadataResolver} yet. + * @param beanFactory the bean factory to load bean definitions into + * @since 4.1.9 + * @see #getEnvironment() + * @see #getBeanNameGenerator() + * @see #getScopeMetadataResolver() + */ + protected AnnotatedBeanDefinitionReader getAnnotatedBeanDefinitionReader(DefaultListableBeanFactory beanFactory) { + return new AnnotatedBeanDefinitionReader(beanFactory, getEnvironment()); + } + + /** + * Build a {@link ClassPathBeanDefinitionScanner} for the given bean factory. + *

This should be pre-configured with the {@code Environment} (if desired) + * but not with a {@code BeanNameGenerator} or {@code ScopeMetadataResolver} yet. + * @param beanFactory the bean factory to load bean definitions into + * @since 4.1.9 + * @see #getEnvironment() + * @see #getBeanNameGenerator() + * @see #getScopeMetadataResolver() + */ + protected ClassPathBeanDefinitionScanner getClassPathBeanDefinitionScanner(DefaultListableBeanFactory beanFactory) { + return new ClassPathBeanDefinitionScanner(beanFactory, true, getEnvironment()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/ContextExposingHttpServletRequest.java b/spring-web/src/main/java/org/springframework/web/context/support/ContextExposingHttpServletRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..fc6f1af7a5f22fcd2ee00dcd9bf335bc063a40fb --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/ContextExposingHttpServletRequest.java @@ -0,0 +1,105 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import java.util.HashSet; +import java.util.Set; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.context.WebApplicationContext; + +/** + * HttpServletRequest decorator that makes all Spring beans in a + * given WebApplicationContext accessible as request attributes, + * through lazy checking once an attribute gets accessed. + * + * @author Juergen Hoeller + * @since 2.5 + */ +public class ContextExposingHttpServletRequest extends HttpServletRequestWrapper { + + private final WebApplicationContext webApplicationContext; + + @Nullable + private final Set exposedContextBeanNames; + + @Nullable + private Set explicitAttributes; + + + /** + * Create a new ContextExposingHttpServletRequest for the given request. + * @param originalRequest the original HttpServletRequest + * @param context the WebApplicationContext that this request runs in + */ + public ContextExposingHttpServletRequest(HttpServletRequest originalRequest, WebApplicationContext context) { + this(originalRequest, context, null); + } + + /** + * Create a new ContextExposingHttpServletRequest for the given request. + * @param originalRequest the original HttpServletRequest + * @param context the WebApplicationContext that this request runs in + * @param exposedContextBeanNames the names of beans in the context which + * are supposed to be exposed (if this is non-null, only the beans in this + * Set are eligible for exposure as attributes) + */ + public ContextExposingHttpServletRequest(HttpServletRequest originalRequest, WebApplicationContext context, + @Nullable Set exposedContextBeanNames) { + + super(originalRequest); + Assert.notNull(context, "WebApplicationContext must not be null"); + this.webApplicationContext = context; + this.exposedContextBeanNames = exposedContextBeanNames; + } + + + /** + * Return the WebApplicationContext that this request runs in. + */ + public final WebApplicationContext getWebApplicationContext() { + return this.webApplicationContext; + } + + + @Override + @Nullable + public Object getAttribute(String name) { + if ((this.explicitAttributes == null || !this.explicitAttributes.contains(name)) && + (this.exposedContextBeanNames == null || this.exposedContextBeanNames.contains(name)) && + this.webApplicationContext.containsBean(name)) { + return this.webApplicationContext.getBean(name); + } + else { + return super.getAttribute(name); + } + } + + @Override + public void setAttribute(String name, Object value) { + super.setAttribute(name, value); + if (this.explicitAttributes == null) { + this.explicitAttributes = new HashSet<>(8); + } + this.explicitAttributes.add(name); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/GenericWebApplicationContext.java b/spring-web/src/main/java/org/springframework/web/context/support/GenericWebApplicationContext.java new file mode 100644 index 0000000000000000000000000000000000000000..c8a39cb2908b0ceaacf63723a093f57d61fc33b3 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/GenericWebApplicationContext.java @@ -0,0 +1,258 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import javax.servlet.ServletConfig; +import javax.servlet.ServletContext; + +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.io.Resource; +import org.springframework.core.io.support.ResourcePatternResolver; +import org.springframework.lang.Nullable; +import org.springframework.ui.context.Theme; +import org.springframework.ui.context.ThemeSource; +import org.springframework.ui.context.support.UiApplicationContextUtils; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.context.ConfigurableWebApplicationContext; +import org.springframework.web.context.ConfigurableWebEnvironment; +import org.springframework.web.context.ServletContextAware; + +/** + * Subclass of {@link GenericApplicationContext}, suitable for web environments. + * + *

Implements {@link org.springframework.web.context.ConfigurableWebApplicationContext}, + * but is not intended for declarative setup in {@code web.xml}. Instead, it is designed + * for programmatic setup, for example for building nested contexts or for use within + * {@link org.springframework.web.WebApplicationInitializer WebApplicationInitializers}. + * + *

If you intend to implement a WebApplicationContext that reads bean definitions + * from configuration files, consider deriving from AbstractRefreshableWebApplicationContext, + * reading the bean definitions in an implementation of the {@code loadBeanDefinitions} + * method. + * + *

Interprets resource paths as servlet context resources, i.e. as paths beneath + * the web application root. Absolute paths, e.g. for files outside the web app root, + * can be accessed via "file:" URLs, as implemented by AbstractApplicationContext. + * + *

In addition to the special beans detected by + * {@link org.springframework.context.support.AbstractApplicationContext}, + * this class detects a ThemeSource bean in the context, with the name "themeSource". + * + * @author Juergen Hoeller + * @author Chris Beams + * @since 1.2 + */ +public class GenericWebApplicationContext extends GenericApplicationContext + implements ConfigurableWebApplicationContext, ThemeSource { + + @Nullable + private ServletContext servletContext; + + @Nullable + private ThemeSource themeSource; + + + /** + * Create a new GenericWebApplicationContext. + * @see #setServletContext + * @see #registerBeanDefinition + * @see #refresh + */ + public GenericWebApplicationContext() { + super(); + } + + /** + * Create a new GenericWebApplicationContext for the given ServletContext. + * @param servletContext the ServletContext to run in + * @see #registerBeanDefinition + * @see #refresh + */ + public GenericWebApplicationContext(ServletContext servletContext) { + this.servletContext = servletContext; + } + + /** + * Create a new GenericWebApplicationContext with the given DefaultListableBeanFactory. + * @param beanFactory the DefaultListableBeanFactory instance to use for this context + * @see #setServletContext + * @see #registerBeanDefinition + * @see #refresh + */ + public GenericWebApplicationContext(DefaultListableBeanFactory beanFactory) { + super(beanFactory); + } + + /** + * Create a new GenericWebApplicationContext with the given DefaultListableBeanFactory. + * @param beanFactory the DefaultListableBeanFactory instance to use for this context + * @param servletContext the ServletContext to run in + * @see #registerBeanDefinition + * @see #refresh + */ + public GenericWebApplicationContext(DefaultListableBeanFactory beanFactory, ServletContext servletContext) { + super(beanFactory); + this.servletContext = servletContext; + } + + + /** + * Set the ServletContext that this WebApplicationContext runs in. + */ + @Override + public void setServletContext(@Nullable ServletContext servletContext) { + this.servletContext = servletContext; + } + + @Override + @Nullable + public ServletContext getServletContext() { + return this.servletContext; + } + + @Override + public String getApplicationName() { + return (this.servletContext != null ? this.servletContext.getContextPath() : ""); + } + + /** + * Create and return a new {@link StandardServletEnvironment}. + */ + @Override + protected ConfigurableEnvironment createEnvironment() { + return new StandardServletEnvironment(); + } + + /** + * Register ServletContextAwareProcessor. + * @see ServletContextAwareProcessor + */ + @Override + protected void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) { + if (this.servletContext != null) { + beanFactory.addBeanPostProcessor(new ServletContextAwareProcessor(this.servletContext)); + beanFactory.ignoreDependencyInterface(ServletContextAware.class); + } + WebApplicationContextUtils.registerWebApplicationScopes(beanFactory, this.servletContext); + WebApplicationContextUtils.registerEnvironmentBeans(beanFactory, this.servletContext); + } + + /** + * This implementation supports file paths beneath the root of the ServletContext. + * @see ServletContextResource + */ + @Override + protected Resource getResourceByPath(String path) { + Assert.state(this.servletContext != null, "No ServletContext available"); + return new ServletContextResource(this.servletContext, path); + } + + /** + * This implementation supports pattern matching in unexpanded WARs too. + * @see ServletContextResourcePatternResolver + */ + @Override + protected ResourcePatternResolver getResourcePatternResolver() { + return new ServletContextResourcePatternResolver(this); + } + + /** + * Initialize the theme capability. + */ + @Override + protected void onRefresh() { + this.themeSource = UiApplicationContextUtils.initThemeSource(this); + } + + /** + * {@inheritDoc} + *

Replace {@code Servlet}-related property sources. + */ + @Override + protected void initPropertySources() { + ConfigurableEnvironment env = getEnvironment(); + if (env instanceof ConfigurableWebEnvironment) { + ((ConfigurableWebEnvironment) env).initPropertySources(this.servletContext, null); + } + } + + @Override + @Nullable + public Theme getTheme(String themeName) { + Assert.state(this.themeSource != null, "No ThemeSource available"); + return this.themeSource.getTheme(themeName); + } + + + // --------------------------------------------------------------------- + // Pseudo-implementation of ConfigurableWebApplicationContext + // --------------------------------------------------------------------- + + @Override + public void setServletConfig(@Nullable ServletConfig servletConfig) { + // no-op + } + + @Override + @Nullable + public ServletConfig getServletConfig() { + throw new UnsupportedOperationException( + "GenericWebApplicationContext does not support getServletConfig()"); + } + + @Override + public void setNamespace(@Nullable String namespace) { + // no-op + } + + @Override + @Nullable + public String getNamespace() { + throw new UnsupportedOperationException( + "GenericWebApplicationContext does not support getNamespace()"); + } + + @Override + public void setConfigLocation(String configLocation) { + if (StringUtils.hasText(configLocation)) { + throw new UnsupportedOperationException( + "GenericWebApplicationContext does not support setConfigLocation(). " + + "Do you still have an 'contextConfigLocations' init-param set?"); + } + } + + @Override + public void setConfigLocations(String... configLocations) { + if (!ObjectUtils.isEmpty(configLocations)) { + throw new UnsupportedOperationException( + "GenericWebApplicationContext does not support setConfigLocations(). " + + "Do you still have an 'contextConfigLocations' init-param set?"); + } + } + + @Override + public String[] getConfigLocations() { + throw new UnsupportedOperationException( + "GenericWebApplicationContext does not support getConfigLocations()"); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/GroovyWebApplicationContext.java b/spring-web/src/main/java/org/springframework/web/context/support/GroovyWebApplicationContext.java new file mode 100644 index 0000000000000000000000000000000000000000..25610fd6a3d1640e21931e2482b4803aa76f5ffc --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/GroovyWebApplicationContext.java @@ -0,0 +1,189 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import java.io.IOException; + +import groovy.lang.GroovyObject; +import groovy.lang.GroovySystem; +import groovy.lang.MetaClass; + +import org.springframework.beans.BeanWrapper; +import org.springframework.beans.BeanWrapperImpl; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.beans.factory.groovy.GroovyBeanDefinitionReader; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.lang.Nullable; + +/** + * {@link org.springframework.web.context.WebApplicationContext} implementation which takes + * its configuration from Groovy bean definition scripts and/or XML files, as understood by + * a {@link org.springframework.beans.factory.groovy.GroovyBeanDefinitionReader}. + * This is essentially the equivalent of + * {@link org.springframework.context.support.GenericGroovyApplicationContext} + * for a web environment. + * + *

By default, the configuration will be taken from "/WEB-INF/applicationContext.groovy" + * for the root context, and "/WEB-INF/test-servlet.groovy" for a context with the namespace + * "test-servlet" (like for a DispatcherServlet instance with the servlet-name "test"). + * + *

The config location defaults can be overridden via the "contextConfigLocation" + * context-param of {@link org.springframework.web.context.ContextLoader} and servlet + * init-param of {@link org.springframework.web.servlet.FrameworkServlet}. Config locations + * can either denote concrete files like "/WEB-INF/context.groovy" or Ant-style patterns + * like "/WEB-INF/*-context.groovy" (see {@link org.springframework.util.PathMatcher} + * javadoc for pattern details). Note that ".xml" files will be parsed as XML content; + * all other kinds of resources will be parsed as Groovy scripts. + * + *

Note: In case of multiple config locations, later bean definitions will + * override ones defined in earlier loaded files. This can be leveraged to + * deliberately override certain bean definitions via an extra Groovy script. + * + *

For a WebApplicationContext that reads in a different bean definition format, + * create an analogous subclass of {@link AbstractRefreshableWebApplicationContext}. + * Such a context implementation can be specified as "contextClass" context-param + * for ContextLoader or "contextClass" init-param for FrameworkServlet. + * + * @author Juergen Hoeller + * @since 4.1 + * @see #setNamespace + * @see #setConfigLocations + * @see org.springframework.beans.factory.groovy.GroovyBeanDefinitionReader + * @see org.springframework.web.context.ContextLoader#initWebApplicationContext + * @see org.springframework.web.servlet.FrameworkServlet#initWebApplicationContext + */ +public class GroovyWebApplicationContext extends AbstractRefreshableWebApplicationContext implements GroovyObject { + + /** Default config location for the root context. */ + public static final String DEFAULT_CONFIG_LOCATION = "/WEB-INF/applicationContext.groovy"; + + /** Default prefix for building a config location for a namespace. */ + public static final String DEFAULT_CONFIG_LOCATION_PREFIX = "/WEB-INF/"; + + /** Default suffix for building a config location for a namespace. */ + public static final String DEFAULT_CONFIG_LOCATION_SUFFIX = ".groovy"; + + + private final BeanWrapper contextWrapper = new BeanWrapperImpl(this); + + private MetaClass metaClass = GroovySystem.getMetaClassRegistry().getMetaClass(getClass()); + + + /** + * Loads the bean definitions via an GroovyBeanDefinitionReader. + * @see org.springframework.beans.factory.groovy.GroovyBeanDefinitionReader + * @see #initBeanDefinitionReader + * @see #loadBeanDefinitions + */ + @Override + protected void loadBeanDefinitions(DefaultListableBeanFactory beanFactory) throws BeansException, IOException { + // Create a new XmlBeanDefinitionReader for the given BeanFactory. + GroovyBeanDefinitionReader beanDefinitionReader = new GroovyBeanDefinitionReader(beanFactory); + + // Configure the bean definition reader with this context's + // resource loading environment. + beanDefinitionReader.setEnvironment(getEnvironment()); + beanDefinitionReader.setResourceLoader(this); + + // Allow a subclass to provide custom initialization of the reader, + // then proceed with actually loading the bean definitions. + initBeanDefinitionReader(beanDefinitionReader); + loadBeanDefinitions(beanDefinitionReader); + } + + /** + * Initialize the bean definition reader used for loading the bean + * definitions of this context. Default implementation is empty. + *

Can be overridden in subclasses. + * @param beanDefinitionReader the bean definition reader used by this context + */ + protected void initBeanDefinitionReader(GroovyBeanDefinitionReader beanDefinitionReader) { + } + + /** + * Load the bean definitions with the given GroovyBeanDefinitionReader. + *

The lifecycle of the bean factory is handled by the refreshBeanFactory method; + * therefore this method is just supposed to load and/or register bean definitions. + *

Delegates to a ResourcePatternResolver for resolving location patterns + * into Resource instances. + * @throws IOException if the required Groovy script or XML file isn't found + * @see #refreshBeanFactory + * @see #getConfigLocations + * @see #getResources + * @see #getResourcePatternResolver + */ + protected void loadBeanDefinitions(GroovyBeanDefinitionReader reader) throws IOException { + String[] configLocations = getConfigLocations(); + if (configLocations != null) { + for (String configLocation : configLocations) { + reader.loadBeanDefinitions(configLocation); + } + } + } + + /** + * The default location for the root context is "/WEB-INF/applicationContext.groovy", + * and "/WEB-INF/test-servlet.groovy" for a context with the namespace "test-servlet" + * (like for a DispatcherServlet instance with the servlet-name "test"). + */ + @Override + protected String[] getDefaultConfigLocations() { + if (getNamespace() != null) { + return new String[] {DEFAULT_CONFIG_LOCATION_PREFIX + getNamespace() + DEFAULT_CONFIG_LOCATION_SUFFIX}; + } + else { + return new String[] {DEFAULT_CONFIG_LOCATION}; + } + } + + + // Implementation of the GroovyObject interface + + @Override + public void setMetaClass(MetaClass metaClass) { + this.metaClass = metaClass; + } + + @Override + public MetaClass getMetaClass() { + return this.metaClass; + } + + @Override + public Object invokeMethod(String name, Object args) { + return this.metaClass.invokeMethod(this, name, args); + } + + @Override + public void setProperty(String property, Object newValue) { + this.metaClass.setProperty(this, property, newValue); + } + + @Override + @Nullable + public Object getProperty(String property) { + if (containsBean(property)) { + return getBean(property); + } + else if (this.contextWrapper.isReadableProperty(property)) { + return this.contextWrapper.getPropertyValue(property); + } + throw new NoSuchBeanDefinitionException(property); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/HttpRequestHandlerServlet.java b/spring-web/src/main/java/org/springframework/web/context/support/HttpRequestHandlerServlet.java new file mode 100644 index 0000000000000000000000000000000000000000..97be4db90667b3208b940e63a3e52bf0cbf5d7aa --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/HttpRequestHandlerServlet.java @@ -0,0 +1,87 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import java.io.IOException; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.context.i18n.LocaleContextHolder; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.HttpRequestHandler; +import org.springframework.web.HttpRequestMethodNotSupportedException; +import org.springframework.web.context.WebApplicationContext; + +/** + * Simple HttpServlet that delegates to an {@link HttpRequestHandler} bean defined + * in Spring's root web application context. The target bean name must match the + * HttpRequestHandlerServlet servlet-name as defined in {@code web.xml}. + * + *

This can for example be used to expose a single Spring remote exporter, + * such as {@link org.springframework.remoting.httpinvoker.HttpInvokerServiceExporter} + * or {@link org.springframework.remoting.caucho.HessianServiceExporter}, + * per HttpRequestHandlerServlet definition. This is a minimal alternative + * to defining remote exporters as beans in a DispatcherServlet context + * (with advanced mapping and interception facilities being available there). + * + * @author Juergen Hoeller + * @since 2.0 + * @see org.springframework.web.HttpRequestHandler + * @see org.springframework.web.servlet.DispatcherServlet + */ +@SuppressWarnings("serial") +public class HttpRequestHandlerServlet extends HttpServlet { + + @Nullable + private HttpRequestHandler target; + + + @Override + public void init() throws ServletException { + WebApplicationContext wac = WebApplicationContextUtils.getRequiredWebApplicationContext(getServletContext()); + this.target = wac.getBean(getServletName(), HttpRequestHandler.class); + } + + + @Override + protected void service(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + Assert.state(this.target != null, "No HttpRequestHandler available"); + + LocaleContextHolder.setLocale(request.getLocale()); + try { + this.target.handleRequest(request, response); + } + catch (HttpRequestMethodNotSupportedException ex) { + String[] supportedMethods = ex.getSupportedMethods(); + if (supportedMethods != null) { + response.setHeader("Allow", StringUtils.arrayToDelimitedString(supportedMethods, ", ")); + } + response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED, ex.getMessage()); + } + finally { + LocaleContextHolder.resetLocaleContext(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/LiveBeansViewServlet.java b/spring-web/src/main/java/org/springframework/web/context/support/LiveBeansViewServlet.java new file mode 100644 index 0000000000000000000000000000000000000000..4302e05d3297cd0b3745bd854eb138f1514f91b9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/LiveBeansViewServlet.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import java.io.IOException; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.context.support.LiveBeansView; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Servlet variant of {@link LiveBeansView}'s MBean exposure. + * + *

Generates a JSON snapshot for current beans and their dependencies in + * all ApplicationContexts that live within the current web application. + * + * @author Juergen Hoeller + * @since 3.2 + * @see org.springframework.context.support.LiveBeansView#getSnapshotAsJson() + */ +@SuppressWarnings("serial") +public class LiveBeansViewServlet extends HttpServlet { + + @Nullable + private LiveBeansView liveBeansView; + + + @Override + public void init() throws ServletException { + this.liveBeansView = buildLiveBeansView(); + } + + protected LiveBeansView buildLiveBeansView() { + return new ServletContextLiveBeansView(getServletContext()); + } + + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + Assert.state(this.liveBeansView != null, "No LiveBeansView available"); + String content = this.liveBeansView.getSnapshotAsJson(); + response.setContentType("application/json"); + response.setContentLength(content.length()); + response.getWriter().write(content); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/RequestHandledEvent.java b/spring-web/src/main/java/org/springframework/web/context/support/RequestHandledEvent.java new file mode 100644 index 0000000000000000000000000000000000000000..4d4bd8c85dd16b69387f01d2df6656378ca890a6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/RequestHandledEvent.java @@ -0,0 +1,167 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import org.springframework.context.ApplicationEvent; +import org.springframework.lang.Nullable; + +/** + * Event raised when a request is handled within an ApplicationContext. + * + *

Supported by Spring's own FrameworkServlet (through a specific + * ServletRequestHandledEvent subclass), but can also be raised by any + * other web component. Used, for example, by Spring's out-of-the-box + * PerformanceMonitorListener. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since January 17, 2001 + * @see ServletRequestHandledEvent + * @see org.springframework.web.servlet.FrameworkServlet + * @see org.springframework.context.ApplicationContext#publishEvent + */ +@SuppressWarnings("serial") +public class RequestHandledEvent extends ApplicationEvent { + + /** Session id that applied to the request, if any. */ + @Nullable + private String sessionId; + + /** Usually the UserPrincipal. */ + @Nullable + private String userName; + + /** Request processing time. */ + private final long processingTimeMillis; + + /** Cause of failure, if any. */ + @Nullable + private Throwable failureCause; + + + /** + * Create a new RequestHandledEvent with session information. + * @param source the component that published the event + * @param sessionId the id of the HTTP session, if any + * @param userName the name of the user that was associated with the + * request, if any (usually the UserPrincipal) + * @param processingTimeMillis the processing time of the request in milliseconds + */ + public RequestHandledEvent(Object source, @Nullable String sessionId, @Nullable String userName, + long processingTimeMillis) { + + super(source); + this.sessionId = sessionId; + this.userName = userName; + this.processingTimeMillis = processingTimeMillis; + } + + /** + * Create a new RequestHandledEvent with session information. + * @param source the component that published the event + * @param sessionId the id of the HTTP session, if any + * @param userName the name of the user that was associated with the + * request, if any (usually the UserPrincipal) + * @param processingTimeMillis the processing time of the request in milliseconds + * @param failureCause the cause of failure, if any + */ + public RequestHandledEvent(Object source, @Nullable String sessionId, @Nullable String userName, + long processingTimeMillis, @Nullable Throwable failureCause) { + + this(source, sessionId, userName, processingTimeMillis); + this.failureCause = failureCause; + } + + + /** + * Return the processing time of the request in milliseconds. + */ + public long getProcessingTimeMillis() { + return this.processingTimeMillis; + } + + /** + * Return the id of the HTTP session, if any. + */ + @Nullable + public String getSessionId() { + return this.sessionId; + } + + /** + * Return the name of the user that was associated with the request + * (usually the UserPrincipal). + * @see javax.servlet.http.HttpServletRequest#getUserPrincipal() + */ + @Nullable + public String getUserName() { + return this.userName; + } + + /** + * Return whether the request failed. + */ + public boolean wasFailure() { + return (this.failureCause != null); + } + + /** + * Return the cause of failure, if any. + */ + @Nullable + public Throwable getFailureCause() { + return this.failureCause; + } + + + /** + * Return a short description of this event, only involving + * the most important context data. + */ + public String getShortDescription() { + StringBuilder sb = new StringBuilder(); + sb.append("session=[").append(this.sessionId).append("]; "); + sb.append("user=[").append(this.userName).append("]; "); + return sb.toString(); + } + + /** + * Return a full description of this event, involving + * all available context data. + */ + public String getDescription() { + StringBuilder sb = new StringBuilder(); + sb.append("session=[").append(this.sessionId).append("]; "); + sb.append("user=[").append(this.userName).append("]; "); + sb.append("time=[").append(this.processingTimeMillis).append("ms]; "); + sb.append("status=["); + if (!wasFailure()) { + sb.append("OK"); + } + else { + sb.append("failed: ").append(this.failureCause); + } + sb.append(']'); + return sb.toString(); + } + + @Override + public String toString() { + return ("RequestHandledEvent: " + getDescription()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/ServletConfigPropertySource.java b/spring-web/src/main/java/org/springframework/web/context/support/ServletConfigPropertySource.java new file mode 100644 index 0000000000000000000000000000000000000000..01fd5d064dd2785462167a2f4e4a7b9bd8aa95ec --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/ServletConfigPropertySource.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import javax.servlet.ServletConfig; + +import org.springframework.core.env.EnumerablePropertySource; +import org.springframework.core.env.PropertySource; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +/** + * {@link PropertySource} that reads init parameters from a {@link ServletConfig} object. + * + * @author Chris Beams + * @since 3.1 + * @see ServletContextPropertySource + */ +public class ServletConfigPropertySource extends EnumerablePropertySource { + + public ServletConfigPropertySource(String name, ServletConfig servletConfig) { + super(name, servletConfig); + } + + @Override + public String[] getPropertyNames() { + return StringUtils.toStringArray(this.source.getInitParameterNames()); + } + + @Override + @Nullable + public String getProperty(String name) { + return this.source.getInitParameter(name); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/ServletContextAttributeExporter.java b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextAttributeExporter.java new file mode 100644 index 0000000000000000000000000000000000000000..8e8ac5518216563e939e382205b3ddac381475ed --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextAttributeExporter.java @@ -0,0 +1,87 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import java.util.Map; + +import javax.servlet.ServletContext; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.lang.Nullable; +import org.springframework.web.context.ServletContextAware; + +/** + * Exporter that takes Spring-defined objects and exposes them as + * ServletContext attributes. Usually, bean references will be used + * to export Spring-defined beans as ServletContext attributes. + * + *

Useful to make Spring-defined beans available to code that is + * not aware of Spring at all, but rather just of the Servlet API. + * Client code can then use plain ServletContext attribute lookups + * to access those objects, despite them being defined in a Spring + * application context. + * + *

Alternatively, consider using the WebApplicationContextUtils + * class to access Spring-defined beans via the WebApplicationContext + * interface. This makes client code aware of Spring API, of course. + * + * @author Juergen Hoeller + * @since 1.1.4 + * @see javax.servlet.ServletContext#getAttribute + * @see WebApplicationContextUtils#getWebApplicationContext + */ +public class ServletContextAttributeExporter implements ServletContextAware { + + protected final Log logger = LogFactory.getLog(getClass()); + + @Nullable + private Map attributes; + + + /** + * Set the ServletContext attributes to expose as key-value pairs. + * Each key will be considered a ServletContext attributes key, + * and each value will be used as corresponding attribute value. + *

Usually, you will use bean references for the values, + * to export Spring-defined beans as ServletContext attributes. + * Of course, it is also possible to define plain values to export. + */ + public void setAttributes(Map attributes) { + this.attributes = attributes; + } + + @Override + public void setServletContext(ServletContext servletContext) { + if (this.attributes != null) { + for (Map.Entry entry : this.attributes.entrySet()) { + String attributeName = entry.getKey(); + if (logger.isDebugEnabled()) { + if (servletContext.getAttribute(attributeName) != null) { + logger.debug("Replacing existing ServletContext attribute with name '" + attributeName + "'"); + } + } + servletContext.setAttribute(attributeName, entry.getValue()); + if (logger.isTraceEnabled()) { + logger.trace("Exported ServletContext attribute with name '" + attributeName + "'"); + } + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/ServletContextAttributeFactoryBean.java b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextAttributeFactoryBean.java new file mode 100644 index 0000000000000000000000000000000000000000..c835219b66aaa435709789b68c13c8ccc56742bd --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextAttributeFactoryBean.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import javax.servlet.ServletContext; + +import org.springframework.beans.factory.FactoryBean; +import org.springframework.lang.Nullable; +import org.springframework.web.context.ServletContextAware; + +/** + * {@link FactoryBean} that fetches a specific, existing ServletContext attribute. + * Exposes that ServletContext attribute when used as bean reference, + * effectively making it available as named Spring bean instance. + * + *

Intended to link in ServletContext attributes that exist before + * the startup of the Spring application context. Typically, such + * attributes will have been put there by third-party web frameworks. + * In a purely Spring-based web application, no such linking in of + * ServletContext attributes will be necessary. + * + *

NOTE: As of Spring 3.0, you may also use the "contextAttributes" default + * bean which is of type Map, and dereference it using an "#{contextAttributes.myKey}" + * expression to access a specific attribute by name. + * + * @author Juergen Hoeller + * @since 1.1.4 + * @see org.springframework.web.context.WebApplicationContext#CONTEXT_ATTRIBUTES_BEAN_NAME + * @see ServletContextParameterFactoryBean + */ +public class ServletContextAttributeFactoryBean implements FactoryBean, ServletContextAware { + + @Nullable + private String attributeName; + + @Nullable + private Object attribute; + + + /** + * Set the name of the ServletContext attribute to expose. + */ + public void setAttributeName(String attributeName) { + this.attributeName = attributeName; + } + + @Override + public void setServletContext(ServletContext servletContext) { + if (this.attributeName == null) { + throw new IllegalArgumentException("Property 'attributeName' is required"); + } + this.attribute = servletContext.getAttribute(this.attributeName); + if (this.attribute == null) { + throw new IllegalStateException("No ServletContext attribute '" + this.attributeName + "' found"); + } + } + + + @Override + @Nullable + public Object getObject() throws Exception { + return this.attribute; + } + + @Override + public Class getObjectType() { + return (this.attribute != null ? this.attribute.getClass() : null); + } + + @Override + public boolean isSingleton() { + return true; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/ServletContextAwareProcessor.java b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextAwareProcessor.java new file mode 100644 index 0000000000000000000000000000000000000000..0146649cc3df822cd3e5a6b8680e15448c58965f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextAwareProcessor.java @@ -0,0 +1,121 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import javax.servlet.ServletConfig; +import javax.servlet.ServletContext; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.config.BeanPostProcessor; +import org.springframework.lang.Nullable; +import org.springframework.web.context.ServletConfigAware; +import org.springframework.web.context.ServletContextAware; + +/** + * {@link org.springframework.beans.factory.config.BeanPostProcessor} + * implementation that passes the ServletContext to beans that implement + * the {@link ServletContextAware} interface. + * + *

Web application contexts will automatically register this with their + * underlying bean factory. Applications do not use this directly. + * + * @author Juergen Hoeller + * @author Phillip Webb + * @since 12.03.2004 + * @see org.springframework.web.context.ServletContextAware + * @see org.springframework.web.context.support.XmlWebApplicationContext#postProcessBeanFactory + */ +public class ServletContextAwareProcessor implements BeanPostProcessor { + + @Nullable + private ServletContext servletContext; + + @Nullable + private ServletConfig servletConfig; + + + /** + * Create a new ServletContextAwareProcessor without an initial context or config. + * When this constructor is used the {@link #getServletContext()} and/or + * {@link #getServletConfig()} methods should be overridden. + */ + protected ServletContextAwareProcessor() { + } + + /** + * Create a new ServletContextAwareProcessor for the given context. + */ + public ServletContextAwareProcessor(ServletContext servletContext) { + this(servletContext, null); + } + + /** + * Create a new ServletContextAwareProcessor for the given config. + */ + public ServletContextAwareProcessor(ServletConfig servletConfig) { + this(null, servletConfig); + } + + /** + * Create a new ServletContextAwareProcessor for the given context and config. + */ + public ServletContextAwareProcessor(@Nullable ServletContext servletContext, @Nullable ServletConfig servletConfig) { + this.servletContext = servletContext; + this.servletConfig = servletConfig; + } + + + /** + * Returns the {@link ServletContext} to be injected or {@code null}. This method + * can be overridden by subclasses when a context is obtained after the post-processor + * has been registered. + */ + @Nullable + protected ServletContext getServletContext() { + if (this.servletContext == null && getServletConfig() != null) { + return getServletConfig().getServletContext(); + } + return this.servletContext; + } + + /** + * Returns the {@link ServletConfig} to be injected or {@code null}. This method + * can be overridden by subclasses when a context is obtained after the post-processor + * has been registered. + */ + @Nullable + protected ServletConfig getServletConfig() { + return this.servletConfig; + } + + @Override + public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { + if (getServletContext() != null && bean instanceof ServletContextAware) { + ((ServletContextAware) bean).setServletContext(getServletContext()); + } + if (getServletConfig() != null && bean instanceof ServletConfigAware) { + ((ServletConfigAware) bean).setServletConfig(getServletConfig()); + } + return bean; + } + + @Override + public Object postProcessAfterInitialization(Object bean, String beanName) { + return bean; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/ServletContextLiveBeansView.java b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextLiveBeansView.java new file mode 100644 index 0000000000000000000000000000000000000000..89cf33fb973373b06aa95d99ead3c2dc4976878e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextLiveBeansView.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import java.util.Enumeration; +import java.util.LinkedHashSet; +import java.util.Set; + +import javax.servlet.ServletContext; + +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.context.support.LiveBeansView; +import org.springframework.util.Assert; + +/** + * {@link LiveBeansView} subclass which looks for all ApplicationContexts + * in the web application, as exposed in ServletContext attributes. + * + * @author Juergen Hoeller + * @since 3.2 + */ +public class ServletContextLiveBeansView extends LiveBeansView { + + private final ServletContext servletContext; + + /** + * Create a new LiveBeansView for the given ServletContext. + * @param servletContext current ServletContext + */ + public ServletContextLiveBeansView(ServletContext servletContext) { + Assert.notNull(servletContext, "ServletContext must not be null"); + this.servletContext = servletContext; + } + + @Override + protected Set findApplicationContexts() { + Set contexts = new LinkedHashSet<>(); + Enumeration attrNames = this.servletContext.getAttributeNames(); + while (attrNames.hasMoreElements()) { + String attrName = attrNames.nextElement(); + Object attrValue = this.servletContext.getAttribute(attrName); + if (attrValue instanceof ConfigurableApplicationContext) { + contexts.add((ConfigurableApplicationContext) attrValue); + } + } + return contexts; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/ServletContextParameterFactoryBean.java b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextParameterFactoryBean.java new file mode 100644 index 0000000000000000000000000000000000000000..74f5ef0d2bb32d95de519a61be5534ae7acf5559 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextParameterFactoryBean.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import javax.servlet.ServletContext; + +import org.springframework.beans.factory.FactoryBean; +import org.springframework.lang.Nullable; +import org.springframework.web.context.ServletContextAware; + +/** + * {@link FactoryBean} that retrieves a specific ServletContext init parameter + * (that is, a "context-param" defined in {@code web.xml}). + * Exposes that ServletContext init parameter when used as bean reference, + * effectively making it available as named Spring bean instance. + * + *

NOTE: As of Spring 3.0, you may also use the "contextParameters" default + * bean which is of type Map, and dereference it using an "#{contextParameters.myKey}" + * expression to access a specific parameter by name. + * + * @author Juergen Hoeller + * @since 1.2.4 + * @see org.springframework.web.context.WebApplicationContext#CONTEXT_PARAMETERS_BEAN_NAME + * @see ServletContextAttributeFactoryBean + */ +public class ServletContextParameterFactoryBean implements FactoryBean, ServletContextAware { + + @Nullable + private String initParamName; + + @Nullable + private String paramValue; + + + /** + * Set the name of the ServletContext init parameter to expose. + */ + public void setInitParamName(String initParamName) { + this.initParamName = initParamName; + } + + @Override + public void setServletContext(ServletContext servletContext) { + if (this.initParamName == null) { + throw new IllegalArgumentException("initParamName is required"); + } + this.paramValue = servletContext.getInitParameter(this.initParamName); + if (this.paramValue == null) { + throw new IllegalStateException("No ServletContext init parameter '" + this.initParamName + "' found"); + } + } + + + @Override + @Nullable + public String getObject() { + return this.paramValue; + } + + @Override + public Class getObjectType() { + return String.class; + } + + @Override + public boolean isSingleton() { + return true; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/ServletContextPropertySource.java b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextPropertySource.java new file mode 100644 index 0000000000000000000000000000000000000000..468d1c3c04f79350c416f2a45ea2aa8ab8ad9718 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextPropertySource.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import javax.servlet.ServletContext; + +import org.springframework.core.env.EnumerablePropertySource; +import org.springframework.core.env.PropertySource; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +/** + * {@link PropertySource} that reads init parameters from a {@link ServletContext} object. + * + * @author Chris Beams + * @since 3.1 + * @see ServletConfigPropertySource + */ +public class ServletContextPropertySource extends EnumerablePropertySource { + + public ServletContextPropertySource(String name, ServletContext servletContext) { + super(name, servletContext); + } + + @Override + public String[] getPropertyNames() { + return StringUtils.toStringArray(this.source.getInitParameterNames()); + } + + @Override + @Nullable + public String getProperty(String name) { + return this.source.getInitParameter(name); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/ServletContextResource.java b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextResource.java new file mode 100644 index 0000000000000000000000000000000000000000..5b8c01f9fc2dff91f6fe3b107d83789ffa38af64 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextResource.java @@ -0,0 +1,260 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; + +import javax.servlet.ServletContext; + +import org.springframework.core.io.AbstractFileResolvingResource; +import org.springframework.core.io.ContextResource; +import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ResourceUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.util.WebUtils; + +/** + * {@link org.springframework.core.io.Resource} implementation for + * {@link javax.servlet.ServletContext} resources, interpreting + * relative paths within the web application root directory. + * + *

Always supports stream access and URL access, but only allows + * {@code java.io.File} access when the web application archive + * is expanded. + * + * @author Juergen Hoeller + * @since 28.12.2003 + * @see javax.servlet.ServletContext#getResourceAsStream + * @see javax.servlet.ServletContext#getResource + * @see javax.servlet.ServletContext#getRealPath + */ +public class ServletContextResource extends AbstractFileResolvingResource implements ContextResource { + + private final ServletContext servletContext; + + private final String path; + + + /** + * Create a new ServletContextResource. + *

The Servlet spec requires that resource paths start with a slash, + * even if many containers accept paths without leading slash too. + * Consequently, the given path will be prepended with a slash if it + * doesn't already start with one. + * @param servletContext the ServletContext to load from + * @param path the path of the resource + */ + public ServletContextResource(ServletContext servletContext, String path) { + // check ServletContext + Assert.notNull(servletContext, "Cannot resolve ServletContextResource without ServletContext"); + this.servletContext = servletContext; + + // check path + Assert.notNull(path, "Path is required"); + String pathToUse = StringUtils.cleanPath(path); + if (!pathToUse.startsWith("/")) { + pathToUse = "/" + pathToUse; + } + this.path = pathToUse; + } + + + /** + * Return the ServletContext for this resource. + */ + public final ServletContext getServletContext() { + return this.servletContext; + } + + /** + * Return the path for this resource. + */ + public final String getPath() { + return this.path; + } + + /** + * This implementation checks {@code ServletContext.getResource}. + * @see javax.servlet.ServletContext#getResource(String) + */ + @Override + public boolean exists() { + try { + URL url = this.servletContext.getResource(this.path); + return (url != null); + } + catch (MalformedURLException ex) { + return false; + } + } + + /** + * This implementation delegates to {@code ServletContext.getResourceAsStream}, + * which returns {@code null} in case of a non-readable resource (e.g. a directory). + * @see javax.servlet.ServletContext#getResourceAsStream(String) + */ + @Override + public boolean isReadable() { + InputStream is = this.servletContext.getResourceAsStream(this.path); + if (is != null) { + try { + is.close(); + } + catch (IOException ex) { + // ignore + } + return true; + } + else { + return false; + } + } + + @Override + public boolean isFile() { + try { + URL url = this.servletContext.getResource(this.path); + if (url != null && ResourceUtils.isFileURL(url)) { + return true; + } + else { + return (this.servletContext.getRealPath(this.path) != null); + } + } + catch (MalformedURLException ex) { + return false; + } + } + + /** + * This implementation delegates to {@code ServletContext.getResourceAsStream}, + * but throws a FileNotFoundException if no resource found. + * @see javax.servlet.ServletContext#getResourceAsStream(String) + */ + @Override + public InputStream getInputStream() throws IOException { + InputStream is = this.servletContext.getResourceAsStream(this.path); + if (is == null) { + throw new FileNotFoundException("Could not open " + getDescription()); + } + return is; + } + + /** + * This implementation delegates to {@code ServletContext.getResource}, + * but throws a FileNotFoundException if no resource found. + * @see javax.servlet.ServletContext#getResource(String) + */ + @Override + public URL getURL() throws IOException { + URL url = this.servletContext.getResource(this.path); + if (url == null) { + throw new FileNotFoundException( + getDescription() + " cannot be resolved to URL because it does not exist"); + } + return url; + } + + /** + * This implementation resolves "file:" URLs or alternatively delegates to + * {@code ServletContext.getRealPath}, throwing a FileNotFoundException + * if not found or not resolvable. + * @see javax.servlet.ServletContext#getResource(String) + * @see javax.servlet.ServletContext#getRealPath(String) + */ + @Override + public File getFile() throws IOException { + URL url = this.servletContext.getResource(this.path); + if (url != null && ResourceUtils.isFileURL(url)) { + // Proceed with file system resolution... + return super.getFile(); + } + else { + String realPath = WebUtils.getRealPath(this.servletContext, this.path); + return new File(realPath); + } + } + + /** + * This implementation creates a ServletContextResource, applying the given path + * relative to the path of the underlying file of this resource descriptor. + * @see org.springframework.util.StringUtils#applyRelativePath(String, String) + */ + @Override + public Resource createRelative(String relativePath) { + String pathToUse = StringUtils.applyRelativePath(this.path, relativePath); + return new ServletContextResource(this.servletContext, pathToUse); + } + + /** + * This implementation returns the name of the file that this ServletContext + * resource refers to. + * @see org.springframework.util.StringUtils#getFilename(String) + */ + @Override + @Nullable + public String getFilename() { + return StringUtils.getFilename(this.path); + } + + /** + * This implementation returns a description that includes the ServletContext + * resource location. + */ + @Override + public String getDescription() { + return "ServletContext resource [" + this.path + "]"; + } + + @Override + public String getPathWithinContext() { + return this.path; + } + + + /** + * This implementation compares the underlying ServletContext resource locations. + */ + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof ServletContextResource)) { + return false; + } + ServletContextResource otherRes = (ServletContextResource) other; + return (this.servletContext.equals(otherRes.servletContext) && this.path.equals(otherRes.path)); + } + + /** + * This implementation returns the hash code of the underlying + * ServletContext resource location. + */ + @Override + public int hashCode() { + return this.path.hashCode(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/ServletContextResourceLoader.java b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextResourceLoader.java new file mode 100644 index 0000000000000000000000000000000000000000..1598788776000ec70d2248eade2a3f989683bd80 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextResourceLoader.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2010 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import javax.servlet.ServletContext; + +import org.springframework.core.io.DefaultResourceLoader; +import org.springframework.core.io.Resource; + +/** + * ResourceLoader implementation that resolves paths as ServletContext + * resources, for use outside a WebApplicationContext (for example, + * in an HttpServletBean or GenericFilterBean subclass). + * + *

Within a WebApplicationContext, resource paths are automatically + * resolved as ServletContext resources by the context implementation. + * + * @author Juergen Hoeller + * @since 1.0.2 + * @see #getResourceByPath + * @see ServletContextResource + * @see org.springframework.web.context.WebApplicationContext + * @see org.springframework.web.servlet.HttpServletBean + * @see org.springframework.web.filter.GenericFilterBean + */ +public class ServletContextResourceLoader extends DefaultResourceLoader { + + private final ServletContext servletContext; + + + /** + * Create a new ServletContextResourceLoader. + * @param servletContext the ServletContext to load resources with + */ + public ServletContextResourceLoader(ServletContext servletContext) { + this.servletContext = servletContext; + } + + /** + * This implementation supports file paths beneath the root of the web application. + * @see ServletContextResource + */ + @Override + protected Resource getResourceByPath(String path) { + return new ServletContextResource(this.servletContext, path); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/ServletContextResourcePatternResolver.java b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextResourcePatternResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..3826c08fa275a7697cde04e687764ea66edcc7f7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextResourcePatternResolver.java @@ -0,0 +1,187 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import java.io.IOException; +import java.util.Enumeration; +import java.util.LinkedHashSet; +import java.util.Set; +import java.util.jar.JarEntry; +import java.util.jar.JarFile; + +import javax.servlet.ServletContext; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.io.Resource; +import org.springframework.core.io.ResourceLoader; +import org.springframework.core.io.UrlResource; +import org.springframework.core.io.support.PathMatchingResourcePatternResolver; +import org.springframework.util.ResourceUtils; +import org.springframework.util.StringUtils; + +/** + * ServletContext-aware subclass of {@link PathMatchingResourcePatternResolver}, + * able to find matching resources below the web application root directory + * via {@link ServletContext#getResourcePaths}. Falls back to the superclass' + * file system checking for other resources. + * + * @author Juergen Hoeller + * @since 1.1.2 + */ +public class ServletContextResourcePatternResolver extends PathMatchingResourcePatternResolver { + + private static final Log logger = LogFactory.getLog(ServletContextResourcePatternResolver.class); + + + /** + * Create a new ServletContextResourcePatternResolver. + * @param servletContext the ServletContext to load resources with + * @see ServletContextResourceLoader#ServletContextResourceLoader(javax.servlet.ServletContext) + */ + public ServletContextResourcePatternResolver(ServletContext servletContext) { + super(new ServletContextResourceLoader(servletContext)); + } + + /** + * Create a new ServletContextResourcePatternResolver. + * @param resourceLoader the ResourceLoader to load root directories and + * actual resources with + */ + public ServletContextResourcePatternResolver(ResourceLoader resourceLoader) { + super(resourceLoader); + } + + + /** + * Overridden version which checks for ServletContextResource + * and uses {@code ServletContext.getResourcePaths} to find + * matching resources below the web application root directory. + * In case of other resources, delegates to the superclass version. + * @see #doRetrieveMatchingServletContextResources + * @see ServletContextResource + * @see javax.servlet.ServletContext#getResourcePaths + */ + @Override + protected Set doFindPathMatchingFileResources(Resource rootDirResource, String subPattern) + throws IOException { + + if (rootDirResource instanceof ServletContextResource) { + ServletContextResource scResource = (ServletContextResource) rootDirResource; + ServletContext sc = scResource.getServletContext(); + String fullPattern = scResource.getPath() + subPattern; + Set result = new LinkedHashSet<>(8); + doRetrieveMatchingServletContextResources(sc, fullPattern, scResource.getPath(), result); + return result; + } + else { + return super.doFindPathMatchingFileResources(rootDirResource, subPattern); + } + } + + /** + * Recursively retrieve ServletContextResources that match the given pattern, + * adding them to the given result set. + * @param servletContext the ServletContext to work on + * @param fullPattern the pattern to match against, + * with preprended root directory path + * @param dir the current directory + * @param result the Set of matching Resources to add to + * @throws IOException if directory contents could not be retrieved + * @see ServletContextResource + * @see javax.servlet.ServletContext#getResourcePaths + */ + protected void doRetrieveMatchingServletContextResources( + ServletContext servletContext, String fullPattern, String dir, Set result) + throws IOException { + + Set candidates = servletContext.getResourcePaths(dir); + if (candidates != null) { + boolean dirDepthNotFixed = fullPattern.contains("**"); + int jarFileSep = fullPattern.indexOf(ResourceUtils.JAR_URL_SEPARATOR); + String jarFilePath = null; + String pathInJarFile = null; + if (jarFileSep > 0 && jarFileSep + ResourceUtils.JAR_URL_SEPARATOR.length() < fullPattern.length()) { + jarFilePath = fullPattern.substring(0, jarFileSep); + pathInJarFile = fullPattern.substring(jarFileSep + ResourceUtils.JAR_URL_SEPARATOR.length()); + } + for (String currPath : candidates) { + if (!currPath.startsWith(dir)) { + // Returned resource path does not start with relative directory: + // assuming absolute path returned -> strip absolute path. + int dirIndex = currPath.indexOf(dir); + if (dirIndex != -1) { + currPath = currPath.substring(dirIndex); + } + } + if (currPath.endsWith("/") && (dirDepthNotFixed || StringUtils.countOccurrencesOf(currPath, "/") <= + StringUtils.countOccurrencesOf(fullPattern, "/"))) { + // Search subdirectories recursively: ServletContext.getResourcePaths + // only returns entries for one directory level. + doRetrieveMatchingServletContextResources(servletContext, fullPattern, currPath, result); + } + if (jarFilePath != null && getPathMatcher().match(jarFilePath, currPath)) { + // Base pattern matches a jar file - search for matching entries within. + String absoluteJarPath = servletContext.getRealPath(currPath); + if (absoluteJarPath != null) { + doRetrieveMatchingJarEntries(absoluteJarPath, pathInJarFile, result); + } + } + if (getPathMatcher().match(fullPattern, currPath)) { + result.add(new ServletContextResource(servletContext, currPath)); + } + } + } + } + + /** + * Extract entries from the given jar by pattern. + * @param jarFilePath the path to the jar file + * @param entryPattern the pattern for jar entries to match + * @param result the Set of matching Resources to add to + */ + private void doRetrieveMatchingJarEntries(String jarFilePath, String entryPattern, Set result) { + if (logger.isDebugEnabled()) { + logger.debug("Searching jar file [" + jarFilePath + "] for entries matching [" + entryPattern + "]"); + } + try { + JarFile jarFile = new JarFile(jarFilePath); + try { + for (Enumeration entries = jarFile.entries(); entries.hasMoreElements();) { + JarEntry entry = entries.nextElement(); + String entryPath = entry.getName(); + if (getPathMatcher().match(entryPattern, entryPath)) { + result.add(new UrlResource( + ResourceUtils.URL_PROTOCOL_JAR, + ResourceUtils.FILE_URL_PREFIX + jarFilePath + ResourceUtils.JAR_URL_SEPARATOR + entryPath)); + } + } + } + finally { + jarFile.close(); + } + } + catch (IOException ex) { + if (logger.isWarnEnabled()) { + logger.warn("Cannot search for matching resources in jar file [" + jarFilePath + + "] because the jar cannot be opened through the file system", ex); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/ServletContextScope.java b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextScope.java new file mode 100644 index 0000000000000000000000000000000000000000..5da0d5298e27b93fb1d271832c511bb64d458f0e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/ServletContextScope.java @@ -0,0 +1,128 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import java.util.LinkedHashMap; +import java.util.Map; + +import javax.servlet.ServletContext; + +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.ObjectFactory; +import org.springframework.beans.factory.config.Scope; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@link Scope} wrapper for a ServletContext, i.e. for global web application attributes. + * + *

This differs from traditional Spring singletons in that it exposes attributes in the + * ServletContext. Those attributes will get destroyed whenever the entire application + * shuts down, which might be earlier or later than the shutdown of the containing Spring + * ApplicationContext. + * + *

The associated destruction mechanism relies on a + * {@link org.springframework.web.context.ContextCleanupListener} being registered in + * {@code web.xml}. Note that {@link org.springframework.web.context.ContextLoaderListener} + * includes ContextCleanupListener's functionality. + * + *

This scope is registered as default scope with key + * {@link org.springframework.web.context.WebApplicationContext#SCOPE_APPLICATION "application"}. + * + * @author Juergen Hoeller + * @since 3.0 + * @see org.springframework.web.context.ContextCleanupListener + */ +public class ServletContextScope implements Scope, DisposableBean { + + private final ServletContext servletContext; + + private final Map destructionCallbacks = new LinkedHashMap<>(); + + + /** + * Create a new Scope wrapper for the given ServletContext. + * @param servletContext the ServletContext to wrap + */ + public ServletContextScope(ServletContext servletContext) { + Assert.notNull(servletContext, "ServletContext must not be null"); + this.servletContext = servletContext; + } + + + @Override + public Object get(String name, ObjectFactory objectFactory) { + Object scopedObject = this.servletContext.getAttribute(name); + if (scopedObject == null) { + scopedObject = objectFactory.getObject(); + this.servletContext.setAttribute(name, scopedObject); + } + return scopedObject; + } + + @Override + @Nullable + public Object remove(String name) { + Object scopedObject = this.servletContext.getAttribute(name); + if (scopedObject != null) { + synchronized (this.destructionCallbacks) { + this.destructionCallbacks.remove(name); + } + this.servletContext.removeAttribute(name); + return scopedObject; + } + else { + return null; + } + } + + @Override + public void registerDestructionCallback(String name, Runnable callback) { + synchronized (this.destructionCallbacks) { + this.destructionCallbacks.put(name, callback); + } + } + + @Override + @Nullable + public Object resolveContextualObject(String key) { + return null; + } + + @Override + @Nullable + public String getConversationId() { + return null; + } + + + /** + * Invoke all registered destruction callbacks. + * To be called on ServletContext shutdown. + * @see org.springframework.web.context.ContextCleanupListener + */ + @Override + public void destroy() { + synchronized (this.destructionCallbacks) { + for (Runnable runnable : this.destructionCallbacks.values()) { + runnable.run(); + } + this.destructionCallbacks.clear(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/ServletRequestHandledEvent.java b/spring-web/src/main/java/org/springframework/web/context/support/ServletRequestHandledEvent.java new file mode 100644 index 0000000000000000000000000000000000000000..e864033b511318caefc16d4d63ca8eb772509627 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/ServletRequestHandledEvent.java @@ -0,0 +1,187 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import org.springframework.lang.Nullable; + +/** + * Servlet-specific subclass of RequestHandledEvent, + * adding servlet-specific context information. + * + * @author Juergen Hoeller + * @since 2.0 + * @see org.springframework.web.servlet.FrameworkServlet + * @see org.springframework.context.ApplicationContext#publishEvent + */ +@SuppressWarnings("serial") +public class ServletRequestHandledEvent extends RequestHandledEvent { + + /** URL that triggered the request. */ + private final String requestUrl; + + /** IP address that the request came from. */ + private final String clientAddress; + + /** Usually GET or POST. */ + private final String method; + + /** Name of the servlet that handled the request. */ + private final String servletName; + + /** HTTP status code of the response. */ + private final int statusCode; + + + /** + * Create a new ServletRequestHandledEvent. + * @param source the component that published the event + * @param requestUrl the URL of the request + * @param clientAddress the IP address that the request came from + * @param method the HTTP method of the request (usually GET or POST) + * @param servletName the name of the servlet that handled the request + * @param sessionId the id of the HTTP session, if any + * @param userName the name of the user that was associated with the + * request, if any (usually the UserPrincipal) + * @param processingTimeMillis the processing time of the request in milliseconds + */ + public ServletRequestHandledEvent(Object source, String requestUrl, + String clientAddress, String method, String servletName, + @Nullable String sessionId, @Nullable String userName, long processingTimeMillis) { + + super(source, sessionId, userName, processingTimeMillis); + this.requestUrl = requestUrl; + this.clientAddress = clientAddress; + this.method = method; + this.servletName = servletName; + this.statusCode = -1; + } + + /** + * Create a new ServletRequestHandledEvent. + * @param source the component that published the event + * @param requestUrl the URL of the request + * @param clientAddress the IP address that the request came from + * @param method the HTTP method of the request (usually GET or POST) + * @param servletName the name of the servlet that handled the request + * @param sessionId the id of the HTTP session, if any + * @param userName the name of the user that was associated with the + * request, if any (usually the UserPrincipal) + * @param processingTimeMillis the processing time of the request in milliseconds + * @param failureCause the cause of failure, if any + */ + public ServletRequestHandledEvent(Object source, String requestUrl, + String clientAddress, String method, String servletName, @Nullable String sessionId, + @Nullable String userName, long processingTimeMillis, @Nullable Throwable failureCause) { + + super(source, sessionId, userName, processingTimeMillis, failureCause); + this.requestUrl = requestUrl; + this.clientAddress = clientAddress; + this.method = method; + this.servletName = servletName; + this.statusCode = -1; + } + + /** + * Create a new ServletRequestHandledEvent. + * @param source the component that published the event + * @param requestUrl the URL of the request + * @param clientAddress the IP address that the request came from + * @param method the HTTP method of the request (usually GET or POST) + * @param servletName the name of the servlet that handled the request + * @param sessionId the id of the HTTP session, if any + * @param userName the name of the user that was associated with the + * request, if any (usually the UserPrincipal) + * @param processingTimeMillis the processing time of the request in milliseconds + * @param failureCause the cause of failure, if any + * @param statusCode the HTTP status code of the response + */ + public ServletRequestHandledEvent(Object source, String requestUrl, + String clientAddress, String method, String servletName, @Nullable String sessionId, + @Nullable String userName, long processingTimeMillis, @Nullable Throwable failureCause, int statusCode) { + + super(source, sessionId, userName, processingTimeMillis, failureCause); + this.requestUrl = requestUrl; + this.clientAddress = clientAddress; + this.method = method; + this.servletName = servletName; + this.statusCode = statusCode; + } + + + /** + * Return the URL of the request. + */ + public String getRequestUrl() { + return this.requestUrl; + } + + /** + * Return the IP address that the request came from. + */ + public String getClientAddress() { + return this.clientAddress; + } + + /** + * Return the HTTP method of the request (usually GET or POST). + */ + public String getMethod() { + return this.method; + } + + /** + * Return the name of the servlet that handled the request. + */ + public String getServletName() { + return this.servletName; + } + + /** + * Return the HTTP status code of the response or -1 if the status + * code is not available. + * @since 4.1 + */ + public int getStatusCode() { + return this.statusCode; + } + + @Override + public String getShortDescription() { + StringBuilder sb = new StringBuilder(); + sb.append("url=[").append(getRequestUrl()).append("]; "); + sb.append("client=[").append(getClientAddress()).append("]; "); + sb.append(super.getShortDescription()); + return sb.toString(); + } + + @Override + public String getDescription() { + StringBuilder sb = new StringBuilder(); + sb.append("url=[").append(getRequestUrl()).append("]; "); + sb.append("client=[").append(getClientAddress()).append("]; "); + sb.append("method=[").append(getMethod()).append("]; "); + sb.append("servlet=[").append(getServletName()).append("]; "); + sb.append(super.getDescription()); + return sb.toString(); + } + + @Override + public String toString() { + return "ServletRequestHandledEvent: " + getDescription(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/SpringBeanAutowiringSupport.java b/spring-web/src/main/java/org/springframework/web/context/support/SpringBeanAutowiringSupport.java new file mode 100644 index 0000000000000000000000000000000000000000..b319a3d8c6a24f71c2be919c607872233fcf4a35 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/SpringBeanAutowiringSupport.java @@ -0,0 +1,113 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import javax.servlet.ServletContext; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.web.context.ContextLoader; +import org.springframework.web.context.WebApplicationContext; + +/** + * Convenient base class for self-autowiring classes that gets constructed + * within a Spring-based web application. Resolves {@code @Autowired} + * annotations in the endpoint class against beans in the current Spring + * root web application context (as determined by the current thread's + * context ClassLoader, which needs to be the web application's ClassLoader). + * Can alternatively be used as a delegate instead of as a base class. + * + *

A typical usage of this base class is a JAX-WS endpoint class: + * Such a Spring-based JAX-WS endpoint implementation will follow the + * standard JAX-WS contract for endpoint classes but will be 'thin' + * in that it delegates the actual work to one or more Spring-managed + * service beans - typically obtained using {@code @Autowired}. + * The lifecycle of such an endpoint instance will be managed by the + * JAX-WS runtime, hence the need for this base class to provide + * {@code @Autowired} processing based on the current Spring context. + * + *

NOTE: If there is an explicit way to access the ServletContext, + * prefer such a way over using this class. The {@link WebApplicationContextUtils} + * class allows for easy access to the Spring root web application context + * based on the ServletContext. + * + * @author Juergen Hoeller + * @since 2.5.1 + * @see WebApplicationObjectSupport + */ +public abstract class SpringBeanAutowiringSupport { + + private static final Log logger = LogFactory.getLog(SpringBeanAutowiringSupport.class); + + + /** + * This constructor performs injection on this instance, + * based on the current web application context. + *

Intended for use as a base class. + * @see #processInjectionBasedOnCurrentContext + */ + public SpringBeanAutowiringSupport() { + processInjectionBasedOnCurrentContext(this); + } + + + /** + * Process {@code @Autowired} injection for the given target object, + * based on the current web application context. + *

Intended for use as a delegate. + * @param target the target object to process + * @see org.springframework.web.context.ContextLoader#getCurrentWebApplicationContext() + */ + public static void processInjectionBasedOnCurrentContext(Object target) { + Assert.notNull(target, "Target object must not be null"); + WebApplicationContext cc = ContextLoader.getCurrentWebApplicationContext(); + if (cc != null) { + AutowiredAnnotationBeanPostProcessor bpp = new AutowiredAnnotationBeanPostProcessor(); + bpp.setBeanFactory(cc.getAutowireCapableBeanFactory()); + bpp.processInjection(target); + } + else { + if (logger.isDebugEnabled()) { + logger.debug("Current WebApplicationContext is not available for processing of " + + ClassUtils.getShortName(target.getClass()) + ": " + + "Make sure this class gets constructed in a Spring web application. Proceeding without injection."); + } + } + } + + + /** + * Process {@code @Autowired} injection for the given target object, + * based on the current root web application context as stored in the ServletContext. + *

Intended for use as a delegate. + * @param target the target object to process + * @param servletContext the ServletContext to find the Spring web application context in + * @see WebApplicationContextUtils#getWebApplicationContext(javax.servlet.ServletContext) + */ + public static void processInjectionBasedOnServletContext(Object target, ServletContext servletContext) { + Assert.notNull(target, "Target object must not be null"); + WebApplicationContext cc = WebApplicationContextUtils.getRequiredWebApplicationContext(servletContext); + AutowiredAnnotationBeanPostProcessor bpp = new AutowiredAnnotationBeanPostProcessor(); + bpp.setBeanFactory(cc.getAutowireCapableBeanFactory()); + bpp.processInjection(target); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/StandardServletEnvironment.java b/spring-web/src/main/java/org/springframework/web/context/support/StandardServletEnvironment.java new file mode 100644 index 0000000000000000000000000000000000000000..c3177c711e82415dc3ee09b7e94799bf3b8986b9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/StandardServletEnvironment.java @@ -0,0 +1,98 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import javax.servlet.ServletConfig; +import javax.servlet.ServletContext; + +import org.springframework.core.env.Environment; +import org.springframework.core.env.MutablePropertySources; +import org.springframework.core.env.PropertySource; +import org.springframework.core.env.PropertySource.StubPropertySource; +import org.springframework.core.env.StandardEnvironment; +import org.springframework.jndi.JndiLocatorDelegate; +import org.springframework.jndi.JndiPropertySource; +import org.springframework.lang.Nullable; +import org.springframework.web.context.ConfigurableWebEnvironment; + +/** + * {@link Environment} implementation to be used by {@code Servlet}-based web + * applications. All web-related (servlet-based) {@code ApplicationContext} classes + * initialize an instance by default. + * + *

Contributes {@code ServletConfig}, {@code ServletContext}, and JNDI-based + * {@link PropertySource} instances. See {@link #customizePropertySources} method + * documentation for details. + * + * @author Chris Beams + * @since 3.1 + * @see StandardEnvironment + */ +public class StandardServletEnvironment extends StandardEnvironment implements ConfigurableWebEnvironment { + + /** Servlet context init parameters property source name: {@value}. */ + public static final String SERVLET_CONTEXT_PROPERTY_SOURCE_NAME = "servletContextInitParams"; + + /** Servlet config init parameters property source name: {@value}. */ + public static final String SERVLET_CONFIG_PROPERTY_SOURCE_NAME = "servletConfigInitParams"; + + /** JNDI property source name: {@value}. */ + public static final String JNDI_PROPERTY_SOURCE_NAME = "jndiProperties"; + + + /** + * Customize the set of property sources with those contributed by superclasses as + * well as those appropriate for standard servlet-based environments: + *

    + *
  • {@value #SERVLET_CONFIG_PROPERTY_SOURCE_NAME} + *
  • {@value #SERVLET_CONTEXT_PROPERTY_SOURCE_NAME} + *
  • {@value #JNDI_PROPERTY_SOURCE_NAME} + *
+ *

Properties present in {@value #SERVLET_CONFIG_PROPERTY_SOURCE_NAME} will + * take precedence over those in {@value #SERVLET_CONTEXT_PROPERTY_SOURCE_NAME}, and + * properties found in either of the above take precedence over those found in + * {@value #JNDI_PROPERTY_SOURCE_NAME}. + *

Properties in any of the above will take precedence over system properties and + * environment variables contributed by the {@link StandardEnvironment} superclass. + *

The {@code Servlet}-related property sources are added as + * {@link StubPropertySource stubs} at this stage, and will be + * {@linkplain #initPropertySources(ServletContext, ServletConfig) fully initialized} + * once the actual {@link ServletContext} object becomes available. + * @see StandardEnvironment#customizePropertySources + * @see org.springframework.core.env.AbstractEnvironment#customizePropertySources + * @see ServletConfigPropertySource + * @see ServletContextPropertySource + * @see org.springframework.jndi.JndiPropertySource + * @see org.springframework.context.support.AbstractApplicationContext#initPropertySources + * @see #initPropertySources(ServletContext, ServletConfig) + */ + @Override + protected void customizePropertySources(MutablePropertySources propertySources) { + propertySources.addLast(new StubPropertySource(SERVLET_CONFIG_PROPERTY_SOURCE_NAME)); + propertySources.addLast(new StubPropertySource(SERVLET_CONTEXT_PROPERTY_SOURCE_NAME)); + if (JndiLocatorDelegate.isDefaultJndiEnvironmentAvailable()) { + propertySources.addLast(new JndiPropertySource(JNDI_PROPERTY_SOURCE_NAME)); + } + super.customizePropertySources(propertySources); + } + + @Override + public void initPropertySources(@Nullable ServletContext servletContext, @Nullable ServletConfig servletConfig) { + WebApplicationContextUtils.initServletPropertySources(getPropertySources(), servletContext, servletConfig); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/StaticWebApplicationContext.java b/spring-web/src/main/java/org/springframework/web/context/support/StaticWebApplicationContext.java new file mode 100644 index 0000000000000000000000000000000000000000..1a011a89dddbd12d1fb5a75050161345cf58f374 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/StaticWebApplicationContext.java @@ -0,0 +1,206 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import javax.servlet.ServletConfig; +import javax.servlet.ServletContext; + +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.context.support.StaticApplicationContext; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.io.Resource; +import org.springframework.core.io.support.ResourcePatternResolver; +import org.springframework.lang.Nullable; +import org.springframework.ui.context.Theme; +import org.springframework.ui.context.ThemeSource; +import org.springframework.ui.context.support.UiApplicationContextUtils; +import org.springframework.util.Assert; +import org.springframework.web.context.ConfigurableWebApplicationContext; +import org.springframework.web.context.ServletConfigAware; +import org.springframework.web.context.ServletContextAware; + +/** + * Static {@link org.springframework.web.context.WebApplicationContext} + * implementation for testing. Not intended for use in production applications. + * + *

Implements the {@link org.springframework.web.context.ConfigurableWebApplicationContext} + * interface to allow for direct replacement of an {@link XmlWebApplicationContext}, + * despite not actually supporting external configuration files. + * + *

Interprets resource paths as servlet context resources, i.e. as paths beneath + * the web application root. Absolute paths, e.g. for files outside the web app root, + * can be accessed via "file:" URLs, as implemented by + * {@link org.springframework.core.io.DefaultResourceLoader}. + * + *

In addition to the special beans detected by + * {@link org.springframework.context.support.AbstractApplicationContext}, + * this class detects a bean of type {@link org.springframework.ui.context.ThemeSource} + * in the context, under the special bean name "themeSource". + * + * @author Rod Johnson + * @author Juergen Hoeller + * @see org.springframework.ui.context.ThemeSource + */ +public class StaticWebApplicationContext extends StaticApplicationContext + implements ConfigurableWebApplicationContext, ThemeSource { + + @Nullable + private ServletContext servletContext; + + @Nullable + private ServletConfig servletConfig; + + @Nullable + private String namespace; + + @Nullable + private ThemeSource themeSource; + + + public StaticWebApplicationContext() { + setDisplayName("Root WebApplicationContext"); + } + + + /** + * Set the ServletContext that this WebApplicationContext runs in. + */ + @Override + public void setServletContext(@Nullable ServletContext servletContext) { + this.servletContext = servletContext; + } + + @Override + @Nullable + public ServletContext getServletContext() { + return this.servletContext; + } + + @Override + public void setServletConfig(@Nullable ServletConfig servletConfig) { + this.servletConfig = servletConfig; + if (servletConfig != null && this.servletContext == null) { + this.servletContext = servletConfig.getServletContext(); + } + } + + @Override + @Nullable + public ServletConfig getServletConfig() { + return this.servletConfig; + } + + @Override + public void setNamespace(@Nullable String namespace) { + this.namespace = namespace; + if (namespace != null) { + setDisplayName("WebApplicationContext for namespace '" + namespace + "'"); + } + } + + @Override + @Nullable + public String getNamespace() { + return this.namespace; + } + + /** + * The {@link StaticWebApplicationContext} class does not support this method. + * @throws UnsupportedOperationException always + */ + @Override + public void setConfigLocation(String configLocation) { + throw new UnsupportedOperationException("StaticWebApplicationContext does not support config locations"); + } + + /** + * The {@link StaticWebApplicationContext} class does not support this method. + * @throws UnsupportedOperationException always + */ + @Override + public void setConfigLocations(String... configLocations) { + throw new UnsupportedOperationException("StaticWebApplicationContext does not support config locations"); + } + + @Override + public String[] getConfigLocations() { + return null; + } + + + /** + * Register request/session scopes, a {@link ServletContextAwareProcessor}, etc. + */ + @Override + protected void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) { + beanFactory.addBeanPostProcessor(new ServletContextAwareProcessor(this.servletContext, this.servletConfig)); + beanFactory.ignoreDependencyInterface(ServletContextAware.class); + beanFactory.ignoreDependencyInterface(ServletConfigAware.class); + + WebApplicationContextUtils.registerWebApplicationScopes(beanFactory, this.servletContext); + WebApplicationContextUtils.registerEnvironmentBeans(beanFactory, this.servletContext, this.servletConfig); + } + + /** + * This implementation supports file paths beneath the root of the ServletContext. + * @see ServletContextResource + */ + @Override + protected Resource getResourceByPath(String path) { + Assert.state(this.servletContext != null, "No ServletContext available"); + return new ServletContextResource(this.servletContext, path); + } + + /** + * This implementation supports pattern matching in unexpanded WARs too. + * @see ServletContextResourcePatternResolver + */ + @Override + protected ResourcePatternResolver getResourcePatternResolver() { + return new ServletContextResourcePatternResolver(this); + } + + /** + * Create and return a new {@link StandardServletEnvironment}. + */ + @Override + protected ConfigurableEnvironment createEnvironment() { + return new StandardServletEnvironment(); + } + + /** + * Initialize the theme capability. + */ + @Override + protected void onRefresh() { + this.themeSource = UiApplicationContextUtils.initThemeSource(this); + } + + @Override + protected void initPropertySources() { + WebApplicationContextUtils.initServletPropertySources(getEnvironment().getPropertySources(), + this.servletContext, this.servletConfig); + } + + @Override + @Nullable + public Theme getTheme(String themeName) { + Assert.state(this.themeSource != null, "No ThemeSource available"); + return this.themeSource.getTheme(themeName); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/WebApplicationContextUtils.java b/spring-web/src/main/java/org/springframework/web/context/support/WebApplicationContextUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..d9bfe57b15238364dafc6d5b47903aa1d125dbdd --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/WebApplicationContextUtils.java @@ -0,0 +1,428 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import java.io.Serializable; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; + +import javax.faces.context.ExternalContext; +import javax.faces.context.FacesContext; +import javax.servlet.ServletConfig; +import javax.servlet.ServletContext; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpSession; + +import org.springframework.beans.factory.ObjectFactory; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.core.env.MutablePropertySources; +import org.springframework.core.env.PropertySource.StubPropertySource; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.web.context.ConfigurableWebApplicationContext; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.RequestScope; +import org.springframework.web.context.request.ServletRequestAttributes; +import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.context.request.SessionScope; +import org.springframework.web.context.request.WebRequest; + +/** + * Convenience methods for retrieving the root {@link WebApplicationContext} for + * a given {@link ServletContext}. This is useful for programmatically accessing + * a Spring application context from within custom web views or MVC actions. + * + *

Note that there are more convenient ways of accessing the root context for + * many web frameworks, either part of Spring or available as an external library. + * This helper class is just the most generic way to access the root context. + * + * @author Juergen Hoeller + * @see org.springframework.web.context.ContextLoader + * @see org.springframework.web.servlet.FrameworkServlet + * @see org.springframework.web.servlet.DispatcherServlet + * @see org.springframework.web.jsf.FacesContextUtils + * @see org.springframework.web.jsf.el.SpringBeanFacesELResolver + */ +public abstract class WebApplicationContextUtils { + + private static final boolean jsfPresent = + ClassUtils.isPresent("javax.faces.context.FacesContext", RequestContextHolder.class.getClassLoader()); + + + /** + * Find the root {@code WebApplicationContext} for this web app, typically + * loaded via {@link org.springframework.web.context.ContextLoaderListener}. + *

Will rethrow an exception that happened on root context startup, + * to differentiate between a failed context startup and no context at all. + * @param sc the ServletContext to find the web application context for + * @return the root WebApplicationContext for this web app + * @throws IllegalStateException if the root WebApplicationContext could not be found + * @see org.springframework.web.context.WebApplicationContext#ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE + */ + public static WebApplicationContext getRequiredWebApplicationContext(ServletContext sc) throws IllegalStateException { + WebApplicationContext wac = getWebApplicationContext(sc); + if (wac == null) { + throw new IllegalStateException("No WebApplicationContext found: no ContextLoaderListener registered?"); + } + return wac; + } + + /** + * Find the root {@code WebApplicationContext} for this web app, typically + * loaded via {@link org.springframework.web.context.ContextLoaderListener}. + *

Will rethrow an exception that happened on root context startup, + * to differentiate between a failed context startup and no context at all. + * @param sc the ServletContext to find the web application context for + * @return the root WebApplicationContext for this web app, or {@code null} if none + * @see org.springframework.web.context.WebApplicationContext#ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE + */ + @Nullable + public static WebApplicationContext getWebApplicationContext(ServletContext sc) { + return getWebApplicationContext(sc, WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE); + } + + /** + * Find a custom {@code WebApplicationContext} for this web app. + * @param sc the ServletContext to find the web application context for + * @param attrName the name of the ServletContext attribute to look for + * @return the desired WebApplicationContext for this web app, or {@code null} if none + */ + @Nullable + public static WebApplicationContext getWebApplicationContext(ServletContext sc, String attrName) { + Assert.notNull(sc, "ServletContext must not be null"); + Object attr = sc.getAttribute(attrName); + if (attr == null) { + return null; + } + if (attr instanceof RuntimeException) { + throw (RuntimeException) attr; + } + if (attr instanceof Error) { + throw (Error) attr; + } + if (attr instanceof Exception) { + throw new IllegalStateException((Exception) attr); + } + if (!(attr instanceof WebApplicationContext)) { + throw new IllegalStateException("Context attribute is not of type WebApplicationContext: " + attr); + } + return (WebApplicationContext) attr; + } + + /** + * Find a unique {@code WebApplicationContext} for this web app: either the + * root web app context (preferred) or a unique {@code WebApplicationContext} + * among the registered {@code ServletContext} attributes (typically coming + * from a single {@code DispatcherServlet} in the current web application). + *

Note that {@code DispatcherServlet}'s exposure of its context can be + * controlled through its {@code publishContext} property, which is {@code true} + * by default but can be selectively switched to only publish a single context + * despite multiple {@code DispatcherServlet} registrations in the web app. + * @param sc the ServletContext to find the web application context for + * @return the desired WebApplicationContext for this web app, or {@code null} if none + * @since 4.2 + * @see #getWebApplicationContext(ServletContext) + * @see ServletContext#getAttributeNames() + */ + @Nullable + public static WebApplicationContext findWebApplicationContext(ServletContext sc) { + WebApplicationContext wac = getWebApplicationContext(sc); + if (wac == null) { + Enumeration attrNames = sc.getAttributeNames(); + while (attrNames.hasMoreElements()) { + String attrName = attrNames.nextElement(); + Object attrValue = sc.getAttribute(attrName); + if (attrValue instanceof WebApplicationContext) { + if (wac != null) { + throw new IllegalStateException("No unique WebApplicationContext found: more than one " + + "DispatcherServlet registered with publishContext=true?"); + } + wac = (WebApplicationContext) attrValue; + } + } + } + return wac; + } + + + /** + * Register web-specific scopes ("request", "session", "globalSession") + * with the given BeanFactory, as used by the WebApplicationContext. + * @param beanFactory the BeanFactory to configure + */ + public static void registerWebApplicationScopes(ConfigurableListableBeanFactory beanFactory) { + registerWebApplicationScopes(beanFactory, null); + } + + /** + * Register web-specific scopes ("request", "session", "globalSession", "application") + * with the given BeanFactory, as used by the WebApplicationContext. + * @param beanFactory the BeanFactory to configure + * @param sc the ServletContext that we're running within + */ + public static void registerWebApplicationScopes(ConfigurableListableBeanFactory beanFactory, + @Nullable ServletContext sc) { + + beanFactory.registerScope(WebApplicationContext.SCOPE_REQUEST, new RequestScope()); + beanFactory.registerScope(WebApplicationContext.SCOPE_SESSION, new SessionScope()); + if (sc != null) { + ServletContextScope appScope = new ServletContextScope(sc); + beanFactory.registerScope(WebApplicationContext.SCOPE_APPLICATION, appScope); + // Register as ServletContext attribute, for ContextCleanupListener to detect it. + sc.setAttribute(ServletContextScope.class.getName(), appScope); + } + + beanFactory.registerResolvableDependency(ServletRequest.class, new RequestObjectFactory()); + beanFactory.registerResolvableDependency(ServletResponse.class, new ResponseObjectFactory()); + beanFactory.registerResolvableDependency(HttpSession.class, new SessionObjectFactory()); + beanFactory.registerResolvableDependency(WebRequest.class, new WebRequestObjectFactory()); + if (jsfPresent) { + FacesDependencyRegistrar.registerFacesDependencies(beanFactory); + } + } + + /** + * Register web-specific environment beans ("contextParameters", "contextAttributes") + * with the given BeanFactory, as used by the WebApplicationContext. + * @param bf the BeanFactory to configure + * @param sc the ServletContext that we're running within + */ + public static void registerEnvironmentBeans(ConfigurableListableBeanFactory bf, @Nullable ServletContext sc) { + registerEnvironmentBeans(bf, sc, null); + } + + /** + * Register web-specific environment beans ("contextParameters", "contextAttributes") + * with the given BeanFactory, as used by the WebApplicationContext. + * @param bf the BeanFactory to configure + * @param servletContext the ServletContext that we're running within + * @param servletConfig the ServletConfig + */ + public static void registerEnvironmentBeans(ConfigurableListableBeanFactory bf, + @Nullable ServletContext servletContext, @Nullable ServletConfig servletConfig) { + + if (servletContext != null && !bf.containsBean(WebApplicationContext.SERVLET_CONTEXT_BEAN_NAME)) { + bf.registerSingleton(WebApplicationContext.SERVLET_CONTEXT_BEAN_NAME, servletContext); + } + + if (servletConfig != null && !bf.containsBean(ConfigurableWebApplicationContext.SERVLET_CONFIG_BEAN_NAME)) { + bf.registerSingleton(ConfigurableWebApplicationContext.SERVLET_CONFIG_BEAN_NAME, servletConfig); + } + + if (!bf.containsBean(WebApplicationContext.CONTEXT_PARAMETERS_BEAN_NAME)) { + Map parameterMap = new HashMap<>(); + if (servletContext != null) { + Enumeration paramNameEnum = servletContext.getInitParameterNames(); + while (paramNameEnum.hasMoreElements()) { + String paramName = (String) paramNameEnum.nextElement(); + parameterMap.put(paramName, servletContext.getInitParameter(paramName)); + } + } + if (servletConfig != null) { + Enumeration paramNameEnum = servletConfig.getInitParameterNames(); + while (paramNameEnum.hasMoreElements()) { + String paramName = (String) paramNameEnum.nextElement(); + parameterMap.put(paramName, servletConfig.getInitParameter(paramName)); + } + } + bf.registerSingleton(WebApplicationContext.CONTEXT_PARAMETERS_BEAN_NAME, + Collections.unmodifiableMap(parameterMap)); + } + + if (!bf.containsBean(WebApplicationContext.CONTEXT_ATTRIBUTES_BEAN_NAME)) { + Map attributeMap = new HashMap<>(); + if (servletContext != null) { + Enumeration attrNameEnum = servletContext.getAttributeNames(); + while (attrNameEnum.hasMoreElements()) { + String attrName = (String) attrNameEnum.nextElement(); + attributeMap.put(attrName, servletContext.getAttribute(attrName)); + } + } + bf.registerSingleton(WebApplicationContext.CONTEXT_ATTRIBUTES_BEAN_NAME, + Collections.unmodifiableMap(attributeMap)); + } + } + + /** + * Convenient variant of {@link #initServletPropertySources(MutablePropertySources, + * ServletContext, ServletConfig)} that always provides {@code null} for the + * {@link ServletConfig} parameter. + * @see #initServletPropertySources(MutablePropertySources, ServletContext, ServletConfig) + */ + public static void initServletPropertySources(MutablePropertySources propertySources, ServletContext servletContext) { + initServletPropertySources(propertySources, servletContext, null); + } + + /** + * Replace {@code Servlet}-based {@link StubPropertySource stub property sources} with + * actual instances populated with the given {@code servletContext} and + * {@code servletConfig} objects. + *

This method is idempotent with respect to the fact it may be called any number + * of times but will perform replacement of stub property sources with their + * corresponding actual property sources once and only once. + * @param sources the {@link MutablePropertySources} to initialize (must not + * be {@code null}) + * @param servletContext the current {@link ServletContext} (ignored if {@code null} + * or if the {@link StandardServletEnvironment#SERVLET_CONTEXT_PROPERTY_SOURCE_NAME + * servlet context property source} has already been initialized) + * @param servletConfig the current {@link ServletConfig} (ignored if {@code null} + * or if the {@link StandardServletEnvironment#SERVLET_CONFIG_PROPERTY_SOURCE_NAME + * servlet config property source} has already been initialized) + * @see org.springframework.core.env.PropertySource.StubPropertySource + * @see org.springframework.core.env.ConfigurableEnvironment#getPropertySources() + */ + public static void initServletPropertySources(MutablePropertySources sources, + @Nullable ServletContext servletContext, @Nullable ServletConfig servletConfig) { + + Assert.notNull(sources, "'propertySources' must not be null"); + String name = StandardServletEnvironment.SERVLET_CONTEXT_PROPERTY_SOURCE_NAME; + if (servletContext != null && sources.get(name) instanceof StubPropertySource) { + sources.replace(name, new ServletContextPropertySource(name, servletContext)); + } + name = StandardServletEnvironment.SERVLET_CONFIG_PROPERTY_SOURCE_NAME; + if (servletConfig != null && sources.get(name) instanceof StubPropertySource) { + sources.replace(name, new ServletConfigPropertySource(name, servletConfig)); + } + } + + /** + * Return the current RequestAttributes instance as ServletRequestAttributes. + * @see RequestContextHolder#currentRequestAttributes() + */ + private static ServletRequestAttributes currentRequestAttributes() { + RequestAttributes requestAttr = RequestContextHolder.currentRequestAttributes(); + if (!(requestAttr instanceof ServletRequestAttributes)) { + throw new IllegalStateException("Current request is not a servlet request"); + } + return (ServletRequestAttributes) requestAttr; + } + + + /** + * Factory that exposes the current request object on demand. + */ + @SuppressWarnings("serial") + private static class RequestObjectFactory implements ObjectFactory, Serializable { + + @Override + public ServletRequest getObject() { + return currentRequestAttributes().getRequest(); + } + + @Override + public String toString() { + return "Current HttpServletRequest"; + } + } + + + /** + * Factory that exposes the current response object on demand. + */ + @SuppressWarnings("serial") + private static class ResponseObjectFactory implements ObjectFactory, Serializable { + + @Override + public ServletResponse getObject() { + ServletResponse response = currentRequestAttributes().getResponse(); + if (response == null) { + throw new IllegalStateException("Current servlet response not available - " + + "consider using RequestContextFilter instead of RequestContextListener"); + } + return response; + } + + @Override + public String toString() { + return "Current HttpServletResponse"; + } + } + + + /** + * Factory that exposes the current session object on demand. + */ + @SuppressWarnings("serial") + private static class SessionObjectFactory implements ObjectFactory, Serializable { + + @Override + public HttpSession getObject() { + return currentRequestAttributes().getRequest().getSession(); + } + + @Override + public String toString() { + return "Current HttpSession"; + } + } + + + /** + * Factory that exposes the current WebRequest object on demand. + */ + @SuppressWarnings("serial") + private static class WebRequestObjectFactory implements ObjectFactory, Serializable { + + @Override + public WebRequest getObject() { + ServletRequestAttributes requestAttr = currentRequestAttributes(); + return new ServletWebRequest(requestAttr.getRequest(), requestAttr.getResponse()); + } + + @Override + public String toString() { + return "Current ServletWebRequest"; + } + } + + + /** + * Inner class to avoid hard-coded JSF dependency. + */ + private static class FacesDependencyRegistrar { + + public static void registerFacesDependencies(ConfigurableListableBeanFactory beanFactory) { + beanFactory.registerResolvableDependency(FacesContext.class, new ObjectFactory() { + @Override + public FacesContext getObject() { + return FacesContext.getCurrentInstance(); + } + @Override + public String toString() { + return "Current JSF FacesContext"; + } + }); + beanFactory.registerResolvableDependency(ExternalContext.class, new ObjectFactory() { + @Override + public ExternalContext getObject() { + return FacesContext.getCurrentInstance().getExternalContext(); + } + @Override + public String toString() { + return "Current JSF ExternalContext"; + } + }); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/WebApplicationObjectSupport.java b/spring-web/src/main/java/org/springframework/web/context/support/WebApplicationObjectSupport.java new file mode 100644 index 0000000000000000000000000000000000000000..e0d213f09cd5de7861083711b65063d7f55f0d71 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/WebApplicationObjectSupport.java @@ -0,0 +1,158 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import java.io.File; + +import javax.servlet.ServletContext; + +import org.springframework.context.ApplicationContext; +import org.springframework.context.support.ApplicationObjectSupport; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.context.ServletContextAware; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.util.WebUtils; + +/** + * Convenient superclass for application objects running in a {@link WebApplicationContext}. + * Provides {@code getWebApplicationContext()}, {@code getServletContext()}, and + * {@code getTempDir()} accessors. + * + *

Note: It is generally recommended to use individual callback interfaces for the actual + * callbacks needed. This broad base class is primarily intended for use within the framework, + * in case of {@link ServletContext} access etc typically being needed. + * + * @author Juergen Hoeller + * @since 28.08.2003 + * @see SpringBeanAutowiringSupport + */ +public abstract class WebApplicationObjectSupport extends ApplicationObjectSupport implements ServletContextAware { + + @Nullable + private ServletContext servletContext; + + + @Override + public final void setServletContext(ServletContext servletContext) { + if (servletContext != this.servletContext) { + this.servletContext = servletContext; + initServletContext(servletContext); + } + } + + /** + * Overrides the base class behavior to enforce running in an ApplicationContext. + * All accessors will throw IllegalStateException if not running in a context. + * @see #getApplicationContext() + * @see #getMessageSourceAccessor() + * @see #getWebApplicationContext() + * @see #getServletContext() + * @see #getTempDir() + */ + @Override + protected boolean isContextRequired() { + return true; + } + + /** + * Calls {@link #initServletContext(javax.servlet.ServletContext)} if the + * given ApplicationContext is a {@link WebApplicationContext}. + */ + @Override + protected void initApplicationContext(ApplicationContext context) { + super.initApplicationContext(context); + if (this.servletContext == null && context instanceof WebApplicationContext) { + this.servletContext = ((WebApplicationContext) context).getServletContext(); + if (this.servletContext != null) { + initServletContext(this.servletContext); + } + } + } + + /** + * Subclasses may override this for custom initialization based + * on the ServletContext that this application object runs in. + *

The default implementation is empty. Called by + * {@link #initApplicationContext(org.springframework.context.ApplicationContext)} + * as well as {@link #setServletContext(javax.servlet.ServletContext)}. + * @param servletContext the ServletContext that this application object runs in + * (never {@code null}) + */ + protected void initServletContext(ServletContext servletContext) { + } + + /** + * Return the current application context as WebApplicationContext. + *

NOTE: Only use this if you actually need to access + * WebApplicationContext-specific functionality. Preferably use + * {@code getApplicationContext()} or {@code getServletContext()} + * else, to be able to run in non-WebApplicationContext environments as well. + * @throws IllegalStateException if not running in a WebApplicationContext + * @see #getApplicationContext() + */ + @Nullable + protected final WebApplicationContext getWebApplicationContext() throws IllegalStateException { + ApplicationContext ctx = getApplicationContext(); + if (ctx instanceof WebApplicationContext) { + return (WebApplicationContext) getApplicationContext(); + } + else if (isContextRequired()) { + throw new IllegalStateException("WebApplicationObjectSupport instance [" + this + + "] does not run in a WebApplicationContext but in: " + ctx); + } + else { + return null; + } + } + + /** + * Return the current ServletContext. + * @throws IllegalStateException if not running within a required ServletContext + * @see #isContextRequired() + */ + @Nullable + protected final ServletContext getServletContext() throws IllegalStateException { + if (this.servletContext != null) { + return this.servletContext; + } + ServletContext servletContext = null; + WebApplicationContext wac = getWebApplicationContext(); + if (wac != null) { + servletContext = wac.getServletContext(); + } + if (servletContext == null && isContextRequired()) { + throw new IllegalStateException("WebApplicationObjectSupport instance [" + this + + "] does not run within a ServletContext. Make sure the object is fully configured!"); + } + return servletContext; + } + + /** + * Return the temporary directory for the current web application, + * as provided by the servlet container. + * @return the File representing the temporary directory + * @throws IllegalStateException if not running within a ServletContext + * @see org.springframework.web.util.WebUtils#getTempDir(javax.servlet.ServletContext) + */ + protected final File getTempDir() throws IllegalStateException { + ServletContext servletContext = getServletContext(); + Assert.state(servletContext != null, "ServletContext is required"); + return WebUtils.getTempDir(servletContext); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/XmlWebApplicationContext.java b/spring-web/src/main/java/org/springframework/web/context/support/XmlWebApplicationContext.java new file mode 100644 index 0000000000000000000000000000000000000000..2f396175b21c1d9b8a7448935943d856c61534a8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/XmlWebApplicationContext.java @@ -0,0 +1,145 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import java.io.IOException; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.xml.ResourceEntityResolver; +import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; + +/** + * {@link org.springframework.web.context.WebApplicationContext} implementation + * which takes its configuration from XML documents, understood by an + * {@link org.springframework.beans.factory.xml.XmlBeanDefinitionReader}. + * This is essentially the equivalent of + * {@link org.springframework.context.support.GenericXmlApplicationContext} + * for a web environment. + * + *

By default, the configuration will be taken from "/WEB-INF/applicationContext.xml" + * for the root context, and "/WEB-INF/test-servlet.xml" for a context with the namespace + * "test-servlet" (like for a DispatcherServlet instance with the servlet-name "test"). + * + *

The config location defaults can be overridden via the "contextConfigLocation" + * context-param of {@link org.springframework.web.context.ContextLoader} and servlet + * init-param of {@link org.springframework.web.servlet.FrameworkServlet}. Config locations + * can either denote concrete files like "/WEB-INF/context.xml" or Ant-style patterns + * like "/WEB-INF/*-context.xml" (see {@link org.springframework.util.PathMatcher} + * javadoc for pattern details). + * + *

Note: In case of multiple config locations, later bean definitions will + * override ones defined in earlier loaded files. This can be leveraged to + * deliberately override certain bean definitions via an extra XML file. + * + *

For a WebApplicationContext that reads in a different bean definition format, + * create an analogous subclass of {@link AbstractRefreshableWebApplicationContext}. + * Such a context implementation can be specified as "contextClass" context-param + * for ContextLoader or "contextClass" init-param for FrameworkServlet. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @see #setNamespace + * @see #setConfigLocations + * @see org.springframework.beans.factory.xml.XmlBeanDefinitionReader + * @see org.springframework.web.context.ContextLoader#initWebApplicationContext + * @see org.springframework.web.servlet.FrameworkServlet#initWebApplicationContext + */ +public class XmlWebApplicationContext extends AbstractRefreshableWebApplicationContext { + + /** Default config location for the root context. */ + public static final String DEFAULT_CONFIG_LOCATION = "/WEB-INF/applicationContext.xml"; + + /** Default prefix for building a config location for a namespace. */ + public static final String DEFAULT_CONFIG_LOCATION_PREFIX = "/WEB-INF/"; + + /** Default suffix for building a config location for a namespace. */ + public static final String DEFAULT_CONFIG_LOCATION_SUFFIX = ".xml"; + + + /** + * Loads the bean definitions via an XmlBeanDefinitionReader. + * @see org.springframework.beans.factory.xml.XmlBeanDefinitionReader + * @see #initBeanDefinitionReader + * @see #loadBeanDefinitions + */ + @Override + protected void loadBeanDefinitions(DefaultListableBeanFactory beanFactory) throws BeansException, IOException { + // Create a new XmlBeanDefinitionReader for the given BeanFactory. + XmlBeanDefinitionReader beanDefinitionReader = new XmlBeanDefinitionReader(beanFactory); + + // Configure the bean definition reader with this context's + // resource loading environment. + beanDefinitionReader.setEnvironment(getEnvironment()); + beanDefinitionReader.setResourceLoader(this); + beanDefinitionReader.setEntityResolver(new ResourceEntityResolver(this)); + + // Allow a subclass to provide custom initialization of the reader, + // then proceed with actually loading the bean definitions. + initBeanDefinitionReader(beanDefinitionReader); + loadBeanDefinitions(beanDefinitionReader); + } + + /** + * Initialize the bean definition reader used for loading the bean + * definitions of this context. Default implementation is empty. + *

Can be overridden in subclasses, e.g. for turning off XML validation + * or using a different XmlBeanDefinitionParser implementation. + * @param beanDefinitionReader the bean definition reader used by this context + * @see org.springframework.beans.factory.xml.XmlBeanDefinitionReader#setValidationMode + * @see org.springframework.beans.factory.xml.XmlBeanDefinitionReader#setDocumentReaderClass + */ + protected void initBeanDefinitionReader(XmlBeanDefinitionReader beanDefinitionReader) { + } + + /** + * Load the bean definitions with the given XmlBeanDefinitionReader. + *

The lifecycle of the bean factory is handled by the refreshBeanFactory method; + * therefore this method is just supposed to load and/or register bean definitions. + *

Delegates to a ResourcePatternResolver for resolving location patterns + * into Resource instances. + * @throws IOException if the required XML document isn't found + * @see #refreshBeanFactory + * @see #getConfigLocations + * @see #getResources + * @see #getResourcePatternResolver + */ + protected void loadBeanDefinitions(XmlBeanDefinitionReader reader) throws IOException { + String[] configLocations = getConfigLocations(); + if (configLocations != null) { + for (String configLocation : configLocations) { + reader.loadBeanDefinitions(configLocation); + } + } + } + + /** + * The default location for the root context is "/WEB-INF/applicationContext.xml", + * and "/WEB-INF/test-servlet.xml" for a context with the namespace "test-servlet" + * (like for a DispatcherServlet instance with the servlet-name "test"). + */ + @Override + protected String[] getDefaultConfigLocations() { + if (getNamespace() != null) { + return new String[] {DEFAULT_CONFIG_LOCATION_PREFIX + getNamespace() + DEFAULT_CONFIG_LOCATION_SUFFIX}; + } + else { + return new String[] {DEFAULT_CONFIG_LOCATION}; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/support/package-info.java b/spring-web/src/main/java/org/springframework/web/context/support/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..78999b0c42be4069dbde1d0ff92fac1c12035300 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/support/package-info.java @@ -0,0 +1,10 @@ +/** + * Classes supporting the {@code org.springframework.web.context} package, + * such as WebApplicationContext implementations and various utility classes. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.context.support; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java b/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java new file mode 100644 index 0000000000000000000000000000000000000000..82ec0c02ae1374fbd81ea33bd96d1bc3cc448342 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java @@ -0,0 +1,508 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; + +/** + * A container for CORS configuration along with methods to check against the + * actual origin, HTTP methods, and headers of a given request. + * + *

By default a newly created {@code CorsConfiguration} does not permit any + * cross-origin requests and must be configured explicitly to indicate what + * should be allowed. Use {@link #applyPermitDefaultValues()} to flip the + * initialization model to start with open defaults that permit all cross-origin + * requests for GET, HEAD, and POST requests. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @author Sam Brannen + * @since 4.2 + * @see CORS spec + */ +public class CorsConfiguration { + + /** Wildcard representing all origins, methods, or headers. */ + public static final String ALL = "*"; + + private static final List DEFAULT_METHODS = Collections.unmodifiableList( + Arrays.asList(HttpMethod.GET, HttpMethod.HEAD)); + + private static final List DEFAULT_PERMIT_METHODS = Collections.unmodifiableList( + Arrays.asList(HttpMethod.GET.name(), HttpMethod.HEAD.name(), HttpMethod.POST.name())); + + private static final List DEFAULT_PERMIT_ALL = Collections.unmodifiableList( + Collections.singletonList(ALL)); + + + @Nullable + private List allowedOrigins; + + @Nullable + private List allowedMethods; + + @Nullable + private List resolvedMethods = DEFAULT_METHODS; + + @Nullable + private List allowedHeaders; + + @Nullable + private List exposedHeaders; + + @Nullable + private Boolean allowCredentials; + + @Nullable + private Long maxAge; + + + /** + * Construct a new {@code CorsConfiguration} instance with no cross-origin + * requests allowed for any origin by default. + * @see #applyPermitDefaultValues() + */ + public CorsConfiguration() { + } + + /** + * Construct a new {@code CorsConfiguration} instance by copying all + * values from the supplied {@code CorsConfiguration}. + */ + public CorsConfiguration(CorsConfiguration other) { + this.allowedOrigins = other.allowedOrigins; + this.allowedMethods = other.allowedMethods; + this.resolvedMethods = other.resolvedMethods; + this.allowedHeaders = other.allowedHeaders; + this.exposedHeaders = other.exposedHeaders; + this.allowCredentials = other.allowCredentials; + this.maxAge = other.maxAge; + } + + + /** + * Set the origins to allow, e.g. {@code "https://domain1.com"}. + *

The special value {@code "*"} allows all domains. + *

By default this is not set. + */ + public void setAllowedOrigins(@Nullable List allowedOrigins) { + this.allowedOrigins = (allowedOrigins != null ? new ArrayList<>(allowedOrigins) : null); + } + + /** + * Return the configured origins to allow, or {@code null} if none. + * @see #addAllowedOrigin(String) + * @see #setAllowedOrigins(List) + */ + @Nullable + public List getAllowedOrigins() { + return this.allowedOrigins; + } + + /** + * Add an origin to allow. + */ + public void addAllowedOrigin(String origin) { + if (this.allowedOrigins == null) { + this.allowedOrigins = new ArrayList<>(4); + } + else if (this.allowedOrigins == DEFAULT_PERMIT_ALL) { + setAllowedOrigins(DEFAULT_PERMIT_ALL); + } + this.allowedOrigins.add(origin); + } + + /** + * Set the HTTP methods to allow, e.g. {@code "GET"}, {@code "POST"}, + * {@code "PUT"}, etc. + *

The special value {@code "*"} allows all methods. + *

If not set, only {@code "GET"} and {@code "HEAD"} are allowed. + *

By default this is not set. + *

Note: CORS checks use values from "Forwarded" + * (RFC 7239), + * "X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" headers, + * if present, in order to reflect the client-originated address. + * Consider using the {@code ForwardedHeaderFilter} in order to choose from a + * central place whether to extract and use, or to discard such headers. + * See the Spring Framework reference for more on this filter. + */ + public void setAllowedMethods(@Nullable List allowedMethods) { + this.allowedMethods = (allowedMethods != null ? new ArrayList<>(allowedMethods) : null); + if (!CollectionUtils.isEmpty(allowedMethods)) { + this.resolvedMethods = new ArrayList<>(allowedMethods.size()); + for (String method : allowedMethods) { + if (ALL.equals(method)) { + this.resolvedMethods = null; + break; + } + this.resolvedMethods.add(HttpMethod.resolve(method)); + } + } + else { + this.resolvedMethods = DEFAULT_METHODS; + } + } + + /** + * Return the allowed HTTP methods, or {@code null} in which case + * only {@code "GET"} and {@code "HEAD"} allowed. + * @see #addAllowedMethod(HttpMethod) + * @see #addAllowedMethod(String) + * @see #setAllowedMethods(List) + */ + @Nullable + public List getAllowedMethods() { + return this.allowedMethods; + } + + /** + * Add an HTTP method to allow. + */ + public void addAllowedMethod(HttpMethod method) { + addAllowedMethod(method.name()); + } + + /** + * Add an HTTP method to allow. + */ + public void addAllowedMethod(String method) { + if (StringUtils.hasText(method)) { + if (this.allowedMethods == null) { + this.allowedMethods = new ArrayList<>(4); + this.resolvedMethods = new ArrayList<>(4); + } + else if (this.allowedMethods == DEFAULT_PERMIT_METHODS) { + setAllowedMethods(DEFAULT_PERMIT_METHODS); + } + this.allowedMethods.add(method); + if (ALL.equals(method)) { + this.resolvedMethods = null; + } + else if (this.resolvedMethods != null) { + this.resolvedMethods.add(HttpMethod.resolve(method)); + } + } + } + + /** + * Set the list of headers that a pre-flight request can list as allowed + * for use during an actual request. + *

The special value {@code "*"} allows actual requests to send any + * header. + *

A header name is not required to be listed if it is one of: + * {@code Cache-Control}, {@code Content-Language}, {@code Expires}, + * {@code Last-Modified}, or {@code Pragma}. + *

By default this is not set. + */ + public void setAllowedHeaders(@Nullable List allowedHeaders) { + this.allowedHeaders = (allowedHeaders != null ? new ArrayList<>(allowedHeaders) : null); + } + + /** + * Return the allowed actual request headers, or {@code null} if none. + * @see #addAllowedHeader(String) + * @see #setAllowedHeaders(List) + */ + @Nullable + public List getAllowedHeaders() { + return this.allowedHeaders; + } + + /** + * Add an actual request header to allow. + */ + public void addAllowedHeader(String allowedHeader) { + if (this.allowedHeaders == null) { + this.allowedHeaders = new ArrayList<>(4); + } + else if (this.allowedHeaders == DEFAULT_PERMIT_ALL) { + setAllowedHeaders(DEFAULT_PERMIT_ALL); + } + this.allowedHeaders.add(allowedHeader); + } + + /** + * Set the list of response headers other than simple headers (i.e. + * {@code Cache-Control}, {@code Content-Language}, {@code Content-Type}, + * {@code Expires}, {@code Last-Modified}, or {@code Pragma}) that an + * actual response might have and can be exposed. + *

The special value {@code "*"} allows all headers to be exposed for + * non-credentialed requests. + *

By default this is not set. + */ + public void setExposedHeaders(@Nullable List exposedHeaders) { + this.exposedHeaders = (exposedHeaders != null ? new ArrayList<>(exposedHeaders) : null); + } + + /** + * Return the configured response headers to expose, or {@code null} if none. + * @see #addExposedHeader(String) + * @see #setExposedHeaders(List) + */ + @Nullable + public List getExposedHeaders() { + return this.exposedHeaders; + } + + /** + * Add a response header to expose. + *

The special value {@code "*"} allows all headers to be exposed for + * non-credentialed requests. + */ + public void addExposedHeader(String exposedHeader) { + if (this.exposedHeaders == null) { + this.exposedHeaders = new ArrayList<>(4); + } + this.exposedHeaders.add(exposedHeader); + } + + /** + * Whether user credentials are supported. + *

By default this is not set (i.e. user credentials are not supported). + */ + public void setAllowCredentials(@Nullable Boolean allowCredentials) { + this.allowCredentials = allowCredentials; + } + + /** + * Return the configured {@code allowCredentials} flag, or {@code null} if none. + * @see #setAllowCredentials(Boolean) + */ + @Nullable + public Boolean getAllowCredentials() { + return this.allowCredentials; + } + + /** + * Configure how long, in seconds, the response from a pre-flight request + * can be cached by clients. + *

By default this is not set. + */ + public void setMaxAge(@Nullable Long maxAge) { + this.maxAge = maxAge; + } + + /** + * Return the configured {@code maxAge} value, or {@code null} if none. + * @see #setMaxAge(Long) + */ + @Nullable + public Long getMaxAge() { + return this.maxAge; + } + + + /** + * By default a newly created {@code CorsConfiguration} does not permit any + * cross-origin requests and must be configured explicitly to indicate what + * should be allowed. + *

Use this method to flip the initialization model to start with open + * defaults that permit all cross-origin requests for GET, HEAD, and POST + * requests. Note however that this method will not override any existing + * values already set. + *

The following defaults are applied if not already set: + *

    + *
  • Allow all origins.
  • + *
  • Allow "simple" methods {@code GET}, {@code HEAD} and {@code POST}.
  • + *
  • Allow all headers.
  • + *
  • Set max age to 1800 seconds (30 minutes).
  • + *
+ */ + public CorsConfiguration applyPermitDefaultValues() { + if (this.allowedOrigins == null) { + this.allowedOrigins = DEFAULT_PERMIT_ALL; + } + if (this.allowedMethods == null) { + this.allowedMethods = DEFAULT_PERMIT_METHODS; + this.resolvedMethods = DEFAULT_PERMIT_METHODS + .stream().map(HttpMethod::resolve).collect(Collectors.toList()); + } + if (this.allowedHeaders == null) { + this.allowedHeaders = DEFAULT_PERMIT_ALL; + } + if (this.maxAge == null) { + this.maxAge = 1800L; + } + return this; + } + + /** + * Combine the non-null properties of the supplied + * {@code CorsConfiguration} with this one. + *

When combining single values like {@code allowCredentials} or + * {@code maxAge}, {@code this} properties are overridden by non-null + * {@code other} properties if any. + *

Combining lists like {@code allowedOrigins}, {@code allowedMethods}, + * {@code allowedHeaders} or {@code exposedHeaders} is done in an additive + * way. For example, combining {@code ["GET", "POST"]} with + * {@code ["PATCH"]} results in {@code ["GET", "POST", "PATCH"]}, but keep + * in mind that combining {@code ["GET", "POST"]} with {@code ["*"]} + * results in {@code ["*"]}. + *

Notice that default permit values set by + * {@link CorsConfiguration#applyPermitDefaultValues()} are overridden by + * any value explicitly defined. + * @return the combined {@code CorsConfiguration}, or {@code this} + * configuration if the supplied configuration is {@code null} + */ + @Nullable + public CorsConfiguration combine(@Nullable CorsConfiguration other) { + if (other == null) { + return this; + } + CorsConfiguration config = new CorsConfiguration(this); + config.setAllowedOrigins(combine(getAllowedOrigins(), other.getAllowedOrigins())); + config.setAllowedMethods(combine(getAllowedMethods(), other.getAllowedMethods())); + config.setAllowedHeaders(combine(getAllowedHeaders(), other.getAllowedHeaders())); + config.setExposedHeaders(combine(getExposedHeaders(), other.getExposedHeaders())); + Boolean allowCredentials = other.getAllowCredentials(); + if (allowCredentials != null) { + config.setAllowCredentials(allowCredentials); + } + Long maxAge = other.getMaxAge(); + if (maxAge != null) { + config.setMaxAge(maxAge); + } + return config; + } + + private List combine(@Nullable List source, @Nullable List other) { + if (other == null) { + return (source != null ? source : Collections.emptyList()); + } + if (source == null) { + return other; + } + if (source == DEFAULT_PERMIT_ALL || source == DEFAULT_PERMIT_METHODS) { + return other; + } + if (other == DEFAULT_PERMIT_ALL || other == DEFAULT_PERMIT_METHODS) { + return source; + } + if (source.contains(ALL) || other.contains(ALL)) { + return new ArrayList<>(Collections.singletonList(ALL)); + } + Set combined = new LinkedHashSet<>(source); + combined.addAll(other); + return new ArrayList<>(combined); + } + + /** + * Check the origin of the request against the configured allowed origins. + * @param requestOrigin the origin to check + * @return the origin to use for the response, or {@code null} which + * means the request origin is not allowed + */ + @Nullable + public String checkOrigin(@Nullable String requestOrigin) { + if (!StringUtils.hasText(requestOrigin)) { + return null; + } + if (ObjectUtils.isEmpty(this.allowedOrigins)) { + return null; + } + + if (this.allowedOrigins.contains(ALL)) { + if (this.allowCredentials != Boolean.TRUE) { + return ALL; + } + else { + return requestOrigin; + } + } + for (String allowedOrigin : this.allowedOrigins) { + if (requestOrigin.equalsIgnoreCase(allowedOrigin)) { + return requestOrigin; + } + } + + return null; + } + + /** + * Check the HTTP request method (or the method from the + * {@code Access-Control-Request-Method} header on a pre-flight request) + * against the configured allowed methods. + * @param requestMethod the HTTP request method to check + * @return the list of HTTP methods to list in the response of a pre-flight + * request, or {@code null} if the supplied {@code requestMethod} is not allowed + */ + @Nullable + public List checkHttpMethod(@Nullable HttpMethod requestMethod) { + if (requestMethod == null) { + return null; + } + if (this.resolvedMethods == null) { + return Collections.singletonList(requestMethod); + } + return (this.resolvedMethods.contains(requestMethod) ? this.resolvedMethods : null); + } + + /** + * Check the supplied request headers (or the headers listed in the + * {@code Access-Control-Request-Headers} of a pre-flight request) against + * the configured allowed headers. + * @param requestHeaders the request headers to check + * @return the list of allowed headers to list in the response of a pre-flight + * request, or {@code null} if none of the supplied request headers is allowed + */ + @Nullable + public List checkHeaders(@Nullable List requestHeaders) { + if (requestHeaders == null) { + return null; + } + if (requestHeaders.isEmpty()) { + return Collections.emptyList(); + } + if (ObjectUtils.isEmpty(this.allowedHeaders)) { + return null; + } + + boolean allowAnyHeader = this.allowedHeaders.contains(ALL); + List result = new ArrayList<>(requestHeaders.size()); + for (String requestHeader : requestHeaders) { + if (StringUtils.hasText(requestHeader)) { + requestHeader = requestHeader.trim(); + if (allowAnyHeader) { + result.add(requestHeader); + } + else { + for (String allowedHeader : this.allowedHeaders) { + if (requestHeader.equalsIgnoreCase(allowedHeader)) { + result.add(requestHeader); + break; + } + } + } + } + } + return (result.isEmpty() ? null : result); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/CorsConfigurationSource.java b/spring-web/src/main/java/org/springframework/web/cors/CorsConfigurationSource.java new file mode 100644 index 0000000000000000000000000000000000000000..27934a126f5e4e17ee6eeac4d099e868116be81f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/CorsConfigurationSource.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.lang.Nullable; + +/** + * Interface to be implemented by classes (usually HTTP request handlers) that + * provides a {@link CorsConfiguration} instance based on the provided request. + * + * @author Sebastien Deleuze + * @since 4.2 + */ +public interface CorsConfigurationSource { + + /** + * Return a {@link CorsConfiguration} based on the incoming request. + * @return the associated {@link CorsConfiguration}, or {@code null} if none + */ + @Nullable + CorsConfiguration getCorsConfiguration(HttpServletRequest request); + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/CorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/CorsProcessor.java new file mode 100644 index 0000000000000000000000000000000000000000..e162d77ebf7ccae198b06cba810158d82b24ecb2 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/CorsProcessor.java @@ -0,0 +1,53 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors; + +import java.io.IOException; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.lang.Nullable; + +/** + * A strategy that takes a request and a {@link CorsConfiguration} and updates + * the response. + * + *

This component is not concerned with how a {@code CorsConfiguration} is + * selected but rather takes follow-up actions such as applying CORS validation + * checks and either rejecting the response or adding CORS headers to the + * response. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 4.2 + * @see CORS W3C recommendation + * @see org.springframework.web.servlet.handler.AbstractHandlerMapping#setCorsProcessor + */ +public interface CorsProcessor { + + /** + * Process a request given a {@code CorsConfiguration}. + * @param configuration the applicable CORS configuration (possibly {@code null}) + * @param request the current request + * @param response the current response + * @return {@code false} if the request is rejected, {@code true} otherwise + */ + boolean processRequest(@Nullable CorsConfiguration configuration, HttpServletRequest request, + HttpServletResponse response) throws IOException; + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/CorsUtils.java b/spring-web/src/main/java/org/springframework/web/cors/CorsUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..2e31588101c2bdeb41bf9f68f819194e3a0dc80a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/CorsUtils.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; + +/** + * Utility class for CORS request handling based on the + * CORS W3C recommendation. + * + * @author Sebastien Deleuze + * @since 4.2 + */ +public abstract class CorsUtils { + + /** + * Returns {@code true} if the request is a valid CORS one. + */ + public static boolean isCorsRequest(HttpServletRequest request) { + return (request.getHeader(HttpHeaders.ORIGIN) != null); + } + + /** + * Returns {@code true} if the request is a valid CORS pre-flight one. + */ + public static boolean isPreFlightRequest(HttpServletRequest request) { + return (isCorsRequest(request) && HttpMethod.OPTIONS.matches(request.getMethod()) && + request.getHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD) != null); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java new file mode 100644 index 0000000000000000000000000000000000000000..b674d9ee1615fa87700ef086714745b6242737f3 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java @@ -0,0 +1,217 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; +import org.springframework.web.util.WebUtils; + +/** + * The default implementation of {@link CorsProcessor}, as defined by the + * CORS W3C recommendation. + * + *

Note that when input {@link CorsConfiguration} is {@code null}, this + * implementation does not reject simple or actual requests outright but simply + * avoid adding CORS headers to the response. CORS processing is also skipped + * if the response already contains CORS headers, or if the request is detected + * as a same-origin one. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 4.2 + */ +public class DefaultCorsProcessor implements CorsProcessor { + + private static final Log logger = LogFactory.getLog(DefaultCorsProcessor.class); + + + @Override + @SuppressWarnings("resource") + public boolean processRequest(@Nullable CorsConfiguration config, HttpServletRequest request, + HttpServletResponse response) throws IOException { + + if (!CorsUtils.isCorsRequest(request)) { + return true; + } + + ServletServerHttpResponse serverResponse = new ServletServerHttpResponse(response); + if (responseHasCors(serverResponse)) { + logger.trace("Skip: response already contains \"Access-Control-Allow-Origin\""); + return true; + } + + ServletServerHttpRequest serverRequest = new ServletServerHttpRequest(request); + if (WebUtils.isSameOrigin(serverRequest)) { + logger.trace("Skip: request is from same origin"); + return true; + } + + boolean preFlightRequest = CorsUtils.isPreFlightRequest(request); + if (config == null) { + if (preFlightRequest) { + rejectRequest(serverResponse); + return false; + } + else { + return true; + } + } + + return handleInternal(serverRequest, serverResponse, config, preFlightRequest); + } + + private boolean responseHasCors(ServerHttpResponse response) { + try { + return (response.getHeaders().getAccessControlAllowOrigin() != null); + } + catch (NullPointerException npe) { + // SPR-11919 and https://issues.jboss.org/browse/WFLY-3474 + return false; + } + } + + /** + * Invoked when one of the CORS checks failed. + * The default implementation sets the response status to 403 and writes + * "Invalid CORS request" to the response. + */ + protected void rejectRequest(ServerHttpResponse response) throws IOException { + response.setStatusCode(HttpStatus.FORBIDDEN); + response.getBody().write("Invalid CORS request".getBytes(StandardCharsets.UTF_8)); + } + + /** + * Handle the given request. + */ + protected boolean handleInternal(ServerHttpRequest request, ServerHttpResponse response, + CorsConfiguration config, boolean preFlightRequest) throws IOException { + + String requestOrigin = request.getHeaders().getOrigin(); + String allowOrigin = checkOrigin(config, requestOrigin); + HttpHeaders responseHeaders = response.getHeaders(); + + responseHeaders.addAll(HttpHeaders.VARY, Arrays.asList(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + + if (allowOrigin == null) { + logger.debug("Reject: '" + requestOrigin + "' origin is not allowed"); + rejectRequest(response); + return false; + } + + HttpMethod requestMethod = getMethodToUse(request, preFlightRequest); + List allowMethods = checkMethods(config, requestMethod); + if (allowMethods == null) { + logger.debug("Reject: HTTP '" + requestMethod + "' is not allowed"); + rejectRequest(response); + return false; + } + + List requestHeaders = getHeadersToUse(request, preFlightRequest); + List allowHeaders = checkHeaders(config, requestHeaders); + if (preFlightRequest && allowHeaders == null) { + logger.debug("Reject: headers '" + requestHeaders + "' are not allowed"); + rejectRequest(response); + return false; + } + + responseHeaders.setAccessControlAllowOrigin(allowOrigin); + + if (preFlightRequest) { + responseHeaders.setAccessControlAllowMethods(allowMethods); + } + + if (preFlightRequest && !allowHeaders.isEmpty()) { + responseHeaders.setAccessControlAllowHeaders(allowHeaders); + } + + if (!CollectionUtils.isEmpty(config.getExposedHeaders())) { + responseHeaders.setAccessControlExposeHeaders(config.getExposedHeaders()); + } + + if (Boolean.TRUE.equals(config.getAllowCredentials())) { + responseHeaders.setAccessControlAllowCredentials(true); + } + + if (preFlightRequest && config.getMaxAge() != null) { + responseHeaders.setAccessControlMaxAge(config.getMaxAge()); + } + + response.flush(); + return true; + } + + /** + * Check the origin and determine the origin for the response. The default + * implementation simply delegates to + * {@link org.springframework.web.cors.CorsConfiguration#checkOrigin(String)}. + */ + @Nullable + protected String checkOrigin(CorsConfiguration config, @Nullable String requestOrigin) { + return config.checkOrigin(requestOrigin); + } + + /** + * Check the HTTP method and determine the methods for the response of a + * pre-flight request. The default implementation simply delegates to + * {@link org.springframework.web.cors.CorsConfiguration#checkHttpMethod(HttpMethod)}. + */ + @Nullable + protected List checkMethods(CorsConfiguration config, @Nullable HttpMethod requestMethod) { + return config.checkHttpMethod(requestMethod); + } + + @Nullable + private HttpMethod getMethodToUse(ServerHttpRequest request, boolean isPreFlight) { + return (isPreFlight ? request.getHeaders().getAccessControlRequestMethod() : request.getMethod()); + } + + /** + * Check the headers and determine the headers for the response of a + * pre-flight request. The default implementation simply delegates to + * {@link org.springframework.web.cors.CorsConfiguration#checkOrigin(String)}. + */ + @Nullable + protected List checkHeaders(CorsConfiguration config, List requestHeaders) { + return config.checkHeaders(requestHeaders); + } + + private List getHeadersToUse(ServerHttpRequest request, boolean isPreFlight) { + HttpHeaders headers = request.getHeaders(); + return (isPreFlight ? headers.getAccessControlRequestHeaders() : new ArrayList<>(headers.keySet())); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/UrlBasedCorsConfigurationSource.java b/spring-web/src/main/java/org/springframework/web/cors/UrlBasedCorsConfigurationSource.java new file mode 100644 index 0000000000000000000000000000000000000000..38516b9b773ac2e3ce17daa0b1898f0061d3fab1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/UrlBasedCorsConfigurationSource.java @@ -0,0 +1,130 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.lang.Nullable; +import org.springframework.util.AntPathMatcher; +import org.springframework.util.Assert; +import org.springframework.util.PathMatcher; +import org.springframework.web.util.UrlPathHelper; + +/** + * Provide a per request {@link CorsConfiguration} instance based on a + * collection of {@link CorsConfiguration} mapped on path patterns. + * + *

Exact path mapping URIs (such as {@code "/admin"}) are supported + * as well as Ant-style path patterns (such as {@code "/admin/**"}). + * + * @author Sebastien Deleuze + * @since 4.2 + */ +public class UrlBasedCorsConfigurationSource implements CorsConfigurationSource { + + private final Map corsConfigurations = new LinkedHashMap<>(); + + private PathMatcher pathMatcher = new AntPathMatcher(); + + private UrlPathHelper urlPathHelper = new UrlPathHelper(); + + + /** + * Set the PathMatcher implementation to use for matching URL paths + * against registered URL patterns. Default is AntPathMatcher. + * @see org.springframework.util.AntPathMatcher + */ + public void setPathMatcher(PathMatcher pathMatcher) { + Assert.notNull(pathMatcher, "PathMatcher must not be null"); + this.pathMatcher = pathMatcher; + } + + /** + * Shortcut to same property on underlying {@link #setUrlPathHelper UrlPathHelper}. + * @see org.springframework.web.util.UrlPathHelper#setAlwaysUseFullPath + */ + public void setAlwaysUseFullPath(boolean alwaysUseFullPath) { + this.urlPathHelper.setAlwaysUseFullPath(alwaysUseFullPath); + } + + /** + * Shortcut to same property on underlying {@link #setUrlPathHelper UrlPathHelper}. + * @see org.springframework.web.util.UrlPathHelper#setUrlDecode + */ + public void setUrlDecode(boolean urlDecode) { + this.urlPathHelper.setUrlDecode(urlDecode); + } + + /** + * Shortcut to same property on underlying {@link #setUrlPathHelper UrlPathHelper}. + * @see org.springframework.web.util.UrlPathHelper#setRemoveSemicolonContent(boolean) + */ + public void setRemoveSemicolonContent(boolean removeSemicolonContent) { + this.urlPathHelper.setRemoveSemicolonContent(removeSemicolonContent); + } + + /** + * Set the UrlPathHelper to use for resolution of lookup paths. + *

Use this to override the default UrlPathHelper with a custom subclass. + */ + public void setUrlPathHelper(UrlPathHelper urlPathHelper) { + Assert.notNull(urlPathHelper, "UrlPathHelper must not be null"); + this.urlPathHelper = urlPathHelper; + } + + /** + * Set CORS configuration based on URL patterns. + */ + public void setCorsConfigurations(@Nullable Map corsConfigurations) { + this.corsConfigurations.clear(); + if (corsConfigurations != null) { + this.corsConfigurations.putAll(corsConfigurations); + } + } + + /** + * Get the CORS configuration. + */ + public Map getCorsConfigurations() { + return Collections.unmodifiableMap(this.corsConfigurations); + } + + /** + * Register a {@link CorsConfiguration} for the specified path pattern. + */ + public void registerCorsConfiguration(String path, CorsConfiguration config) { + this.corsConfigurations.put(path, config); + } + + + @Override + @Nullable + public CorsConfiguration getCorsConfiguration(HttpServletRequest request) { + String lookupPath = this.urlPathHelper.getLookupPathForRequest(request); + for (Map.Entry entry : this.corsConfigurations.entrySet()) { + if (this.pathMatcher.match(entry.getKey(), lookupPath)) { + return entry.getValue(); + } + } + return null; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/package-info.java b/spring-web/src/main/java/org/springframework/web/cors/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..3c92402aed248f9eb2f19bceb1badbc954c6b42c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/package-info.java @@ -0,0 +1,10 @@ +/** + * Support for CORS (Cross-Origin Resource Sharing), + * based on a common {@code CorsProcessor} strategy. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.cors; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsConfigurationSource.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsConfigurationSource.java new file mode 100644 index 0000000000000000000000000000000000000000..98a31159ef2b9240dda4bb40ea8d9321c53c901f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsConfigurationSource.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import org.springframework.lang.Nullable; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.server.ServerWebExchange; + +/** + * Interface to be implemented by classes (usually HTTP request handlers) that + * provides a {@link CorsConfiguration} instance based on the provided reactive request. + * + * @author Sebastien Deleuze + * @since 5.0 + */ +public interface CorsConfigurationSource { + + /** + * Return a {@link CorsConfiguration} based on the incoming request. + * @return the associated {@link CorsConfiguration}, or {@code null} if none + */ + @Nullable + CorsConfiguration getCorsConfiguration(ServerWebExchange exchange); + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsProcessor.java new file mode 100644 index 0000000000000000000000000000000000000000..c27a06575c9d05465a160aa71c3fd2d2850c93f3 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsProcessor.java @@ -0,0 +1,44 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import org.springframework.lang.Nullable; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.server.ServerWebExchange; + +/** + * A strategy to apply CORS validation checks and updates to a + * {@link ServerWebExchange}, either rejecting through the response or adding + * CORS related headers, based on a pre-selected {@link CorsConfiguration}. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 5.0 + * @see CORS W3C recommendation + */ +public interface CorsProcessor { + + /** + * Process a request using the given {@code CorsConfiguration}. + * @param configuration the CORS configuration to use; possibly {@code null} + * in which case pre-flight requests are rejected, but all others allowed. + * @param exchange the current exchange + * @return {@code false} if the request was rejected, {@code true} otherwise + */ + boolean process(@Nullable CorsConfiguration configuration, ServerWebExchange exchange); + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..73107bb1dc7fb697b5a1c3b7b247032eed6b0c9a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java @@ -0,0 +1,97 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import java.net.URI; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * Utility class for CORS reactive request handling based on the + * CORS W3C recommendation. + * + * @author Sebastien Deleuze + * @since 5.0 + */ +public abstract class CorsUtils { + + /** + * Returns {@code true} if the request is a valid CORS one. + */ + public static boolean isCorsRequest(ServerHttpRequest request) { + return (request.getHeaders().get(HttpHeaders.ORIGIN) != null); + } + + /** + * Returns {@code true} if the request is a valid CORS pre-flight one. + */ + public static boolean isPreFlightRequest(ServerHttpRequest request) { + return (request.getMethod() == HttpMethod.OPTIONS && isCorsRequest(request) && + request.getHeaders().get(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD) != null); + } + + /** + * Check if the request is a same-origin one, based on {@code Origin}, and + * {@code Host} headers. + * + *

Note: as of 5.1 this method ignores + * {@code "Forwarded"} and {@code "X-Forwarded-*"} headers that specify the + * client-originated address. Consider using the {@code ForwardedHeaderFilter} + * to extract and use, or to discard such headers. + * + * @return {@code true} if the request is a same-origin one, {@code false} in case + * of a cross-origin request + */ + public static boolean isSameOrigin(ServerHttpRequest request) { + String origin = request.getHeaders().getOrigin(); + if (origin == null) { + return true; + } + + URI uri = request.getURI(); + String actualScheme = uri.getScheme(); + String actualHost = uri.getHost(); + int actualPort = getPort(uri.getScheme(), uri.getPort()); + Assert.notNull(actualScheme, "Actual request scheme must not be null"); + Assert.notNull(actualHost, "Actual request host must not be null"); + Assert.isTrue(actualPort != -1, "Actual request port must not be undefined"); + + UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); + return (actualScheme.equals(originUrl.getScheme()) && + actualHost.equals(originUrl.getHost()) && + actualPort == getPort(originUrl.getScheme(), originUrl.getPort())); + } + + private static int getPort(@Nullable String scheme, int port) { + if (port == -1) { + if ("http".equals(scheme) || "ws".equals(scheme)) { + port = 80; + } + else if ("https".equals(scheme) || "wss".equals(scheme)) { + port = 443; + } + } + return port; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsWebFilter.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsWebFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..4938d7842e2184f1b6ae5c1224d446ad9d7763d6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsWebFilter.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import reactor.core.publisher.Mono; + +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.util.Assert; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; + + +/** + * {@link WebFilter} that handles CORS preflight requests and intercepts + * CORS simple and actual requests thanks to a {@link CorsProcessor} implementation + * ({@link DefaultCorsProcessor} by default) in order to add the relevant CORS + * response headers (like {@code Access-Control-Allow-Origin}) using the provided + * {@link CorsConfigurationSource} (for example an {@link UrlBasedCorsConfigurationSource} + * instance. + * + *

This is an alternative to Spring WebFlux Java config CORS configuration, + * mostly useful for applications using the functional API. + * + * @author Sebastien Deleuze + * @since 5.0 + * @see CORS W3C recommendation + */ +public class CorsWebFilter implements WebFilter { + + private final CorsConfigurationSource configSource; + + private final CorsProcessor processor; + + + /** + * Constructor accepting a {@link CorsConfigurationSource} used by the filter + * to find the {@link CorsConfiguration} to use for each incoming request. + * @see UrlBasedCorsConfigurationSource + */ + public CorsWebFilter(CorsConfigurationSource configSource) { + this(configSource, new DefaultCorsProcessor()); + } + + /** + * Constructor accepting a {@link CorsConfigurationSource} used by the filter + * to find the {@link CorsConfiguration} to use for each incoming request and a + * custom {@link CorsProcessor} to use to apply the matched + * {@link CorsConfiguration} for a request. + * @see UrlBasedCorsConfigurationSource + */ + public CorsWebFilter(CorsConfigurationSource configSource, CorsProcessor processor) { + Assert.notNull(configSource, "CorsConfigurationSource must not be null"); + Assert.notNull(processor, "CorsProcessor must not be null"); + this.configSource = configSource; + this.processor = processor; + } + + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + ServerHttpRequest request = exchange.getRequest(); + if (CorsUtils.isCorsRequest(request)) { + CorsConfiguration corsConfiguration = this.configSource.getCorsConfiguration(exchange); + if (corsConfiguration != null) { + boolean isValid = this.processor.process(corsConfiguration, exchange); + if (!isValid || CorsUtils.isPreFlightRequest(request)) { + return Mono.empty(); + } + } + } + return chain.filter(exchange); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java new file mode 100644 index 0000000000000000000000000000000000000000..f52d9be190e7e1cc3b7ea0c8fb1badd5fa135dc3 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java @@ -0,0 +1,203 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.server.ServerWebExchange; + +/** + * The default implementation of {@link CorsProcessor}, + * as defined by the CORS W3C recommendation. + * + *

Note that when input {@link CorsConfiguration} is {@code null}, this + * implementation does not reject simple or actual requests outright but simply + * avoid adding CORS headers to the response. CORS processing is also skipped + * if the response already contains CORS headers, or if the request is detected + * as a same-origin one. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class DefaultCorsProcessor implements CorsProcessor { + + private static final Log logger = LogFactory.getLog(DefaultCorsProcessor.class); + + + @Override + public boolean process(@Nullable CorsConfiguration config, ServerWebExchange exchange) { + + ServerHttpRequest request = exchange.getRequest(); + ServerHttpResponse response = exchange.getResponse(); + + if (!CorsUtils.isCorsRequest(request)) { + return true; + } + + if (responseHasCors(response)) { + logger.trace("Skip: response already contains \"Access-Control-Allow-Origin\""); + return true; + } + + if (CorsUtils.isSameOrigin(request)) { + logger.trace("Skip: request is from same origin"); + return true; + } + + boolean preFlightRequest = CorsUtils.isPreFlightRequest(request); + if (config == null) { + if (preFlightRequest) { + rejectRequest(response); + return false; + } + else { + return true; + } + } + + return handleInternal(exchange, config, preFlightRequest); + } + + private boolean responseHasCors(ServerHttpResponse response) { + return response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN) != null; + } + + /** + * Invoked when one of the CORS checks failed. + */ + protected void rejectRequest(ServerHttpResponse response) { + response.setStatusCode(HttpStatus.FORBIDDEN); + } + + /** + * Handle the given request. + */ + protected boolean handleInternal(ServerWebExchange exchange, + CorsConfiguration config, boolean preFlightRequest) { + + ServerHttpRequest request = exchange.getRequest(); + ServerHttpResponse response = exchange.getResponse(); + HttpHeaders responseHeaders = response.getHeaders(); + + response.getHeaders().addAll(HttpHeaders.VARY, Arrays.asList(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + + String requestOrigin = request.getHeaders().getOrigin(); + String allowOrigin = checkOrigin(config, requestOrigin); + if (allowOrigin == null) { + logger.debug("Reject: '" + requestOrigin + "' origin is not allowed"); + rejectRequest(response); + return false; + } + + HttpMethod requestMethod = getMethodToUse(request, preFlightRequest); + List allowMethods = checkMethods(config, requestMethod); + if (allowMethods == null) { + logger.debug("Reject: HTTP '" + requestMethod + "' is not allowed"); + rejectRequest(response); + return false; + } + + List requestHeaders = getHeadersToUse(request, preFlightRequest); + List allowHeaders = checkHeaders(config, requestHeaders); + if (preFlightRequest && allowHeaders == null) { + logger.debug("Reject: headers '" + requestHeaders + "' are not allowed"); + rejectRequest(response); + return false; + } + + responseHeaders.setAccessControlAllowOrigin(allowOrigin); + + if (preFlightRequest) { + responseHeaders.setAccessControlAllowMethods(allowMethods); + } + + if (preFlightRequest && !allowHeaders.isEmpty()) { + responseHeaders.setAccessControlAllowHeaders(allowHeaders); + } + + if (!CollectionUtils.isEmpty(config.getExposedHeaders())) { + responseHeaders.setAccessControlExposeHeaders(config.getExposedHeaders()); + } + + if (Boolean.TRUE.equals(config.getAllowCredentials())) { + responseHeaders.setAccessControlAllowCredentials(true); + } + + if (preFlightRequest && config.getMaxAge() != null) { + responseHeaders.setAccessControlMaxAge(config.getMaxAge()); + } + + return true; + } + + /** + * Check the origin and determine the origin for the response. The default + * implementation simply delegates to + * {@link CorsConfiguration#checkOrigin(String)}. + */ + @Nullable + protected String checkOrigin(CorsConfiguration config, @Nullable String requestOrigin) { + return config.checkOrigin(requestOrigin); + } + + /** + * Check the HTTP method and determine the methods for the response of a + * pre-flight request. The default implementation simply delegates to + * {@link CorsConfiguration#checkOrigin(String)}. + */ + @Nullable + protected List checkMethods(CorsConfiguration config, @Nullable HttpMethod requestMethod) { + return config.checkHttpMethod(requestMethod); + } + + @Nullable + private HttpMethod getMethodToUse(ServerHttpRequest request, boolean isPreFlight) { + return (isPreFlight ? request.getHeaders().getAccessControlRequestMethod() : request.getMethod()); + } + + /** + * Check the headers and determine the headers for the response of a + * pre-flight request. The default implementation simply delegates to + * {@link CorsConfiguration#checkOrigin(String)}. + */ + @Nullable + + protected List checkHeaders(CorsConfiguration config, List requestHeaders) { + return config.checkHeaders(requestHeaders); + } + + private List getHeadersToUse(ServerHttpRequest request, boolean isPreFlight) { + HttpHeaders headers = request.getHeaders(); + return (isPreFlight ? headers.getAccessControlRequestHeaders() : new ArrayList<>(headers.keySet())); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSource.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSource.java new file mode 100644 index 0000000000000000000000000000000000000000..2b97a1c7a866bbc73e5ae0f0b7f59b190cb53885 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSource.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import java.util.LinkedHashMap; +import java.util.Map; + +import org.springframework.http.server.PathContainer; +import org.springframework.lang.Nullable; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.util.pattern.PathPattern; +import org.springframework.web.util.pattern.PathPatternParser; + +/** + * Provide a per reactive request {@link CorsConfiguration} instance based on a + * collection of {@link CorsConfiguration} mapped on path patterns. + * + *

Exact path mapping URIs (such as {@code "/admin"}) are supported + * as well as Ant-style path patterns (such as {@code "/admin/**"}). + * + * @author Sebastien Deleuze + * @author Brian Clozel + * @since 5.0 + */ +public class UrlBasedCorsConfigurationSource implements CorsConfigurationSource { + + private final Map corsConfigurations; + + private final PathPatternParser patternParser; + + + /** + * Construct a new {@code UrlBasedCorsConfigurationSource} instance with default + * {@code PathPatternParser}. + * @since 5.0.6 + */ + public UrlBasedCorsConfigurationSource() { + this(new PathPatternParser()); + } + + /** + * Construct a new {@code UrlBasedCorsConfigurationSource} instance from the supplied + * {@code PathPatternParser}. + */ + public UrlBasedCorsConfigurationSource(PathPatternParser patternParser) { + this.corsConfigurations = new LinkedHashMap<>(); + this.patternParser = patternParser; + } + + + /** + * Set CORS configuration based on URL patterns. + */ + public void setCorsConfigurations(@Nullable Map corsConfigurations) { + this.corsConfigurations.clear(); + if (corsConfigurations != null) { + corsConfigurations.forEach(this::registerCorsConfiguration); + } + } + + /** + * Register a {@link CorsConfiguration} for the specified path pattern. + */ + public void registerCorsConfiguration(String path, CorsConfiguration config) { + this.corsConfigurations.put(this.patternParser.parse(path), config); + } + + @Override + @Nullable + public CorsConfiguration getCorsConfiguration(ServerWebExchange exchange) { + PathContainer lookupPath = exchange.getRequest().getPath().pathWithinApplication(); + return this.corsConfigurations.entrySet().stream() + .filter(entry -> entry.getKey().matches(lookupPath)) + .map(Map.Entry::getValue) + .findFirst() + .orElse(null); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/package-info.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..0cca26c2a03db38af20a533f06d0abdedc95b543 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/package-info.java @@ -0,0 +1,10 @@ +/** + * Reactive support for CORS (Cross-Origin Resource Sharing), + * based on a common {@code CorsProcessor} strategy. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.cors.reactive; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java b/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..f7c9a1d3101c378763cf3995506fb901a7e1056a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java @@ -0,0 +1,395 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; +import java.io.UnsupportedEncodingException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; + +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.util.ContentCachingRequestWrapper; +import org.springframework.web.util.WebUtils; + +/** + * Base class for {@code Filter}s that perform logging operations before and after a request + * is processed. + * + *

Subclasses should override the {@code beforeRequest(HttpServletRequest, String)} and + * {@code afterRequest(HttpServletRequest, String)} methods to perform the actual logging + * around the request. + * + *

Subclasses are passed the message to write to the log in the {@code beforeRequest} and + * {@code afterRequest} methods. By default, only the URI of the request is logged. However, + * setting the {@code includeQueryString} property to {@code true} will cause the query string of + * the request to be included also; this can be further extended through {@code includeClientInfo} + * and {@code includeHeaders}. The payload (body content) of the request can be logged via the + * {@code includePayload} flag: Note that this will only log the part of the payload which has + * actually been read, not necessarily the entire body of the request. + * + *

Prefixes and suffixes for the before and after messages can be configured using the + * {@code beforeMessagePrefix}, {@code afterMessagePrefix}, {@code beforeMessageSuffix} and + * {@code afterMessageSuffix} properties. + * + * @author Rob Harrop + * @author Juergen Hoeller + * @author Rossen Stoyanchev + * @since 1.2.5 + * @see #beforeRequest + * @see #afterRequest + */ +public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter { + + /** + * The default value prepended to the log message written before a request is + * processed. + */ + public static final String DEFAULT_BEFORE_MESSAGE_PREFIX = "Before request ["; + + /** + * The default value appended to the log message written before a request is + * processed. + */ + public static final String DEFAULT_BEFORE_MESSAGE_SUFFIX = "]"; + + /** + * The default value prepended to the log message written after a request is + * processed. + */ + public static final String DEFAULT_AFTER_MESSAGE_PREFIX = "After request ["; + + /** + * The default value appended to the log message written after a request is + * processed. + */ + public static final String DEFAULT_AFTER_MESSAGE_SUFFIX = "]"; + + private static final int DEFAULT_MAX_PAYLOAD_LENGTH = 50; + + + private boolean includeQueryString = false; + + private boolean includeClientInfo = false; + + private boolean includeHeaders = false; + + private boolean includePayload = false; + + private int maxPayloadLength = DEFAULT_MAX_PAYLOAD_LENGTH; + + private String beforeMessagePrefix = DEFAULT_BEFORE_MESSAGE_PREFIX; + + private String beforeMessageSuffix = DEFAULT_BEFORE_MESSAGE_SUFFIX; + + private String afterMessagePrefix = DEFAULT_AFTER_MESSAGE_PREFIX; + + private String afterMessageSuffix = DEFAULT_AFTER_MESSAGE_SUFFIX; + + + /** + * Set whether the query string should be included in the log message. + *

Should be configured using an {@code } for parameter name + * "includeQueryString" in the filter definition in {@code web.xml}. + */ + public void setIncludeQueryString(boolean includeQueryString) { + this.includeQueryString = includeQueryString; + } + + /** + * Return whether the query string should be included in the log message. + */ + protected boolean isIncludeQueryString() { + return this.includeQueryString; + } + + /** + * Set whether the client address and session id should be included in the + * log message. + *

Should be configured using an {@code } for parameter name + * "includeClientInfo" in the filter definition in {@code web.xml}. + */ + public void setIncludeClientInfo(boolean includeClientInfo) { + this.includeClientInfo = includeClientInfo; + } + + /** + * Return whether the client address and session id should be included in the + * log message. + */ + protected boolean isIncludeClientInfo() { + return this.includeClientInfo; + } + + /** + * Set whether the request headers should be included in the log message. + *

Should be configured using an {@code } for parameter name + * "includeHeaders" in the filter definition in {@code web.xml}. + * @since 4.3 + */ + public void setIncludeHeaders(boolean includeHeaders) { + this.includeHeaders = includeHeaders; + } + + /** + * Return whether the request headers should be included in the log message. + * @since 4.3 + */ + protected boolean isIncludeHeaders() { + return this.includeHeaders; + } + + /** + * Set whether the request payload (body) should be included in the log message. + *

Should be configured using an {@code } for parameter name + * "includePayload" in the filter definition in {@code web.xml}. + * @since 3.0 + */ + public void setIncludePayload(boolean includePayload) { + this.includePayload = includePayload; + } + + /** + * Return whether the request payload (body) should be included in the log message. + * @since 3.0 + */ + protected boolean isIncludePayload() { + return this.includePayload; + } + + /** + * Set the maximum length of the payload body to be included in the log message. + * Default is 50 characters. + * @since 3.0 + */ + public void setMaxPayloadLength(int maxPayloadLength) { + Assert.isTrue(maxPayloadLength >= 0, "'maxPayloadLength' should be larger than or equal to 0"); + this.maxPayloadLength = maxPayloadLength; + } + + /** + * Return the maximum length of the payload body to be included in the log message. + * @since 3.0 + */ + protected int getMaxPayloadLength() { + return this.maxPayloadLength; + } + + /** + * Set the value that should be prepended to the log message written + * before a request is processed. + */ + public void setBeforeMessagePrefix(String beforeMessagePrefix) { + this.beforeMessagePrefix = beforeMessagePrefix; + } + + /** + * Set the value that should be appended to the log message written + * before a request is processed. + */ + public void setBeforeMessageSuffix(String beforeMessageSuffix) { + this.beforeMessageSuffix = beforeMessageSuffix; + } + + /** + * Set the value that should be prepended to the log message written + * after a request is processed. + */ + public void setAfterMessagePrefix(String afterMessagePrefix) { + this.afterMessagePrefix = afterMessagePrefix; + } + + /** + * Set the value that should be appended to the log message written + * after a request is processed. + */ + public void setAfterMessageSuffix(String afterMessageSuffix) { + this.afterMessageSuffix = afterMessageSuffix; + } + + + /** + * The default value is "false" so that the filter may log a "before" message + * at the start of request processing and an "after" message at the end from + * when the last asynchronously dispatched thread is exiting. + */ + @Override + protected boolean shouldNotFilterAsyncDispatch() { + return false; + } + + /** + * Forwards the request to the next filter in the chain and delegates down to the subclasses + * to perform the actual request logging both before and after the request is processed. + * @see #beforeRequest + * @see #afterRequest + */ + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + boolean isFirstRequest = !isAsyncDispatch(request); + HttpServletRequest requestToUse = request; + + if (isIncludePayload() && isFirstRequest && !(request instanceof ContentCachingRequestWrapper)) { + requestToUse = new ContentCachingRequestWrapper(request, getMaxPayloadLength()); + } + + boolean shouldLog = shouldLog(requestToUse); + if (shouldLog && isFirstRequest) { + beforeRequest(requestToUse, getBeforeMessage(requestToUse)); + } + try { + filterChain.doFilter(requestToUse, response); + } + finally { + if (shouldLog && !isAsyncStarted(requestToUse)) { + afterRequest(requestToUse, getAfterMessage(requestToUse)); + } + } + } + + /** + * Get the message to write to the log before the request. + * @see #createMessage + */ + private String getBeforeMessage(HttpServletRequest request) { + return createMessage(request, this.beforeMessagePrefix, this.beforeMessageSuffix); + } + + /** + * Get the message to write to the log after the request. + * @see #createMessage + */ + private String getAfterMessage(HttpServletRequest request) { + return createMessage(request, this.afterMessagePrefix, this.afterMessageSuffix); + } + + /** + * Create a log message for the given request, prefix and suffix. + *

If {@code includeQueryString} is {@code true}, then the inner part + * of the log message will take the form {@code request_uri?query_string}; + * otherwise the message will simply be of the form {@code request_uri}. + *

The final message is composed of the inner part as described and + * the supplied prefix and suffix. + */ + protected String createMessage(HttpServletRequest request, String prefix, String suffix) { + StringBuilder msg = new StringBuilder(); + msg.append(prefix); + msg.append("uri=").append(request.getRequestURI()); + + if (isIncludeQueryString()) { + String queryString = request.getQueryString(); + if (queryString != null) { + msg.append('?').append(queryString); + } + } + + if (isIncludeClientInfo()) { + String client = request.getRemoteAddr(); + if (StringUtils.hasLength(client)) { + msg.append(";client=").append(client); + } + HttpSession session = request.getSession(false); + if (session != null) { + msg.append(";session=").append(session.getId()); + } + String user = request.getRemoteUser(); + if (user != null) { + msg.append(";user=").append(user); + } + } + + if (isIncludeHeaders()) { + msg.append(";headers=").append(new ServletServerHttpRequest(request).getHeaders()); + } + + if (isIncludePayload()) { + String payload = getMessagePayload(request); + if (payload != null) { + msg.append(";payload=").append(payload); + } + } + + msg.append(suffix); + return msg.toString(); + } + + /** + * Extracts the message payload portion of the message created by + * {@link #createMessage(HttpServletRequest, String, String)} when + * {@link #isIncludePayload()} returns true. + * @since 5.0.3 + */ + @Nullable + protected String getMessagePayload(HttpServletRequest request) { + ContentCachingRequestWrapper wrapper = + WebUtils.getNativeRequest(request, ContentCachingRequestWrapper.class); + if (wrapper != null) { + byte[] buf = wrapper.getContentAsByteArray(); + if (buf.length > 0) { + int length = Math.min(buf.length, getMaxPayloadLength()); + try { + return new String(buf, 0, length, wrapper.getCharacterEncoding()); + } + catch (UnsupportedEncodingException ex) { + return "[unknown]"; + } + } + } + return null; + } + + + /** + * Determine whether to call the {@link #beforeRequest}/{@link #afterRequest} + * methods for the current request, i.e. whether logging is currently active + * (and the log message is worth building). + *

The default implementation always returns {@code true}. Subclasses may + * override this with a log level check. + * @param request current HTTP request + * @return {@code true} if the before/after method should get called; + * {@code false} otherwise + * @since 4.1.5 + */ + protected boolean shouldLog(HttpServletRequest request) { + return true; + } + + /** + * Concrete subclasses should implement this method to write a log message + * before the request is processed. + * @param request current HTTP request + * @param message the message to log + */ + protected abstract void beforeRequest(HttpServletRequest request, String message); + + /** + * Concrete subclasses should implement this method to write a log message + * after the request is processed. + * @param request current HTTP request + * @param message the message to log + */ + protected abstract void afterRequest(HttpServletRequest request, String message); + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/CharacterEncodingFilter.java b/spring-web/src/main/java/org/springframework/web/filter/CharacterEncodingFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..be97ae84aa9f161e1b09113100570ea2aa0d01b2 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/CharacterEncodingFilter.java @@ -0,0 +1,204 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Servlet Filter that allows one to specify a character encoding for requests. + * This is useful because current browsers typically do not set a character + * encoding even if specified in the HTML page or form. + * + *

This filter can either apply its encoding if the request does not already + * specify an encoding, or enforce this filter's encoding in any case + * ("forceEncoding"="true"). In the latter case, the encoding will also be + * applied as default response encoding (although this will usually be overridden + * by a full content type set in the view). + * + * @author Juergen Hoeller + * @since 15.03.2004 + * @see #setEncoding + * @see #setForceEncoding + * @see javax.servlet.http.HttpServletRequest#setCharacterEncoding + * @see javax.servlet.http.HttpServletResponse#setCharacterEncoding + */ +public class CharacterEncodingFilter extends OncePerRequestFilter { + + @Nullable + private String encoding; + + private boolean forceRequestEncoding = false; + + private boolean forceResponseEncoding = false; + + + /** + * Create a default {@code CharacterEncodingFilter}, + * with the encoding to be set via {@link #setEncoding}. + * @see #setEncoding + */ + public CharacterEncodingFilter() { + } + + /** + * Create a {@code CharacterEncodingFilter} for the given encoding. + * @param encoding the encoding to apply + * @since 4.2.3 + * @see #setEncoding + */ + public CharacterEncodingFilter(String encoding) { + this(encoding, false); + } + + /** + * Create a {@code CharacterEncodingFilter} for the given encoding. + * @param encoding the encoding to apply + * @param forceEncoding whether the specified encoding is supposed to + * override existing request and response encodings + * @since 4.2.3 + * @see #setEncoding + * @see #setForceEncoding + */ + public CharacterEncodingFilter(String encoding, boolean forceEncoding) { + this(encoding, forceEncoding, forceEncoding); + } + + /** + * Create a {@code CharacterEncodingFilter} for the given encoding. + * @param encoding the encoding to apply + * @param forceRequestEncoding whether the specified encoding is supposed to + * override existing request encodings + * @param forceResponseEncoding whether the specified encoding is supposed to + * override existing response encodings + * @since 4.3 + * @see #setEncoding + * @see #setForceRequestEncoding(boolean) + * @see #setForceResponseEncoding(boolean) + */ + public CharacterEncodingFilter(String encoding, boolean forceRequestEncoding, boolean forceResponseEncoding) { + Assert.hasLength(encoding, "Encoding must not be empty"); + this.encoding = encoding; + this.forceRequestEncoding = forceRequestEncoding; + this.forceResponseEncoding = forceResponseEncoding; + } + + + /** + * Set the encoding to use for requests. This encoding will be passed into a + * {@link javax.servlet.http.HttpServletRequest#setCharacterEncoding} call. + *

Whether this encoding will override existing request encodings + * (and whether it will be applied as default response encoding as well) + * depends on the {@link #setForceEncoding "forceEncoding"} flag. + */ + public void setEncoding(@Nullable String encoding) { + this.encoding = encoding; + } + + /** + * Return the configured encoding for requests and/or responses. + * @since 4.3 + */ + @Nullable + public String getEncoding() { + return this.encoding; + } + + /** + * Set whether the configured {@link #setEncoding encoding} of this filter + * is supposed to override existing request and response encodings. + *

Default is "false", i.e. do not modify the encoding if + * {@link javax.servlet.http.HttpServletRequest#getCharacterEncoding()} + * returns a non-null value. Switch this to "true" to enforce the specified + * encoding in any case, applying it as default response encoding as well. + *

This is the equivalent to setting both {@link #setForceRequestEncoding(boolean)} + * and {@link #setForceResponseEncoding(boolean)}. + * @see #setForceRequestEncoding(boolean) + * @see #setForceResponseEncoding(boolean) + */ + public void setForceEncoding(boolean forceEncoding) { + this.forceRequestEncoding = forceEncoding; + this.forceResponseEncoding = forceEncoding; + } + + /** + * Set whether the configured {@link #setEncoding encoding} of this filter + * is supposed to override existing request encodings. + *

Default is "false", i.e. do not modify the encoding if + * {@link javax.servlet.http.HttpServletRequest#getCharacterEncoding()} + * returns a non-null value. Switch this to "true" to enforce the specified + * encoding in any case. + * @since 4.3 + */ + public void setForceRequestEncoding(boolean forceRequestEncoding) { + this.forceRequestEncoding = forceRequestEncoding; + } + + /** + * Return whether the encoding should be forced on requests. + * @since 4.3 + */ + public boolean isForceRequestEncoding() { + return this.forceRequestEncoding; + } + + /** + * Set whether the configured {@link #setEncoding encoding} of this filter + * is supposed to override existing response encodings. + *

Default is "false", i.e. do not modify the encoding. + * Switch this to "true" to enforce the specified encoding + * for responses in any case. + * @since 4.3 + */ + public void setForceResponseEncoding(boolean forceResponseEncoding) { + this.forceResponseEncoding = forceResponseEncoding; + } + + /** + * Return whether the encoding should be forced on responses. + * @since 4.3 + */ + public boolean isForceResponseEncoding() { + return this.forceResponseEncoding; + } + + + @Override + protected void doFilterInternal( + HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + String encoding = getEncoding(); + if (encoding != null) { + if (isForceRequestEncoding() || request.getCharacterEncoding() == null) { + request.setCharacterEncoding(encoding); + } + if (isForceResponseEncoding()) { + response.setCharacterEncoding(encoding); + } + } + filterChain.doFilter(request, response); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/CommonsRequestLoggingFilter.java b/spring-web/src/main/java/org/springframework/web/filter/CommonsRequestLoggingFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..564a47040bef7ca06f63b62e09313b7579a4b5a2 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/CommonsRequestLoggingFilter.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import javax.servlet.http.HttpServletRequest; + +/** + * Simple request logging filter that writes the request URI + * (and optionally the query string) to the Commons Log. + * + * @author Rob Harrop + * @author Juergen Hoeller + * @since 1.2.5 + * @see #setIncludeQueryString + * @see #setBeforeMessagePrefix + * @see #setBeforeMessageSuffix + * @see #setAfterMessagePrefix + * @see #setAfterMessageSuffix + * @see org.apache.commons.logging.Log#debug(Object) + */ +public class CommonsRequestLoggingFilter extends AbstractRequestLoggingFilter { + + @Override + protected boolean shouldLog(HttpServletRequest request) { + return logger.isDebugEnabled(); + } + + /** + * Writes a log message before the request is processed. + */ + @Override + protected void beforeRequest(HttpServletRequest request, String message) { + logger.debug(message); + } + + /** + * Writes a log message after the request is processed. + */ + @Override + protected void afterRequest(HttpServletRequest request, String message) { + logger.debug(message); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/CompositeFilter.java b/spring-web/src/main/java/org/springframework/web/filter/CompositeFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..f8ec4e37a0bd3307cdd451256228158107d20e04 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/CompositeFilter.java @@ -0,0 +1,118 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; + +/** + * A generic composite servlet {@link Filter} that just delegates its behavior + * to a chain (list) of user-supplied filters, achieving the functionality of a + * {@link FilterChain}, but conveniently using only {@link Filter} instances. + * + *

This is useful for filters that require dependency injection, and can + * therefore be set up in a Spring application context. Typically, this + * composite would be used in conjunction with {@link DelegatingFilterProxy}, + * so that it can be declared in Spring but applied to a servlet context. + * + * @author Dave Syer + * @since 3.1 + */ +public class CompositeFilter implements Filter { + + private List filters = new ArrayList<>(); + + + public void setFilters(List filters) { + this.filters = new ArrayList<>(filters); + } + + + /** + * Initialize all the filters, calling each one's init method in turn in the order supplied. + * @see Filter#init(FilterConfig) + */ + @Override + public void init(FilterConfig config) throws ServletException { + for (Filter filter : this.filters) { + filter.init(config); + } + } + + /** + * Forms a temporary chain from the list of delegate filters supplied ({@link #setFilters}) + * and executes them in order. Each filter delegates to the next one in the list, achieving + * the normal behavior of a {@link FilterChain}, despite the fact that this is a {@link Filter}. + * @see Filter#doFilter(ServletRequest, ServletResponse, FilterChain) + */ + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException { + + new VirtualFilterChain(chain, this.filters).doFilter(request, response); + } + + /** + * Clean up all the filters supplied, calling each one's destroy method in turn, but in reverse order. + * @see Filter#init(FilterConfig) + */ + @Override + public void destroy() { + for (int i = this.filters.size(); i-- > 0;) { + Filter filter = this.filters.get(i); + filter.destroy(); + } + } + + + private static class VirtualFilterChain implements FilterChain { + + private final FilterChain originalChain; + + private final List additionalFilters; + + private int currentPosition = 0; + + public VirtualFilterChain(FilterChain chain, List additionalFilters) { + this.originalChain = chain; + this.additionalFilters = additionalFilters; + } + + @Override + public void doFilter(final ServletRequest request, final ServletResponse response) + throws IOException, ServletException { + + if (this.currentPosition == this.additionalFilters.size()) { + this.originalChain.doFilter(request, response); + } + else { + this.currentPosition++; + Filter nextFilter = this.additionalFilters.get(this.currentPosition - 1); + nextFilter.doFilter(request, response, this); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/CorsFilter.java b/spring-web/src/main/java/org/springframework/web/filter/CorsFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..2b857985bf73f946765ddc442b90c9b2e4fd45f9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/CorsFilter.java @@ -0,0 +1,100 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.util.Assert; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.cors.CorsConfigurationSource; +import org.springframework.web.cors.CorsProcessor; +import org.springframework.web.cors.CorsUtils; +import org.springframework.web.cors.DefaultCorsProcessor; +import org.springframework.web.cors.UrlBasedCorsConfigurationSource; + +/** + * {@link javax.servlet.Filter} that handles CORS preflight requests and intercepts + * CORS simple and actual requests thanks to a {@link CorsProcessor} implementation + * ({@link DefaultCorsProcessor} by default) in order to add the relevant CORS + * response headers (like {@code Access-Control-Allow-Origin}) using the provided + * {@link CorsConfigurationSource} (for example an {@link UrlBasedCorsConfigurationSource} + * instance. + * + *

This is an alternative to Spring MVC Java config and XML namespace CORS configuration, + * useful for applications depending only on spring-web (not on spring-webmvc) or for + * security constraints requiring CORS checks to be performed at {@link javax.servlet.Filter} + * level. + * + *

This filter could be used in conjunction with {@link DelegatingFilterProxy} in order + * to help with its initialization. + * + * @author Sebastien Deleuze + * @since 4.2 + * @see CORS W3C recommendation + */ +public class CorsFilter extends OncePerRequestFilter { + + private final CorsConfigurationSource configSource; + + private CorsProcessor processor = new DefaultCorsProcessor(); + + + /** + * Constructor accepting a {@link CorsConfigurationSource} used by the filter + * to find the {@link CorsConfiguration} to use for each incoming request. + * @see UrlBasedCorsConfigurationSource + */ + public CorsFilter(CorsConfigurationSource configSource) { + Assert.notNull(configSource, "CorsConfigurationSource must not be null"); + this.configSource = configSource; + } + + + /** + * Configure a custom {@link CorsProcessor} to use to apply the matched + * {@link CorsConfiguration} for a request. + *

By default {@link DefaultCorsProcessor} is used. + */ + public void setCorsProcessor(CorsProcessor processor) { + Assert.notNull(processor, "CorsProcessor must not be null"); + this.processor = processor; + } + + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + + if (CorsUtils.isCorsRequest(request)) { + CorsConfiguration corsConfiguration = this.configSource.getCorsConfiguration(request); + if (corsConfiguration != null) { + boolean isValid = this.processor.processRequest(corsConfiguration, request, response); + if (!isValid || CorsUtils.isPreFlightRequest(request)) { + return; + } + } + } + + filterChain.doFilter(request, response); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/DelegatingFilterProxy.java b/spring-web/src/main/java/org/springframework/web/filter/DelegatingFilterProxy.java new file mode 100644 index 0000000000000000000000000000000000000000..bfbe62739dcec307d37ee8b1a91bb44a1fcdccb7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/DelegatingFilterProxy.java @@ -0,0 +1,374 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; + +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.WebApplicationContextUtils; + +/** + * Proxy for a standard Servlet Filter, delegating to a Spring-managed bean that + * implements the Filter interface. Supports a "targetBeanName" filter init-param + * in {@code web.xml}, specifying the name of the target bean in the Spring + * application context. + * + *

{@code web.xml} will usually contain a {@code DelegatingFilterProxy} definition, + * with the specified {@code filter-name} corresponding to a bean name in + * Spring's root application context. All calls to the filter proxy will then + * be delegated to that bean in the Spring context, which is required to implement + * the standard Servlet Filter interface. + * + *

This approach is particularly useful for Filter implementation with complex + * setup needs, allowing to apply the full Spring bean definition machinery to + * Filter instances. Alternatively, consider standard Filter setup in combination + * with looking up service beans from the Spring root application context. + * + *

NOTE: The lifecycle methods defined by the Servlet Filter interface + * will by default not be delegated to the target bean, relying on the + * Spring application context to manage the lifecycle of that bean. Specifying + * the "targetFilterLifecycle" filter init-param as "true" will enforce invocation + * of the {@code Filter.init} and {@code Filter.destroy} lifecycle methods + * on the target bean, letting the servlet container manage the filter lifecycle. + * + *

As of Spring 3.1, {@code DelegatingFilterProxy} has been updated to optionally accept + * constructor parameters when using Servlet 3.0's instance-based filter registration + * methods, usually in conjunction with Spring 3.1's + * {@link org.springframework.web.WebApplicationInitializer} SPI. These constructors allow + * for providing the delegate Filter bean directly, or providing the application context + * and bean name to fetch, avoiding the need to look up the application context from the + * ServletContext. + * + *

This class was originally inspired by Spring Security's {@code FilterToBeanProxy} + * class, written by Ben Alex. + * + * @author Juergen Hoeller + * @author Sam Brannen + * @author Chris Beams + * @since 1.2 + * @see #setTargetBeanName + * @see #setTargetFilterLifecycle + * @see javax.servlet.Filter#doFilter + * @see javax.servlet.Filter#init + * @see javax.servlet.Filter#destroy + * @see #DelegatingFilterProxy(Filter) + * @see #DelegatingFilterProxy(String) + * @see #DelegatingFilterProxy(String, WebApplicationContext) + * @see javax.servlet.ServletContext#addFilter(String, Filter) + * @see org.springframework.web.WebApplicationInitializer + */ +public class DelegatingFilterProxy extends GenericFilterBean { + + @Nullable + private String contextAttribute; + + @Nullable + private WebApplicationContext webApplicationContext; + + @Nullable + private String targetBeanName; + + private boolean targetFilterLifecycle = false; + + @Nullable + private volatile Filter delegate; + + private final Object delegateMonitor = new Object(); + + + /** + * Create a new {@code DelegatingFilterProxy}. For traditional (pre-Servlet 3.0) use + * in {@code web.xml}. + * @see #setTargetBeanName(String) + */ + public DelegatingFilterProxy() { + } + + /** + * Create a new {@code DelegatingFilterProxy} with the given {@link Filter} delegate. + * Bypasses entirely the need for interacting with a Spring application context, + * specifying the {@linkplain #setTargetBeanName target bean name}, etc. + *

For use in Servlet 3.0+ environments where instance-based registration of + * filters is supported. + * @param delegate the {@code Filter} instance that this proxy will delegate to and + * manage the lifecycle for (must not be {@code null}). + * @see #doFilter(ServletRequest, ServletResponse, FilterChain) + * @see #invokeDelegate(Filter, ServletRequest, ServletResponse, FilterChain) + * @see #destroy() + * @see #setEnvironment(org.springframework.core.env.Environment) + */ + public DelegatingFilterProxy(Filter delegate) { + Assert.notNull(delegate, "Delegate Filter must not be null"); + this.delegate = delegate; + } + + /** + * Create a new {@code DelegatingFilterProxy} that will retrieve the named target + * bean from the Spring {@code WebApplicationContext} found in the {@code ServletContext} + * (either the 'root' application context or the context named by + * {@link #setContextAttribute}). + *

For use in Servlet 3.0+ environments where instance-based registration of + * filters is supported. + *

The target bean must implement the standard Servlet Filter. + * @param targetBeanName name of the target filter bean to look up in the Spring + * application context (must not be {@code null}). + * @see #findWebApplicationContext() + * @see #setEnvironment(org.springframework.core.env.Environment) + */ + public DelegatingFilterProxy(String targetBeanName) { + this(targetBeanName, null); + } + + /** + * Create a new {@code DelegatingFilterProxy} that will retrieve the named target + * bean from the given Spring {@code WebApplicationContext}. + *

For use in Servlet 3.0+ environments where instance-based registration of + * filters is supported. + *

The target bean must implement the standard Servlet Filter interface. + *

The given {@code WebApplicationContext} may or may not be refreshed when passed + * in. If it has not, and if the context implements {@link ConfigurableApplicationContext}, + * a {@link ConfigurableApplicationContext#refresh() refresh()} will be attempted before + * retrieving the named target bean. + *

This proxy's {@code Environment} will be inherited from the given + * {@code WebApplicationContext}. + * @param targetBeanName name of the target filter bean in the Spring application + * context (must not be {@code null}). + * @param wac the application context from which the target filter will be retrieved; + * if {@code null}, an application context will be looked up from {@code ServletContext} + * as a fallback. + * @see #findWebApplicationContext() + * @see #setEnvironment(org.springframework.core.env.Environment) + */ + public DelegatingFilterProxy(String targetBeanName, @Nullable WebApplicationContext wac) { + Assert.hasText(targetBeanName, "Target Filter bean name must not be null or empty"); + this.setTargetBeanName(targetBeanName); + this.webApplicationContext = wac; + if (wac != null) { + this.setEnvironment(wac.getEnvironment()); + } + } + + /** + * Set the name of the ServletContext attribute which should be used to retrieve the + * {@link WebApplicationContext} from which to load the delegate {@link Filter} bean. + */ + public void setContextAttribute(@Nullable String contextAttribute) { + this.contextAttribute = contextAttribute; + } + + /** + * Return the name of the ServletContext attribute which should be used to retrieve the + * {@link WebApplicationContext} from which to load the delegate {@link Filter} bean. + */ + @Nullable + public String getContextAttribute() { + return this.contextAttribute; + } + + /** + * Set the name of the target bean in the Spring application context. + * The target bean must implement the standard Servlet Filter interface. + *

By default, the {@code filter-name} as specified for the + * DelegatingFilterProxy in {@code web.xml} will be used. + */ + public void setTargetBeanName(@Nullable String targetBeanName) { + this.targetBeanName = targetBeanName; + } + + /** + * Return the name of the target bean in the Spring application context. + */ + @Nullable + protected String getTargetBeanName() { + return this.targetBeanName; + } + + /** + * Set whether to invoke the {@code Filter.init} and + * {@code Filter.destroy} lifecycle methods on the target bean. + *

Default is "false"; target beans usually rely on the Spring application + * context for managing their lifecycle. Setting this flag to "true" means + * that the servlet container will control the lifecycle of the target + * Filter, with this proxy delegating the corresponding calls. + */ + public void setTargetFilterLifecycle(boolean targetFilterLifecycle) { + this.targetFilterLifecycle = targetFilterLifecycle; + } + + /** + * Return whether to invoke the {@code Filter.init} and + * {@code Filter.destroy} lifecycle methods on the target bean. + */ + protected boolean isTargetFilterLifecycle() { + return this.targetFilterLifecycle; + } + + + @Override + protected void initFilterBean() throws ServletException { + synchronized (this.delegateMonitor) { + if (this.delegate == null) { + // If no target bean name specified, use filter name. + if (this.targetBeanName == null) { + this.targetBeanName = getFilterName(); + } + // Fetch Spring root application context and initialize the delegate early, + // if possible. If the root application context will be started after this + // filter proxy, we'll have to resort to lazy initialization. + WebApplicationContext wac = findWebApplicationContext(); + if (wac != null) { + this.delegate = initDelegate(wac); + } + } + } + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + // Lazily initialize the delegate if necessary. + Filter delegateToUse = this.delegate; + if (delegateToUse == null) { + synchronized (this.delegateMonitor) { + delegateToUse = this.delegate; + if (delegateToUse == null) { + WebApplicationContext wac = findWebApplicationContext(); + if (wac == null) { + throw new IllegalStateException("No WebApplicationContext found: " + + "no ContextLoaderListener or DispatcherServlet registered?"); + } + delegateToUse = initDelegate(wac); + } + this.delegate = delegateToUse; + } + } + + // Let the delegate perform the actual doFilter operation. + invokeDelegate(delegateToUse, request, response, filterChain); + } + + @Override + public void destroy() { + Filter delegateToUse = this.delegate; + if (delegateToUse != null) { + destroyDelegate(delegateToUse); + } + } + + + /** + * Return the {@code WebApplicationContext} passed in at construction time, if available. + * Otherwise, attempt to retrieve a {@code WebApplicationContext} from the + * {@code ServletContext} attribute with the {@linkplain #setContextAttribute + * configured name} if set. Otherwise look up a {@code WebApplicationContext} under + * the well-known "root" application context attribute. The + * {@code WebApplicationContext} must have already been loaded and stored in the + * {@code ServletContext} before this filter gets initialized (or invoked). + *

Subclasses may override this method to provide a different + * {@code WebApplicationContext} retrieval strategy. + * @return the {@code WebApplicationContext} for this proxy, or {@code null} if not found + * @see #DelegatingFilterProxy(String, WebApplicationContext) + * @see #getContextAttribute() + * @see WebApplicationContextUtils#getWebApplicationContext(javax.servlet.ServletContext) + * @see WebApplicationContext#ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE + */ + @Nullable + protected WebApplicationContext findWebApplicationContext() { + if (this.webApplicationContext != null) { + // The user has injected a context at construction time -> use it... + if (this.webApplicationContext instanceof ConfigurableApplicationContext) { + ConfigurableApplicationContext cac = (ConfigurableApplicationContext) this.webApplicationContext; + if (!cac.isActive()) { + // The context has not yet been refreshed -> do so before returning it... + cac.refresh(); + } + } + return this.webApplicationContext; + } + String attrName = getContextAttribute(); + if (attrName != null) { + return WebApplicationContextUtils.getWebApplicationContext(getServletContext(), attrName); + } + else { + return WebApplicationContextUtils.findWebApplicationContext(getServletContext()); + } + } + + /** + * Initialize the Filter delegate, defined as bean the given Spring + * application context. + *

The default implementation fetches the bean from the application context + * and calls the standard {@code Filter.init} method on it, passing + * in the FilterConfig of this Filter proxy. + * @param wac the root application context + * @return the initialized delegate Filter + * @throws ServletException if thrown by the Filter + * @see #getTargetBeanName() + * @see #isTargetFilterLifecycle() + * @see #getFilterConfig() + * @see javax.servlet.Filter#init(javax.servlet.FilterConfig) + */ + protected Filter initDelegate(WebApplicationContext wac) throws ServletException { + String targetBeanName = getTargetBeanName(); + Assert.state(targetBeanName != null, "No target bean name set"); + Filter delegate = wac.getBean(targetBeanName, Filter.class); + if (isTargetFilterLifecycle()) { + delegate.init(getFilterConfig()); + } + return delegate; + } + + /** + * Actually invoke the delegate Filter with the given request and response. + * @param delegate the delegate Filter + * @param request the current HTTP request + * @param response the current HTTP response + * @param filterChain the current FilterChain + * @throws ServletException if thrown by the Filter + * @throws IOException if thrown by the Filter + */ + protected void invokeDelegate( + Filter delegate, ServletRequest request, ServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + delegate.doFilter(request, response, filterChain); + } + + /** + * Destroy the Filter delegate. + * Default implementation simply calls {@code Filter.destroy} on it. + * @param delegate the Filter delegate (never {@code null}) + * @see #isTargetFilterLifecycle() + * @see javax.servlet.Filter#destroy() + */ + protected void destroyDelegate(Filter delegate) { + if (isTargetFilterLifecycle()) { + delegate.destroy(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/FormContentFilter.java b/spring-web/src/main/java/org/springframework/web/filter/FormContentFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..41eb89b10e60cba9e4c10d1581e814fbd7c8b5d9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/FormContentFilter.java @@ -0,0 +1,182 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.http.HttpInputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * {@code Filter} that parses form data for HTTP PUT, PATCH, and DELETE requests + * and exposes it as Servlet request parameters. By default the Servlet spec + * only requires this for HTTP POST. + * + * @author Rossen Stoyanchev + * @since 5.1 + */ +public class FormContentFilter extends OncePerRequestFilter { + + private static final List HTTP_METHODS = Arrays.asList("PUT", "PATCH", "DELETE"); + + private FormHttpMessageConverter formConverter = new AllEncompassingFormHttpMessageConverter(); + + + /** + * Set the converter to use for parsing form content. + *

By default this is an instance of {@link AllEncompassingFormHttpMessageConverter}. + */ + public void setFormConverter(FormHttpMessageConverter converter) { + Assert.notNull(converter, "FormHttpMessageConverter is required"); + this.formConverter = converter; + } + + /** + * The default character set to use for reading form data. + * This is a shortcut for:
+ * {@code getFormConverter.setCharset(charset)}. + */ + public void setCharset(Charset charset) { + this.formConverter.setCharset(charset); + } + + + @Override + protected void doFilterInternal( + HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + MultiValueMap params = parseIfNecessary(request); + if (!CollectionUtils.isEmpty(params)) { + filterChain.doFilter(new FormContentRequestWrapper(request, params), response); + } + else { + filterChain.doFilter(request, response); + } + } + + @Nullable + private MultiValueMap parseIfNecessary(HttpServletRequest request) throws IOException { + if (!shouldParse(request)) { + return null; + } + + HttpInputMessage inputMessage = new ServletServerHttpRequest(request) { + @Override + public InputStream getBody() throws IOException { + return request.getInputStream(); + } + }; + return this.formConverter.read(null, inputMessage); + } + + private boolean shouldParse(HttpServletRequest request) { + if (!HTTP_METHODS.contains(request.getMethod())) { + return false; + } + try { + MediaType mediaType = MediaType.parseMediaType(request.getContentType()); + return MediaType.APPLICATION_FORM_URLENCODED.includes(mediaType); + } + catch (IllegalArgumentException ex) { + return false; + } + } + + + private static class FormContentRequestWrapper extends HttpServletRequestWrapper { + + private MultiValueMap formParams; + + public FormContentRequestWrapper(HttpServletRequest request, MultiValueMap params) { + super(request); + this.formParams = params; + } + + @Override + @Nullable + public String getParameter(String name) { + String queryStringValue = super.getParameter(name); + String formValue = this.formParams.getFirst(name); + return (queryStringValue != null ? queryStringValue : formValue); + } + + @Override + public Map getParameterMap() { + Map result = new LinkedHashMap<>(); + Enumeration names = getParameterNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + result.put(name, getParameterValues(name)); + } + return result; + } + + @Override + public Enumeration getParameterNames() { + Set names = new LinkedHashSet<>(); + names.addAll(Collections.list(super.getParameterNames())); + names.addAll(this.formParams.keySet()); + return Collections.enumeration(names); + } + + @Override + @Nullable + public String[] getParameterValues(String name) { + String[] parameterValues = super.getParameterValues(name); + List formParam = this.formParams.get(name); + if (formParam == null) { + return parameterValues; + } + if (parameterValues == null || getQueryString() == null) { + return StringUtils.toStringArray(formParam); + } + else { + List result = new ArrayList<>(parameterValues.length + formParam.size()); + result.addAll(Arrays.asList(parameterValues)); + result.addAll(formParam); + return StringUtils.toStringArray(result); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..b9c4a4c3014fb9aa56edb793e02b7478a35882bc --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java @@ -0,0 +1,430 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; +import java.util.Collections; +import java.util.Enumeration; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; + +import org.springframework.http.HttpRequest; +import org.springframework.http.HttpStatus; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UrlPathHelper; + +/** + * Extract values from "Forwarded" and "X-Forwarded-*" headers, wrap the request + * and response, and make they reflect the client-originated protocol and + * address in the following methods: + *

    + *
  • {@link HttpServletRequest#getServerName() getServerName()} + *
  • {@link HttpServletRequest#getServerPort() getServerPort()} + *
  • {@link HttpServletRequest#getScheme() getScheme()} + *
  • {@link HttpServletRequest#isSecure() isSecure()} + *
  • {@link HttpServletResponse#sendRedirect(String) sendRedirect(String)}. + *
+ * + *

This filter can also be used in a {@link #setRemoveOnly removeOnly} mode + * where "Forwarded" and "X-Forwarded-*" headers are eliminated, and not used. + * + * @author Rossen Stoyanchev + * @author Eddú Meléndez + * @author Rob Winch + * @since 4.3 + * @see https://tools.ietf.org/html/rfc7239 + */ +public class ForwardedHeaderFilter extends OncePerRequestFilter { + + private static final Set FORWARDED_HEADER_NAMES = + Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(6, Locale.ENGLISH)); + + static { + FORWARDED_HEADER_NAMES.add("Forwarded"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Host"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Port"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Ssl"); + } + + + private boolean removeOnly; + + private boolean relativeRedirects; + + + /** + * Enables mode in which any "Forwarded" or "X-Forwarded-*" headers are + * removed only and the information in them ignored. + * @param removeOnly whether to discard and ignore forwarded headers + * @since 4.3.9 + */ + public void setRemoveOnly(boolean removeOnly) { + this.removeOnly = removeOnly; + } + + /** + * Use this property to enable relative redirects as explained in + * {@link RelativeRedirectFilter}, and also using the same response wrapper + * as that filter does, or if both are configured, only one will wrap. + *

By default, if this property is set to false, in which case calls to + * {@link HttpServletResponse#sendRedirect(String)} are overridden in order + * to turn relative into absolute URLs, also taking into account forwarded + * headers. + * @param relativeRedirects whether to use relative redirects + * @since 4.3.10 + */ + public void setRelativeRedirects(boolean relativeRedirects) { + this.relativeRedirects = relativeRedirects; + } + + + @Override + protected boolean shouldNotFilter(HttpServletRequest request) { + for (String headerName : FORWARDED_HEADER_NAMES) { + if (request.getHeader(headerName) != null) { + return false; + } + } + return true; + } + + @Override + protected boolean shouldNotFilterAsyncDispatch() { + return false; + } + + @Override + protected boolean shouldNotFilterErrorDispatch() { + return false; + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + + if (this.removeOnly) { + ForwardedHeaderRemovingRequest wrappedRequest = new ForwardedHeaderRemovingRequest(request); + filterChain.doFilter(wrappedRequest, response); + } + else { + HttpServletRequest wrappedRequest = + new ForwardedHeaderExtractingRequest(request); + + HttpServletResponse wrappedResponse = this.relativeRedirects ? + RelativeRedirectResponseWrapper.wrapIfNecessary(response, HttpStatus.SEE_OTHER) : + new ForwardedHeaderExtractingResponse(response, wrappedRequest); + + filterChain.doFilter(wrappedRequest, wrappedResponse); + } + } + + @Override + protected void doFilterNestedErrorDispatch(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + + doFilterInternal(request, response, filterChain); + } + + /** + * Hide "Forwarded" or "X-Forwarded-*" headers. + */ + private static class ForwardedHeaderRemovingRequest extends HttpServletRequestWrapper { + + private final Map> headers; + + public ForwardedHeaderRemovingRequest(HttpServletRequest request) { + super(request); + this.headers = initHeaders(request); + } + + private static Map> initHeaders(HttpServletRequest request) { + Map> headers = new LinkedCaseInsensitiveMap<>(Locale.ENGLISH); + Enumeration names = request.getHeaderNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + if (!FORWARDED_HEADER_NAMES.contains(name)) { + headers.put(name, Collections.list(request.getHeaders(name))); + } + } + return headers; + } + + // Override header accessors to not expose forwarded headers + + @Override + @Nullable + public String getHeader(String name) { + List value = this.headers.get(name); + return (CollectionUtils.isEmpty(value) ? null : value.get(0)); + } + + @Override + public Enumeration getHeaders(String name) { + List value = this.headers.get(name); + return (Collections.enumeration(value != null ? value : Collections.emptySet())); + } + + @Override + public Enumeration getHeaderNames() { + return Collections.enumeration(this.headers.keySet()); + } + } + + + /** + * Extract and use "Forwarded" or "X-Forwarded-*" headers. + */ + private static class ForwardedHeaderExtractingRequest extends ForwardedHeaderRemovingRequest { + + @Nullable + private final String scheme; + + private final boolean secure; + + @Nullable + private final String host; + + private final int port; + + private final ForwardedPrefixExtractor forwardedPrefixExtractor; + + + ForwardedHeaderExtractingRequest(HttpServletRequest request) { + super(request); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents uriComponents = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + int port = uriComponents.getPort(); + + this.scheme = uriComponents.getScheme(); + this.secure = "https".equals(this.scheme); + this.host = uriComponents.getHost(); + this.port = (port == -1 ? (this.secure ? 443 : 80) : port); + + String baseUrl = this.scheme + "://" + this.host + (port == -1 ? "" : ":" + port); + Supplier delegateRequest = () -> (HttpServletRequest) getRequest(); + this.forwardedPrefixExtractor = new ForwardedPrefixExtractor(delegateRequest, baseUrl); + } + + + @Override + @Nullable + public String getScheme() { + return this.scheme; + } + + @Override + @Nullable + public String getServerName() { + return this.host; + } + + @Override + public int getServerPort() { + return this.port; + } + + @Override + public boolean isSecure() { + return this.secure; + } + + @Override + public String getContextPath() { + return this.forwardedPrefixExtractor.getContextPath(); + } + + @Override + public String getRequestURI() { + return this.forwardedPrefixExtractor.getRequestUri(); + } + + @Override + public StringBuffer getRequestURL() { + return this.forwardedPrefixExtractor.getRequestUrl(); + } + } + + + /** + * Responsible for the contextPath, requestURI, and requestURL with forwarded + * headers in mind, and also taking into account changes to the path of the + * underlying delegate request (e.g. on a Servlet FORWARD). + */ + private static class ForwardedPrefixExtractor { + + private final Supplier delegate; + + private final String baseUrl; + + private String actualRequestUri; + + @Nullable + private final String forwardedPrefix; + + @Nullable + private String requestUri; + + private String requestUrl; + + + /** + * Constructor with required information. + * @param delegateRequest supplier for the current + * {@link HttpServletRequestWrapper#getRequest() delegate request} which + * may change during a forward (e.g. Tomcat. + * @param baseUrl the host, scheme, and port based on forwarded headers + */ + public ForwardedPrefixExtractor(Supplier delegateRequest, String baseUrl) { + this.delegate = delegateRequest; + this.baseUrl = baseUrl; + this.actualRequestUri = delegateRequest.get().getRequestURI(); + + this.forwardedPrefix = initForwardedPrefix(delegateRequest.get()); + this.requestUri = initRequestUri(); + this.requestUrl = initRequestUrl(); // Keep the order: depends on requestUri + } + + @Nullable + private static String initForwardedPrefix(HttpServletRequest request) { + String result = null; + Enumeration names = request.getHeaderNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + if ("X-Forwarded-Prefix".equalsIgnoreCase(name)) { + result = request.getHeader(name); + } + } + if (result != null) { + while (result.endsWith("/")) { + result = result.substring(0, result.length() - 1); + } + } + return result; + } + + @Nullable + private String initRequestUri() { + if (this.forwardedPrefix != null) { + return this.forwardedPrefix + + UrlPathHelper.rawPathInstance.getPathWithinApplication(this.delegate.get()); + } + return null; + } + + private String initRequestUrl() { + return this.baseUrl + (this.requestUri != null ? this.requestUri : this.delegate.get().getRequestURI()); + } + + + public String getContextPath() { + return this.forwardedPrefix == null ? this.delegate.get().getContextPath() : this.forwardedPrefix; + } + + public String getRequestUri() { + if (this.requestUri == null) { + return this.delegate.get().getRequestURI(); + } + recalculatePathsIfNecessary(); + return this.requestUri; + } + + public StringBuffer getRequestUrl() { + recalculatePathsIfNecessary(); + return new StringBuffer(this.requestUrl); + } + + private void recalculatePathsIfNecessary() { + if (!this.actualRequestUri.equals(this.delegate.get().getRequestURI())) { + // Underlying path change (e.g. Servlet FORWARD). + this.actualRequestUri = this.delegate.get().getRequestURI(); + this.requestUri = initRequestUri(); + this.requestUrl = initRequestUrl(); // Keep the order: depends on requestUri + } + } + } + + + private static class ForwardedHeaderExtractingResponse extends HttpServletResponseWrapper { + + private static final String FOLDER_SEPARATOR = "/"; + + private final HttpServletRequest request; + + + ForwardedHeaderExtractingResponse(HttpServletResponse response, HttpServletRequest request) { + super(response); + this.request = request; + } + + + @Override + public void sendRedirect(String location) throws IOException { + + UriComponentsBuilder builder = UriComponentsBuilder.fromUriString(location); + UriComponents uriComponents = builder.build(); + + // Absolute location + if (uriComponents.getScheme() != null) { + super.sendRedirect(location); + return; + } + + // Network-path reference + if (location.startsWith("//")) { + String scheme = this.request.getScheme(); + super.sendRedirect(builder.scheme(scheme).toUriString()); + return; + } + + String path = uriComponents.getPath(); + if (path != null) { + // Relative to Servlet container root or to current request + path = (path.startsWith(FOLDER_SEPARATOR) ? path : + StringUtils.applyRelativePath(this.request.getRequestURI(), path)); + } + + String result = UriComponentsBuilder + .fromHttpRequest(new ServletServerHttpRequest(this.request)) + .replacePath(path) + .replaceQuery(uriComponents.getQuery()) + .fragment(uriComponents.getFragment()) + .build().normalize().toUriString(); + + super.sendRedirect(result); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/GenericFilterBean.java b/spring-web/src/main/java/org/springframework/web/filter/GenericFilterBean.java new file mode 100644 index 0000000000000000000000000000000000000000..da36fb123b32c7b000cb37666dde485426c6b60e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/GenericFilterBean.java @@ -0,0 +1,365 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.util.Enumeration; +import java.util.HashSet; +import java.util.Set; + +import javax.servlet.Filter; +import javax.servlet.FilterConfig; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.BeanWrapper; +import org.springframework.beans.BeansException; +import org.springframework.beans.MutablePropertyValues; +import org.springframework.beans.PropertyAccessorFactory; +import org.springframework.beans.PropertyValue; +import org.springframework.beans.PropertyValues; +import org.springframework.beans.factory.BeanNameAware; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.context.EnvironmentAware; +import org.springframework.core.env.Environment; +import org.springframework.core.env.EnvironmentCapable; +import org.springframework.core.io.Resource; +import org.springframework.core.io.ResourceEditor; +import org.springframework.core.io.ResourceLoader; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.context.ServletContextAware; +import org.springframework.web.context.support.ServletContextResourceLoader; +import org.springframework.web.context.support.StandardServletEnvironment; +import org.springframework.web.util.NestedServletException; + +/** + * Simple base implementation of {@link javax.servlet.Filter} which treats + * its config parameters ({@code init-param} entries within the + * {@code filter} tag in {@code web.xml}) as bean properties. + * + *

A handy superclass for any type of filter. Type conversion of config + * parameters is automatic, with the corresponding setter method getting + * invoked with the converted value. It is also possible for subclasses to + * specify required properties. Parameters without matching bean property + * setter will simply be ignored. + * + *

This filter leaves actual filtering to subclasses, which have to + * implement the {@link javax.servlet.Filter#doFilter} method. + * + *

This generic filter base class has no dependency on the Spring + * {@link org.springframework.context.ApplicationContext} concept. + * Filters usually don't load their own context but rather access service + * beans from the Spring root application context, accessible via the + * filter's {@link #getServletContext() ServletContext} (see + * {@link org.springframework.web.context.support.WebApplicationContextUtils}). + * + * @author Juergen Hoeller + * @since 06.12.2003 + * @see #addRequiredProperty + * @see #initFilterBean + * @see #doFilter + */ +public abstract class GenericFilterBean implements Filter, BeanNameAware, EnvironmentAware, + EnvironmentCapable, ServletContextAware, InitializingBean, DisposableBean { + + /** Logger available to subclasses. */ + protected final Log logger = LogFactory.getLog(getClass()); + + @Nullable + private String beanName; + + @Nullable + private Environment environment; + + @Nullable + private ServletContext servletContext; + + @Nullable + private FilterConfig filterConfig; + + private final Set requiredProperties = new HashSet<>(4); + + + /** + * Stores the bean name as defined in the Spring bean factory. + *

Only relevant in case of initialization as bean, to have a name as + * fallback to the filter name usually provided by a FilterConfig instance. + * @see org.springframework.beans.factory.BeanNameAware + * @see #getFilterName() + */ + @Override + public void setBeanName(String beanName) { + this.beanName = beanName; + } + + /** + * Set the {@code Environment} that this filter runs in. + *

Any environment set here overrides the {@link StandardServletEnvironment} + * provided by default. + *

This {@code Environment} object is used only for resolving placeholders in + * resource paths passed into init-parameters for this filter. If no init-params are + * used, this {@code Environment} can be essentially ignored. + */ + @Override + public void setEnvironment(Environment environment) { + this.environment = environment; + } + + /** + * Return the {@link Environment} associated with this filter. + *

If none specified, a default environment will be initialized via + * {@link #createEnvironment()}. + * @since 4.3.9 + */ + @Override + public Environment getEnvironment() { + if (this.environment == null) { + this.environment = createEnvironment(); + } + return this.environment; + } + + /** + * Create and return a new {@link StandardServletEnvironment}. + *

Subclasses may override this in order to configure the environment or + * specialize the environment type returned. + * @since 4.3.9 + */ + protected Environment createEnvironment() { + return new StandardServletEnvironment(); + } + + /** + * Stores the ServletContext that the bean factory runs in. + *

Only relevant in case of initialization as bean, to have a ServletContext + * as fallback to the context usually provided by a FilterConfig instance. + * @see org.springframework.web.context.ServletContextAware + * @see #getServletContext() + */ + @Override + public void setServletContext(ServletContext servletContext) { + this.servletContext = servletContext; + } + + /** + * Calls the {@code initFilterBean()} method that might + * contain custom initialization of a subclass. + *

Only relevant in case of initialization as bean, where the + * standard {@code init(FilterConfig)} method won't be called. + * @see #initFilterBean() + * @see #init(javax.servlet.FilterConfig) + */ + @Override + public void afterPropertiesSet() throws ServletException { + initFilterBean(); + } + + /** + * Subclasses may override this to perform custom filter shutdown. + *

Note: This method will be called from standard filter destruction + * as well as filter bean destruction in a Spring application context. + *

This default implementation is empty. + */ + @Override + public void destroy() { + } + + + /** + * Subclasses can invoke this method to specify that this property + * (which must match a JavaBean property they expose) is mandatory, + * and must be supplied as a config parameter. This should be called + * from the constructor of a subclass. + *

This method is only relevant in case of traditional initialization + * driven by a FilterConfig instance. + * @param property name of the required property + */ + protected final void addRequiredProperty(String property) { + this.requiredProperties.add(property); + } + + /** + * Standard way of initializing this filter. + * Map config parameters onto bean properties of this filter, and + * invoke subclass initialization. + * @param filterConfig the configuration for this filter + * @throws ServletException if bean properties are invalid (or required + * properties are missing), or if subclass initialization fails. + * @see #initFilterBean + */ + @Override + public final void init(FilterConfig filterConfig) throws ServletException { + Assert.notNull(filterConfig, "FilterConfig must not be null"); + + this.filterConfig = filterConfig; + + // Set bean properties from init parameters. + PropertyValues pvs = new FilterConfigPropertyValues(filterConfig, this.requiredProperties); + if (!pvs.isEmpty()) { + try { + BeanWrapper bw = PropertyAccessorFactory.forBeanPropertyAccess(this); + ResourceLoader resourceLoader = new ServletContextResourceLoader(filterConfig.getServletContext()); + Environment env = this.environment; + if (env == null) { + env = new StandardServletEnvironment(); + } + bw.registerCustomEditor(Resource.class, new ResourceEditor(resourceLoader, env)); + initBeanWrapper(bw); + bw.setPropertyValues(pvs, true); + } + catch (BeansException ex) { + String msg = "Failed to set bean properties on filter '" + + filterConfig.getFilterName() + "': " + ex.getMessage(); + logger.error(msg, ex); + throw new NestedServletException(msg, ex); + } + } + + // Let subclasses do whatever initialization they like. + initFilterBean(); + + if (logger.isDebugEnabled()) { + logger.debug("Filter '" + filterConfig.getFilterName() + "' configured for use"); + } + } + + /** + * Initialize the BeanWrapper for this GenericFilterBean, + * possibly with custom editors. + *

This default implementation is empty. + * @param bw the BeanWrapper to initialize + * @throws BeansException if thrown by BeanWrapper methods + * @see org.springframework.beans.BeanWrapper#registerCustomEditor + */ + protected void initBeanWrapper(BeanWrapper bw) throws BeansException { + } + + /** + * Subclasses may override this to perform custom initialization. + * All bean properties of this filter will have been set before this + * method is invoked. + *

Note: This method will be called from standard filter initialization + * as well as filter bean initialization in a Spring application context. + * Filter name and ServletContext will be available in both cases. + *

This default implementation is empty. + * @throws ServletException if subclass initialization fails + * @see #getFilterName() + * @see #getServletContext() + */ + protected void initFilterBean() throws ServletException { + } + + /** + * Make the FilterConfig of this filter available, if any. + * Analogous to GenericServlet's {@code getServletConfig()}. + *

Public to resemble the {@code getFilterConfig()} method + * of the Servlet Filter version that shipped with WebLogic 6.1. + * @return the FilterConfig instance, or {@code null} if none available + * @see javax.servlet.GenericServlet#getServletConfig() + */ + @Nullable + public FilterConfig getFilterConfig() { + return this.filterConfig; + } + + /** + * Make the name of this filter available to subclasses. + * Analogous to GenericServlet's {@code getServletName()}. + *

Takes the FilterConfig's filter name by default. + * If initialized as bean in a Spring application context, + * it falls back to the bean name as defined in the bean factory. + * @return the filter name, or {@code null} if none available + * @see javax.servlet.GenericServlet#getServletName() + * @see javax.servlet.FilterConfig#getFilterName() + * @see #setBeanName + */ + @Nullable + protected String getFilterName() { + return (this.filterConfig != null ? this.filterConfig.getFilterName() : this.beanName); + } + + /** + * Make the ServletContext of this filter available to subclasses. + * Analogous to GenericServlet's {@code getServletContext()}. + *

Takes the FilterConfig's ServletContext by default. + * If initialized as bean in a Spring application context, + * it falls back to the ServletContext that the bean factory runs in. + * @return the ServletContext instance + * @throws IllegalStateException if no ServletContext is available + * @see javax.servlet.GenericServlet#getServletContext() + * @see javax.servlet.FilterConfig#getServletContext() + * @see #setServletContext + */ + protected ServletContext getServletContext() { + if (this.filterConfig != null) { + return this.filterConfig.getServletContext(); + } + else if (this.servletContext != null) { + return this.servletContext; + } + else { + throw new IllegalStateException("No ServletContext"); + } + } + + + /** + * PropertyValues implementation created from FilterConfig init parameters. + */ + @SuppressWarnings("serial") + private static class FilterConfigPropertyValues extends MutablePropertyValues { + + /** + * Create new FilterConfigPropertyValues. + * @param config the FilterConfig we'll use to take PropertyValues from + * @param requiredProperties set of property names we need, where + * we can't accept default values + * @throws ServletException if any required properties are missing + */ + public FilterConfigPropertyValues(FilterConfig config, Set requiredProperties) + throws ServletException { + + Set missingProps = (!CollectionUtils.isEmpty(requiredProperties) ? + new HashSet<>(requiredProperties) : null); + + Enumeration paramNames = config.getInitParameterNames(); + while (paramNames.hasMoreElements()) { + String property = paramNames.nextElement(); + Object value = config.getInitParameter(property); + addPropertyValue(new PropertyValue(property, value)); + if (missingProps != null) { + missingProps.remove(property); + } + } + + // Fail if we are still missing properties. + if (!CollectionUtils.isEmpty(missingProps)) { + throw new ServletException( + "Initialization from FilterConfig for filter '" + config.getFilterName() + + "' failed; the following required properties were missing: " + + StringUtils.collectionToDelimitedString(missingProps, ", ")); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/HiddenHttpMethodFilter.java b/spring-web/src/main/java/org/springframework/web/filter/HiddenHttpMethodFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..0ed01daa40cbec894e7dbf4e39ca3421f12a2758 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/HiddenHttpMethodFilter.java @@ -0,0 +1,117 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.http.HttpMethod; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.util.WebUtils; + +/** + * {@link javax.servlet.Filter} that converts posted method parameters into HTTP methods, + * retrievable via {@link HttpServletRequest#getMethod()}. Since browsers currently only + * support GET and POST, a common technique - used by the Prototype library, for instance - + * is to use a normal POST with an additional hidden form field ({@code _method}) + * to pass the "real" HTTP method along. This filter reads that parameter and changes + * the {@link HttpServletRequestWrapper#getMethod()} return value accordingly. + * Only {@code "PUT"}, {@code "DELETE"} and {@code "PATCH"} HTTP methods are allowed. + * + *

The name of the request parameter defaults to {@code _method}, but can be + * adapted via the {@link #setMethodParam(String) methodParam} property. + * + *

NOTE: This filter needs to run after multipart processing in case of a multipart + * POST request, due to its inherent need for checking a POST body parameter. + * So typically, put a Spring {@link org.springframework.web.multipart.support.MultipartFilter} + * before this HiddenHttpMethodFilter in your {@code web.xml} filter chain. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 3.0 + */ +public class HiddenHttpMethodFilter extends OncePerRequestFilter { + + private static final List ALLOWED_METHODS = + Collections.unmodifiableList(Arrays.asList(HttpMethod.PUT.name(), + HttpMethod.DELETE.name(), HttpMethod.PATCH.name())); + + /** Default method parameter: {@code _method}. */ + public static final String DEFAULT_METHOD_PARAM = "_method"; + + private String methodParam = DEFAULT_METHOD_PARAM; + + + /** + * Set the parameter name to look for HTTP methods. + * @see #DEFAULT_METHOD_PARAM + */ + public void setMethodParam(String methodParam) { + Assert.hasText(methodParam, "'methodParam' must not be empty"); + this.methodParam = methodParam; + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + HttpServletRequest requestToUse = request; + + if ("POST".equals(request.getMethod()) && request.getAttribute(WebUtils.ERROR_EXCEPTION_ATTRIBUTE) == null) { + String paramValue = request.getParameter(this.methodParam); + if (StringUtils.hasLength(paramValue)) { + String method = paramValue.toUpperCase(Locale.ENGLISH); + if (ALLOWED_METHODS.contains(method)) { + requestToUse = new HttpMethodRequestWrapper(request, method); + } + } + } + + filterChain.doFilter(requestToUse, response); + } + + + /** + * Simple {@link HttpServletRequest} wrapper that returns the supplied method for + * {@link HttpServletRequest#getMethod()}. + */ + private static class HttpMethodRequestWrapper extends HttpServletRequestWrapper { + + private final String method; + + public HttpMethodRequestWrapper(HttpServletRequest request, String method) { + super(request); + this.method = method; + } + + @Override + public String getMethod() { + return this.method; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/HttpPutFormContentFilter.java b/spring-web/src/main/java/org/springframework/web/filter/HttpPutFormContentFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..2e6c04aa264a15db9ca7243fd1f3bcb931cafe2e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/HttpPutFormContentFilter.java @@ -0,0 +1,189 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.http.HttpInputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * {@link javax.servlet.Filter} that makes form encoded data available through + * the {@code ServletRequest.getParameter*()} family of methods during HTTP PUT + * or PATCH requests. + * + *

The Servlet spec requires form data to be available for HTTP POST but + * not for HTTP PUT or PATCH requests. This filter intercepts HTTP PUT and PATCH + * requests where content type is {@code 'application/x-www-form-urlencoded'}, + * reads form encoded content from the body of the request, and wraps the ServletRequest + * in order to make the form data available as request parameters just like + * it is for HTTP POST requests. + * + * @author Rossen Stoyanchev + * @since 3.1 + * @deprecated as of 5.1 in favor of {@link FormContentFilter} which is the same + * but also handles DELETE. + */ +@Deprecated +public class HttpPutFormContentFilter extends OncePerRequestFilter { + + private FormHttpMessageConverter formConverter = new AllEncompassingFormHttpMessageConverter(); + + + /** + * Set the converter to use for parsing form content. + *

By default this is an instance of {@link AllEncompassingFormHttpMessageConverter}. + */ + public void setFormConverter(FormHttpMessageConverter converter) { + Assert.notNull(converter, "FormHttpMessageConverter is required."); + this.formConverter = converter; + } + + public FormHttpMessageConverter getFormConverter() { + return this.formConverter; + } + + /** + * The default character set to use for reading form data. + * This is a shortcut for:
+ * {@code getFormConverter.setCharset(charset)}. + */ + public void setCharset(Charset charset) { + this.formConverter.setCharset(charset); + } + + + @Override + protected void doFilterInternal(final HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + + if (("PUT".equals(request.getMethod()) || "PATCH".equals(request.getMethod())) && isFormContentType(request)) { + HttpInputMessage inputMessage = new ServletServerHttpRequest(request) { + @Override + public InputStream getBody() throws IOException { + return request.getInputStream(); + } + }; + MultiValueMap formParameters = this.formConverter.read(null, inputMessage); + if (!formParameters.isEmpty()) { + HttpServletRequest wrapper = new HttpPutFormContentRequestWrapper(request, formParameters); + filterChain.doFilter(wrapper, response); + return; + } + } + + filterChain.doFilter(request, response); + } + + private boolean isFormContentType(HttpServletRequest request) { + String contentType = request.getContentType(); + if (contentType != null) { + try { + MediaType mediaType = MediaType.parseMediaType(contentType); + return (MediaType.APPLICATION_FORM_URLENCODED.includes(mediaType)); + } + catch (IllegalArgumentException ex) { + return false; + } + } + else { + return false; + } + } + + + private static class HttpPutFormContentRequestWrapper extends HttpServletRequestWrapper { + + private MultiValueMap formParameters; + + public HttpPutFormContentRequestWrapper(HttpServletRequest request, MultiValueMap parameters) { + super(request); + this.formParameters = parameters; + } + + @Override + @Nullable + public String getParameter(String name) { + String queryStringValue = super.getParameter(name); + String formValue = this.formParameters.getFirst(name); + return (queryStringValue != null ? queryStringValue : formValue); + } + + @Override + public Map getParameterMap() { + Map result = new LinkedHashMap<>(); + Enumeration names = getParameterNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + result.put(name, getParameterValues(name)); + } + return result; + } + + @Override + public Enumeration getParameterNames() { + Set names = new LinkedHashSet<>(); + names.addAll(Collections.list(super.getParameterNames())); + names.addAll(this.formParameters.keySet()); + return Collections.enumeration(names); + } + + @Override + @Nullable + public String[] getParameterValues(String name) { + String[] parameterValues = super.getParameterValues(name); + List formParam = this.formParameters.get(name); + if (formParam == null) { + return parameterValues; + } + if (parameterValues == null || getQueryString() == null) { + return StringUtils.toStringArray(formParam); + } + else { + List result = new ArrayList<>(parameterValues.length + formParam.size()); + result.addAll(Arrays.asList(parameterValues)); + result.addAll(formParam); + return StringUtils.toStringArray(result); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java b/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..a775b8b504e4546f8092523675c500faf8f49017 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java @@ -0,0 +1,254 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; + +import javax.servlet.DispatcherType; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.WebAsyncUtils; +import org.springframework.web.util.WebUtils; + +/** + * Filter base class that aims to guarantee a single execution per request + * dispatch, on any servlet container. It provides a {@link #doFilterInternal} + * method with HttpServletRequest and HttpServletResponse arguments. + * + *

As of Servlet 3.0, a filter may be invoked as part of a + * {@link javax.servlet.DispatcherType#REQUEST REQUEST} or + * {@link javax.servlet.DispatcherType#ASYNC ASYNC} dispatches that occur in + * separate threads. A filter can be configured in {@code web.xml} whether it + * should be involved in async dispatches. However, in some cases servlet + * containers assume different default configuration. Therefore sub-classes can + * override the method {@link #shouldNotFilterAsyncDispatch()} to declare + * statically if they should indeed be invoked, once, during both types + * of dispatches in order to provide thread initialization, logging, security, + * and so on. This mechanism complements and does not replace the need to + * configure a filter in {@code web.xml} with dispatcher types. + * + *

Subclasses may use {@link #isAsyncDispatch(HttpServletRequest)} to + * determine when a filter is invoked as part of an async dispatch, and use + * {@link #isAsyncStarted(HttpServletRequest)} to determine when the request + * has been placed in async mode and therefore the current dispatch won't be + * the last one for the given request. + * + *

Yet another dispatch type that also occurs in its own thread is + * {@link javax.servlet.DispatcherType#ERROR ERROR}. Subclasses can override + * {@link #shouldNotFilterErrorDispatch()} if they wish to declare statically + * if they should be invoked once during error dispatches. + * + *

The {@link #getAlreadyFilteredAttributeName} method determines how to + * identify that a request is already filtered. The default implementation is + * based on the configured name of the concrete filter instance. + * + * @author Juergen Hoeller + * @author Rossen Stoyanchev + * @since 06.12.2003 + */ +public abstract class OncePerRequestFilter extends GenericFilterBean { + + /** + * Suffix that gets appended to the filter name for the + * "already filtered" request attribute. + * @see #getAlreadyFilteredAttributeName + */ + public static final String ALREADY_FILTERED_SUFFIX = ".FILTERED"; + + + /** + * This {@code doFilter} implementation stores a request attribute for + * "already filtered", proceeding without filtering again if the + * attribute is already there. + * @see #getAlreadyFilteredAttributeName + * @see #shouldNotFilter + * @see #doFilterInternal + */ + @Override + public final void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + if (!(request instanceof HttpServletRequest) || !(response instanceof HttpServletResponse)) { + throw new ServletException("OncePerRequestFilter just supports HTTP requests"); + } + HttpServletRequest httpRequest = (HttpServletRequest) request; + HttpServletResponse httpResponse = (HttpServletResponse) response; + + String alreadyFilteredAttributeName = getAlreadyFilteredAttributeName(); + boolean hasAlreadyFilteredAttribute = request.getAttribute(alreadyFilteredAttributeName) != null; + + if (skipDispatch(httpRequest) || shouldNotFilter(httpRequest)) { + + // Proceed without invoking this filter... + filterChain.doFilter(request, response); + } + else if (hasAlreadyFilteredAttribute) { + + if (DispatcherType.ERROR.equals(request.getDispatcherType())) { + doFilterNestedErrorDispatch(httpRequest, httpResponse, filterChain); + return; + } + + // Proceed without invoking this filter... + filterChain.doFilter(request, response); + } + else { + // Do invoke this filter... + request.setAttribute(alreadyFilteredAttributeName, Boolean.TRUE); + try { + doFilterInternal(httpRequest, httpResponse, filterChain); + } + finally { + // Remove the "already filtered" request attribute for this request. + request.removeAttribute(alreadyFilteredAttributeName); + } + } + } + + private boolean skipDispatch(HttpServletRequest request) { + if (isAsyncDispatch(request) && shouldNotFilterAsyncDispatch()) { + return true; + } + if (request.getAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE) != null && shouldNotFilterErrorDispatch()) { + return true; + } + return false; + } + + /** + * The dispatcher type {@code javax.servlet.DispatcherType.ASYNC} introduced + * in Servlet 3.0 means a filter can be invoked in more than one thread over + * the course of a single request. This method returns {@code true} if the + * filter is currently executing within an asynchronous dispatch. + * @param request the current request + * @since 3.2 + * @see WebAsyncManager#hasConcurrentResult() + */ + protected boolean isAsyncDispatch(HttpServletRequest request) { + return WebAsyncUtils.getAsyncManager(request).hasConcurrentResult(); + } + + /** + * Whether request processing is in asynchronous mode meaning that the + * response will not be committed after the current thread is exited. + * @param request the current request + * @since 3.2 + * @see WebAsyncManager#isConcurrentHandlingStarted() + */ + protected boolean isAsyncStarted(HttpServletRequest request) { + return WebAsyncUtils.getAsyncManager(request).isConcurrentHandlingStarted(); + } + + /** + * Return the name of the request attribute that identifies that a request + * is already filtered. + *

The default implementation takes the configured name of the concrete filter + * instance and appends ".FILTERED". If the filter is not fully initialized, + * it falls back to its class name. + * @see #getFilterName + * @see #ALREADY_FILTERED_SUFFIX + */ + protected String getAlreadyFilteredAttributeName() { + String name = getFilterName(); + if (name == null) { + name = getClass().getName(); + } + return name + ALREADY_FILTERED_SUFFIX; + } + + /** + * Can be overridden in subclasses for custom filtering control, + * returning {@code true} to avoid filtering of the given request. + *

The default implementation always returns {@code false}. + * @param request current HTTP request + * @return whether the given request should not be filtered + * @throws ServletException in case of errors + */ + protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { + return false; + } + + /** + * The dispatcher type {@code javax.servlet.DispatcherType.ASYNC} introduced + * in Servlet 3.0 means a filter can be invoked in more than one thread + * over the course of a single request. Some filters only need to filter + * the initial thread (e.g. request wrapping) while others may need + * to be invoked at least once in each additional thread for example for + * setting up thread locals or to perform final processing at the very end. + *

Note that although a filter can be mapped to handle specific dispatcher + * types via {@code web.xml} or in Java through the {@code ServletContext}, + * servlet containers may enforce different defaults with regards to + * dispatcher types. This flag enforces the design intent of the filter. + *

The default return value is "true", which means the filter will not be + * invoked during subsequent async dispatches. If "false", the filter will + * be invoked during async dispatches with the same guarantees of being + * invoked only once during a request within a single thread. + * @since 3.2 + */ + protected boolean shouldNotFilterAsyncDispatch() { + return true; + } + + /** + * Whether to filter error dispatches such as when the servlet container + * processes and error mapped in {@code web.xml}. The default return value + * is "true", which means the filter will not be invoked in case of an error + * dispatch. + * @since 3.2 + */ + protected boolean shouldNotFilterErrorDispatch() { + return true; + } + + + /** + * Same contract as for {@code doFilter}, but guaranteed to be + * just invoked once per request within a single request thread. + * See {@link #shouldNotFilterAsyncDispatch()} for details. + *

Provides HttpServletRequest and HttpServletResponse arguments instead of the + * default ServletRequest and ServletResponse ones. + */ + protected abstract void doFilterInternal( + HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException; + + /** + * Typically an ERROR dispatch happens after the REQUEST dispatch completes, + * and the filter chain starts anew. On some servers however the ERROR + * dispatch may be nested within the REQUEST dispatch, e.g. as a result of + * calling {@code sendError} on the response. In that case we are still in + * the filter chain, on the same thread, but the request and response have + * been switched to the original, unwrapped ones. + *

Sub-classes may use this method to filter such nested ERROR dispatches + * and re-apply wrapping on the request or response. {@code ThreadLocal} + * context, if any, should still be active as we are still nested within + * the filter chain. + * @since 5.1.9 + */ + protected void doFilterNestedErrorDispatch(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + + filterChain.doFilter(request, response); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/RelativeRedirectFilter.java b/spring-web/src/main/java/org/springframework/web/filter/RelativeRedirectFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..25df286a0c4809e5c73f71447fbcba3824255ff4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/RelativeRedirectFilter.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.http.HttpStatus; +import org.springframework.util.Assert; + +/** + * Overrides {@link HttpServletResponse#sendRedirect(String)} and handles it by + * setting the HTTP status and "Location" headers, which keeps the Servlet + * container from re-writing relative redirect URLs into absolute ones. + * Servlet containers are required to do that but against the recommendation of + * RFC 7231 Section 7.1.2, + * and furthermore not necessarily taking into account "X-Forwarded" headers. + * + *

Note: While relative redirects are recommended in the + * RFC, under some configurations with reverse proxies they may not work. + * + * @author Rob Winch + * @author Rossen Stoyanchev + * @since 4.3.10 + */ +public class RelativeRedirectFilter extends OncePerRequestFilter { + + private HttpStatus redirectStatus = HttpStatus.SEE_OTHER; + + + /** + * Set the default HTTP Status to use for redirects. + *

By default this is {@link HttpStatus#SEE_OTHER}. + * @param status the 3xx redirect status to use + */ + public void setRedirectStatus(HttpStatus status) { + Assert.notNull(status, "Property 'redirectStatus' is required"); + Assert.isTrue(status.is3xxRedirection(), "Not a redirect status code"); + this.redirectStatus = status; + } + + /** + * Return the configured redirect status. + */ + public HttpStatus getRedirectStatus() { + return this.redirectStatus; + } + + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + + response = RelativeRedirectResponseWrapper.wrapIfNecessary(response, this.redirectStatus); + filterChain.doFilter(request, response); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/RelativeRedirectResponseWrapper.java b/spring-web/src/main/java/org/springframework/web/filter/RelativeRedirectResponseWrapper.java new file mode 100644 index 0000000000000000000000000000000000000000..baf4ab53ce9b80de51c6e2e51b83a4c4bffdb741 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/RelativeRedirectResponseWrapper.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.filter; + +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.util.Assert; +import org.springframework.web.util.WebUtils; + +/** + * A response wrapper used for the implementation of + * {@link RelativeRedirectFilter} also shared with {@link ForwardedHeaderFilter}. + * + * @author Rossen Stoyanchev + * @since 4.3.10 + */ +final class RelativeRedirectResponseWrapper extends HttpServletResponseWrapper { + + private final HttpStatus redirectStatus; + + + private RelativeRedirectResponseWrapper(HttpServletResponse response, HttpStatus redirectStatus) { + super(response); + Assert.notNull(redirectStatus, "'redirectStatus' is required"); + this.redirectStatus = redirectStatus; + } + + + @Override + public void sendRedirect(String location) { + setStatus(this.redirectStatus.value()); + setHeader(HttpHeaders.LOCATION, location); + } + + + public static HttpServletResponse wrapIfNecessary(HttpServletResponse response, + HttpStatus redirectStatus) { + + RelativeRedirectResponseWrapper wrapper = + WebUtils.getNativeResponse(response, RelativeRedirectResponseWrapper.class); + + return (wrapper != null ? response : + new RelativeRedirectResponseWrapper(response, redirectStatus)); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/RequestContextFilter.java b/spring-web/src/main/java/org/springframework/web/filter/RequestContextFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..8bb6a821bb2ed0d76f2c136a297cfbf5f444899e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/RequestContextFilter.java @@ -0,0 +1,124 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.context.i18n.LocaleContextHolder; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; + +/** + * Servlet Filter that exposes the request to the current thread, + * through both {@link org.springframework.context.i18n.LocaleContextHolder} and + * {@link RequestContextHolder}. To be registered as filter in {@code web.xml}. + * + *

Alternatively, Spring's {@link org.springframework.web.context.request.RequestContextListener} + * and Spring's {@link org.springframework.web.servlet.DispatcherServlet} also expose + * the same request context to the current thread. + * + *

This filter is mainly for use with third-party servlets, e.g. the JSF FacesServlet. + * Within Spring's own web support, DispatcherServlet's processing is perfectly sufficient. + * + * @author Juergen Hoeller + * @author Rod Johnson + * @author Rossen Stoyanchev + * @since 2.0 + * @see org.springframework.context.i18n.LocaleContextHolder + * @see org.springframework.web.context.request.RequestContextHolder + * @see org.springframework.web.context.request.RequestContextListener + * @see org.springframework.web.servlet.DispatcherServlet + */ +public class RequestContextFilter extends OncePerRequestFilter { + + private boolean threadContextInheritable = false; + + + /** + * Set whether to expose the LocaleContext and RequestAttributes as inheritable + * for child threads (using an {@link java.lang.InheritableThreadLocal}). + *

Default is "false", to avoid side effects on spawned background threads. + * Switch this to "true" to enable inheritance for custom child threads which + * are spawned during request processing and only used for this request + * (that is, ending after their initial task, without reuse of the thread). + *

WARNING: Do not use inheritance for child threads if you are + * accessing a thread pool which is configured to potentially add new threads + * on demand (e.g. a JDK {@link java.util.concurrent.ThreadPoolExecutor}), + * since this will expose the inherited context to such a pooled thread. + */ + public void setThreadContextInheritable(boolean threadContextInheritable) { + this.threadContextInheritable = threadContextInheritable; + } + + + /** + * Returns "false" so that the filter may set up the request context in each + * asynchronously dispatched thread. + */ + @Override + protected boolean shouldNotFilterAsyncDispatch() { + return false; + } + + /** + * Returns "false" so that the filter may set up the request context in an + * error dispatch. + */ + @Override + protected boolean shouldNotFilterErrorDispatch() { + return false; + } + + @Override + protected void doFilterInternal( + HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + ServletRequestAttributes attributes = new ServletRequestAttributes(request, response); + initContextHolders(request, attributes); + + try { + filterChain.doFilter(request, response); + } + finally { + resetContextHolders(); + if (logger.isTraceEnabled()) { + logger.trace("Cleared thread-bound request context: " + request); + } + attributes.requestCompleted(); + } + } + + private void initContextHolders(HttpServletRequest request, ServletRequestAttributes requestAttributes) { + LocaleContextHolder.setLocale(request.getLocale(), this.threadContextInheritable); + RequestContextHolder.setRequestAttributes(requestAttributes, this.threadContextInheritable); + if (logger.isTraceEnabled()) { + logger.trace("Bound request context to thread: " + request); + } + } + + private void resetContextHolders() { + LocaleContextHolder.resetLocaleContext(); + RequestContextHolder.resetRequestAttributes(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/ServletContextRequestLoggingFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ServletContextRequestLoggingFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..40cf7681e20d56b3d779d965d275b13d931c8821 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/ServletContextRequestLoggingFilter.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2005 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import javax.servlet.http.HttpServletRequest; + +/** + * Simple request logging filter that writes the request URI + * (and optionally the query string) to the ServletContext log. + * + * @author Juergen Hoeller + * @since 1.2.5 + * @see #setIncludeQueryString + * @see #setBeforeMessagePrefix + * @see #setBeforeMessageSuffix + * @see #setAfterMessagePrefix + * @see #setAfterMessageSuffix + * @see javax.servlet.ServletContext#log(String) + */ +public class ServletContextRequestLoggingFilter extends AbstractRequestLoggingFilter { + + /** + * Writes a log message before the request is processed. + */ + @Override + protected void beforeRequest(HttpServletRequest request, String message) { + getServletContext().log(message); + } + + /** + * Writes a log message after the request is processed. + */ + @Override + protected void afterRequest(HttpServletRequest request, String message) { + getServletContext().log(message); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..0dfcd6e6e5de3e8c33182d4cfae9f80f3950355c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java @@ -0,0 +1,239 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; +import java.io.InputStream; +import java.io.PrintWriter; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletOutputStream; +import javax.servlet.ServletRequest; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.http.HttpMethod; +import org.springframework.util.Assert; +import org.springframework.util.DigestUtils; +import org.springframework.web.util.ContentCachingResponseWrapper; +import org.springframework.web.util.WebUtils; + +/** + * {@link javax.servlet.Filter} that generates an {@code ETag} value based on the + * content on the response. This ETag is compared to the {@code If-None-Match} + * header of the request. If these headers are equal, the response content is + * not sent, but rather a {@code 304 "Not Modified"} status instead. + * + *

Since the ETag is based on the response content, the response + * (e.g. a {@link org.springframework.web.servlet.View}) is still rendered. + * As such, this filter only saves bandwidth, not server performance. + * + *

NOTE: As of Spring Framework 5.0, this filter uses request/response + * decorators built on the Servlet 3.1 API. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Brian Clozel + * @author Juergen Hoeller + * @since 3.0 + */ +public class ShallowEtagHeaderFilter extends OncePerRequestFilter { + + private static final String HEADER_ETAG = "ETag"; + + private static final String HEADER_IF_NONE_MATCH = "If-None-Match"; + + private static final String HEADER_CACHE_CONTROL = "Cache-Control"; + + private static final String DIRECTIVE_NO_STORE = "no-store"; + + private static final String STREAMING_ATTRIBUTE = ShallowEtagHeaderFilter.class.getName() + ".STREAMING"; + + + private boolean writeWeakETag = false; + + + /** + * Set whether the ETag value written to the response should be weak, as per RFC 7232. + *

Should be configured using an {@code } for parameter name + * "writeWeakETag" in the filter definition in {@code web.xml}. + * @since 4.3 + * @see RFC 7232 section 2.3 + */ + public void setWriteWeakETag(boolean writeWeakETag) { + this.writeWeakETag = writeWeakETag; + } + + /** + * Return whether the ETag value written to the response should be weak, as per RFC 7232. + * @since 4.3 + */ + public boolean isWriteWeakETag() { + return this.writeWeakETag; + } + + + /** + * The default value is {@code false} so that the filter may delay the generation + * of an ETag until the last asynchronously dispatched thread. + */ + @Override + protected boolean shouldNotFilterAsyncDispatch() { + return false; + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + HttpServletResponse responseToUse = response; + if (!isAsyncDispatch(request) && !(response instanceof ContentCachingResponseWrapper)) { + responseToUse = new HttpStreamingAwareContentCachingResponseWrapper(response, request); + } + + filterChain.doFilter(request, responseToUse); + + if (!isAsyncStarted(request) && !isContentCachingDisabled(request)) { + updateResponse(request, responseToUse); + } + } + + private void updateResponse(HttpServletRequest request, HttpServletResponse response) throws IOException { + ContentCachingResponseWrapper responseWrapper = + WebUtils.getNativeResponse(response, ContentCachingResponseWrapper.class); + Assert.notNull(responseWrapper, "ContentCachingResponseWrapper not found"); + HttpServletResponse rawResponse = (HttpServletResponse) responseWrapper.getResponse(); + int statusCode = responseWrapper.getStatusCode(); + + if (rawResponse.isCommitted()) { + responseWrapper.copyBodyToResponse(); + } + else if (isEligibleForEtag(request, responseWrapper, statusCode, responseWrapper.getContentInputStream())) { + String responseETag = generateETagHeaderValue(responseWrapper.getContentInputStream(), this.writeWeakETag); + rawResponse.setHeader(HEADER_ETAG, responseETag); + String requestETag = request.getHeader(HEADER_IF_NONE_MATCH); + if (requestETag != null && ("*".equals(requestETag) || compareETagHeaderValue(requestETag, responseETag))) { + rawResponse.setStatus(HttpServletResponse.SC_NOT_MODIFIED); + } + else { + responseWrapper.copyBodyToResponse(); + } + } + else { + responseWrapper.copyBodyToResponse(); + } + } + + /** + * Indicates whether the given request and response are eligible for ETag generation. + *

The default implementation returns {@code true} if all conditions match: + *

    + *
  • response status codes in the {@code 2xx} series
  • + *
  • request method is a GET
  • + *
  • response Cache-Control header is not set or does not contain a "no-store" directive
  • + *
+ * @param request the HTTP request + * @param response the HTTP response + * @param responseStatusCode the HTTP response status code + * @param inputStream the response body + * @return {@code true} if eligible for ETag generation, {@code false} otherwise + */ + protected boolean isEligibleForEtag(HttpServletRequest request, HttpServletResponse response, + int responseStatusCode, InputStream inputStream) { + + String method = request.getMethod(); + if (responseStatusCode >= 200 && responseStatusCode < 300 && HttpMethod.GET.matches(method)) { + String cacheControl = response.getHeader(HEADER_CACHE_CONTROL); + return (cacheControl == null || !cacheControl.contains(DIRECTIVE_NO_STORE)); + } + return false; + } + + /** + * Generate the ETag header value from the given response body byte array. + *

The default implementation generates an MD5 hash. + * @param inputStream the response body as an InputStream + * @param isWeak whether the generated ETag should be weak + * @return the ETag header value + * @see org.springframework.util.DigestUtils + */ + protected String generateETagHeaderValue(InputStream inputStream, boolean isWeak) throws IOException { + // length of W/ + " + 0 + 32bits md5 hash + " + StringBuilder builder = new StringBuilder(37); + if (isWeak) { + builder.append("W/"); + } + builder.append("\"0"); + DigestUtils.appendMd5DigestAsHex(inputStream, builder); + builder.append('"'); + return builder.toString(); + } + + private boolean compareETagHeaderValue(String requestETag, String responseETag) { + if (requestETag.startsWith("W/")) { + requestETag = requestETag.substring(2); + } + if (responseETag.startsWith("W/")) { + responseETag = responseETag.substring(2); + } + return requestETag.equals(responseETag); + } + + + /** + * This method can be used to disable the content caching response wrapper + * of the ShallowEtagHeaderFilter. This can be done before the start of HTTP + * streaming for example where the response will be written to asynchronously + * and not in the context of a Servlet container thread. + * @since 4.2 + */ + public static void disableContentCaching(ServletRequest request) { + Assert.notNull(request, "ServletRequest must not be null"); + request.setAttribute(STREAMING_ATTRIBUTE, true); + } + + private static boolean isContentCachingDisabled(HttpServletRequest request) { + return (request.getAttribute(STREAMING_ATTRIBUTE) != null); + } + + + private static class HttpStreamingAwareContentCachingResponseWrapper extends ContentCachingResponseWrapper { + + private final HttpServletRequest request; + + public HttpStreamingAwareContentCachingResponseWrapper(HttpServletResponse response, HttpServletRequest request) { + super(response); + this.request = request; + } + + @Override + public ServletOutputStream getOutputStream() throws IOException { + return (useRawResponse() ? getResponse().getOutputStream() : super.getOutputStream()); + } + + @Override + public PrintWriter getWriter() throws IOException { + return (useRawResponse() ? getResponse().getWriter() : super.getWriter()); + } + + private boolean useRawResponse() { + return isContentCachingDisabled(this.request); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/package-info.java b/spring-web/src/main/java/org/springframework/web/filter/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..d54e7e3d3fd810390d43767a598539ed9a298e91 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/package-info.java @@ -0,0 +1,9 @@ +/** + * Provides generic filter base classes allowing for bean-style configuration. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.filter; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..27f97d85a201888c931e31c1c82c3445f11cc9bd --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter.reactive; + +import reactor.core.publisher.Mono; + +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; +import org.springframework.web.server.adapter.ForwardedHeaderTransformer; + +/** + * Extract values from "Forwarded" and "X-Forwarded-*" headers to override the + * request URI (i.e. {@link ServerHttpRequest#getURI()}) so it reflects the + * client-originated protocol and address. + * + *

Alternatively if {@link #setRemoveOnly removeOnly} is set to "true", then + * "Forwarded" and "X-Forwarded-*" headers are only removed, and not used. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @deprecated as of 5.1 this filter is deprecated in favor of using + * {@link ForwardedHeaderTransformer} which can be declared as a bean with the + * name "forwardedHeaderTransformer" or registered explicitly in + * {@link org.springframework.web.server.adapter.WebHttpHandlerBuilder + * WebHttpHandlerBuilder}. + * @since 5.0 + * @see https://tools.ietf.org/html/rfc7239 + */ +@Deprecated +public class ForwardedHeaderFilter extends ForwardedHeaderTransformer implements WebFilter { + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + ServerHttpRequest request = exchange.getRequest(); + if (hasForwardedHeaders(request)) { + exchange = exchange.mutate().request(apply(request)).build(); + } + return chain.filter(exchange); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/reactive/HiddenHttpMethodFilter.java b/spring-web/src/main/java/org/springframework/web/filter/reactive/HiddenHttpMethodFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..90d18497169c047392d074981dd08a48fcec6bdc --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/reactive/HiddenHttpMethodFilter.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter.reactive; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; + +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpMethod; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; + +/** + * Reactive {@link WebFilter} that converts posted method parameters into HTTP methods, + * retrievable via {@link ServerHttpRequest#getMethod()}. Since browsers currently only + * support GET and POST, a common technique is to use a normal POST with an additional + * hidden form field ({@code _method}) to pass the "real" HTTP method along. + * This filter reads that parameter and changes the {@link ServerHttpRequest#getMethod()} + * return value using {@link ServerWebExchange#mutate()}. + * + *

The name of the request parameter defaults to {@code _method}, but can be + * adapted via the {@link #setMethodParamName(String) methodParamName} property. + * + * @author Greg Turnquist + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class HiddenHttpMethodFilter implements WebFilter { + + private static final List ALLOWED_METHODS = + Collections.unmodifiableList(Arrays.asList(HttpMethod.PUT, + HttpMethod.DELETE, HttpMethod.PATCH)); + + /** Default name of the form parameter with the HTTP method to use. */ + public static final String DEFAULT_METHOD_PARAMETER_NAME = "_method"; + + + private String methodParamName = DEFAULT_METHOD_PARAMETER_NAME; + + + /** + * Set the name of the form parameter with the HTTP method to use. + *

By default this is set to {@code "_method"}. + */ + public void setMethodParamName(String methodParamName) { + Assert.hasText(methodParamName, "'methodParamName' must not be empty"); + this.methodParamName = methodParamName; + } + + + /** + * Transform an HTTP POST into another method based on {@code methodParamName}. + * @param exchange the current server exchange + * @param chain provides a way to delegate to the next filter + * @return {@code Mono} to indicate when request processing is complete + */ + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + + if (exchange.getRequest().getMethod() != HttpMethod.POST) { + return chain.filter(exchange); + } + + return exchange.getFormData() + .map(formData -> { + String method = formData.getFirst(this.methodParamName); + return StringUtils.hasLength(method) ? mapExchange(exchange, method) : exchange; + }) + .flatMap(chain::filter); + } + + private ServerWebExchange mapExchange(ServerWebExchange exchange, String methodParamValue) { + HttpMethod httpMethod = HttpMethod.resolve(methodParamValue.toUpperCase(Locale.ENGLISH)); + Assert.notNull(httpMethod, () -> "HttpMethod '" + methodParamValue + "' not supported"); + if (ALLOWED_METHODS.contains(httpMethod)) { + return exchange.mutate().request(builder -> builder.method(httpMethod)).build(); + } + else { + return exchange; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/reactive/package-info.java b/spring-web/src/main/java/org/springframework/web/filter/reactive/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..23a7beb13a2466bceb8d5273c29960c059b3f8c7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/reactive/package-info.java @@ -0,0 +1,10 @@ +/** + * {@link org.springframework.web.server.WebFilter} implementations for use in + * reactive web applications. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.filter.reactive; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/jsf/DecoratingNavigationHandler.java b/spring-web/src/main/java/org/springframework/web/jsf/DecoratingNavigationHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..7d8f082586f6fbaa705a5b55ad3013654daa20cd --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/jsf/DecoratingNavigationHandler.java @@ -0,0 +1,157 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.jsf; + +import javax.faces.application.NavigationHandler; +import javax.faces.context.FacesContext; + +import org.springframework.lang.Nullable; + +/** + * Base class for JSF NavigationHandler implementations that want + * to be capable of decorating an original NavigationHandler. + * + *

Supports the standard JSF style of decoration (through a constructor argument) + * as well as an overloaded {@code handleNavigation} method with explicit + * NavigationHandler argument (passing in the original NavigationHandler). Subclasses + * are forced to implement this overloaded {@code handleNavigation} method. + * Standard JSF invocations will automatically delegate to the overloaded method, + * with the constructor-injected NavigationHandler as argument. + * + * @author Juergen Hoeller + * @since 1.2.7 + * @see #handleNavigation(javax.faces.context.FacesContext, String, String, NavigationHandler) + * @see DelegatingNavigationHandlerProxy + */ +public abstract class DecoratingNavigationHandler extends NavigationHandler { + + @Nullable + private NavigationHandler decoratedNavigationHandler; + + + /** + * Create a DecoratingNavigationHandler without fixed original NavigationHandler. + */ + protected DecoratingNavigationHandler() { + } + + /** + * Create a DecoratingNavigationHandler with fixed original NavigationHandler. + * @param originalNavigationHandler the original NavigationHandler to decorate + */ + protected DecoratingNavigationHandler(NavigationHandler originalNavigationHandler) { + this.decoratedNavigationHandler = originalNavigationHandler; + } + + /** + * Return the fixed original NavigationHandler decorated by this handler, if any + * (that is, if passed in through the constructor). + */ + @Nullable + public final NavigationHandler getDecoratedNavigationHandler() { + return this.decoratedNavigationHandler; + } + + + /** + * This implementation of the standard JSF {@code handleNavigation} method + * delegates to the overloaded variant, passing in constructor-injected + * NavigationHandler as argument. + * @see #handleNavigation(javax.faces.context.FacesContext, String, String, javax.faces.application.NavigationHandler) + */ + @Override + public final void handleNavigation(FacesContext facesContext, String fromAction, String outcome) { + handleNavigation(facesContext, fromAction, outcome, this.decoratedNavigationHandler); + } + + /** + * Special {@code handleNavigation} variant with explicit NavigationHandler + * argument. Either called directly, by code with an explicit original handler, + * or called from the standard {@code handleNavigation} method, as + * plain JSF-defined NavigationHandler. + *

Implementations should invoke {@code callNextHandlerInChain} to + * delegate to the next handler in the chain. This will always call the most + * appropriate next handler (see {@code callNextHandlerInChain} javadoc). + * Alternatively, the decorated NavigationHandler or the passed-in original + * NavigationHandler can also be called directly; however, this is not as + * flexible in terms of reacting to potential positions in the chain. + * @param facesContext the current JSF context + * @param fromAction the action binding expression that was evaluated to retrieve the + * specified outcome, or {@code null} if the outcome was acquired by some other means + * @param outcome the logical outcome returned by a previous invoked application action + * (which may be {@code null}) + * @param originalNavigationHandler the original NavigationHandler, + * or {@code null} if none + * @see #callNextHandlerInChain + */ + public abstract void handleNavigation(FacesContext facesContext, @Nullable String fromAction, + @Nullable String outcome, @Nullable NavigationHandler originalNavigationHandler); + + + /** + * Method to be called by subclasses when intending to delegate to the next + * handler in the NavigationHandler chain. Will always call the most + * appropriate next handler, either the decorated NavigationHandler passed + * in as constructor argument or the original NavigationHandler as passed + * into this method - according to the position of this instance in the chain. + *

Will call the decorated NavigationHandler specified as constructor + * argument, if any. In case of a DecoratingNavigationHandler as target, the + * original NavigationHandler as passed into this method will be passed on to + * the next element in the chain: This ensures propagation of the original + * handler that the last element in the handler chain might delegate back to. + * In case of a standard NavigationHandler as target, the original handler + * will simply not get passed on; no delegating back to the original is + * possible further down the chain in that scenario. + *

If no decorated NavigationHandler specified as constructor argument, + * this instance is the last element in the chain. Hence, this method will + * call the original NavigationHandler as passed into this method. If no + * original NavigationHandler has been passed in (for example if this + * instance is the last element in a chain with standard NavigationHandlers + * as earlier elements), this method corresponds to a no-op. + * @param facesContext the current JSF context + * @param fromAction the action binding expression that was evaluated to retrieve the + * specified outcome, or {@code null} if the outcome was acquired by some other means + * @param outcome the logical outcome returned by a previous invoked application action + * (which may be {@code null}) + * @param originalNavigationHandler the original NavigationHandler, + * or {@code null} if none + */ + protected final void callNextHandlerInChain(FacesContext facesContext, @Nullable String fromAction, + @Nullable String outcome, @Nullable NavigationHandler originalNavigationHandler) { + + NavigationHandler decoratedNavigationHandler = getDecoratedNavigationHandler(); + + if (decoratedNavigationHandler instanceof DecoratingNavigationHandler) { + // DecoratingNavigationHandler specified through constructor argument: + // Call it with original NavigationHandler passed in. + DecoratingNavigationHandler decHandler = (DecoratingNavigationHandler) decoratedNavigationHandler; + decHandler.handleNavigation(facesContext, fromAction, outcome, originalNavigationHandler); + } + else if (decoratedNavigationHandler != null) { + // Standard NavigationHandler specified through constructor argument: + // Call it through standard API, without original NavigationHandler passed in. + // The called handler will not be able to redirect to the original handler. + decoratedNavigationHandler.handleNavigation(facesContext, fromAction, outcome); + } + else if (originalNavigationHandler != null) { + // No NavigationHandler specified through constructor argument: + // Call original handler, marking the end of this chain. + originalNavigationHandler.handleNavigation(facesContext, fromAction, outcome); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/jsf/DelegatingNavigationHandlerProxy.java b/spring-web/src/main/java/org/springframework/web/jsf/DelegatingNavigationHandlerProxy.java new file mode 100644 index 0000000000000000000000000000000000000000..a6eafaace42c7182aa038ddfbd0c13e9c1840ed8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/jsf/DelegatingNavigationHandlerProxy.java @@ -0,0 +1,170 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.jsf; + +import javax.faces.application.NavigationHandler; +import javax.faces.context.FacesContext; + +import org.springframework.beans.factory.BeanFactory; +import org.springframework.lang.Nullable; +import org.springframework.web.context.WebApplicationContext; + +/** + * JSF NavigationHandler implementation that delegates to a NavigationHandler + * bean obtained from the Spring root WebApplicationContext. + * + *

Configure this handler proxy in your {@code faces-config.xml} file + * as follows: + * + *

+ * <application>
+ *   ...
+ *   <navigation-handler>
+ * 	   org.springframework.web.jsf.DelegatingNavigationHandlerProxy
+ *   </navigation-handler>
+ *   ...
+ * </application>
+ * + * By default, the Spring ApplicationContext will be searched for the NavigationHandler + * under the bean name "jsfNavigationHandler". In the simplest case, this is a plain + * Spring bean definition like the following. However, all of Spring's bean configuration + * power can be applied to such a bean, in particular all flavors of dependency injection. + * + *
+ * <bean name="jsfNavigationHandler" class="mypackage.MyNavigationHandler">
+ *   <property name="myProperty" ref="myOtherBean"/>
+ * </bean>
+ * + * The target NavigationHandler bean will typically extend the standard JSF + * NavigationHandler class. However, note that decorating the original + * NavigationHandler (the JSF provider's default handler) is not supported + * in such a scenario, since we can't inject the original handler in standard + * JSF style (that is, as constructor argument). + * + *

For decorating the original NavigationHandler, make sure that your + * target bean extends Spring's DecoratingNavigationHandler class. This + * allows to pass in the original handler as method argument, which this proxy + * automatically detects. Note that a DecoratingNavigationHandler subclass + * will still work as standard JSF NavigationHandler as well! + * + *

This proxy may be subclassed to change the bean name used to search for the + * navigation handler, change the strategy used to obtain the target handler, + * or change the strategy used to access the ApplicationContext (normally obtained + * via {@link FacesContextUtils#getWebApplicationContext(FacesContext)}). + * + * @author Juergen Hoeller + * @author Colin Sampaleanu + * @since 1.2.7 + * @see DecoratingNavigationHandler + */ +public class DelegatingNavigationHandlerProxy extends NavigationHandler { + + /** + * Default name of the target bean in the Spring application context: + * "jsfNavigationHandler". + */ + public static final String DEFAULT_TARGET_BEAN_NAME = "jsfNavigationHandler"; + + @Nullable + private NavigationHandler originalNavigationHandler; + + + /** + * Create a new DelegatingNavigationHandlerProxy. + */ + public DelegatingNavigationHandlerProxy() { + } + + /** + * Create a new DelegatingNavigationHandlerProxy. + * @param originalNavigationHandler the original NavigationHandler + */ + public DelegatingNavigationHandlerProxy(NavigationHandler originalNavigationHandler) { + this.originalNavigationHandler = originalNavigationHandler; + } + + + /** + * Handle the navigation request implied by the specified parameters, + * through delegating to the target bean in the Spring application context. + *

The target bean needs to extend the JSF NavigationHandler class. + * If it extends Spring's DecoratingNavigationHandler, the overloaded + * {@code handleNavigation} method with the original NavigationHandler + * as argument will be used. Else, the standard {@code handleNavigation} + * method will be called. + */ + @Override + public void handleNavigation(FacesContext facesContext, String fromAction, String outcome) { + NavigationHandler handler = getDelegate(facesContext); + if (handler instanceof DecoratingNavigationHandler) { + ((DecoratingNavigationHandler) handler).handleNavigation( + facesContext, fromAction, outcome, this.originalNavigationHandler); + } + else { + handler.handleNavigation(facesContext, fromAction, outcome); + } + } + + /** + * Return the target NavigationHandler to delegate to. + *

By default, a bean with the name "jsfNavigationHandler" is obtained + * from the Spring root WebApplicationContext, for every invocation. + * @param facesContext the current JSF context + * @return the target NavigationHandler to delegate to + * @see #getTargetBeanName + * @see #getBeanFactory + */ + protected NavigationHandler getDelegate(FacesContext facesContext) { + String targetBeanName = getTargetBeanName(facesContext); + return getBeanFactory(facesContext).getBean(targetBeanName, NavigationHandler.class); + } + + /** + * Return the name of the target NavigationHandler bean in the BeanFactory. + * Default is "jsfNavigationHandler". + * @param facesContext the current JSF context + * @return the name of the target bean + */ + protected String getTargetBeanName(FacesContext facesContext) { + return DEFAULT_TARGET_BEAN_NAME; + } + + /** + * Retrieve the Spring BeanFactory to delegate bean name resolution to. + *

Default implementation delegates to {@code getWebApplicationContext}. + * Can be overridden to provide an arbitrary BeanFactory reference to resolve + * against; usually, this will be a full Spring ApplicationContext. + * @param facesContext the current JSF context + * @return the Spring BeanFactory (never {@code null}) + * @see #getWebApplicationContext + */ + protected BeanFactory getBeanFactory(FacesContext facesContext) { + return getWebApplicationContext(facesContext); + } + + /** + * Retrieve the web application context to delegate bean name resolution to. + *

Default implementation delegates to FacesContextUtils. + * @param facesContext the current JSF context + * @return the Spring web application context (never {@code null}) + * @see FacesContextUtils#getRequiredWebApplicationContext + */ + protected WebApplicationContext getWebApplicationContext(FacesContext facesContext) { + return FacesContextUtils.getRequiredWebApplicationContext(facesContext); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/jsf/DelegatingPhaseListenerMulticaster.java b/spring-web/src/main/java/org/springframework/web/jsf/DelegatingPhaseListenerMulticaster.java new file mode 100644 index 0000000000000000000000000000000000000000..261d9baac1b3d25da566d51fbb278ff5f1b89d8a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/jsf/DelegatingPhaseListenerMulticaster.java @@ -0,0 +1,122 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.jsf; + +import java.util.Collection; + +import javax.faces.context.FacesContext; +import javax.faces.event.PhaseEvent; +import javax.faces.event.PhaseId; +import javax.faces.event.PhaseListener; + +import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.web.context.WebApplicationContext; + +/** + * JSF PhaseListener implementation that delegates to one or more Spring-managed + * PhaseListener beans coming from the Spring root WebApplicationContext. + * + *

Configure this listener multicaster in your {@code faces-config.xml} file + * as follows: + * + *

+ * <application>
+ *   ...
+ *   <phase-listener>
+ *     org.springframework.web.jsf.DelegatingPhaseListenerMulticaster
+ *   </phase-listener>
+ *   ...
+ * </application>
+ * + * The multicaster will delegate all {@code beforePhase} and {@code afterPhase} + * events to all target PhaseListener beans. By default, those will simply be obtained + * by type: All beans in the Spring root WebApplicationContext that implement the + * PhaseListener interface will be fetched and invoked. + * + *

Note: This multicaster's {@code getPhaseId()} method will always return + * {@code ANY_PHASE}. The phase id exposed by the target listener beans + * will be ignored; all events will be propagated to all listeners. + * + *

This multicaster may be subclassed to change the strategy used to obtain + * the listener beans, or to change the strategy used to access the ApplicationContext + * (normally obtained via {@link FacesContextUtils#getWebApplicationContext(FacesContext)}). + * + * @author Juergen Hoeller + * @author Colin Sampaleanu + * @since 1.2.7 + */ +@SuppressWarnings("serial") +public class DelegatingPhaseListenerMulticaster implements PhaseListener { + + @Override + public PhaseId getPhaseId() { + return PhaseId.ANY_PHASE; + } + + @Override + public void beforePhase(PhaseEvent event) { + for (PhaseListener listener : getDelegates(event.getFacesContext())) { + listener.beforePhase(event); + } + } + + @Override + public void afterPhase(PhaseEvent event) { + for (PhaseListener listener : getDelegates(event.getFacesContext())) { + listener.afterPhase(event); + } + } + + + /** + * Obtain the delegate PhaseListener beans from the Spring root WebApplicationContext. + * @param facesContext the current JSF context + * @return a Collection of PhaseListener objects + * @see #getBeanFactory + * @see org.springframework.beans.factory.ListableBeanFactory#getBeansOfType(Class) + */ + protected Collection getDelegates(FacesContext facesContext) { + ListableBeanFactory bf = getBeanFactory(facesContext); + return BeanFactoryUtils.beansOfTypeIncludingAncestors(bf, PhaseListener.class, true, false).values(); + } + + /** + * Retrieve the Spring BeanFactory to delegate bean name resolution to. + *

The default implementation delegates to {@code getWebApplicationContext}. + * Can be overridden to provide an arbitrary ListableBeanFactory reference to + * resolve against; usually, this will be a full Spring ApplicationContext. + * @param facesContext the current JSF context + * @return the Spring ListableBeanFactory (never {@code null}) + * @see #getWebApplicationContext + */ + protected ListableBeanFactory getBeanFactory(FacesContext facesContext) { + return getWebApplicationContext(facesContext); + } + + /** + * Retrieve the web application context to delegate bean name resolution to. + *

The default implementation delegates to FacesContextUtils. + * @param facesContext the current JSF context + * @return the Spring web application context (never {@code null}) + * @see FacesContextUtils#getRequiredWebApplicationContext + */ + protected WebApplicationContext getWebApplicationContext(FacesContext facesContext) { + return FacesContextUtils.getRequiredWebApplicationContext(facesContext); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/jsf/FacesContextUtils.java b/spring-web/src/main/java/org/springframework/web/jsf/FacesContextUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..1444460314790417cfabcbc9afbc9cff36503d26 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/jsf/FacesContextUtils.java @@ -0,0 +1,119 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.jsf; + +import javax.faces.context.ExternalContext; +import javax.faces.context.FacesContext; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.util.WebUtils; + +/** + * Convenience methods to retrieve Spring's root {@link WebApplicationContext} + * for a given JSF {@link FacesContext}. This is useful for accessing a + * Spring application context from custom JSF-based code. + * + *

Analogous to Spring's WebApplicationContextUtils for the ServletContext. + * + * @author Juergen Hoeller + * @since 1.1 + * @see org.springframework.web.context.ContextLoader + * @see org.springframework.web.context.support.WebApplicationContextUtils + */ +public abstract class FacesContextUtils { + + /** + * Find the root {@link WebApplicationContext} for this web app, typically + * loaded via {@link org.springframework.web.context.ContextLoaderListener}. + *

Will rethrow an exception that happened on root context startup, + * to differentiate between a failed context startup and no context at all. + * @param fc the FacesContext to find the web application context for + * @return the root WebApplicationContext for this web app, or {@code null} if none + * @see org.springframework.web.context.WebApplicationContext#ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE + */ + @Nullable + public static WebApplicationContext getWebApplicationContext(FacesContext fc) { + Assert.notNull(fc, "FacesContext must not be null"); + Object attr = fc.getExternalContext().getApplicationMap().get( + WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE); + if (attr == null) { + return null; + } + if (attr instanceof RuntimeException) { + throw (RuntimeException) attr; + } + if (attr instanceof Error) { + throw (Error) attr; + } + if (!(attr instanceof WebApplicationContext)) { + throw new IllegalStateException("Root context attribute is not of type WebApplicationContext: " + attr); + } + return (WebApplicationContext) attr; + } + + /** + * Find the root {@link WebApplicationContext} for this web app, typically + * loaded via {@link org.springframework.web.context.ContextLoaderListener}. + *

Will rethrow an exception that happened on root context startup, + * to differentiate between a failed context startup and no context at all. + * @param fc the FacesContext to find the web application context for + * @return the root WebApplicationContext for this web app + * @throws IllegalStateException if the root WebApplicationContext could not be found + * @see org.springframework.web.context.WebApplicationContext#ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE + */ + public static WebApplicationContext getRequiredWebApplicationContext(FacesContext fc) throws IllegalStateException { + WebApplicationContext wac = getWebApplicationContext(fc); + if (wac == null) { + throw new IllegalStateException("No WebApplicationContext found: no ContextLoaderListener registered?"); + } + return wac; + } + + /** + * Return the best available mutex for the given session: + * that is, an object to synchronize on for the given session. + *

Returns the session mutex attribute if available; usually, + * this means that the HttpSessionMutexListener needs to be defined + * in {@code web.xml}. Falls back to the Session reference itself + * if no mutex attribute found. + *

The session mutex is guaranteed to be the same object during + * the entire lifetime of the session, available under the key defined + * by the {@code SESSION_MUTEX_ATTRIBUTE} constant. It serves as a + * safe reference to synchronize on for locking on the current session. + *

In many cases, the Session reference itself is a safe mutex + * as well, since it will always be the same object reference for the + * same active logical session. However, this is not guaranteed across + * different servlet containers; the only 100% safe way is a session mutex. + * @param fc the FacesContext to find the session mutex for + * @return the mutex object (never {@code null}) + * @see org.springframework.web.util.WebUtils#SESSION_MUTEX_ATTRIBUTE + * @see org.springframework.web.util.HttpSessionMutexListener + */ + @Nullable + public static Object getSessionMutex(FacesContext fc) { + Assert.notNull(fc, "FacesContext must not be null"); + ExternalContext ec = fc.getExternalContext(); + Object mutex = ec.getSessionMap().get(WebUtils.SESSION_MUTEX_ATTRIBUTE); + if (mutex == null) { + mutex = ec.getSession(true); + } + return mutex; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/jsf/el/SpringBeanFacesELResolver.java b/spring-web/src/main/java/org/springframework/web/jsf/el/SpringBeanFacesELResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..922174d830d01eaa391e0c03d39120e6353e23d0 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/jsf/el/SpringBeanFacesELResolver.java @@ -0,0 +1,155 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.jsf.el; + +import java.beans.FeatureDescriptor; +import java.util.Iterator; + +import javax.el.ELContext; +import javax.el.ELException; +import javax.el.ELResolver; +import javax.el.PropertyNotWritableException; +import javax.faces.context.FacesContext; + +import org.springframework.lang.Nullable; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.jsf.FacesContextUtils; + +/** + * JSF {@code ELResolver} that delegates to the Spring root {@code WebApplicationContext}, + * resolving name references to Spring-defined beans. + * + *

Configure this resolver in your {@code faces-config.xml} file as follows: + * + *

+ * <application>
+ *   ...
+ *   <el-resolver>org.springframework.web.jsf.el.SpringBeanFacesELResolver</el-resolver>
+ * </application>
+ * + * All your JSF expressions can then implicitly refer to the names of + * Spring-managed service layer beans, for example in property values of + * JSF-managed beans: + * + *
+ * <managed-bean>
+ *   <managed-bean-name>myJsfManagedBean</managed-bean-name>
+ *   <managed-bean-class>example.MyJsfManagedBean</managed-bean-class>
+ *   <managed-bean-scope>session</managed-bean-scope>
+ *   <managed-property>
+ *     <property-name>mySpringManagedBusinessObject</property-name>
+ *     <value>#{mySpringManagedBusinessObject}</value>
+ *   </managed-property>
+ * </managed-bean>
+ * + * with "mySpringManagedBusinessObject" defined as Spring bean in + * applicationContext.xml: + * + *
+ * <bean id="mySpringManagedBusinessObject" class="example.MySpringManagedBusinessObject">
+ *   ...
+ * </bean>
+ * + * @author Juergen Hoeller + * @since 2.5 + * @see WebApplicationContextFacesELResolver + * @see org.springframework.web.jsf.FacesContextUtils#getRequiredWebApplicationContext + */ +public class SpringBeanFacesELResolver extends ELResolver { + + @Override + @Nullable + public Object getValue(ELContext elContext, @Nullable Object base, Object property) throws ELException { + if (base == null) { + String beanName = property.toString(); + WebApplicationContext wac = getWebApplicationContext(elContext); + if (wac.containsBean(beanName)) { + elContext.setPropertyResolved(true); + return wac.getBean(beanName); + } + } + return null; + } + + @Override + @Nullable + public Class getType(ELContext elContext, @Nullable Object base, Object property) throws ELException { + if (base == null) { + String beanName = property.toString(); + WebApplicationContext wac = getWebApplicationContext(elContext); + if (wac.containsBean(beanName)) { + elContext.setPropertyResolved(true); + return wac.getType(beanName); + } + } + return null; + } + + @Override + public void setValue(ELContext elContext, @Nullable Object base, Object property, Object value) throws ELException { + if (base == null) { + String beanName = property.toString(); + WebApplicationContext wac = getWebApplicationContext(elContext); + if (wac.containsBean(beanName)) { + if (value == wac.getBean(beanName)) { + // Setting the bean reference to the same value is alright - can simply be ignored... + elContext.setPropertyResolved(true); + } + else { + throw new PropertyNotWritableException( + "Variable '" + beanName + "' refers to a Spring bean which by definition is not writable"); + } + } + } + } + + @Override + public boolean isReadOnly(ELContext elContext, @Nullable Object base, Object property) throws ELException { + if (base == null) { + String beanName = property.toString(); + WebApplicationContext wac = getWebApplicationContext(elContext); + if (wac.containsBean(beanName)) { + return true; + } + } + return false; + } + + @Override + @Nullable + public Iterator getFeatureDescriptors(ELContext elContext, @Nullable Object base) { + return null; + } + + @Override + public Class getCommonPropertyType(ELContext elContext, @Nullable Object base) { + return Object.class; + } + + /** + * Retrieve the web application context to delegate bean name resolution to. + *

The default implementation delegates to FacesContextUtils. + * @param elContext the current JSF ELContext + * @return the Spring web application context (never {@code null}) + * @see org.springframework.web.jsf.FacesContextUtils#getRequiredWebApplicationContext + */ + protected WebApplicationContext getWebApplicationContext(ELContext elContext) { + FacesContext facesContext = FacesContext.getCurrentInstance(); + return FacesContextUtils.getRequiredWebApplicationContext(facesContext); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/jsf/el/WebApplicationContextFacesELResolver.java b/spring-web/src/main/java/org/springframework/web/jsf/el/WebApplicationContextFacesELResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..d729187464d3cecce0157bb6dcd80a8cab133512 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/jsf/el/WebApplicationContextFacesELResolver.java @@ -0,0 +1,185 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.jsf.el; + +import java.beans.FeatureDescriptor; +import java.util.Iterator; + +import javax.el.ELContext; +import javax.el.ELException; +import javax.el.ELResolver; +import javax.faces.context.FacesContext; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.BeansException; +import org.springframework.lang.Nullable; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.jsf.FacesContextUtils; + +/** + * Special JSF {@code ELResolver} that exposes the Spring {@code WebApplicationContext} + * instance under a variable named "webApplicationContext". + * + *

In contrast to {@link SpringBeanFacesELResolver}, this ELResolver variant + * does not resolve JSF variable names as Spring bean names. It rather + * exposes Spring's root WebApplicationContext itself under a special name, + * and is able to resolve "webApplicationContext.mySpringManagedBusinessObject" + * dereferences to Spring-defined beans in that application context. + * + *

Configure this resolver in your {@code faces-config.xml} file as follows: + * + *

+ * <application>
+ *   ...
+ *   <el-resolver>org.springframework.web.jsf.el.WebApplicationContextFacesELResolver</el-resolver>
+ * </application>
+ * + * @author Juergen Hoeller + * @since 2.5 + * @see SpringBeanFacesELResolver + * @see org.springframework.web.jsf.FacesContextUtils#getWebApplicationContext + */ +public class WebApplicationContextFacesELResolver extends ELResolver { + + /** + * Name of the exposed WebApplicationContext variable: "webApplicationContext". + */ + public static final String WEB_APPLICATION_CONTEXT_VARIABLE_NAME = "webApplicationContext"; + + + /** Logger available to subclasses. */ + protected final Log logger = LogFactory.getLog(getClass()); + + + @Override + @Nullable + public Object getValue(ELContext elContext, @Nullable Object base, Object property) throws ELException { + if (base != null) { + if (base instanceof WebApplicationContext) { + WebApplicationContext wac = (WebApplicationContext) base; + String beanName = property.toString(); + if (logger.isTraceEnabled()) { + logger.trace("Attempting to resolve property '" + beanName + "' in root WebApplicationContext"); + } + if (wac.containsBean(beanName)) { + if (logger.isDebugEnabled()) { + logger.debug("Successfully resolved property '" + beanName + "' in root WebApplicationContext"); + } + elContext.setPropertyResolved(true); + try { + return wac.getBean(beanName); + } + catch (BeansException ex) { + throw new ELException(ex); + } + } + else { + // Mimic standard JSF/JSP behavior when base is a Map by returning null. + return null; + } + } + } + else { + if (WEB_APPLICATION_CONTEXT_VARIABLE_NAME.equals(property)) { + elContext.setPropertyResolved(true); + return getWebApplicationContext(elContext); + } + } + + return null; + } + + @Override + @Nullable + public Class getType(ELContext elContext, @Nullable Object base, Object property) throws ELException { + if (base != null) { + if (base instanceof WebApplicationContext) { + WebApplicationContext wac = (WebApplicationContext) base; + String beanName = property.toString(); + if (logger.isDebugEnabled()) { + logger.debug("Attempting to resolve property '" + beanName + "' in root WebApplicationContext"); + } + if (wac.containsBean(beanName)) { + if (logger.isDebugEnabled()) { + logger.debug("Successfully resolved property '" + beanName + "' in root WebApplicationContext"); + } + elContext.setPropertyResolved(true); + try { + return wac.getType(beanName); + } + catch (BeansException ex) { + throw new ELException(ex); + } + } + else { + // Mimic standard JSF/JSP behavior when base is a Map by returning null. + return null; + } + } + } + else { + if (WEB_APPLICATION_CONTEXT_VARIABLE_NAME.equals(property)) { + elContext.setPropertyResolved(true); + return WebApplicationContext.class; + } + } + + return null; + } + + @Override + public void setValue(ELContext elContext, Object base, Object property, Object value) throws ELException { + } + + @Override + public boolean isReadOnly(ELContext elContext, Object base, Object property) throws ELException { + if (base instanceof WebApplicationContext) { + elContext.setPropertyResolved(true); + return true; + } + return false; + } + + @Override + @Nullable + public Iterator getFeatureDescriptors(ELContext elContext, Object base) { + return null; + } + + @Override + public Class getCommonPropertyType(ELContext elContext, Object base) { + return Object.class; + } + + + /** + * Retrieve the {@link WebApplicationContext} reference to expose. + *

The default implementation delegates to {@link FacesContextUtils}, + * returning {@code null} if no {@code WebApplicationContext} found. + * @param elContext the current JSF ELContext + * @return the Spring web application context + * @see org.springframework.web.jsf.FacesContextUtils#getWebApplicationContext + */ + @Nullable + protected WebApplicationContext getWebApplicationContext(ELContext elContext) { + FacesContext facesContext = FacesContext.getCurrentInstance(); + return FacesContextUtils.getRequiredWebApplicationContext(facesContext); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/jsf/el/package-info.java b/spring-web/src/main/java/org/springframework/web/jsf/el/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..2ce617b9ae81abcd98f8cdcbc1208e1c2250995b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/jsf/el/package-info.java @@ -0,0 +1,10 @@ +/** + * ELResolvers for integrating a JSF web layer with a Spring service layer + * which is hosted in a Spring root WebApplicationContext. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.jsf.el; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/jsf/package-info.java b/spring-web/src/main/java/org/springframework/web/jsf/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..53bfeb3e1af43e67715efe23f1f54ee8c5cc62c6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/jsf/package-info.java @@ -0,0 +1,13 @@ +/** + * Support classes for integrating a JSF web layer with a Spring service layer + * which is hosted in a Spring root WebApplicationContext. + * + *

Supports easy access to beans in the Spring root WebApplicationContext + * from JSF EL expressions, for example in property values of JSF-managed beans. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.jsf; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/method/ControllerAdviceBean.java b/spring-web/src/main/java/org/springframework/web/method/ControllerAdviceBean.java new file mode 100644 index 0000000000000000000000000000000000000000..ddf67af093bab1ac27c00e7be74d649a12677619 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/ControllerAdviceBean.java @@ -0,0 +1,210 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import org.springframework.aop.scope.ScopedProxyUtils; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.context.ApplicationContext; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.core.annotation.OrderUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.web.bind.annotation.ControllerAdvice; + +/** + * Encapsulates information about an {@linkplain ControllerAdvice @ControllerAdvice} + * Spring-managed bean without necessarily requiring it to be instantiated. + * + *

The {@link #findAnnotatedBeans(ApplicationContext)} method can be used to + * discover such beans. However, a {@code ControllerAdviceBean} may be created + * from any object, including ones without an {@code @ControllerAdvice}. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @author Juergen Hoeller + * @author Sam Brannen + * @since 3.2 + */ +public class ControllerAdviceBean implements Ordered { + + private final Object bean; + + @Nullable + private final BeanFactory beanFactory; + + private final int order; + + private final HandlerTypePredicate beanTypePredicate; + + + /** + * Create a {@code ControllerAdviceBean} using the given bean instance. + * @param bean the bean instance + */ + public ControllerAdviceBean(Object bean) { + this(bean, null); + } + + /** + * Create a {@code ControllerAdviceBean} using the given bean name. + * @param beanName the name of the bean + * @param beanFactory a BeanFactory that can be used later to resolve the bean + */ + public ControllerAdviceBean(String beanName, @Nullable BeanFactory beanFactory) { + this((Object) beanName, beanFactory); + } + + private ControllerAdviceBean(Object bean, @Nullable BeanFactory beanFactory) { + this.bean = bean; + this.beanFactory = beanFactory; + Class beanType; + + if (bean instanceof String) { + String beanName = (String) bean; + Assert.hasText(beanName, "Bean name must not be null"); + Assert.notNull(beanFactory, "BeanFactory must not be null"); + if (!beanFactory.containsBean(beanName)) { + throw new IllegalArgumentException("BeanFactory [" + beanFactory + + "] does not contain specified controller advice bean '" + beanName + "'"); + } + beanType = this.beanFactory.getType(beanName); + this.order = initOrderFromBeanType(beanType); + } + else { + Assert.notNull(bean, "Bean must not be null"); + beanType = bean.getClass(); + this.order = initOrderFromBean(bean); + } + + ControllerAdvice annotation = (beanType != null ? + AnnotatedElementUtils.findMergedAnnotation(beanType, ControllerAdvice.class) : null); + + if (annotation != null) { + this.beanTypePredicate = HandlerTypePredicate.builder() + .basePackage(annotation.basePackages()) + .basePackageClass(annotation.basePackageClasses()) + .assignableType(annotation.assignableTypes()) + .annotation(annotation.annotations()) + .build(); + } + else { + this.beanTypePredicate = HandlerTypePredicate.forAnyHandlerType(); + } + } + + + /** + * Returns the order value extracted from the {@link ControllerAdvice} + * annotation, or {@link Ordered#LOWEST_PRECEDENCE} otherwise. + */ + @Override + public int getOrder() { + return this.order; + } + + /** + * Return the type of the contained bean. + *

If the bean type is a CGLIB-generated class, the original + * user-defined class is returned. + */ + @Nullable + public Class getBeanType() { + Class beanType = (this.bean instanceof String ? + obtainBeanFactory().getType((String) this.bean) : this.bean.getClass()); + return (beanType != null ? ClassUtils.getUserClass(beanType) : null); + } + + /** + * Return a bean instance if necessary resolving the bean name through the BeanFactory. + */ + public Object resolveBean() { + return (this.bean instanceof String ? obtainBeanFactory().getBean((String) this.bean) : this.bean); + } + + private BeanFactory obtainBeanFactory() { + Assert.state(this.beanFactory != null, "No BeanFactory set"); + return this.beanFactory; + } + + /** + * Check whether the given bean type should be assisted by this + * {@code @ControllerAdvice} instance. + * @param beanType the type of the bean to check + * @since 4.0 + * @see org.springframework.web.bind.annotation.ControllerAdvice + */ + public boolean isApplicableToBeanType(@Nullable Class beanType) { + return this.beanTypePredicate.test(beanType); + } + + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof ControllerAdviceBean)) { + return false; + } + ControllerAdviceBean otherAdvice = (ControllerAdviceBean) other; + return (this.bean.equals(otherAdvice.bean) && this.beanFactory == otherAdvice.beanFactory); + } + + @Override + public int hashCode() { + return this.bean.hashCode(); + } + + @Override + public String toString() { + return this.bean.toString(); + } + + + /** + * Find the names of beans annotated with + * {@linkplain ControllerAdvice @ControllerAdvice} in the given + * ApplicationContext and wrap them as {@code ControllerAdviceBean} instances. + */ + public static List findAnnotatedBeans(ApplicationContext context) { + return Arrays.stream(BeanFactoryUtils.beanNamesForTypeIncludingAncestors(context, Object.class)) + .filter(name -> !ScopedProxyUtils.isScopedTarget(name)) + .filter(name -> context.findAnnotationOnBean(name, ControllerAdvice.class) != null) + .map(name -> new ControllerAdviceBean(name, context)) + .collect(Collectors.toList()); + } + + private static int initOrderFromBean(Object bean) { + return (bean instanceof Ordered ? ((Ordered) bean).getOrder() : initOrderFromBeanType(bean.getClass())); + } + + private static int initOrderFromBeanType(@Nullable Class beanType) { + Integer order = null; + if (beanType != null) { + order = OrderUtils.getOrder(beanType); + } + return (order != null ? order : Ordered.LOWEST_PRECEDENCE); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/HandlerMethod.java b/spring-web/src/main/java/org/springframework/web/method/HandlerMethod.java new file mode 100644 index 0000000000000000000000000000000000000000..a8feb6e4050755917104cf8ebfe68c5f13199fa7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/HandlerMethod.java @@ -0,0 +1,549 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.BeanFactory; +import org.springframework.core.BridgeMethodResolver; +import org.springframework.core.GenericTypeResolver; +import org.springframework.core.MethodParameter; +import org.springframework.core.ResolvableType; +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.core.annotation.SynthesizingMethodParameter; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.bind.annotation.ResponseStatus; + +/** + * Encapsulates information about a handler method consisting of a + * {@linkplain #getMethod() method} and a {@linkplain #getBean() bean}. + * Provides convenient access to method parameters, the method return value, + * method annotations, etc. + * + *

The class may be created with a bean instance or with a bean name + * (e.g. lazy-init bean, prototype bean). Use {@link #createWithResolvedBean()} + * to obtain a {@code HandlerMethod} instance with a bean instance resolved + * through the associated {@link BeanFactory}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @author Sam Brannen + * @since 3.1 + */ +public class HandlerMethod { + + /** Logger that is available to subclasses. */ + protected final Log logger = LogFactory.getLog(getClass()); + + private final Object bean; + + @Nullable + private final BeanFactory beanFactory; + + private final Class beanType; + + private final Method method; + + private final Method bridgedMethod; + + private final MethodParameter[] parameters; + + @Nullable + private HttpStatus responseStatus; + + @Nullable + private String responseStatusReason; + + @Nullable + private HandlerMethod resolvedFromHandlerMethod; + + @Nullable + private volatile List interfaceParameterAnnotations; + + + /** + * Create an instance from a bean instance and a method. + */ + public HandlerMethod(Object bean, Method method) { + Assert.notNull(bean, "Bean is required"); + Assert.notNull(method, "Method is required"); + this.bean = bean; + this.beanFactory = null; + this.beanType = ClassUtils.getUserClass(bean); + this.method = method; + this.bridgedMethod = BridgeMethodResolver.findBridgedMethod(method); + this.parameters = initMethodParameters(); + evaluateResponseStatus(); + } + + /** + * Create an instance from a bean instance, method name, and parameter types. + * @throws NoSuchMethodException when the method cannot be found + */ + public HandlerMethod(Object bean, String methodName, Class... parameterTypes) throws NoSuchMethodException { + Assert.notNull(bean, "Bean is required"); + Assert.notNull(methodName, "Method name is required"); + this.bean = bean; + this.beanFactory = null; + this.beanType = ClassUtils.getUserClass(bean); + this.method = bean.getClass().getMethod(methodName, parameterTypes); + this.bridgedMethod = BridgeMethodResolver.findBridgedMethod(this.method); + this.parameters = initMethodParameters(); + evaluateResponseStatus(); + } + + /** + * Create an instance from a bean name, a method, and a {@code BeanFactory}. + * The method {@link #createWithResolvedBean()} may be used later to + * re-create the {@code HandlerMethod} with an initialized bean. + */ + public HandlerMethod(String beanName, BeanFactory beanFactory, Method method) { + Assert.hasText(beanName, "Bean name is required"); + Assert.notNull(beanFactory, "BeanFactory is required"); + Assert.notNull(method, "Method is required"); + this.bean = beanName; + this.beanFactory = beanFactory; + Class beanType = beanFactory.getType(beanName); + if (beanType == null) { + throw new IllegalStateException("Cannot resolve bean type for bean with name '" + beanName + "'"); + } + this.beanType = ClassUtils.getUserClass(beanType); + this.method = method; + this.bridgedMethod = BridgeMethodResolver.findBridgedMethod(method); + this.parameters = initMethodParameters(); + evaluateResponseStatus(); + } + + /** + * Copy constructor for use in subclasses. + */ + protected HandlerMethod(HandlerMethod handlerMethod) { + Assert.notNull(handlerMethod, "HandlerMethod is required"); + this.bean = handlerMethod.bean; + this.beanFactory = handlerMethod.beanFactory; + this.beanType = handlerMethod.beanType; + this.method = handlerMethod.method; + this.bridgedMethod = handlerMethod.bridgedMethod; + this.parameters = handlerMethod.parameters; + this.responseStatus = handlerMethod.responseStatus; + this.responseStatusReason = handlerMethod.responseStatusReason; + this.resolvedFromHandlerMethod = handlerMethod.resolvedFromHandlerMethod; + } + + /** + * Re-create HandlerMethod with the resolved handler. + */ + private HandlerMethod(HandlerMethod handlerMethod, Object handler) { + Assert.notNull(handlerMethod, "HandlerMethod is required"); + Assert.notNull(handler, "Handler object is required"); + this.bean = handler; + this.beanFactory = handlerMethod.beanFactory; + this.beanType = handlerMethod.beanType; + this.method = handlerMethod.method; + this.bridgedMethod = handlerMethod.bridgedMethod; + this.parameters = handlerMethod.parameters; + this.responseStatus = handlerMethod.responseStatus; + this.responseStatusReason = handlerMethod.responseStatusReason; + this.resolvedFromHandlerMethod = handlerMethod; + } + + private MethodParameter[] initMethodParameters() { + int count = this.bridgedMethod.getParameterCount(); + MethodParameter[] result = new MethodParameter[count]; + for (int i = 0; i < count; i++) { + HandlerMethodParameter parameter = new HandlerMethodParameter(i); + GenericTypeResolver.resolveParameterType(parameter, this.beanType); + result[i] = parameter; + } + return result; + } + + private void evaluateResponseStatus() { + ResponseStatus annotation = getMethodAnnotation(ResponseStatus.class); + if (annotation == null) { + annotation = AnnotatedElementUtils.findMergedAnnotation(getBeanType(), ResponseStatus.class); + } + if (annotation != null) { + this.responseStatus = annotation.code(); + this.responseStatusReason = annotation.reason(); + } + } + + + /** + * Return the bean for this handler method. + */ + public Object getBean() { + return this.bean; + } + + /** + * Return the method for this handler method. + */ + public Method getMethod() { + return this.method; + } + + /** + * This method returns the type of the handler for this handler method. + *

Note that if the bean type is a CGLIB-generated class, the original + * user-defined class is returned. + */ + public Class getBeanType() { + return this.beanType; + } + + /** + * If the bean method is a bridge method, this method returns the bridged + * (user-defined) method. Otherwise it returns the same method as {@link #getMethod()}. + */ + protected Method getBridgedMethod() { + return this.bridgedMethod; + } + + /** + * Return the method parameters for this handler method. + */ + public MethodParameter[] getMethodParameters() { + return this.parameters; + } + + /** + * Return the specified response status, if any. + * @since 4.3.8 + * @see ResponseStatus#code() + */ + @Nullable + protected HttpStatus getResponseStatus() { + return this.responseStatus; + } + + /** + * Return the associated response status reason, if any. + * @since 4.3.8 + * @see ResponseStatus#reason() + */ + @Nullable + protected String getResponseStatusReason() { + return this.responseStatusReason; + } + + /** + * Return the HandlerMethod return type. + */ + public MethodParameter getReturnType() { + return new HandlerMethodParameter(-1); + } + + /** + * Return the actual return value type. + */ + public MethodParameter getReturnValueType(@Nullable Object returnValue) { + return new ReturnValueMethodParameter(returnValue); + } + + /** + * Return {@code true} if the method return type is void, {@code false} otherwise. + */ + public boolean isVoid() { + return Void.TYPE.equals(getReturnType().getParameterType()); + } + + /** + * Return a single annotation on the underlying method traversing its super methods + * if no annotation can be found on the given method itself. + *

Also supports merged composed annotations with attribute + * overrides as of Spring Framework 4.2.2. + * @param annotationType the type of annotation to introspect the method for + * @return the annotation, or {@code null} if none found + * @see AnnotatedElementUtils#findMergedAnnotation + */ + @Nullable + public A getMethodAnnotation(Class annotationType) { + return AnnotatedElementUtils.findMergedAnnotation(this.method, annotationType); + } + + /** + * Return whether the parameter is declared with the given annotation type. + * @param annotationType the annotation type to look for + * @since 4.3 + * @see AnnotatedElementUtils#hasAnnotation + */ + public boolean hasMethodAnnotation(Class annotationType) { + return AnnotatedElementUtils.hasAnnotation(this.method, annotationType); + } + + /** + * Return the HandlerMethod from which this HandlerMethod instance was + * resolved via {@link #createWithResolvedBean()}. + */ + @Nullable + public HandlerMethod getResolvedFromHandlerMethod() { + return this.resolvedFromHandlerMethod; + } + + /** + * If the provided instance contains a bean name rather than an object instance, + * the bean name is resolved before a {@link HandlerMethod} is created and returned. + */ + public HandlerMethod createWithResolvedBean() { + Object handler = this.bean; + if (this.bean instanceof String) { + Assert.state(this.beanFactory != null, "Cannot resolve bean name without BeanFactory"); + String beanName = (String) this.bean; + handler = this.beanFactory.getBean(beanName); + } + return new HandlerMethod(this, handler); + } + + /** + * Return a short representation of this handler method for log message purposes. + * @since 4.3 + */ + public String getShortLogMessage() { + return getBeanType().getName() + "#" + this.method.getName() + + "[" + this.method.getParameterCount() + " args]"; + } + + + private List getInterfaceParameterAnnotations() { + List parameterAnnotations = this.interfaceParameterAnnotations; + if (parameterAnnotations == null) { + parameterAnnotations = new ArrayList<>(); + for (Class ifc : ClassUtils.getAllInterfacesForClassAsSet(this.method.getDeclaringClass())) { + for (Method candidate : ifc.getMethods()) { + if (isOverrideFor(candidate)) { + parameterAnnotations.add(candidate.getParameterAnnotations()); + } + } + } + this.interfaceParameterAnnotations = parameterAnnotations; + } + return parameterAnnotations; + } + + private boolean isOverrideFor(Method candidate) { + if (!candidate.getName().equals(this.method.getName()) || + candidate.getParameterCount() != this.method.getParameterCount()) { + return false; + } + Class[] paramTypes = this.method.getParameterTypes(); + if (Arrays.equals(candidate.getParameterTypes(), paramTypes)) { + return true; + } + for (int i = 0; i < paramTypes.length; i++) { + if (paramTypes[i] != + ResolvableType.forMethodParameter(candidate, i, this.method.getDeclaringClass()).resolve()) { + return false; + } + } + return true; + } + + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof HandlerMethod)) { + return false; + } + HandlerMethod otherMethod = (HandlerMethod) other; + return (this.bean.equals(otherMethod.bean) && this.method.equals(otherMethod.method)); + } + + @Override + public int hashCode() { + return (this.bean.hashCode() * 31 + this.method.hashCode()); + } + + @Override + public String toString() { + return this.method.toGenericString(); + } + + + // Support methods for use in "InvocableHandlerMethod" sub-class variants.. + + @Nullable + protected static Object findProvidedArgument(MethodParameter parameter, @Nullable Object... providedArgs) { + if (!ObjectUtils.isEmpty(providedArgs)) { + for (Object providedArg : providedArgs) { + if (parameter.getParameterType().isInstance(providedArg)) { + return providedArg; + } + } + } + return null; + } + + protected static String formatArgumentError(MethodParameter param, String message) { + return "Could not resolve parameter [" + param.getParameterIndex() + "] in " + + param.getExecutable().toGenericString() + (StringUtils.hasText(message) ? ": " + message : ""); + } + + /** + * Assert that the target bean class is an instance of the class where the given + * method is declared. In some cases the actual controller instance at request- + * processing time may be a JDK dynamic proxy (lazy initialization, prototype + * beans, and others). {@code @Controller}'s that require proxying should prefer + * class-based proxy mechanisms. + */ + protected void assertTargetBean(Method method, Object targetBean, Object[] args) { + Class methodDeclaringClass = method.getDeclaringClass(); + Class targetBeanClass = targetBean.getClass(); + if (!methodDeclaringClass.isAssignableFrom(targetBeanClass)) { + String text = "The mapped handler method class '" + methodDeclaringClass.getName() + + "' is not an instance of the actual controller bean class '" + + targetBeanClass.getName() + "'. If the controller requires proxying " + + "(e.g. due to @Transactional), please use class-based proxying."; + throw new IllegalStateException(formatInvokeError(text, args)); + } + } + + protected String formatInvokeError(String text, Object[] args) { + String formattedArgs = IntStream.range(0, args.length) + .mapToObj(i -> (args[i] != null ? + "[" + i + "] [type=" + args[i].getClass().getName() + "] [value=" + args[i] + "]" : + "[" + i + "] [null]")) + .collect(Collectors.joining(",\n", " ", " ")); + return text + "\n" + + "Controller [" + getBeanType().getName() + "]\n" + + "Method [" + getBridgedMethod().toGenericString() + "] " + + "with argument values:\n" + formattedArgs; + } + + + /** + * A MethodParameter with HandlerMethod-specific behavior. + */ + protected class HandlerMethodParameter extends SynthesizingMethodParameter { + + @Nullable + private volatile Annotation[] combinedAnnotations; + + public HandlerMethodParameter(int index) { + super(HandlerMethod.this.bridgedMethod, index); + } + + protected HandlerMethodParameter(HandlerMethodParameter original) { + super(original); + } + + @Override + public Class getContainingClass() { + return HandlerMethod.this.getBeanType(); + } + + @Override + public T getMethodAnnotation(Class annotationType) { + return HandlerMethod.this.getMethodAnnotation(annotationType); + } + + @Override + public boolean hasMethodAnnotation(Class annotationType) { + return HandlerMethod.this.hasMethodAnnotation(annotationType); + } + + @Override + public Annotation[] getParameterAnnotations() { + Annotation[] anns = this.combinedAnnotations; + if (anns == null) { + anns = super.getParameterAnnotations(); + int index = getParameterIndex(); + if (index >= 0) { + for (Annotation[][] ifcAnns : getInterfaceParameterAnnotations()) { + if (index < ifcAnns.length) { + Annotation[] paramAnns = ifcAnns[index]; + if (paramAnns.length > 0) { + List merged = new ArrayList<>(anns.length + paramAnns.length); + merged.addAll(Arrays.asList(anns)); + for (Annotation paramAnn : paramAnns) { + boolean existingType = false; + for (Annotation ann : anns) { + if (ann.annotationType() == paramAnn.annotationType()) { + existingType = true; + break; + } + } + if (!existingType) { + merged.add(adaptAnnotation(paramAnn)); + } + } + anns = merged.toArray(new Annotation[0]); + } + } + } + } + this.combinedAnnotations = anns; + } + return anns; + } + + @Override + public HandlerMethodParameter clone() { + return new HandlerMethodParameter(this); + } + } + + + /** + * A MethodParameter for a HandlerMethod return type based on an actual return value. + */ + private class ReturnValueMethodParameter extends HandlerMethodParameter { + + @Nullable + private final Object returnValue; + + public ReturnValueMethodParameter(@Nullable Object returnValue) { + super(-1); + this.returnValue = returnValue; + } + + protected ReturnValueMethodParameter(ReturnValueMethodParameter original) { + super(original); + this.returnValue = original.returnValue; + } + + @Override + public Class getParameterType() { + return (this.returnValue != null ? this.returnValue.getClass() : super.getParameterType()); + } + + @Override + public ReturnValueMethodParameter clone() { + return new ReturnValueMethodParameter(this); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/HandlerTypePredicate.java b/spring-web/src/main/java/org/springframework/web/method/HandlerTypePredicate.java new file mode 100644 index 0000000000000000000000000000000000000000..b40e0483f3e1df902bcd8d2a3173e806ba26d5cb --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/HandlerTypePredicate.java @@ -0,0 +1,211 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method; + +import java.lang.annotation.Annotation; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; +import java.util.function.Predicate; + +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.util.ClassUtils; +import org.springframework.util.StringUtils; + +/** + * A {@code Predicate} to match request handling component types if + * any of the following selectors match: + *

+ *

Composability methods on {@link Predicate} can be used : + *

+ * Predicate<Class<?>> predicate =
+ * 		HandlerTypePredicate.forAnnotation(RestController.class)
+ * 				.and(HandlerTypePredicate.forBasePackage("org.example"));
+ * 
+ * + * @author Rossen Stoyanchev + * @since 5.1 + */ +public final class HandlerTypePredicate implements Predicate> { + + private final Set basePackages; + + private final List> assignableTypes; + + private final List> annotations; + + + /** + * Private constructor. See static factory methods. + */ + private HandlerTypePredicate(Set basePackages, List> assignableTypes, + List> annotations) { + + this.basePackages = Collections.unmodifiableSet(basePackages); + this.assignableTypes = Collections.unmodifiableList(assignableTypes); + this.annotations = Collections.unmodifiableList(annotations); + } + + + @Override + public boolean test(Class controllerType) { + if (!hasSelectors()) { + return true; + } + else if (controllerType != null) { + for (String basePackage : this.basePackages) { + if (controllerType.getName().startsWith(basePackage)) { + return true; + } + } + for (Class clazz : this.assignableTypes) { + if (ClassUtils.isAssignable(clazz, controllerType)) { + return true; + } + } + for (Class annotationClass : this.annotations) { + if (AnnotationUtils.findAnnotation(controllerType, annotationClass) != null) { + return true; + } + } + } + return false; + } + + private boolean hasSelectors() { + return (!this.basePackages.isEmpty() || !this.assignableTypes.isEmpty() || !this.annotations.isEmpty()); + } + + + // Static factory methods + + /** + * {@code Predicate} that applies to any handlers. + */ + public static HandlerTypePredicate forAnyHandlerType() { + return new HandlerTypePredicate( + Collections.emptySet(), Collections.emptyList(), Collections.emptyList()); + } + + /** + * Match handlers declared under a base package, e.g. "org.example". + * @param packages one or more base package names + */ + public static HandlerTypePredicate forBasePackage(String... packages) { + return new Builder().basePackage(packages).build(); + } + + /** + * Type-safe alternative to {@link #forBasePackage(String...)} to specify a + * base package through a class. + * @param packageClasses one or more base package classes + */ + public static HandlerTypePredicate forBasePackageClass(Class... packageClasses) { + return new Builder().basePackageClass(packageClasses).build(); + } + + /** + * Match handlers that are assignable to a given type. + * @param types one or more handler super types + */ + public static HandlerTypePredicate forAssignableType(Class... types) { + return new Builder().assignableType(types).build(); + } + + /** + * Match handlers annotated with a specific annotation. + * @param annotations one or more annotations to check for + */ + @SafeVarargs + public static HandlerTypePredicate forAnnotation(Class... annotations) { + return new Builder().annotation(annotations).build(); + } + + /** + * Return a builder for a {@code HandlerTypePredicate}. + */ + public static Builder builder() { + return new Builder(); + } + + + /** + * A {@link HandlerTypePredicate} builder. + */ + public static class Builder { + + private final Set basePackages = new LinkedHashSet<>(); + + private final List> assignableTypes = new ArrayList<>(); + + private final List> annotations = new ArrayList<>(); + + /** + * Match handlers declared under a base package, e.g. "org.example". + * @param packages one or more base package classes + */ + public Builder basePackage(String... packages) { + Arrays.stream(packages).filter(StringUtils::hasText).forEach(this::addBasePackage); + return this; + } + + /** + * Type-safe alternative to {@link #forBasePackage(String...)} to specify a + * base package through a class. + * @param packageClasses one or more base package names + */ + public Builder basePackageClass(Class... packageClasses) { + Arrays.stream(packageClasses).forEach(clazz -> addBasePackage(ClassUtils.getPackageName(clazz))); + return this; + } + + private void addBasePackage(String basePackage) { + this.basePackages.add(basePackage.endsWith(".") ? basePackage : basePackage + "."); + } + + /** + * Match handlers that are assignable to a given type. + * @param types one or more handler super types + */ + public Builder assignableType(Class... types) { + this.assignableTypes.addAll(Arrays.asList(types)); + return this; + } + + /** + * Match types that are annotated with one of the given annotations. + * @param annotations one or more annotations to check for + */ + @SuppressWarnings("unchecked") + public final Builder annotation(Class... annotations) { + this.annotations.addAll(Arrays.asList(annotations)); + return this; + } + + public HandlerTypePredicate build() { + return new HandlerTypePredicate(this.basePackages, this.assignableTypes, this.annotations); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/AbstractCookieValueMethodArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/method/annotation/AbstractCookieValueMethodArgumentResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..99103719f53ddd38e63800ee393e07ab76169cda --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/AbstractCookieValueMethodArgumentResolver.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.bind.MissingRequestCookieException; +import org.springframework.web.bind.ServletRequestBindingException; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.CookieValue; + +/** + * A base abstract class to resolve method arguments annotated with + * {@code @CookieValue}. Subclasses extract the cookie value from the request. + * + *

An {@code @CookieValue} is a named value that is resolved from a cookie. + * It has a required flag and a default value to fall back on when the cookie + * does not exist. + * + *

A {@link WebDataBinder} may be invoked to apply type conversion to the + * resolved cookie value. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @since 3.1 + */ +public abstract class AbstractCookieValueMethodArgumentResolver extends AbstractNamedValueMethodArgumentResolver { + + /** + * Crate a new {@link AbstractCookieValueMethodArgumentResolver} instance. + * @param beanFactory a bean factory to use for resolving ${...} + * placeholder and #{...} SpEL expressions in default values; + * or {@code null} if default values are not expected to contain expressions + */ + public AbstractCookieValueMethodArgumentResolver(@Nullable ConfigurableBeanFactory beanFactory) { + super(beanFactory); + } + + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return parameter.hasParameterAnnotation(CookieValue.class); + } + + @Override + protected NamedValueInfo createNamedValueInfo(MethodParameter parameter) { + CookieValue annotation = parameter.getParameterAnnotation(CookieValue.class); + Assert.state(annotation != null, "No CookieValue annotation"); + return new CookieValueNamedValueInfo(annotation); + } + + @Override + protected void handleMissingValue(String name, MethodParameter parameter) throws ServletRequestBindingException { + throw new MissingRequestCookieException(name, parameter); + } + + + private static final class CookieValueNamedValueInfo extends NamedValueInfo { + + private CookieValueNamedValueInfo(CookieValue annotation) { + super(annotation.name(), annotation.required(), annotation.defaultValue()); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/AbstractNamedValueMethodArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/method/annotation/AbstractNamedValueMethodArgumentResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..55ea38da9ef0c22718d91156676bc080b68945fc --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/AbstractNamedValueMethodArgumentResolver.java @@ -0,0 +1,285 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import javax.servlet.ServletException; + +import org.springframework.beans.ConversionNotSupportedException; +import org.springframework.beans.TypeMismatchException; +import org.springframework.beans.factory.config.BeanExpressionContext; +import org.springframework.beans.factory.config.BeanExpressionResolver; +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.web.bind.ServletRequestBindingException; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.ValueConstants; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.RequestScope; +import org.springframework.web.method.support.HandlerMethodArgumentResolver; +import org.springframework.web.method.support.ModelAndViewContainer; + +/** + * Abstract base class for resolving method arguments from a named value. + * Request parameters, request headers, and path variables are examples of named + * values. Each may have a name, a required flag, and a default value. + * + *

Subclasses define how to do the following: + *

    + *
  • Obtain named value information for a method parameter + *
  • Resolve names into argument values + *
  • Handle missing argument values when argument values are required + *
  • Optionally handle a resolved value + *
+ * + *

A default value string can contain ${...} placeholders and Spring Expression + * Language #{...} expressions. For this to work a + * {@link ConfigurableBeanFactory} must be supplied to the class constructor. + * + *

A {@link WebDataBinder} is created to apply type conversion to the resolved + * argument value if it doesn't match the method parameter type. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.1 + */ +public abstract class AbstractNamedValueMethodArgumentResolver implements HandlerMethodArgumentResolver { + + @Nullable + private final ConfigurableBeanFactory configurableBeanFactory; + + @Nullable + private final BeanExpressionContext expressionContext; + + private final Map namedValueInfoCache = new ConcurrentHashMap<>(256); + + + public AbstractNamedValueMethodArgumentResolver() { + this.configurableBeanFactory = null; + this.expressionContext = null; + } + + /** + * Create a new {@link AbstractNamedValueMethodArgumentResolver} instance. + * @param beanFactory a bean factory to use for resolving ${...} placeholder + * and #{...} SpEL expressions in default values, or {@code null} if default + * values are not expected to contain expressions + */ + public AbstractNamedValueMethodArgumentResolver(@Nullable ConfigurableBeanFactory beanFactory) { + this.configurableBeanFactory = beanFactory; + this.expressionContext = + (beanFactory != null ? new BeanExpressionContext(beanFactory, new RequestScope()) : null); + } + + + @Override + @Nullable + public final Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer, + NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception { + + NamedValueInfo namedValueInfo = getNamedValueInfo(parameter); + MethodParameter nestedParameter = parameter.nestedIfOptional(); + + Object resolvedName = resolveStringValue(namedValueInfo.name); + if (resolvedName == null) { + throw new IllegalArgumentException( + "Specified name must not resolve to null: [" + namedValueInfo.name + "]"); + } + + Object arg = resolveName(resolvedName.toString(), nestedParameter, webRequest); + if (arg == null) { + if (namedValueInfo.defaultValue != null) { + arg = resolveStringValue(namedValueInfo.defaultValue); + } + else if (namedValueInfo.required && !nestedParameter.isOptional()) { + handleMissingValue(namedValueInfo.name, nestedParameter, webRequest); + } + arg = handleNullValue(namedValueInfo.name, arg, nestedParameter.getNestedParameterType()); + } + else if ("".equals(arg) && namedValueInfo.defaultValue != null) { + arg = resolveStringValue(namedValueInfo.defaultValue); + } + + if (binderFactory != null) { + WebDataBinder binder = binderFactory.createBinder(webRequest, null, namedValueInfo.name); + try { + arg = binder.convertIfNecessary(arg, parameter.getParameterType(), parameter); + } + catch (ConversionNotSupportedException ex) { + throw new MethodArgumentConversionNotSupportedException(arg, ex.getRequiredType(), + namedValueInfo.name, parameter, ex.getCause()); + } + catch (TypeMismatchException ex) { + throw new MethodArgumentTypeMismatchException(arg, ex.getRequiredType(), + namedValueInfo.name, parameter, ex.getCause()); + } + } + + handleResolvedValue(arg, namedValueInfo.name, parameter, mavContainer, webRequest); + + return arg; + } + + /** + * Obtain the named value for the given method parameter. + */ + private NamedValueInfo getNamedValueInfo(MethodParameter parameter) { + NamedValueInfo namedValueInfo = this.namedValueInfoCache.get(parameter); + if (namedValueInfo == null) { + namedValueInfo = createNamedValueInfo(parameter); + namedValueInfo = updateNamedValueInfo(parameter, namedValueInfo); + this.namedValueInfoCache.put(parameter, namedValueInfo); + } + return namedValueInfo; + } + + /** + * Create the {@link NamedValueInfo} object for the given method parameter. Implementations typically + * retrieve the method annotation by means of {@link MethodParameter#getParameterAnnotation(Class)}. + * @param parameter the method parameter + * @return the named value information + */ + protected abstract NamedValueInfo createNamedValueInfo(MethodParameter parameter); + + /** + * Create a new NamedValueInfo based on the given NamedValueInfo with sanitized values. + */ + private NamedValueInfo updateNamedValueInfo(MethodParameter parameter, NamedValueInfo info) { + String name = info.name; + if (info.name.isEmpty()) { + name = parameter.getParameterName(); + if (name == null) { + throw new IllegalArgumentException( + "Name for argument of type [" + parameter.getNestedParameterType().getName() + + "] not specified, and parameter name information not found in class file either."); + } + } + String defaultValue = (ValueConstants.DEFAULT_NONE.equals(info.defaultValue) ? null : info.defaultValue); + return new NamedValueInfo(name, info.required, defaultValue); + } + + /** + * Resolve the given annotation-specified value, + * potentially containing placeholders and expressions. + */ + @Nullable + private Object resolveStringValue(String value) { + if (this.configurableBeanFactory == null || this.expressionContext == null) { + return value; + } + String placeholdersResolved = this.configurableBeanFactory.resolveEmbeddedValue(value); + BeanExpressionResolver exprResolver = this.configurableBeanFactory.getBeanExpressionResolver(); + if (exprResolver == null) { + return value; + } + return exprResolver.evaluate(placeholdersResolved, this.expressionContext); + } + + /** + * Resolve the given parameter type and value name into an argument value. + * @param name the name of the value being resolved + * @param parameter the method parameter to resolve to an argument value + * (pre-nested in case of a {@link java.util.Optional} declaration) + * @param request the current request + * @return the resolved argument (may be {@code null}) + * @throws Exception in case of errors + */ + @Nullable + protected abstract Object resolveName(String name, MethodParameter parameter, NativeWebRequest request) + throws Exception; + + /** + * Invoked when a named value is required, but {@link #resolveName(String, MethodParameter, NativeWebRequest)} + * returned {@code null} and there is no default value. Subclasses typically throw an exception in this case. + * @param name the name for the value + * @param parameter the method parameter + * @param request the current request + * @since 4.3 + */ + protected void handleMissingValue(String name, MethodParameter parameter, NativeWebRequest request) + throws Exception { + + handleMissingValue(name, parameter); + } + + /** + * Invoked when a named value is required, but {@link #resolveName(String, MethodParameter, NativeWebRequest)} + * returned {@code null} and there is no default value. Subclasses typically throw an exception in this case. + * @param name the name for the value + * @param parameter the method parameter + */ + protected void handleMissingValue(String name, MethodParameter parameter) throws ServletException { + throw new ServletRequestBindingException("Missing argument '" + name + + "' for method parameter of type " + parameter.getNestedParameterType().getSimpleName()); + } + + /** + * A {@code null} results in a {@code false} value for {@code boolean}s or an exception for other primitives. + */ + @Nullable + private Object handleNullValue(String name, @Nullable Object value, Class paramType) { + if (value == null) { + if (Boolean.TYPE.equals(paramType)) { + return Boolean.FALSE; + } + else if (paramType.isPrimitive()) { + throw new IllegalStateException("Optional " + paramType.getSimpleName() + " parameter '" + name + + "' is present but cannot be translated into a null value due to being declared as a " + + "primitive type. Consider declaring it as object wrapper for the corresponding primitive type."); + } + } + return value; + } + + /** + * Invoked after a value is resolved. + * @param arg the resolved argument value + * @param name the argument name + * @param parameter the argument parameter type + * @param mavContainer the {@link ModelAndViewContainer} (may be {@code null}) + * @param webRequest the current request + */ + protected void handleResolvedValue(@Nullable Object arg, String name, MethodParameter parameter, + @Nullable ModelAndViewContainer mavContainer, NativeWebRequest webRequest) { + } + + + /** + * Represents the information about a named value, including name, whether it's required and a default value. + */ + protected static class NamedValueInfo { + + private final String name; + + private final boolean required; + + @Nullable + private final String defaultValue; + + public NamedValueInfo(String name, boolean required, @Nullable String defaultValue) { + this.name = name; + this.required = required; + this.defaultValue = defaultValue; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/AbstractWebArgumentResolverAdapter.java b/spring-web/src/main/java/org/springframework/web/method/annotation/AbstractWebArgumentResolverAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..0d0272a373e6a78af2bee634de0e52624a38eca0 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/AbstractWebArgumentResolverAdapter.java @@ -0,0 +1,118 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.web.bind.support.WebArgumentResolver; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.method.support.HandlerMethodArgumentResolver; +import org.springframework.web.method.support.ModelAndViewContainer; + +/** + * An abstract base class adapting a {@link WebArgumentResolver} to the + * {@link HandlerMethodArgumentResolver} contract. + * + *

Note: This class is provided for backwards compatibility. + * However it is recommended to re-write a {@code WebArgumentResolver} as + * {@code HandlerMethodArgumentResolver}. Since {@link #supportsParameter} + * can only be implemented by actually resolving the value and then checking + * the result is not {@code WebArgumentResolver#UNRESOLVED} any exceptions + * raised must be absorbed and ignored since it's not clear whether the adapter + * doesn't support the parameter or whether it failed for an internal reason. + * The {@code HandlerMethodArgumentResolver} contract also provides access to + * model attributes and to {@code WebDataBinderFactory} (for type conversion). + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @since 3.1 + */ +public abstract class AbstractWebArgumentResolverAdapter implements HandlerMethodArgumentResolver { + + private final Log logger = LogFactory.getLog(getClass()); + + private final WebArgumentResolver adaptee; + + + /** + * Create a new instance. + */ + public AbstractWebArgumentResolverAdapter(WebArgumentResolver adaptee) { + Assert.notNull(adaptee, "'adaptee' must not be null"); + this.adaptee = adaptee; + } + + + /** + * Actually resolve the value and check the resolved value is not + * {@link WebArgumentResolver#UNRESOLVED} absorbing _any_ exceptions. + */ + @Override + public boolean supportsParameter(MethodParameter parameter) { + try { + NativeWebRequest webRequest = getWebRequest(); + Object result = this.adaptee.resolveArgument(parameter, webRequest); + if (result == WebArgumentResolver.UNRESOLVED) { + return false; + } + else { + return ClassUtils.isAssignableValue(parameter.getParameterType(), result); + } + } + catch (Exception ex) { + // ignore (see class-level doc) + if (logger.isDebugEnabled()) { + logger.debug("Error in checking support for parameter [" + parameter + "]: " + ex.getMessage()); + } + return false; + } + } + + /** + * Delegate to the {@link WebArgumentResolver} instance. + * @throws IllegalStateException if the resolved value is not assignable + * to the method parameter. + */ + @Override + @Nullable + public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer, + NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception { + + Class paramType = parameter.getParameterType(); + Object result = this.adaptee.resolveArgument(parameter, webRequest); + if (result == WebArgumentResolver.UNRESOLVED || !ClassUtils.isAssignableValue(paramType, result)) { + throw new IllegalStateException( + "Standard argument type [" + paramType.getName() + "] in method " + parameter.getMethod() + + "resolved to incompatible value of type [" + (result != null ? result.getClass() : null) + + "]. Consider declaring the argument type in a less specific fashion."); + } + return result; + } + + + /** + * Required for access to NativeWebRequest in {@link #supportsParameter}. + */ + protected abstract NativeWebRequest getWebRequest(); + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/ErrorsMethodArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/method/annotation/ErrorsMethodArgumentResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..2e727c7975f5eb1a6f078bfd309b55684cdbad7c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/ErrorsMethodArgumentResolver.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.ui.ModelMap; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.validation.BindingResult; +import org.springframework.validation.Errors; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.method.support.HandlerMethodArgumentResolver; +import org.springframework.web.method.support.ModelAndViewContainer; + +/** + * Resolves {@link Errors} method arguments. + * + *

An {@code Errors} method argument is expected to appear immediately after + * the model attribute in the method signature. It is resolved by expecting the + * last two attributes added to the model to be the model attribute and its + * {@link BindingResult}. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.1 + */ +public class ErrorsMethodArgumentResolver implements HandlerMethodArgumentResolver { + + @Override + public boolean supportsParameter(MethodParameter parameter) { + Class paramType = parameter.getParameterType(); + return Errors.class.isAssignableFrom(paramType); + } + + @Override + @Nullable + public Object resolveArgument(MethodParameter parameter, + @Nullable ModelAndViewContainer mavContainer, NativeWebRequest webRequest, + @Nullable WebDataBinderFactory binderFactory) throws Exception { + + Assert.state(mavContainer != null, + "Errors/BindingResult argument only supported on regular handler methods"); + + ModelMap model = mavContainer.getModel(); + String lastKey = CollectionUtils.lastElement(model.keySet()); + if (lastKey != null && lastKey.startsWith(BindingResult.MODEL_KEY_PREFIX)) { + return model.get(lastKey); + } + + throw new IllegalStateException( + "An Errors/BindingResult argument is expected to be declared immediately after " + + "the model attribute, the @RequestBody or the @RequestPart arguments " + + "to which they apply: " + parameter.getMethod()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/ExceptionHandlerMethodResolver.java b/spring-web/src/main/java/org/springframework/web/method/annotation/ExceptionHandlerMethodResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..c247262b25b052e80647824b2eb15d1653d831f2 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/ExceptionHandlerMethodResolver.java @@ -0,0 +1,179 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.springframework.core.ExceptionDepthComparator; +import org.springframework.core.MethodIntrospector; +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ConcurrentReferenceHashMap; +import org.springframework.util.ReflectionUtils.MethodFilter; +import org.springframework.web.bind.annotation.ExceptionHandler; + +/** + * Discovers {@linkplain ExceptionHandler @ExceptionHandler} methods in a given class, + * including all of its superclasses, and helps to resolve a given {@link Exception} + * to the exception types supported by a given {@link Method}. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.1 + */ +public class ExceptionHandlerMethodResolver { + + /** + * A filter for selecting {@code @ExceptionHandler} methods. + */ + public static final MethodFilter EXCEPTION_HANDLER_METHODS = method -> + AnnotatedElementUtils.hasAnnotation(method, ExceptionHandler.class); + + + private final Map, Method> mappedMethods = new HashMap<>(16); + + private final Map, Method> exceptionLookupCache = new ConcurrentReferenceHashMap<>(16); + + + /** + * A constructor that finds {@link ExceptionHandler} methods in the given type. + * @param handlerType the type to introspect + */ + public ExceptionHandlerMethodResolver(Class handlerType) { + for (Method method : MethodIntrospector.selectMethods(handlerType, EXCEPTION_HANDLER_METHODS)) { + for (Class exceptionType : detectExceptionMappings(method)) { + addExceptionMapping(exceptionType, method); + } + } + } + + + /** + * Extract exception mappings from the {@code @ExceptionHandler} annotation first, + * and then as a fallback from the method signature itself. + */ + @SuppressWarnings("unchecked") + private List> detectExceptionMappings(Method method) { + List> result = new ArrayList<>(); + detectAnnotationExceptionMappings(method, result); + if (result.isEmpty()) { + for (Class paramType : method.getParameterTypes()) { + if (Throwable.class.isAssignableFrom(paramType)) { + result.add((Class) paramType); + } + } + } + if (result.isEmpty()) { + throw new IllegalStateException("No exception types mapped to " + method); + } + return result; + } + + private void detectAnnotationExceptionMappings(Method method, List> result) { + ExceptionHandler ann = AnnotatedElementUtils.findMergedAnnotation(method, ExceptionHandler.class); + Assert.state(ann != null, "No ExceptionHandler annotation"); + result.addAll(Arrays.asList(ann.value())); + } + + private void addExceptionMapping(Class exceptionType, Method method) { + Method oldMethod = this.mappedMethods.put(exceptionType, method); + if (oldMethod != null && !oldMethod.equals(method)) { + throw new IllegalStateException("Ambiguous @ExceptionHandler method mapped for [" + + exceptionType + "]: {" + oldMethod + ", " + method + "}"); + } + } + + /** + * Whether the contained type has any exception mappings. + */ + public boolean hasExceptionMappings() { + return !this.mappedMethods.isEmpty(); + } + + /** + * Find a {@link Method} to handle the given exception. + * Use {@link ExceptionDepthComparator} if more than one match is found. + * @param exception the exception + * @return a Method to handle the exception, or {@code null} if none found + */ + @Nullable + public Method resolveMethod(Exception exception) { + return resolveMethodByThrowable(exception); + } + + /** + * Find a {@link Method} to handle the given Throwable. + * Use {@link ExceptionDepthComparator} if more than one match is found. + * @param exception the exception + * @return a Method to handle the exception, or {@code null} if none found + * @since 5.0 + */ + @Nullable + public Method resolveMethodByThrowable(Throwable exception) { + Method method = resolveMethodByExceptionType(exception.getClass()); + if (method == null) { + Throwable cause = exception.getCause(); + if (cause != null) { + method = resolveMethodByExceptionType(cause.getClass()); + } + } + return method; + } + + /** + * Find a {@link Method} to handle the given exception type. This can be + * useful if an {@link Exception} instance is not available (e.g. for tools). + * @param exceptionType the exception type + * @return a Method to handle the exception, or {@code null} if none found + */ + @Nullable + public Method resolveMethodByExceptionType(Class exceptionType) { + Method method = this.exceptionLookupCache.get(exceptionType); + if (method == null) { + method = getMappedMethod(exceptionType); + this.exceptionLookupCache.put(exceptionType, method); + } + return method; + } + + /** + * Return the {@link Method} mapped to the given exception type, or {@code null} if none. + */ + @Nullable + private Method getMappedMethod(Class exceptionType) { + List> matches = new ArrayList<>(); + for (Class mappedException : this.mappedMethods.keySet()) { + if (mappedException.isAssignableFrom(exceptionType)) { + matches.add(mappedException); + } + } + if (!matches.isEmpty()) { + matches.sort(new ExceptionDepthComparator(exceptionType)); + return this.mappedMethods.get(matches.get(0)); + } + else { + return null; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/ExpressionValueMethodArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/method/annotation/ExpressionValueMethodArgumentResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..18c60dee2b753858d921fd6260ef1216701c4d21 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/ExpressionValueMethodArgumentResolver.java @@ -0,0 +1,87 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import javax.servlet.ServletException; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Resolves method arguments annotated with {@code @Value}. + * + *

An {@code @Value} does not have a name but gets resolved from the default + * value string, which may contain ${...} placeholder or Spring Expression + * Language #{...} expressions. + * + *

A {@link WebDataBinder} may be invoked to apply type conversion to + * resolved argument value. + * + * @author Rossen Stoyanchev + * @since 3.1 + */ +public class ExpressionValueMethodArgumentResolver extends AbstractNamedValueMethodArgumentResolver { + + /** + * Create a new {@link ExpressionValueMethodArgumentResolver} instance. + * @param beanFactory a bean factory to use for resolving ${...} + * placeholder and #{...} SpEL expressions in default values; + * or {@code null} if default values are not expected to contain expressions + */ + public ExpressionValueMethodArgumentResolver(@Nullable ConfigurableBeanFactory beanFactory) { + super(beanFactory); + } + + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return parameter.hasParameterAnnotation(Value.class); + } + + @Override + protected NamedValueInfo createNamedValueInfo(MethodParameter parameter) { + Value ann = parameter.getParameterAnnotation(Value.class); + Assert.state(ann != null, "No Value annotation"); + return new ExpressionValueNamedValueInfo(ann); + } + + @Override + @Nullable + protected Object resolveName(String name, MethodParameter parameter, NativeWebRequest webRequest) throws Exception { + // No name to resolve + return null; + } + + @Override + protected void handleMissingValue(String name, MethodParameter parameter) throws ServletException { + throw new UnsupportedOperationException("@Value is never required: " + parameter.getMethod()); + } + + + private static final class ExpressionValueNamedValueInfo extends NamedValueInfo { + + private ExpressionValueNamedValueInfo(Value annotation) { + super("@Value", false, annotation.value()); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/InitBinderDataBinderFactory.java b/spring-web/src/main/java/org/springframework/web/method/annotation/InitBinderDataBinderFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..d8446a353ac02a2e6275fc4aa14337d01eab74cb --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/InitBinderDataBinderFactory.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.util.Collections; +import java.util.List; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.InitBinder; +import org.springframework.web.bind.support.DefaultDataBinderFactory; +import org.springframework.web.bind.support.WebBindingInitializer; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.method.HandlerMethod; +import org.springframework.web.method.support.InvocableHandlerMethod; + +/** + * Adds initialization to a WebDataBinder via {@code @InitBinder} methods. + * + * @author Rossen Stoyanchev + * @since 3.1 + */ +public class InitBinderDataBinderFactory extends DefaultDataBinderFactory { + + private final List binderMethods; + + + /** + * Create a new InitBinderDataBinderFactory instance. + * @param binderMethods {@code @InitBinder} methods + * @param initializer for global data binder initialization + */ + public InitBinderDataBinderFactory(@Nullable List binderMethods, + @Nullable WebBindingInitializer initializer) { + + super(initializer); + this.binderMethods = (binderMethods != null ? binderMethods : Collections.emptyList()); + } + + + /** + * Initialize a WebDataBinder with {@code @InitBinder} methods. + *

If the {@code @InitBinder} annotation specifies attributes names, + * it is invoked only if the names include the target object name. + * @throws Exception if one of the invoked @{@link InitBinder} methods fails + * @see #isBinderMethodApplicable + */ + @Override + public void initBinder(WebDataBinder dataBinder, NativeWebRequest request) throws Exception { + for (InvocableHandlerMethod binderMethod : this.binderMethods) { + if (isBinderMethodApplicable(binderMethod, dataBinder)) { + Object returnValue = binderMethod.invokeForRequest(request, null, dataBinder); + if (returnValue != null) { + throw new IllegalStateException( + "@InitBinder methods must not return a value (should be void): " + binderMethod); + } + } + } + } + + /** + * Determine whether the given {@code @InitBinder} method should be used + * to initialize the given {@link WebDataBinder} instance. By default we + * check the specified attribute names in the annotation value, if any. + */ + protected boolean isBinderMethodApplicable(HandlerMethod initBinderMethod, WebDataBinder dataBinder) { + InitBinder ann = initBinderMethod.getMethodAnnotation(InitBinder.class); + Assert.state(ann != null, "No InitBinder annotation"); + String[] names = ann.value(); + return (ObjectUtils.isEmpty(names) || ObjectUtils.containsElement(names, dataBinder.getObjectName())); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/MapMethodProcessor.java b/spring-web/src/main/java/org/springframework/web/method/annotation/MapMethodProcessor.java new file mode 100644 index 0000000000000000000000000000000000000000..ea505d276bb2107b99de8d232014825fa9fd91d4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/MapMethodProcessor.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.util.Map; + +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.method.support.HandlerMethodArgumentResolver; +import org.springframework.web.method.support.HandlerMethodReturnValueHandler; +import org.springframework.web.method.support.ModelAndViewContainer; + +/** + * Resolves {@link Map} method arguments and handles {@link Map} return values. + * + *

A Map return value can be interpreted in more than one ways depending + * on the presence of annotations like {@code @ModelAttribute} or + * {@code @ResponseBody}. Therefore this handler should be configured after + * the handlers that support these annotations. + * + * @author Rossen Stoyanchev + * @since 3.1 + */ +public class MapMethodProcessor implements HandlerMethodArgumentResolver, HandlerMethodReturnValueHandler { + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return Map.class.isAssignableFrom(parameter.getParameterType()); + } + + @Override + @Nullable + public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer, + NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception { + + Assert.state(mavContainer != null, "ModelAndViewContainer is required for model exposure"); + return mavContainer.getModel(); + } + + @Override + public boolean supportsReturnType(MethodParameter returnType) { + return Map.class.isAssignableFrom(returnType.getParameterType()); + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public void handleReturnValue(@Nullable Object returnValue, MethodParameter returnType, + ModelAndViewContainer mavContainer, NativeWebRequest webRequest) throws Exception { + + if (returnValue instanceof Map){ + mavContainer.addAllAttributes((Map) returnValue); + } + else if (returnValue != null) { + // should not happen + throw new UnsupportedOperationException("Unexpected return type: " + + returnType.getParameterType().getName() + " in method: " + returnType.getMethod()); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/MethodArgumentConversionNotSupportedException.java b/spring-web/src/main/java/org/springframework/web/method/annotation/MethodArgumentConversionNotSupportedException.java new file mode 100644 index 0000000000000000000000000000000000000000..a30602dcdd498c4b6655e123c471eca57ab07d84 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/MethodArgumentConversionNotSupportedException.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import org.springframework.beans.ConversionNotSupportedException; +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; + +/** + * A ConversionNotSupportedException raised while resolving a method argument. + * Provides access to the target {@link org.springframework.core.MethodParameter + * MethodParameter}. + * + * @author Rossen Stoyanchev + * @since 4.2 + */ +@SuppressWarnings("serial") +public class MethodArgumentConversionNotSupportedException extends ConversionNotSupportedException { + + private final String name; + + private final MethodParameter parameter; + + + public MethodArgumentConversionNotSupportedException(@Nullable Object value, + @Nullable Class requiredType, String name, MethodParameter param, Throwable cause) { + + super(value, requiredType, cause); + this.name = name; + this.parameter = param; + } + + + /** + * Return the name of the method argument. + */ + public String getName() { + return this.name; + } + + /** + * Return the target method parameter. + */ + public MethodParameter getParameter() { + return this.parameter; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/MethodArgumentTypeMismatchException.java b/spring-web/src/main/java/org/springframework/web/method/annotation/MethodArgumentTypeMismatchException.java new file mode 100644 index 0000000000000000000000000000000000000000..2f91434a29c4be3129269793d655301babfd63b6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/MethodArgumentTypeMismatchException.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import org.springframework.beans.TypeMismatchException; +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; + +/** + * A TypeMismatchException raised while resolving a controller method argument. + * Provides access to the target {@link org.springframework.core.MethodParameter + * MethodParameter}. + * + * @author Rossen Stoyanchev + * @since 4.2 + */ +@SuppressWarnings("serial") +public class MethodArgumentTypeMismatchException extends TypeMismatchException { + + private final String name; + + private final MethodParameter parameter; + + + public MethodArgumentTypeMismatchException(@Nullable Object value, + @Nullable Class requiredType, String name, MethodParameter param, Throwable cause) { + + super(value, requiredType, cause); + this.name = name; + this.parameter = param; + } + + + /** + * Return the name of the method argument. + */ + public String getName() { + return this.name; + } + + /** + * Return the target method parameter. + */ + public MethodParameter getParameter() { + return this.parameter; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/ModelAttributeMethodProcessor.java b/spring-web/src/main/java/org/springframework/web/method/annotation/ModelAttributeMethodProcessor.java new file mode 100644 index 0000000000000000000000000000000000000000..340a039df6e07b7cfa111343a2ef75f3b0cf29a0 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/ModelAttributeMethodProcessor.java @@ -0,0 +1,528 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.beans.ConstructorProperties; +import java.lang.annotation.Annotation; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.BeanUtils; +import org.springframework.beans.TypeMismatchException; +import org.springframework.core.DefaultParameterNameDiscoverer; +import org.springframework.core.MethodParameter; +import org.springframework.core.ParameterNameDiscoverer; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.validation.BindException; +import org.springframework.validation.BindingResult; +import org.springframework.validation.Errors; +import org.springframework.validation.SmartValidator; +import org.springframework.validation.Validator; +import org.springframework.validation.annotation.Validated; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.ModelAttribute; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.bind.support.WebRequestDataBinder; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.method.support.HandlerMethodArgumentResolver; +import org.springframework.web.method.support.HandlerMethodReturnValueHandler; +import org.springframework.web.method.support.ModelAndViewContainer; + +/** + * Resolve {@code @ModelAttribute} annotated method arguments and handle + * return values from {@code @ModelAttribute} annotated methods. + * + *

Model attributes are obtained from the model or created with a default + * constructor (and then added to the model). Once created the attribute is + * populated via data binding to Servlet request parameters. Validation may be + * applied if the argument is annotated with {@code @javax.validation.Valid}. + * or Spring's own {@code @org.springframework.validation.annotation.Validated}. + * + *

When this handler is created with {@code annotationNotRequired=true} + * any non-simple type argument and return value is regarded as a model + * attribute with or without the presence of an {@code @ModelAttribute}. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @author Sebastien Deleuze + * @since 3.1 + */ +public class ModelAttributeMethodProcessor implements HandlerMethodArgumentResolver, HandlerMethodReturnValueHandler { + + private static final ParameterNameDiscoverer parameterNameDiscoverer = new DefaultParameterNameDiscoverer(); + + protected final Log logger = LogFactory.getLog(getClass()); + + private final boolean annotationNotRequired; + + + /** + * Class constructor. + * @param annotationNotRequired if "true", non-simple method arguments and + * return values are considered model attributes with or without a + * {@code @ModelAttribute} annotation + */ + public ModelAttributeMethodProcessor(boolean annotationNotRequired) { + this.annotationNotRequired = annotationNotRequired; + } + + + /** + * Returns {@code true} if the parameter is annotated with + * {@link ModelAttribute} or, if in default resolution mode, for any + * method parameter that is not a simple type. + */ + @Override + public boolean supportsParameter(MethodParameter parameter) { + return (parameter.hasParameterAnnotation(ModelAttribute.class) || + (this.annotationNotRequired && !BeanUtils.isSimpleProperty(parameter.getParameterType()))); + } + + /** + * Resolve the argument from the model or if not found instantiate it with + * its default if it is available. The model attribute is then populated + * with request values via data binding and optionally validated + * if {@code @java.validation.Valid} is present on the argument. + * @throws BindException if data binding and validation result in an error + * and the next method parameter is not of type {@link Errors} + * @throws Exception if WebDataBinder initialization fails + */ + @Override + @Nullable + public final Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer, + NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception { + + Assert.state(mavContainer != null, "ModelAttributeMethodProcessor requires ModelAndViewContainer"); + Assert.state(binderFactory != null, "ModelAttributeMethodProcessor requires WebDataBinderFactory"); + + String name = ModelFactory.getNameForParameter(parameter); + ModelAttribute ann = parameter.getParameterAnnotation(ModelAttribute.class); + if (ann != null) { + mavContainer.setBinding(name, ann.binding()); + } + + Object attribute = null; + BindingResult bindingResult = null; + + if (mavContainer.containsAttribute(name)) { + attribute = mavContainer.getModel().get(name); + } + else { + // Create attribute instance + try { + attribute = createAttribute(name, parameter, binderFactory, webRequest); + } + catch (BindException ex) { + if (isBindExceptionRequired(parameter)) { + // No BindingResult parameter -> fail with BindException + throw ex; + } + // Otherwise, expose null/empty value and associated BindingResult + if (parameter.getParameterType() == Optional.class) { + attribute = Optional.empty(); + } + bindingResult = ex.getBindingResult(); + } + } + + if (bindingResult == null) { + // Bean property binding and validation; + // skipped in case of binding failure on construction. + WebDataBinder binder = binderFactory.createBinder(webRequest, attribute, name); + if (binder.getTarget() != null) { + if (!mavContainer.isBindingDisabled(name)) { + bindRequestParameters(binder, webRequest); + } + validateIfApplicable(binder, parameter); + if (binder.getBindingResult().hasErrors() && isBindExceptionRequired(binder, parameter)) { + throw new BindException(binder.getBindingResult()); + } + } + // Value type adaptation, also covering java.util.Optional + if (!parameter.getParameterType().isInstance(attribute)) { + attribute = binder.convertIfNecessary(binder.getTarget(), parameter.getParameterType(), parameter); + } + bindingResult = binder.getBindingResult(); + } + + // Add resolved attribute and BindingResult at the end of the model + Map bindingResultModel = bindingResult.getModel(); + mavContainer.removeAttributes(bindingResultModel); + mavContainer.addAllAttributes(bindingResultModel); + + return attribute; + } + + /** + * Extension point to create the model attribute if not found in the model, + * with subsequent parameter binding through bean properties (unless suppressed). + *

The default implementation typically uses the unique public no-arg constructor + * if available but also handles a "primary constructor" approach for data classes: + * It understands the JavaBeans {@link ConstructorProperties} annotation as well as + * runtime-retained parameter names in the bytecode, associating request parameters + * with constructor arguments by name. If no such constructor is found, the default + * constructor will be used (even if not public), assuming subsequent bean property + * bindings through setter methods. + * @param attributeName the name of the attribute (never {@code null}) + * @param parameter the method parameter declaration + * @param binderFactory for creating WebDataBinder instance + * @param webRequest the current request + * @return the created model attribute (never {@code null}) + * @throws BindException in case of constructor argument binding failure + * @throws Exception in case of constructor invocation failure + * @see #constructAttribute(Constructor, String, MethodParameter, WebDataBinderFactory, NativeWebRequest) + * @see BeanUtils#findPrimaryConstructor(Class) + */ + protected Object createAttribute(String attributeName, MethodParameter parameter, + WebDataBinderFactory binderFactory, NativeWebRequest webRequest) throws Exception { + + MethodParameter nestedParameter = parameter.nestedIfOptional(); + Class clazz = nestedParameter.getNestedParameterType(); + + Constructor ctor = BeanUtils.findPrimaryConstructor(clazz); + if (ctor == null) { + Constructor[] ctors = clazz.getConstructors(); + if (ctors.length == 1) { + ctor = ctors[0]; + } + else { + try { + ctor = clazz.getDeclaredConstructor(); + } + catch (NoSuchMethodException ex) { + throw new IllegalStateException("No primary or default constructor found for " + clazz, ex); + } + } + } + + Object attribute = constructAttribute(ctor, attributeName, parameter, binderFactory, webRequest); + if (parameter != nestedParameter) { + attribute = Optional.of(attribute); + } + return attribute; + } + + /** + * Construct a new attribute instance with the given constructor. + *

Called from + * {@link #createAttribute(String, MethodParameter, WebDataBinderFactory, NativeWebRequest)} + * after constructor resolution. + * @param ctor the constructor to use + * @param attributeName the name of the attribute (never {@code null}) + * @param binderFactory for creating WebDataBinder instance + * @param webRequest the current request + * @return the created model attribute (never {@code null}) + * @throws BindException in case of constructor argument binding failure + * @throws Exception in case of constructor invocation failure + * @since 5.1 + */ + @SuppressWarnings("deprecation") + protected Object constructAttribute(Constructor ctor, String attributeName, MethodParameter parameter, + WebDataBinderFactory binderFactory, NativeWebRequest webRequest) throws Exception { + + Object constructed = constructAttribute(ctor, attributeName, binderFactory, webRequest); + if (constructed != null) { + return constructed; + } + + if (ctor.getParameterCount() == 0) { + // A single default constructor -> clearly a standard JavaBeans arrangement. + return BeanUtils.instantiateClass(ctor); + } + + // A single data class constructor -> resolve constructor arguments from request parameters. + ConstructorProperties cp = ctor.getAnnotation(ConstructorProperties.class); + String[] paramNames = (cp != null ? cp.value() : parameterNameDiscoverer.getParameterNames(ctor)); + Assert.state(paramNames != null, () -> "Cannot resolve parameter names for constructor " + ctor); + Class[] paramTypes = ctor.getParameterTypes(); + Assert.state(paramNames.length == paramTypes.length, + () -> "Invalid number of parameter names: " + paramNames.length + " for constructor " + ctor); + + Object[] args = new Object[paramTypes.length]; + WebDataBinder binder = binderFactory.createBinder(webRequest, null, attributeName); + String fieldDefaultPrefix = binder.getFieldDefaultPrefix(); + String fieldMarkerPrefix = binder.getFieldMarkerPrefix(); + boolean bindingFailure = false; + Set failedParams = new HashSet<>(4); + + for (int i = 0; i < paramNames.length; i++) { + String paramName = paramNames[i]; + Class paramType = paramTypes[i]; + Object value = webRequest.getParameterValues(paramName); + if (value == null) { + if (fieldDefaultPrefix != null) { + value = webRequest.getParameter(fieldDefaultPrefix + paramName); + } + if (value == null && fieldMarkerPrefix != null) { + if (webRequest.getParameter(fieldMarkerPrefix + paramName) != null) { + value = binder.getEmptyValue(paramType); + } + } + } + try { + MethodParameter methodParam = new FieldAwareConstructorParameter(ctor, i, paramName); + if (value == null && methodParam.isOptional()) { + args[i] = (methodParam.getParameterType() == Optional.class ? Optional.empty() : null); + } + else { + args[i] = binder.convertIfNecessary(value, paramType, methodParam); + } + } + catch (TypeMismatchException ex) { + ex.initPropertyName(paramName); + args[i] = value; + failedParams.add(paramName); + binder.getBindingResult().recordFieldValue(paramName, paramType, value); + binder.getBindingErrorProcessor().processPropertyAccessException(ex, binder.getBindingResult()); + bindingFailure = true; + } + } + + if (bindingFailure) { + BindingResult result = binder.getBindingResult(); + for (int i = 0; i < paramNames.length; i++) { + String paramName = paramNames[i]; + if (!failedParams.contains(paramName)) { + Object value = args[i]; + result.recordFieldValue(paramName, paramTypes[i], value); + validateValueIfApplicable(binder, parameter, ctor.getDeclaringClass(), paramName, value); + } + } + throw new BindException(result); + } + + return BeanUtils.instantiateClass(ctor, args); + } + + /** + * Construct a new attribute instance with the given constructor. + * @since 5.0 + * @deprecated as of 5.1, in favor of + * {@link #constructAttribute(Constructor, String, MethodParameter, WebDataBinderFactory, NativeWebRequest)} + */ + @Deprecated + @Nullable + protected Object constructAttribute(Constructor ctor, String attributeName, + WebDataBinderFactory binderFactory, NativeWebRequest webRequest) throws Exception { + + return null; + } + + /** + * Extension point to bind the request to the target object. + * @param binder the data binder instance to use for the binding + * @param request the current request + */ + protected void bindRequestParameters(WebDataBinder binder, NativeWebRequest request) { + ((WebRequestDataBinder) binder).bind(request); + } + + /** + * Validate the model attribute if applicable. + *

The default implementation checks for {@code @javax.validation.Valid}, + * Spring's {@link org.springframework.validation.annotation.Validated}, + * and custom annotations whose name starts with "Valid". + * @param binder the DataBinder to be used + * @param parameter the method parameter declaration + * @see WebDataBinder#validate(Object...) + * @see SmartValidator#validate(Object, Errors, Object...) + */ + protected void validateIfApplicable(WebDataBinder binder, MethodParameter parameter) { + for (Annotation ann : parameter.getParameterAnnotations()) { + Object[] validationHints = determineValidationHints(ann); + if (validationHints != null) { + binder.validate(validationHints); + break; + } + } + } + + /** + * Validate the specified candidate value if applicable. + *

The default implementation checks for {@code @javax.validation.Valid}, + * Spring's {@link org.springframework.validation.annotation.Validated}, + * and custom annotations whose name starts with "Valid". + * @param binder the DataBinder to be used + * @param parameter the method parameter declaration + * @param targetType the target type + * @param fieldName the name of the field + * @param value the candidate value + * @since 5.1 + * @see #validateIfApplicable(WebDataBinder, MethodParameter) + * @see SmartValidator#validateValue(Class, String, Object, Errors, Object...) + */ + protected void validateValueIfApplicable(WebDataBinder binder, MethodParameter parameter, + Class targetType, String fieldName, @Nullable Object value) { + + for (Annotation ann : parameter.getParameterAnnotations()) { + Object[] validationHints = determineValidationHints(ann); + if (validationHints != null) { + for (Validator validator : binder.getValidators()) { + if (validator instanceof SmartValidator) { + try { + ((SmartValidator) validator).validateValue(targetType, fieldName, value, + binder.getBindingResult(), validationHints); + } + catch (IllegalArgumentException ex) { + // No corresponding field on the target class... + } + } + } + break; + } + } + } + + /** + * Determine any validation triggered by the given annotation. + * @param ann the annotation (potentially a validation annotation) + * @return the validation hints to apply (possibly an empty array), + * or {@code null} if this annotation does not trigger any validation + * @since 5.1 + */ + @Nullable + private Object[] determineValidationHints(Annotation ann) { + Validated validatedAnn = AnnotationUtils.getAnnotation(ann, Validated.class); + if (validatedAnn != null || ann.annotationType().getSimpleName().startsWith("Valid")) { + Object hints = (validatedAnn != null ? validatedAnn.value() : AnnotationUtils.getValue(ann)); + if (hints == null) { + return new Object[0]; + } + return (hints instanceof Object[] ? (Object[]) hints : new Object[] {hints}); + } + return null; + } + + /** + * Whether to raise a fatal bind exception on validation errors. + *

The default implementation delegates to {@link #isBindExceptionRequired(MethodParameter)}. + * @param binder the data binder used to perform data binding + * @param parameter the method parameter declaration + * @return {@code true} if the next method parameter is not of type {@link Errors} + * @see #isBindExceptionRequired(MethodParameter) + */ + protected boolean isBindExceptionRequired(WebDataBinder binder, MethodParameter parameter) { + return isBindExceptionRequired(parameter); + } + + /** + * Whether to raise a fatal bind exception on validation errors. + * @param parameter the method parameter declaration + * @return {@code true} if the next method parameter is not of type {@link Errors} + * @since 5.0 + */ + protected boolean isBindExceptionRequired(MethodParameter parameter) { + int i = parameter.getParameterIndex(); + Class[] paramTypes = parameter.getExecutable().getParameterTypes(); + boolean hasBindingResult = (paramTypes.length > (i + 1) && Errors.class.isAssignableFrom(paramTypes[i + 1])); + return !hasBindingResult; + } + + /** + * Return {@code true} if there is a method-level {@code @ModelAttribute} + * or, in default resolution mode, for any return value type that is not + * a simple type. + */ + @Override + public boolean supportsReturnType(MethodParameter returnType) { + return (returnType.hasMethodAnnotation(ModelAttribute.class) || + (this.annotationNotRequired && !BeanUtils.isSimpleProperty(returnType.getParameterType()))); + } + + /** + * Add non-null return values to the {@link ModelAndViewContainer}. + */ + @Override + public void handleReturnValue(@Nullable Object returnValue, MethodParameter returnType, + ModelAndViewContainer mavContainer, NativeWebRequest webRequest) throws Exception { + + if (returnValue != null) { + String name = ModelFactory.getNameForReturnValue(returnValue, returnType); + mavContainer.addAttribute(name, returnValue); + } + } + + + /** + * {@link MethodParameter} subclass which detects field annotations as well. + * @since 5.1 + */ + private static class FieldAwareConstructorParameter extends MethodParameter { + + private final String parameterName; + + @Nullable + private volatile Annotation[] combinedAnnotations; + + public FieldAwareConstructorParameter(Constructor constructor, int parameterIndex, String parameterName) { + super(constructor, parameterIndex); + this.parameterName = parameterName; + } + + @Override + public Annotation[] getParameterAnnotations() { + Annotation[] anns = this.combinedAnnotations; + if (anns == null) { + anns = super.getParameterAnnotations(); + try { + Field field = getDeclaringClass().getDeclaredField(this.parameterName); + Annotation[] fieldAnns = field.getAnnotations(); + if (fieldAnns.length > 0) { + List merged = new ArrayList<>(anns.length + fieldAnns.length); + merged.addAll(Arrays.asList(anns)); + for (Annotation fieldAnn : fieldAnns) { + boolean existingType = false; + for (Annotation ann : anns) { + if (ann.annotationType() == fieldAnn.annotationType()) { + existingType = true; + break; + } + } + if (!existingType) { + merged.add(fieldAnn); + } + } + anns = merged.toArray(new Annotation[0]); + } + } + catch (NoSuchFieldException | SecurityException ex) { + // ignore + } + this.combinedAnnotations = anns; + } + return anns; + } + + @Override + public String getParameterName() { + return this.parameterName; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/ModelFactory.java b/spring-web/src/main/java/org/springframework/web/method/annotation/ModelFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..5c83393b2901dd0af5e88a6441e677a160e6cb99 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/ModelFactory.java @@ -0,0 +1,323 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.BeanUtils; +import org.springframework.core.Conventions; +import org.springframework.core.GenericTypeResolver; +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.ui.Model; +import org.springframework.ui.ModelMap; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.validation.BindingResult; +import org.springframework.web.HttpSessionRequiredException; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.ModelAttribute; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.method.HandlerMethod; +import org.springframework.web.method.support.InvocableHandlerMethod; +import org.springframework.web.method.support.ModelAndViewContainer; + +/** + * Assist with initialization of the {@link Model} before controller method + * invocation and with updates to it after the invocation. + * + *

On initialization the model is populated with attributes temporarily stored + * in the session and through the invocation of {@code @ModelAttribute} methods. + * + *

On update model attributes are synchronized with the session and also + * {@link BindingResult} attributes are added if missing. + * + * @author Rossen Stoyanchev + * @since 3.1 + */ +public final class ModelFactory { + + private static final Log logger = LogFactory.getLog(ModelFactory.class); + + private final List modelMethods = new ArrayList<>(); + + private final WebDataBinderFactory dataBinderFactory; + + private final SessionAttributesHandler sessionAttributesHandler; + + + /** + * Create a new instance with the given {@code @ModelAttribute} methods. + * @param handlerMethods the {@code @ModelAttribute} methods to invoke + * @param binderFactory for preparation of {@link BindingResult} attributes + * @param attributeHandler for access to session attributes + */ + public ModelFactory(@Nullable List handlerMethods, + WebDataBinderFactory binderFactory, SessionAttributesHandler attributeHandler) { + + if (handlerMethods != null) { + for (InvocableHandlerMethod handlerMethod : handlerMethods) { + this.modelMethods.add(new ModelMethod(handlerMethod)); + } + } + this.dataBinderFactory = binderFactory; + this.sessionAttributesHandler = attributeHandler; + } + + + /** + * Populate the model in the following order: + *

    + *
  1. Retrieve "known" session attributes listed as {@code @SessionAttributes}. + *
  2. Invoke {@code @ModelAttribute} methods + *
  3. Find {@code @ModelAttribute} method arguments also listed as + * {@code @SessionAttributes} and ensure they're present in the model raising + * an exception if necessary. + *
+ * @param request the current request + * @param container a container with the model to be initialized + * @param handlerMethod the method for which the model is initialized + * @throws Exception may arise from {@code @ModelAttribute} methods + */ + public void initModel(NativeWebRequest request, ModelAndViewContainer container, HandlerMethod handlerMethod) + throws Exception { + + Map sessionAttributes = this.sessionAttributesHandler.retrieveAttributes(request); + container.mergeAttributes(sessionAttributes); + invokeModelAttributeMethods(request, container); + + for (String name : findSessionAttributeArguments(handlerMethod)) { + if (!container.containsAttribute(name)) { + Object value = this.sessionAttributesHandler.retrieveAttribute(request, name); + if (value == null) { + throw new HttpSessionRequiredException("Expected session attribute '" + name + "'", name); + } + container.addAttribute(name, value); + } + } + } + + /** + * Invoke model attribute methods to populate the model. + * Attributes are added only if not already present in the model. + */ + private void invokeModelAttributeMethods(NativeWebRequest request, ModelAndViewContainer container) + throws Exception { + + while (!this.modelMethods.isEmpty()) { + InvocableHandlerMethod modelMethod = getNextModelMethod(container).getHandlerMethod(); + ModelAttribute ann = modelMethod.getMethodAnnotation(ModelAttribute.class); + Assert.state(ann != null, "No ModelAttribute annotation"); + if (container.containsAttribute(ann.name())) { + if (!ann.binding()) { + container.setBindingDisabled(ann.name()); + } + continue; + } + + Object returnValue = modelMethod.invokeForRequest(request, container); + if (!modelMethod.isVoid()){ + String returnValueName = getNameForReturnValue(returnValue, modelMethod.getReturnType()); + if (!ann.binding()) { + container.setBindingDisabled(returnValueName); + } + if (!container.containsAttribute(returnValueName)) { + container.addAttribute(returnValueName, returnValue); + } + } + } + } + + private ModelMethod getNextModelMethod(ModelAndViewContainer container) { + for (ModelMethod modelMethod : this.modelMethods) { + if (modelMethod.checkDependencies(container)) { + this.modelMethods.remove(modelMethod); + return modelMethod; + } + } + ModelMethod modelMethod = this.modelMethods.get(0); + this.modelMethods.remove(modelMethod); + return modelMethod; + } + + /** + * Find {@code @ModelAttribute} arguments also listed as {@code @SessionAttributes}. + */ + private List findSessionAttributeArguments(HandlerMethod handlerMethod) { + List result = new ArrayList<>(); + for (MethodParameter parameter : handlerMethod.getMethodParameters()) { + if (parameter.hasParameterAnnotation(ModelAttribute.class)) { + String name = getNameForParameter(parameter); + Class paramType = parameter.getParameterType(); + if (this.sessionAttributesHandler.isHandlerSessionAttribute(name, paramType)) { + result.add(name); + } + } + } + return result; + } + + /** + * Promote model attributes listed as {@code @SessionAttributes} to the session. + * Add {@link BindingResult} attributes where necessary. + * @param request the current request + * @param container contains the model to update + * @throws Exception if creating BindingResult attributes fails + */ + public void updateModel(NativeWebRequest request, ModelAndViewContainer container) throws Exception { + ModelMap defaultModel = container.getDefaultModel(); + if (container.getSessionStatus().isComplete()){ + this.sessionAttributesHandler.cleanupAttributes(request); + } + else { + this.sessionAttributesHandler.storeAttributes(request, defaultModel); + } + if (!container.isRequestHandled() && container.getModel() == defaultModel) { + updateBindingResult(request, defaultModel); + } + } + + /** + * Add {@link BindingResult} attributes to the model for attributes that require it. + */ + private void updateBindingResult(NativeWebRequest request, ModelMap model) throws Exception { + List keyNames = new ArrayList<>(model.keySet()); + for (String name : keyNames) { + Object value = model.get(name); + if (value != null && isBindingCandidate(name, value)) { + String bindingResultKey = BindingResult.MODEL_KEY_PREFIX + name; + if (!model.containsAttribute(bindingResultKey)) { + WebDataBinder dataBinder = this.dataBinderFactory.createBinder(request, value, name); + model.put(bindingResultKey, dataBinder.getBindingResult()); + } + } + } + } + + /** + * Whether the given attribute requires a {@link BindingResult} in the model. + */ + private boolean isBindingCandidate(String attributeName, Object value) { + if (attributeName.startsWith(BindingResult.MODEL_KEY_PREFIX)) { + return false; + } + + if (this.sessionAttributesHandler.isHandlerSessionAttribute(attributeName, value.getClass())) { + return true; + } + + return (!value.getClass().isArray() && !(value instanceof Collection) && + !(value instanceof Map) && !BeanUtils.isSimpleValueType(value.getClass())); + } + + + /** + * Derive the model attribute name for the given method parameter based on + * a {@code @ModelAttribute} parameter annotation (if present) or falling + * back on parameter type based conventions. + * @param parameter a descriptor for the method parameter + * @return the derived name + * @see Conventions#getVariableNameForParameter(MethodParameter) + */ + public static String getNameForParameter(MethodParameter parameter) { + ModelAttribute ann = parameter.getParameterAnnotation(ModelAttribute.class); + String name = (ann != null ? ann.value() : null); + return (StringUtils.hasText(name) ? name : Conventions.getVariableNameForParameter(parameter)); + } + + /** + * Derive the model attribute name for the given return value. Results will be + * based on: + *
    + *
  1. the method {@code ModelAttribute} annotation value + *
  2. the declared return type if it is more specific than {@code Object} + *
  3. the actual return value type + *
+ * @param returnValue the value returned from a method invocation + * @param returnType a descriptor for the return type of the method + * @return the derived name (never {@code null} or empty String) + */ + public static String getNameForReturnValue(@Nullable Object returnValue, MethodParameter returnType) { + ModelAttribute ann = returnType.getMethodAnnotation(ModelAttribute.class); + if (ann != null && StringUtils.hasText(ann.value())) { + return ann.value(); + } + else { + Method method = returnType.getMethod(); + Assert.state(method != null, "No handler method"); + Class containingClass = returnType.getContainingClass(); + Class resolvedType = GenericTypeResolver.resolveReturnType(method, containingClass); + return Conventions.getVariableNameForReturnType(method, resolvedType, returnValue); + } + } + + + private static class ModelMethod { + + private final InvocableHandlerMethod handlerMethod; + + private final Set dependencies = new HashSet<>(); + + public ModelMethod(InvocableHandlerMethod handlerMethod) { + this.handlerMethod = handlerMethod; + for (MethodParameter parameter : handlerMethod.getMethodParameters()) { + if (parameter.hasParameterAnnotation(ModelAttribute.class)) { + this.dependencies.add(getNameForParameter(parameter)); + } + } + } + + public InvocableHandlerMethod getHandlerMethod() { + return this.handlerMethod; + } + + public boolean checkDependencies(ModelAndViewContainer mavContainer) { + for (String name : this.dependencies) { + if (!mavContainer.containsAttribute(name)) { + return false; + } + } + return true; + } + + public List getUnresolvedDependencies(ModelAndViewContainer mavContainer) { + List result = new ArrayList<>(this.dependencies.size()); + for (String name : this.dependencies) { + if (!mavContainer.containsAttribute(name)) { + result.add(name); + } + } + return result; + } + + @Override + public String toString() { + return this.handlerMethod.getMethod().toGenericString(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/ModelMethodProcessor.java b/spring-web/src/main/java/org/springframework/web/method/annotation/ModelMethodProcessor.java new file mode 100644 index 0000000000000000000000000000000000000000..52a2756ac1859bb82defaa3e6e1f5c9651d03652 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/ModelMethodProcessor.java @@ -0,0 +1,78 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.ui.Model; +import org.springframework.util.Assert; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.method.support.HandlerMethodArgumentResolver; +import org.springframework.web.method.support.HandlerMethodReturnValueHandler; +import org.springframework.web.method.support.ModelAndViewContainer; + +/** + * Resolves {@link Model} arguments and handles {@link Model} return values. + * + *

A {@link Model} return type has a set purpose. Therefore this handler + * should be configured ahead of handlers that support any return value type + * annotated with {@code @ModelAttribute} or {@code @ResponseBody} to ensure + * they don't take over. + * + * @author Rossen Stoyanchev + * @since 3.1 + */ +public class ModelMethodProcessor implements HandlerMethodArgumentResolver, HandlerMethodReturnValueHandler { + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return Model.class.isAssignableFrom(parameter.getParameterType()); + } + + @Override + @Nullable + public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer, + NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception { + + Assert.state(mavContainer != null, "ModelAndViewContainer is required for model exposure"); + return mavContainer.getModel(); + } + + @Override + public boolean supportsReturnType(MethodParameter returnType) { + return Model.class.isAssignableFrom(returnType.getParameterType()); + } + + @Override + public void handleReturnValue(@Nullable Object returnValue, MethodParameter returnType, + ModelAndViewContainer mavContainer, NativeWebRequest webRequest) throws Exception { + + if (returnValue == null) { + return; + } + else if (returnValue instanceof Model) { + mavContainer.addAllAttributes(((Model) returnValue).asMap()); + } + else { + // should not happen + throw new UnsupportedOperationException("Unexpected return type: " + + returnType.getParameterType().getName() + " in method: " + returnType.getMethod()); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/RequestHeaderMapMethodArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/method/annotation/RequestHeaderMapMethodArgumentResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..558d2ea06eb970f70a142b0535b05793dceffe4b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/RequestHeaderMapMethodArgumentResolver.java @@ -0,0 +1,92 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.springframework.core.MethodParameter; +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.bind.annotation.RequestHeader; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.method.support.HandlerMethodArgumentResolver; +import org.springframework.web.method.support.ModelAndViewContainer; + +/** + * Resolves {@link Map} method arguments annotated with {@code @RequestHeader}. + * For individual header values annotated with {@code @RequestHeader} see + * {@link RequestHeaderMethodArgumentResolver} instead. + * + *

The created {@link Map} contains all request header name/value pairs. + * The method parameter type may be a {@link MultiValueMap} to receive all + * values for a header, not only the first one. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @since 3.1 + */ +public class RequestHeaderMapMethodArgumentResolver implements HandlerMethodArgumentResolver { + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return (parameter.hasParameterAnnotation(RequestHeader.class) && + Map.class.isAssignableFrom(parameter.getParameterType())); + } + + @Override + public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer, + NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception { + + Class paramType = parameter.getParameterType(); + if (MultiValueMap.class.isAssignableFrom(paramType)) { + MultiValueMap result; + if (HttpHeaders.class.isAssignableFrom(paramType)) { + result = new HttpHeaders(); + } + else { + result = new LinkedMultiValueMap<>(); + } + for (Iterator iterator = webRequest.getHeaderNames(); iterator.hasNext();) { + String headerName = iterator.next(); + String[] headerValues = webRequest.getHeaderValues(headerName); + if (headerValues != null) { + for (String headerValue : headerValues) { + result.add(headerName, headerValue); + } + } + } + return result; + } + else { + Map result = new LinkedHashMap<>(); + for (Iterator iterator = webRequest.getHeaderNames(); iterator.hasNext();) { + String headerName = iterator.next(); + String headerValue = webRequest.getHeader(headerName); + if (headerValue != null) { + result.put(headerName, headerValue); + } + } + return result; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/RequestHeaderMethodArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/method/annotation/RequestHeaderMethodArgumentResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..85b856bf43c2079bc782f7257cadf147c2a4327a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/RequestHeaderMethodArgumentResolver.java @@ -0,0 +1,98 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.util.Map; + +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.bind.MissingRequestHeaderException; +import org.springframework.web.bind.ServletRequestBindingException; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.RequestHeader; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Resolves method arguments annotated with {@code @RequestHeader} except for + * {@link Map} arguments. See {@link RequestHeaderMapMethodArgumentResolver} for + * details on {@link Map} arguments annotated with {@code @RequestHeader}. + * + *

An {@code @RequestHeader} is a named value resolved from a request header. + * It has a required flag and a default value to fall back on when the request + * header does not exist. + * + *

A {@link WebDataBinder} is invoked to apply type conversion to resolved + * request header values that don't yet match the method parameter type. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @since 3.1 + */ +public class RequestHeaderMethodArgumentResolver extends AbstractNamedValueMethodArgumentResolver { + + /** + * Create a new {@link RequestHeaderMethodArgumentResolver} instance. + * @param beanFactory a bean factory to use for resolving ${...} + * placeholder and #{...} SpEL expressions in default values; + * or {@code null} if default values are not expected to have expressions + */ + public RequestHeaderMethodArgumentResolver(@Nullable ConfigurableBeanFactory beanFactory) { + super(beanFactory); + } + + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return (parameter.hasParameterAnnotation(RequestHeader.class) && + !Map.class.isAssignableFrom(parameter.nestedIfOptional().getNestedParameterType())); + } + + @Override + protected NamedValueInfo createNamedValueInfo(MethodParameter parameter) { + RequestHeader ann = parameter.getParameterAnnotation(RequestHeader.class); + Assert.state(ann != null, "No RequestHeader annotation"); + return new RequestHeaderNamedValueInfo(ann); + } + + @Override + @Nullable + protected Object resolveName(String name, MethodParameter parameter, NativeWebRequest request) throws Exception { + String[] headerValues = request.getHeaderValues(name); + if (headerValues != null) { + return (headerValues.length == 1 ? headerValues[0] : headerValues); + } + else { + return null; + } + } + + @Override + protected void handleMissingValue(String name, MethodParameter parameter) throws ServletRequestBindingException { + throw new MissingRequestHeaderException(name, parameter); + } + + + private static final class RequestHeaderNamedValueInfo extends NamedValueInfo { + + private RequestHeaderNamedValueInfo(RequestHeader annotation) { + super(annotation.name(), annotation.required(), annotation.defaultValue()); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/RequestParamMapMethodArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/method/annotation/RequestParamMapMethodArgumentResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..e4934100b869092adb6f798f186dc8be28083c3d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/RequestParamMapMethodArgumentResolver.java @@ -0,0 +1,141 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.Part; + +import org.springframework.core.MethodParameter; +import org.springframework.core.ResolvableType; +import org.springframework.lang.Nullable; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.method.support.HandlerMethodArgumentResolver; +import org.springframework.web.method.support.ModelAndViewContainer; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.multipart.MultipartRequest; +import org.springframework.web.multipart.support.MultipartResolutionDelegate; + +/** + * Resolves {@link Map} method arguments annotated with an @{@link RequestParam} + * where the annotation does not specify a request parameter name. + * + *

The created {@link Map} contains all request parameter name/value pairs, + * or all multipart files for a given parameter name if specifically declared + * with {@link MultipartFile} as the value type. If the method parameter type is + * {@link MultiValueMap} instead, the created map contains all request parameters + * and all their values for cases where request parameters have multiple values + * (or multiple multipart files of the same name). + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.1 + * @see RequestParamMethodArgumentResolver + * @see HttpServletRequest#getParameterMap() + * @see MultipartRequest#getMultiFileMap() + * @see MultipartRequest#getFileMap() + */ +public class RequestParamMapMethodArgumentResolver implements HandlerMethodArgumentResolver { + + @Override + public boolean supportsParameter(MethodParameter parameter) { + RequestParam requestParam = parameter.getParameterAnnotation(RequestParam.class); + return (requestParam != null && Map.class.isAssignableFrom(parameter.getParameterType()) && + !StringUtils.hasText(requestParam.name())); + } + + @Override + public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer, + NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception { + + ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter); + + if (MultiValueMap.class.isAssignableFrom(parameter.getParameterType())) { + // MultiValueMap + Class valueType = resolvableType.as(MultiValueMap.class).getGeneric(1).resolve(); + if (valueType == MultipartFile.class) { + MultipartRequest multipartRequest = MultipartResolutionDelegate.resolveMultipartRequest(webRequest); + return (multipartRequest != null ? multipartRequest.getMultiFileMap() : new LinkedMultiValueMap<>(0)); + } + else if (valueType == Part.class) { + HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class); + if (servletRequest != null && MultipartResolutionDelegate.isMultipartRequest(servletRequest)) { + Collection parts = servletRequest.getParts(); + LinkedMultiValueMap result = new LinkedMultiValueMap<>(parts.size()); + for (Part part : parts) { + result.add(part.getName(), part); + } + return result; + } + return new LinkedMultiValueMap<>(0); + } + else { + Map parameterMap = webRequest.getParameterMap(); + MultiValueMap result = new LinkedMultiValueMap<>(parameterMap.size()); + parameterMap.forEach((key, values) -> { + for (String value : values) { + result.add(key, value); + } + }); + return result; + } + } + + else { + // Regular Map + Class valueType = resolvableType.asMap().getGeneric(1).resolve(); + if (valueType == MultipartFile.class) { + MultipartRequest multipartRequest = MultipartResolutionDelegate.resolveMultipartRequest(webRequest); + return (multipartRequest != null ? multipartRequest.getFileMap() : new LinkedHashMap<>(0)); + } + else if (valueType == Part.class) { + HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class); + if (servletRequest != null && MultipartResolutionDelegate.isMultipartRequest(servletRequest)) { + Collection parts = servletRequest.getParts(); + LinkedHashMap result = new LinkedHashMap<>(parts.size()); + for (Part part : parts) { + if (!result.containsKey(part.getName())) { + result.put(part.getName(), part); + } + } + return result; + } + return new LinkedHashMap<>(0); + } + else { + Map parameterMap = webRequest.getParameterMap(); + Map result = new LinkedHashMap<>(parameterMap.size()); + parameterMap.forEach((key, values) -> { + if (values.length > 0) { + result.put(key, values[0]); + } + }); + return result; + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/RequestParamMethodArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/method/annotation/RequestParamMethodArgumentResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..ae80c2ae2630f59b83ed09d4f320e3e1e8486d14 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/RequestParamMethodArgumentResolver.java @@ -0,0 +1,269 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.beans.PropertyEditor; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.Part; + +import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.core.MethodParameter; +import org.springframework.core.convert.ConversionService; +import org.springframework.core.convert.TypeDescriptor; +import org.springframework.core.convert.converter.Converter; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.bind.MissingServletRequestParameterException; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RequestPart; +import org.springframework.web.bind.annotation.ValueConstants; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.method.support.UriComponentsContributor; +import org.springframework.web.multipart.MultipartException; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.multipart.MultipartRequest; +import org.springframework.web.multipart.MultipartResolver; +import org.springframework.web.multipart.support.MissingServletRequestPartException; +import org.springframework.web.multipart.support.MultipartResolutionDelegate; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * Resolves method arguments annotated with @{@link RequestParam}, arguments of + * type {@link MultipartFile} in conjunction with Spring's {@link MultipartResolver} + * abstraction, and arguments of type {@code javax.servlet.http.Part} in conjunction + * with Servlet 3.0 multipart requests. This resolver can also be created in default + * resolution mode in which simple types (int, long, etc.) not annotated with + * {@link RequestParam @RequestParam} are also treated as request parameters with + * the parameter name derived from the argument name. + * + *

If the method parameter type is {@link Map}, the name specified in the + * annotation is used to resolve the request parameter String value. The value is + * then converted to a {@link Map} via type conversion assuming a suitable + * {@link Converter} or {@link PropertyEditor} has been registered. + * Or if a request parameter name is not specified the + * {@link RequestParamMapMethodArgumentResolver} is used instead to provide + * access to all request parameters in the form of a map. + * + *

A {@link WebDataBinder} is invoked to apply type conversion to resolved request + * header values that don't yet match the method parameter type. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 3.1 + * @see RequestParamMapMethodArgumentResolver + */ +public class RequestParamMethodArgumentResolver extends AbstractNamedValueMethodArgumentResolver + implements UriComponentsContributor { + + private static final TypeDescriptor STRING_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(String.class); + + private final boolean useDefaultResolution; + + + /** + * Create a new {@link RequestParamMethodArgumentResolver} instance. + * @param useDefaultResolution in default resolution mode a method argument + * that is a simple type, as defined in {@link BeanUtils#isSimpleProperty}, + * is treated as a request parameter even if it isn't annotated, the + * request parameter name is derived from the method parameter name. + */ + public RequestParamMethodArgumentResolver(boolean useDefaultResolution) { + this.useDefaultResolution = useDefaultResolution; + } + + /** + * Create a new {@link RequestParamMethodArgumentResolver} instance. + * @param beanFactory a bean factory used for resolving ${...} placeholder + * and #{...} SpEL expressions in default values, or {@code null} if default + * values are not expected to contain expressions + * @param useDefaultResolution in default resolution mode a method argument + * that is a simple type, as defined in {@link BeanUtils#isSimpleProperty}, + * is treated as a request parameter even if it isn't annotated, the + * request parameter name is derived from the method parameter name. + */ + public RequestParamMethodArgumentResolver(@Nullable ConfigurableBeanFactory beanFactory, + boolean useDefaultResolution) { + + super(beanFactory); + this.useDefaultResolution = useDefaultResolution; + } + + + /** + * Supports the following: + *

    + *
  • @RequestParam-annotated method arguments. + * This excludes {@link Map} params where the annotation does not specify a name. + * See {@link RequestParamMapMethodArgumentResolver} instead for such params. + *
  • Arguments of type {@link MultipartFile} unless annotated with @{@link RequestPart}. + *
  • Arguments of type {@code Part} unless annotated with @{@link RequestPart}. + *
  • In default resolution mode, simple type arguments even if not with @{@link RequestParam}. + *
+ */ + @Override + public boolean supportsParameter(MethodParameter parameter) { + if (parameter.hasParameterAnnotation(RequestParam.class)) { + if (Map.class.isAssignableFrom(parameter.nestedIfOptional().getNestedParameterType())) { + RequestParam requestParam = parameter.getParameterAnnotation(RequestParam.class); + return (requestParam != null && StringUtils.hasText(requestParam.name())); + } + else { + return true; + } + } + else { + if (parameter.hasParameterAnnotation(RequestPart.class)) { + return false; + } + parameter = parameter.nestedIfOptional(); + if (MultipartResolutionDelegate.isMultipartArgument(parameter)) { + return true; + } + else if (this.useDefaultResolution) { + return BeanUtils.isSimpleProperty(parameter.getNestedParameterType()); + } + else { + return false; + } + } + } + + @Override + protected NamedValueInfo createNamedValueInfo(MethodParameter parameter) { + RequestParam ann = parameter.getParameterAnnotation(RequestParam.class); + return (ann != null ? new RequestParamNamedValueInfo(ann) : new RequestParamNamedValueInfo()); + } + + @Override + @Nullable + protected Object resolveName(String name, MethodParameter parameter, NativeWebRequest request) throws Exception { + HttpServletRequest servletRequest = request.getNativeRequest(HttpServletRequest.class); + + if (servletRequest != null) { + Object mpArg = MultipartResolutionDelegate.resolveMultipartArgument(name, parameter, servletRequest); + if (mpArg != MultipartResolutionDelegate.UNRESOLVABLE) { + return mpArg; + } + } + + Object arg = null; + MultipartRequest multipartRequest = request.getNativeRequest(MultipartRequest.class); + if (multipartRequest != null) { + List files = multipartRequest.getFiles(name); + if (!files.isEmpty()) { + arg = (files.size() == 1 ? files.get(0) : files); + } + } + if (arg == null) { + String[] paramValues = request.getParameterValues(name); + if (paramValues != null) { + arg = (paramValues.length == 1 ? paramValues[0] : paramValues); + } + } + return arg; + } + + @Override + protected void handleMissingValue(String name, MethodParameter parameter, NativeWebRequest request) + throws Exception { + + HttpServletRequest servletRequest = request.getNativeRequest(HttpServletRequest.class); + if (MultipartResolutionDelegate.isMultipartArgument(parameter)) { + if (servletRequest == null || !MultipartResolutionDelegate.isMultipartRequest(servletRequest)) { + throw new MultipartException("Current request is not a multipart request"); + } + else { + throw new MissingServletRequestPartException(name); + } + } + else { + throw new MissingServletRequestParameterException(name, + parameter.getNestedParameterType().getSimpleName()); + } + } + + @Override + public void contributeMethodArgument(MethodParameter parameter, @Nullable Object value, + UriComponentsBuilder builder, Map uriVariables, ConversionService conversionService) { + + Class paramType = parameter.getNestedParameterType(); + if (Map.class.isAssignableFrom(paramType) || MultipartFile.class == paramType || Part.class == paramType) { + return; + } + + RequestParam requestParam = parameter.getParameterAnnotation(RequestParam.class); + String name = (requestParam != null && StringUtils.hasLength(requestParam.name()) ? + requestParam.name() : parameter.getParameterName()); + Assert.state(name != null, "Unresolvable parameter name"); + + if (value == null) { + if (requestParam != null && + (!requestParam.required() || !requestParam.defaultValue().equals(ValueConstants.DEFAULT_NONE))) { + return; + } + builder.queryParam(name); + } + else if (value instanceof Collection) { + for (Object element : (Collection) value) { + element = formatUriValue(conversionService, TypeDescriptor.nested(parameter, 1), element); + builder.queryParam(name, element); + } + } + else { + builder.queryParam(name, formatUriValue(conversionService, new TypeDescriptor(parameter), value)); + } + } + + @Nullable + protected String formatUriValue( + @Nullable ConversionService cs, @Nullable TypeDescriptor sourceType, @Nullable Object value) { + + if (value == null) { + return null; + } + else if (value instanceof String) { + return (String) value; + } + else if (cs != null) { + return (String) cs.convert(value, sourceType, STRING_TYPE_DESCRIPTOR); + } + else { + return value.toString(); + } + } + + + private static class RequestParamNamedValueInfo extends NamedValueInfo { + + public RequestParamNamedValueInfo() { + super("", false, ValueConstants.DEFAULT_NONE); + } + + public RequestParamNamedValueInfo(RequestParam annotation) { + super(annotation.name(), annotation.required(), annotation.defaultValue()); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/SessionAttributesHandler.java b/spring-web/src/main/java/org/springframework/web/method/annotation/SessionAttributesHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..ca0f8edf9ada794dd2729f01ef2348378cf60eca --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/SessionAttributesHandler.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.bind.annotation.SessionAttributes; +import org.springframework.web.bind.support.SessionAttributeStore; +import org.springframework.web.bind.support.SessionStatus; +import org.springframework.web.context.request.WebRequest; + +/** + * Manages controller-specific session attributes declared via + * {@link SessionAttributes @SessionAttributes}. Actual storage is + * delegated to a {@link SessionAttributeStore} instance. + * + *

When a controller annotated with {@code @SessionAttributes} adds + * attributes to its model, those attributes are checked against names and + * types specified via {@code @SessionAttributes}. Matching model attributes + * are saved in the HTTP session and remain there until the controller calls + * {@link SessionStatus#setComplete()}. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.1 + */ +public class SessionAttributesHandler { + + private final Set attributeNames = new HashSet<>(); + + private final Set> attributeTypes = new HashSet<>(); + + private final Set knownAttributeNames = Collections.newSetFromMap(new ConcurrentHashMap<>(4)); + + private final SessionAttributeStore sessionAttributeStore; + + + /** + * Create a new session attributes handler. Session attribute names and types + * are extracted from the {@code @SessionAttributes} annotation, if present, + * on the given type. + * @param handlerType the controller type + * @param sessionAttributeStore used for session access + */ + public SessionAttributesHandler(Class handlerType, SessionAttributeStore sessionAttributeStore) { + Assert.notNull(sessionAttributeStore, "SessionAttributeStore may not be null"); + this.sessionAttributeStore = sessionAttributeStore; + + SessionAttributes ann = AnnotatedElementUtils.findMergedAnnotation(handlerType, SessionAttributes.class); + if (ann != null) { + Collections.addAll(this.attributeNames, ann.names()); + Collections.addAll(this.attributeTypes, ann.types()); + } + this.knownAttributeNames.addAll(this.attributeNames); + } + + + /** + * Whether the controller represented by this instance has declared any + * session attributes through an {@link SessionAttributes} annotation. + */ + public boolean hasSessionAttributes() { + return (!this.attributeNames.isEmpty() || !this.attributeTypes.isEmpty()); + } + + /** + * Whether the attribute name or type match the names and types specified + * via {@code @SessionAttributes} on the underlying controller. + *

Attributes successfully resolved through this method are "remembered" + * and subsequently used in {@link #retrieveAttributes(WebRequest)} and + * {@link #cleanupAttributes(WebRequest)}. + * @param attributeName the attribute name to check + * @param attributeType the type for the attribute + */ + public boolean isHandlerSessionAttribute(String attributeName, Class attributeType) { + Assert.notNull(attributeName, "Attribute name must not be null"); + if (this.attributeNames.contains(attributeName) || this.attributeTypes.contains(attributeType)) { + this.knownAttributeNames.add(attributeName); + return true; + } + else { + return false; + } + } + + /** + * Store a subset of the given attributes in the session. Attributes not + * declared as session attributes via {@code @SessionAttributes} are ignored. + * @param request the current request + * @param attributes candidate attributes for session storage + */ + public void storeAttributes(WebRequest request, Map attributes) { + attributes.forEach((name, value) -> { + if (value != null && isHandlerSessionAttribute(name, value.getClass())) { + this.sessionAttributeStore.storeAttribute(request, name, value); + } + }); + } + + /** + * Retrieve "known" attributes from the session, i.e. attributes listed + * by name in {@code @SessionAttributes} or attributes previously stored + * in the model that matched by type. + * @param request the current request + * @return a map with handler session attributes, possibly empty + */ + public Map retrieveAttributes(WebRequest request) { + Map attributes = new HashMap<>(); + for (String name : this.knownAttributeNames) { + Object value = this.sessionAttributeStore.retrieveAttribute(request, name); + if (value != null) { + attributes.put(name, value); + } + } + return attributes; + } + + /** + * Remove "known" attributes from the session, i.e. attributes listed + * by name in {@code @SessionAttributes} or attributes previously stored + * in the model that matched by type. + * @param request the current request + */ + public void cleanupAttributes(WebRequest request) { + for (String attributeName : this.knownAttributeNames) { + this.sessionAttributeStore.cleanupAttribute(request, attributeName); + } + } + + /** + * A pass-through call to the underlying {@link SessionAttributeStore}. + * @param request the current request + * @param attributeName the name of the attribute of interest + * @return the attribute value, or {@code null} if none + */ + @Nullable + Object retrieveAttribute(WebRequest request, String attributeName) { + return this.sessionAttributeStore.retrieveAttribute(request, attributeName); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/SessionStatusMethodArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/method/annotation/SessionStatusMethodArgumentResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..5676fa24be1b038962cafd6c98783d18382ddb14 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/SessionStatusMethodArgumentResolver.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.bind.support.SessionStatus; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.method.support.HandlerMethodArgumentResolver; +import org.springframework.web.method.support.ModelAndViewContainer; + +/** + * Resolves a {@link SessionStatus} argument by obtaining it from + * the {@link ModelAndViewContainer}. + * + * @author Rossen Stoyanchev + * @since 3.1 + */ +public class SessionStatusMethodArgumentResolver implements HandlerMethodArgumentResolver { + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return SessionStatus.class == parameter.getParameterType(); + } + + @Override + public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer, + NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception { + + Assert.state(mavContainer != null, "ModelAndViewContainer is required for session status exposure"); + return mavContainer.getSessionStatus(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/package-info.java b/spring-web/src/main/java/org/springframework/web/method/annotation/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..0a15dd473fb2baafe7770fb5784b42622dae612c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/package-info.java @@ -0,0 +1,9 @@ +/** + * Support classes for annotation-based handler method processing. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.method.annotation; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/method/package-info.java b/spring-web/src/main/java/org/springframework/web/method/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..772e1ec8a093cf14ff51f1bdd8d1548827fba1a0 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/package-info.java @@ -0,0 +1,10 @@ +/** + * Common infrastructure for handler method processing, as used by + * Spring MVC's {@code org.springframework.web.servlet.mvc.method} package. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.method; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/method/support/AsyncHandlerMethodReturnValueHandler.java b/spring-web/src/main/java/org/springframework/web/method/support/AsyncHandlerMethodReturnValueHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..c43de97316c8fc75f542da39fdeaef2b84688965 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/support/AsyncHandlerMethodReturnValueHandler.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.support; + +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; + +/** + * A return value handler that supports async types. Such return value types + * need to be handled with priority so the async value can be "unwrapped". + * + *

Note: implementing this contract is not required but it + * should be implemented when the handler needs to be prioritized ahead of others. + * For example custom (async) handlers, by default ordered after built-in + * handlers, should take precedence over {@code @ResponseBody} or + * {@code @ModelAttribute} handling, which should occur once the async value is + * ready. By contrast, built-in (async) handlers are already ordered ahead of + * sync handlers. + * + * @author Rossen Stoyanchev + * @since 4.2 + */ +public interface AsyncHandlerMethodReturnValueHandler extends HandlerMethodReturnValueHandler { + + /** + * Whether the given return value represents asynchronous computation. + * @param returnValue the value returned from the handler method + * @param returnType the return type + * @return {@code true} if the return value type represents an async value + */ + boolean isAsyncReturnValue(@Nullable Object returnValue, MethodParameter returnType); + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/support/CompositeUriComponentsContributor.java b/spring-web/src/main/java/org/springframework/web/method/support/CompositeUriComponentsContributor.java new file mode 100644 index 0000000000000000000000000000000000000000..e47ef428b3f973c6f4ecd39edfeeb9a889d59e00 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/support/CompositeUriComponentsContributor.java @@ -0,0 +1,145 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.support; + +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +import org.springframework.core.MethodParameter; +import org.springframework.core.convert.ConversionService; +import org.springframework.format.support.DefaultFormattingConversionService; +import org.springframework.lang.Nullable; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * A {@link UriComponentsContributor} containing a list of other contributors + * to delegate and also encapsulating a specific {@link ConversionService} to + * use for formatting method argument values to Strings. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class CompositeUriComponentsContributor implements UriComponentsContributor { + + private final List contributors = new LinkedList<>(); + + private final ConversionService conversionService; + + + /** + * Create an instance from a collection of {@link UriComponentsContributor UriComponentsContributors} or + * {@link HandlerMethodArgumentResolver HandlerMethodArgumentResolvers}. Since both of these tend to be implemented + * by the same class, the most convenient option is to obtain the configured + * {@code HandlerMethodArgumentResolvers} in {@code RequestMappingHandlerAdapter} + * and provide that to this constructor. + * @param contributors a collection of {@link UriComponentsContributor} + * or {@link HandlerMethodArgumentResolver HandlerMethodArgumentResolvers}. + */ + public CompositeUriComponentsContributor(UriComponentsContributor... contributors) { + Collections.addAll(this.contributors, contributors); + this.conversionService = new DefaultFormattingConversionService(); + } + + /** + * Create an instance from a collection of {@link UriComponentsContributor UriComponentsContributors} or + * {@link HandlerMethodArgumentResolver HandlerMethodArgumentResolvers}. Since both of these tend to be implemented + * by the same class, the most convenient option is to obtain the configured + * {@code HandlerMethodArgumentResolvers} in {@code RequestMappingHandlerAdapter} + * and provide that to this constructor. + * @param contributors a collection of {@link UriComponentsContributor} + * or {@link HandlerMethodArgumentResolver HandlerMethodArgumentResolvers}. + */ + public CompositeUriComponentsContributor(Collection contributors) { + this(contributors, null); + } + + /** + * Create an instance from a collection of {@link UriComponentsContributor UriComponentsContributors} or + * {@link HandlerMethodArgumentResolver HandlerMethodArgumentResolvers}. Since both of these tend to be implemented + * by the same class, the most convenient option is to obtain the configured + * {@code HandlerMethodArgumentResolvers} in the {@code RequestMappingHandlerAdapter} + * and provide that to this constructor. + *

If the {@link ConversionService} argument is {@code null}, + * {@link org.springframework.format.support.DefaultFormattingConversionService} + * will be used by default. + * @param contributors a collection of {@link UriComponentsContributor} + * or {@link HandlerMethodArgumentResolver HandlerMethodArgumentResolvers}. + * @param cs a ConversionService to use when method argument values + * need to be formatted as Strings before being added to the URI + */ + public CompositeUriComponentsContributor(@Nullable Collection contributors, @Nullable ConversionService cs) { + if (contributors != null) { + this.contributors.addAll(contributors); + } + this.conversionService = (cs != null ? cs : new DefaultFormattingConversionService()); + } + + + public boolean hasContributors() { + return this.contributors.isEmpty(); + } + + @Override + public boolean supportsParameter(MethodParameter parameter) { + for (Object contributor : this.contributors) { + if (contributor instanceof UriComponentsContributor) { + if (((UriComponentsContributor) contributor).supportsParameter(parameter)) { + return true; + } + } + else if (contributor instanceof HandlerMethodArgumentResolver) { + if (((HandlerMethodArgumentResolver) contributor).supportsParameter(parameter)) { + return false; + } + } + } + return false; + } + + @Override + public void contributeMethodArgument(MethodParameter parameter, Object value, + UriComponentsBuilder builder, Map uriVariables, ConversionService conversionService) { + + for (Object contributor : this.contributors) { + if (contributor instanceof UriComponentsContributor) { + UriComponentsContributor ucc = (UriComponentsContributor) contributor; + if (ucc.supportsParameter(parameter)) { + ucc.contributeMethodArgument(parameter, value, builder, uriVariables, conversionService); + break; + } + } + else if (contributor instanceof HandlerMethodArgumentResolver) { + if (((HandlerMethodArgumentResolver) contributor).supportsParameter(parameter)) { + break; + } + } + } + } + + /** + * An overloaded method that uses the ConversionService created at construction. + */ + public void contributeMethodArgument(MethodParameter parameter, Object value, UriComponentsBuilder builder, + Map uriVariables) { + + this.contributeMethodArgument(parameter, value, builder, uriVariables, this.conversionService); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/support/HandlerMethodArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/method/support/HandlerMethodArgumentResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..72b1972e1c63e98f6b7eebba1b88e4ef320788a2 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/support/HandlerMethodArgumentResolver.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.support; + +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Strategy interface for resolving method parameters into argument values in + * the context of a given request. + * + * @author Arjen Poutsma + * @since 3.1 + * @see HandlerMethodReturnValueHandler + */ +public interface HandlerMethodArgumentResolver { + + /** + * Whether the given {@linkplain MethodParameter method parameter} is + * supported by this resolver. + * @param parameter the method parameter to check + * @return {@code true} if this resolver supports the supplied parameter; + * {@code false} otherwise + */ + boolean supportsParameter(MethodParameter parameter); + + /** + * Resolves a method parameter into an argument value from a given request. + * A {@link ModelAndViewContainer} provides access to the model for the + * request. A {@link WebDataBinderFactory} provides a way to create + * a {@link WebDataBinder} instance when needed for data binding and + * type conversion purposes. + * @param parameter the method parameter to resolve. This parameter must + * have previously been passed to {@link #supportsParameter} which must + * have returned {@code true}. + * @param mavContainer the ModelAndViewContainer for the current request + * @param webRequest the current request + * @param binderFactory a factory for creating {@link WebDataBinder} instances + * @return the resolved argument value, or {@code null} if not resolvable + * @throws Exception in case of errors with the preparation of argument values + */ + @Nullable + Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer, + NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception; + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/support/HandlerMethodArgumentResolverComposite.java b/spring-web/src/main/java/org/springframework/web/method/support/HandlerMethodArgumentResolverComposite.java new file mode 100644 index 0000000000000000000000000000000000000000..ea6de1a45afa33c96d9a5572248df6a2d7c09bc1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/support/HandlerMethodArgumentResolverComposite.java @@ -0,0 +1,149 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.support; + +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Resolves method parameters by delegating to a list of registered + * {@link HandlerMethodArgumentResolver HandlerMethodArgumentResolvers}. + * Previously resolved method parameters are cached for faster lookups. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.1 + */ +public class HandlerMethodArgumentResolverComposite implements HandlerMethodArgumentResolver { + + @Deprecated + protected final Log logger = LogFactory.getLog(getClass()); + + private final List argumentResolvers = new LinkedList<>(); + + private final Map argumentResolverCache = + new ConcurrentHashMap<>(256); + + + /** + * Add the given {@link HandlerMethodArgumentResolver}. + */ + public HandlerMethodArgumentResolverComposite addResolver(HandlerMethodArgumentResolver resolver) { + this.argumentResolvers.add(resolver); + return this; + } + + /** + * Add the given {@link HandlerMethodArgumentResolver HandlerMethodArgumentResolvers}. + * @since 4.3 + */ + public HandlerMethodArgumentResolverComposite addResolvers( + @Nullable HandlerMethodArgumentResolver... resolvers) { + + if (resolvers != null) { + Collections.addAll(this.argumentResolvers, resolvers); + } + return this; + } + + /** + * Add the given {@link HandlerMethodArgumentResolver HandlerMethodArgumentResolvers}. + */ + public HandlerMethodArgumentResolverComposite addResolvers( + @Nullable List resolvers) { + + if (resolvers != null) { + this.argumentResolvers.addAll(resolvers); + } + return this; + } + + /** + * Return a read-only list with the contained resolvers, or an empty list. + */ + public List getResolvers() { + return Collections.unmodifiableList(this.argumentResolvers); + } + + /** + * Clear the list of configured resolvers. + * @since 4.3 + */ + public void clear() { + this.argumentResolvers.clear(); + } + + + /** + * Whether the given {@linkplain MethodParameter method parameter} is + * supported by any registered {@link HandlerMethodArgumentResolver}. + */ + @Override + public boolean supportsParameter(MethodParameter parameter) { + return getArgumentResolver(parameter) != null; + } + + /** + * Iterate over registered + * {@link HandlerMethodArgumentResolver HandlerMethodArgumentResolvers} + * and invoke the one that supports it. + * @throws IllegalArgumentException if no suitable argument resolver is found + */ + @Override + @Nullable + public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer, + NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception { + + HandlerMethodArgumentResolver resolver = getArgumentResolver(parameter); + if (resolver == null) { + throw new IllegalArgumentException("Unsupported parameter type [" + + parameter.getParameterType().getName() + "]. supportsParameter should be called first."); + } + return resolver.resolveArgument(parameter, mavContainer, webRequest, binderFactory); + } + + /** + * Find a registered {@link HandlerMethodArgumentResolver} that supports + * the given method parameter. + */ + @Nullable + private HandlerMethodArgumentResolver getArgumentResolver(MethodParameter parameter) { + HandlerMethodArgumentResolver result = this.argumentResolverCache.get(parameter); + if (result == null) { + for (HandlerMethodArgumentResolver resolver : this.argumentResolvers) { + if (resolver.supportsParameter(parameter)) { + result = resolver; + this.argumentResolverCache.put(parameter, result); + break; + } + } + } + return result; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/support/HandlerMethodReturnValueHandler.java b/spring-web/src/main/java/org/springframework/web/method/support/HandlerMethodReturnValueHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..6317abf679ac30512eddca39b232cd8ad3052af5 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/support/HandlerMethodReturnValueHandler.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.support; + +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Strategy interface to handle the value returned from the invocation of a + * handler method . + * + * @author Arjen Poutsma + * @since 3.1 + * @see HandlerMethodArgumentResolver + */ +public interface HandlerMethodReturnValueHandler { + + /** + * Whether the given {@linkplain MethodParameter method return type} is + * supported by this handler. + * @param returnType the method return type to check + * @return {@code true} if this handler supports the supplied return type; + * {@code false} otherwise + */ + boolean supportsReturnType(MethodParameter returnType); + + /** + * Handle the given return value by adding attributes to the model and + * setting a view or setting the + * {@link ModelAndViewContainer#setRequestHandled} flag to {@code true} + * to indicate the response has been handled directly. + * @param returnValue the value returned from the handler method + * @param returnType the type of the return value. This type must have + * previously been passed to {@link #supportsReturnType} which must + * have returned {@code true}. + * @param mavContainer the ModelAndViewContainer for the current request + * @param webRequest the current request + * @throws Exception if the return value handling results in an error + */ + void handleReturnValue(@Nullable Object returnValue, MethodParameter returnType, + ModelAndViewContainer mavContainer, NativeWebRequest webRequest) throws Exception; + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/support/HandlerMethodReturnValueHandlerComposite.java b/spring-web/src/main/java/org/springframework/web/method/support/HandlerMethodReturnValueHandlerComposite.java new file mode 100644 index 0000000000000000000000000000000000000000..2e7f3a9478d80abec225af3f583e7db98f601434 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/support/HandlerMethodReturnValueHandlerComposite.java @@ -0,0 +1,129 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.support; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Handles method return values by delegating to a list of registered {@link HandlerMethodReturnValueHandler HandlerMethodReturnValueHandlers}. + * Previously resolved return types are cached for faster lookups. + * + * @author Rossen Stoyanchev + * @since 3.1 + */ +public class HandlerMethodReturnValueHandlerComposite implements HandlerMethodReturnValueHandler { + + protected final Log logger = LogFactory.getLog(getClass()); + + private final List returnValueHandlers = new ArrayList<>(); + + + /** + * Return a read-only list with the registered handlers, or an empty list. + */ + public List getHandlers() { + return Collections.unmodifiableList(this.returnValueHandlers); + } + + /** + * Whether the given {@linkplain MethodParameter method return type} is supported by any registered + * {@link HandlerMethodReturnValueHandler}. + */ + @Override + public boolean supportsReturnType(MethodParameter returnType) { + return getReturnValueHandler(returnType) != null; + } + + @Nullable + private HandlerMethodReturnValueHandler getReturnValueHandler(MethodParameter returnType) { + for (HandlerMethodReturnValueHandler handler : this.returnValueHandlers) { + if (handler.supportsReturnType(returnType)) { + return handler; + } + } + return null; + } + + /** + * Iterate over registered {@link HandlerMethodReturnValueHandler HandlerMethodReturnValueHandlers} and invoke the one that supports it. + * @throws IllegalStateException if no suitable {@link HandlerMethodReturnValueHandler} is found. + */ + @Override + public void handleReturnValue(@Nullable Object returnValue, MethodParameter returnType, + ModelAndViewContainer mavContainer, NativeWebRequest webRequest) throws Exception { + + HandlerMethodReturnValueHandler handler = selectHandler(returnValue, returnType); + if (handler == null) { + throw new IllegalArgumentException("Unknown return value type: " + returnType.getParameterType().getName()); + } + handler.handleReturnValue(returnValue, returnType, mavContainer, webRequest); + } + + @Nullable + private HandlerMethodReturnValueHandler selectHandler(@Nullable Object value, MethodParameter returnType) { + boolean isAsyncValue = isAsyncReturnValue(value, returnType); + for (HandlerMethodReturnValueHandler handler : this.returnValueHandlers) { + if (isAsyncValue && !(handler instanceof AsyncHandlerMethodReturnValueHandler)) { + continue; + } + if (handler.supportsReturnType(returnType)) { + return handler; + } + } + return null; + } + + private boolean isAsyncReturnValue(@Nullable Object value, MethodParameter returnType) { + for (HandlerMethodReturnValueHandler handler : this.returnValueHandlers) { + if (handler instanceof AsyncHandlerMethodReturnValueHandler && + ((AsyncHandlerMethodReturnValueHandler) handler).isAsyncReturnValue(value, returnType)) { + return true; + } + } + return false; + } + + /** + * Add the given {@link HandlerMethodReturnValueHandler}. + */ + public HandlerMethodReturnValueHandlerComposite addHandler(HandlerMethodReturnValueHandler handler) { + this.returnValueHandlers.add(handler); + return this; + } + + /** + * Add the given {@link HandlerMethodReturnValueHandler HandlerMethodReturnValueHandlers}. + */ + public HandlerMethodReturnValueHandlerComposite addHandlers( + @Nullable List handlers) { + + if (handlers != null) { + this.returnValueHandlers.addAll(handlers); + } + return this; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java b/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java new file mode 100644 index 0000000000000000000000000000000000000000..817a4f2935198a807d4e15e8f849b9bb8543435c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java @@ -0,0 +1,215 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.support; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.Arrays; + +import org.springframework.core.DefaultParameterNameDiscoverer; +import org.springframework.core.MethodParameter; +import org.springframework.core.ParameterNameDiscoverer; +import org.springframework.lang.Nullable; +import org.springframework.util.ObjectUtils; +import org.springframework.util.ReflectionUtils; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.support.SessionStatus; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.method.HandlerMethod; + +/** + * Extension of {@link HandlerMethod} that invokes the underlying method with + * argument values resolved from the current HTTP request through a list of + * {@link HandlerMethodArgumentResolver}. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.1 + */ +public class InvocableHandlerMethod extends HandlerMethod { + + private static final Object[] EMPTY_ARGS = new Object[0]; + + + private HandlerMethodArgumentResolverComposite resolvers = new HandlerMethodArgumentResolverComposite(); + + private ParameterNameDiscoverer parameterNameDiscoverer = new DefaultParameterNameDiscoverer(); + + @Nullable + private WebDataBinderFactory dataBinderFactory; + + + /** + * Create an instance from a {@code HandlerMethod}. + */ + public InvocableHandlerMethod(HandlerMethod handlerMethod) { + super(handlerMethod); + } + + /** + * Create an instance from a bean instance and a method. + */ + public InvocableHandlerMethod(Object bean, Method method) { + super(bean, method); + } + + /** + * Construct a new handler method with the given bean instance, method name and parameters. + * @param bean the object bean + * @param methodName the method name + * @param parameterTypes the method parameter types + * @throws NoSuchMethodException when the method cannot be found + */ + public InvocableHandlerMethod(Object bean, String methodName, Class... parameterTypes) + throws NoSuchMethodException { + + super(bean, methodName, parameterTypes); + } + + + /** + * Set {@link HandlerMethodArgumentResolver HandlerMethodArgumentResolvers} + * to use for resolving method argument values. + */ + public void setHandlerMethodArgumentResolvers(HandlerMethodArgumentResolverComposite argumentResolvers) { + this.resolvers = argumentResolvers; + } + + /** + * Set the ParameterNameDiscoverer for resolving parameter names when needed + * (e.g. default request attribute name). + *

Default is a {@link org.springframework.core.DefaultParameterNameDiscoverer}. + */ + public void setParameterNameDiscoverer(ParameterNameDiscoverer parameterNameDiscoverer) { + this.parameterNameDiscoverer = parameterNameDiscoverer; + } + + /** + * Set the {@link WebDataBinderFactory} to be passed to argument resolvers allowing them + * to create a {@link WebDataBinder} for data binding and type conversion purposes. + */ + public void setDataBinderFactory(WebDataBinderFactory dataBinderFactory) { + this.dataBinderFactory = dataBinderFactory; + } + + + /** + * Invoke the method after resolving its argument values in the context of the given request. + *

Argument values are commonly resolved through + * {@link HandlerMethodArgumentResolver HandlerMethodArgumentResolvers}. + * The {@code providedArgs} parameter however may supply argument values to be used directly, + * i.e. without argument resolution. Examples of provided argument values include a + * {@link WebDataBinder}, a {@link SessionStatus}, or a thrown exception instance. + * Provided argument values are checked before argument resolvers. + *

Delegates to {@link #getMethodArgumentValues} and calls {@link #doInvoke} with the + * resolved arguments. + * @param request the current request + * @param mavContainer the ModelAndViewContainer for this request + * @param providedArgs "given" arguments matched by type, not resolved + * @return the raw value returned by the invoked method + * @throws Exception raised if no suitable argument resolver can be found, + * or if the method raised an exception + * @see #getMethodArgumentValues + * @see #doInvoke + */ + @Nullable + public Object invokeForRequest(NativeWebRequest request, @Nullable ModelAndViewContainer mavContainer, + Object... providedArgs) throws Exception { + + Object[] args = getMethodArgumentValues(request, mavContainer, providedArgs); + if (logger.isTraceEnabled()) { + logger.trace("Arguments: " + Arrays.toString(args)); + } + return doInvoke(args); + } + + /** + * Get the method argument values for the current request, checking the provided + * argument values and falling back to the configured argument resolvers. + *

The resulting array will be passed into {@link #doInvoke}. + * @since 5.1.2 + */ + protected Object[] getMethodArgumentValues(NativeWebRequest request, @Nullable ModelAndViewContainer mavContainer, + Object... providedArgs) throws Exception { + + MethodParameter[] parameters = getMethodParameters(); + if (ObjectUtils.isEmpty(parameters)) { + return EMPTY_ARGS; + } + + Object[] args = new Object[parameters.length]; + for (int i = 0; i < parameters.length; i++) { + MethodParameter parameter = parameters[i]; + parameter.initParameterNameDiscovery(this.parameterNameDiscoverer); + args[i] = findProvidedArgument(parameter, providedArgs); + if (args[i] != null) { + continue; + } + if (!this.resolvers.supportsParameter(parameter)) { + throw new IllegalStateException(formatArgumentError(parameter, "No suitable resolver")); + } + try { + args[i] = this.resolvers.resolveArgument(parameter, mavContainer, request, this.dataBinderFactory); + } + catch (Exception ex) { + // Leave stack trace for later, exception may actually be resolved and handled... + if (logger.isDebugEnabled()) { + String exMsg = ex.getMessage(); + if (exMsg != null && !exMsg.contains(parameter.getExecutable().toGenericString())) { + logger.debug(formatArgumentError(parameter, exMsg)); + } + } + throw ex; + } + } + return args; + } + + /** + * Invoke the handler method with the given argument values. + */ + @Nullable + protected Object doInvoke(Object... args) throws Exception { + ReflectionUtils.makeAccessible(getBridgedMethod()); + try { + return getBridgedMethod().invoke(getBean(), args); + } + catch (IllegalArgumentException ex) { + assertTargetBean(getBridgedMethod(), getBean(), args); + String text = (ex.getMessage() != null ? ex.getMessage() : "Illegal argument"); + throw new IllegalStateException(formatInvokeError(text, args), ex); + } + catch (InvocationTargetException ex) { + // Unwrap for HandlerExceptionResolvers ... + Throwable targetException = ex.getTargetException(); + if (targetException instanceof RuntimeException) { + throw (RuntimeException) targetException; + } + else if (targetException instanceof Error) { + throw (Error) targetException; + } + else if (targetException instanceof Exception) { + throw (Exception) targetException; + } + else { + throw new IllegalStateException(formatInvokeError("Invocation failure", args), targetException); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/support/ModelAndViewContainer.java b/spring-web/src/main/java/org/springframework/web/method/support/ModelAndViewContainer.java new file mode 100644 index 0000000000000000000000000000000000000000..44b6be8b54ac7dceec91b72f82ccb5e7fe9b111b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/support/ModelAndViewContainer.java @@ -0,0 +1,355 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.support; + +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; +import org.springframework.ui.Model; +import org.springframework.ui.ModelMap; +import org.springframework.validation.support.BindingAwareModelMap; +import org.springframework.web.bind.support.SessionStatus; +import org.springframework.web.bind.support.SimpleSessionStatus; + +/** + * Records model and view related decisions made by + * {@link HandlerMethodArgumentResolver HandlerMethodArgumentResolvers} and + * {@link HandlerMethodReturnValueHandler HandlerMethodReturnValueHandlers} during the course of invocation of + * a controller method. + * + *

The {@link #setRequestHandled} flag can be used to indicate the request + * has been handled directly and view resolution is not required. + * + *

A default {@link Model} is automatically created at instantiation. + * An alternate model instance may be provided via {@link #setRedirectModel} + * for use in a redirect scenario. When {@link #setRedirectModelScenario} is set + * to {@code true} signalling a redirect scenario, the {@link #getModel()} + * returns the redirect model instead of the default model. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.1 + */ +public class ModelAndViewContainer { + + private boolean ignoreDefaultModelOnRedirect = false; + + @Nullable + private Object view; + + private final ModelMap defaultModel = new BindingAwareModelMap(); + + @Nullable + private ModelMap redirectModel; + + private boolean redirectModelScenario = false; + + @Nullable + private HttpStatus status; + + private final Set noBinding = new HashSet<>(4); + + private final Set bindingDisabled = new HashSet<>(4); + + private final SessionStatus sessionStatus = new SimpleSessionStatus(); + + private boolean requestHandled = false; + + + /** + * By default the content of the "default" model is used both during + * rendering and redirect scenarios. Alternatively controller methods + * can declare an argument of type {@code RedirectAttributes} and use + * it to provide attributes to prepare the redirect URL. + *

Setting this flag to {@code true} guarantees the "default" model is + * never used in a redirect scenario even if a RedirectAttributes argument + * is not declared. Setting it to {@code false} means the "default" model + * may be used in a redirect if the controller method doesn't declare a + * RedirectAttributes argument. + *

The default setting is {@code false}. + */ + public void setIgnoreDefaultModelOnRedirect(boolean ignoreDefaultModelOnRedirect) { + this.ignoreDefaultModelOnRedirect = ignoreDefaultModelOnRedirect; + } + + /** + * Set a view name to be resolved by the DispatcherServlet via a ViewResolver. + * Will override any pre-existing view name or View. + */ + public void setViewName(@Nullable String viewName) { + this.view = viewName; + } + + /** + * Return the view name to be resolved by the DispatcherServlet via a + * ViewResolver, or {@code null} if a View object is set. + */ + @Nullable + public String getViewName() { + return (this.view instanceof String ? (String) this.view : null); + } + + /** + * Set a View object to be used by the DispatcherServlet. + * Will override any pre-existing view name or View. + */ + public void setView(@Nullable Object view) { + this.view = view; + } + + /** + * Return the View object, or {@code null} if we using a view name + * to be resolved by the DispatcherServlet via a ViewResolver. + */ + @Nullable + public Object getView() { + return this.view; + } + + /** + * Whether the view is a view reference specified via a name to be + * resolved by the DispatcherServlet via a ViewResolver. + */ + public boolean isViewReference() { + return (this.view instanceof String); + } + + /** + * Return the model to use -- either the "default" or the "redirect" model. + * The default model is used if {@code redirectModelScenario=false} or + * there is no redirect model (i.e. RedirectAttributes was not declared as + * a method argument) and {@code ignoreDefaultModelOnRedirect=false}. + */ + public ModelMap getModel() { + if (useDefaultModel()) { + return this.defaultModel; + } + else { + if (this.redirectModel == null) { + this.redirectModel = new ModelMap(); + } + return this.redirectModel; + } + } + + /** + * Whether to use the default model or the redirect model. + */ + private boolean useDefaultModel() { + return (!this.redirectModelScenario || (this.redirectModel == null && !this.ignoreDefaultModelOnRedirect)); + } + + /** + * Return the "default" model created at instantiation. + *

In general it is recommended to use {@link #getModel()} instead which + * returns either the "default" model (template rendering) or the "redirect" + * model (redirect URL preparation). Use of this method may be needed for + * advanced cases when access to the "default" model is needed regardless, + * e.g. to save model attributes specified via {@code @SessionAttributes}. + * @return the default model (never {@code null}) + * @since 4.1.4 + */ + public ModelMap getDefaultModel() { + return this.defaultModel; + } + + /** + * Provide a separate model instance to use in a redirect scenario. + *

The provided additional model however is not used unless + * {@link #setRedirectModelScenario} gets set to {@code true} + * to signal an actual redirect scenario. + */ + public void setRedirectModel(ModelMap redirectModel) { + this.redirectModel = redirectModel; + } + + /** + * Whether the controller has returned a redirect instruction, e.g. a + * "redirect:" prefixed view name, a RedirectView instance, etc. + */ + public void setRedirectModelScenario(boolean redirectModelScenario) { + this.redirectModelScenario = redirectModelScenario; + } + + /** + * Provide an HTTP status that will be passed on to with the + * {@code ModelAndView} used for view rendering purposes. + * @since 4.3 + */ + public void setStatus(@Nullable HttpStatus status) { + this.status = status; + } + + /** + * Return the configured HTTP status, if any. + * @since 4.3 + */ + @Nullable + public HttpStatus getStatus() { + return this.status; + } + + /** + * Programmatically register an attribute for which data binding should not occur, + * not even for a subsequent {@code @ModelAttribute} declaration. + * @param attributeName the name of the attribute + * @since 4.3 + */ + public void setBindingDisabled(String attributeName) { + this.bindingDisabled.add(attributeName); + } + + /** + * Whether binding is disabled for the given model attribute. + * @since 4.3 + */ + public boolean isBindingDisabled(String name) { + return (this.bindingDisabled.contains(name) || this.noBinding.contains(name)); + } + + /** + * Register whether data binding should occur for a corresponding model attribute, + * corresponding to an {@code @ModelAttribute(binding=true/false)} declaration. + *

Note: While this flag will be taken into account by {@link #isBindingDisabled}, + * a hard {@link #setBindingDisabled} declaration will always override it. + * @param attributeName the name of the attribute + * @since 4.3.13 + */ + public void setBinding(String attributeName, boolean enabled) { + if (!enabled) { + this.noBinding.add(attributeName); + } + else { + this.noBinding.remove(attributeName); + } + } + + /** + * Return the {@link SessionStatus} instance to use that can be used to + * signal that session processing is complete. + */ + public SessionStatus getSessionStatus() { + return this.sessionStatus; + } + + /** + * Whether the request has been handled fully within the handler, e.g. + * {@code @ResponseBody} method, and therefore view resolution is not + * necessary. This flag can also be set when controller methods declare an + * argument of type {@code ServletResponse} or {@code OutputStream}). + *

The default value is {@code false}. + */ + public void setRequestHandled(boolean requestHandled) { + this.requestHandled = requestHandled; + } + + /** + * Whether the request has been handled fully within the handler. + */ + public boolean isRequestHandled() { + return this.requestHandled; + } + + /** + * Add the supplied attribute to the underlying model. + * A shortcut for {@code getModel().addAttribute(String, Object)}. + */ + public ModelAndViewContainer addAttribute(String name, @Nullable Object value) { + getModel().addAttribute(name, value); + return this; + } + + /** + * Add the supplied attribute to the underlying model. + * A shortcut for {@code getModel().addAttribute(Object)}. + */ + public ModelAndViewContainer addAttribute(Object value) { + getModel().addAttribute(value); + return this; + } + + /** + * Copy all attributes to the underlying model. + * A shortcut for {@code getModel().addAllAttributes(Map)}. + */ + public ModelAndViewContainer addAllAttributes(@Nullable Map attributes) { + getModel().addAllAttributes(attributes); + return this; + } + + /** + * Copy attributes in the supplied {@code Map} with existing objects of + * the same name taking precedence (i.e. not getting replaced). + * A shortcut for {@code getModel().mergeAttributes(Map)}. + */ + public ModelAndViewContainer mergeAttributes(@Nullable Map attributes) { + getModel().mergeAttributes(attributes); + return this; + } + + /** + * Remove the given attributes from the model. + */ + public ModelAndViewContainer removeAttributes(@Nullable Map attributes) { + if (attributes != null) { + for (String key : attributes.keySet()) { + getModel().remove(key); + } + } + return this; + } + + /** + * Whether the underlying model contains the given attribute name. + * A shortcut for {@code getModel().containsAttribute(String)}. + */ + public boolean containsAttribute(String name) { + return getModel().containsAttribute(name); + } + + + /** + * Return diagnostic information. + */ + @Override + public String toString() { + StringBuilder sb = new StringBuilder("ModelAndViewContainer: "); + if (!isRequestHandled()) { + if (isViewReference()) { + sb.append("reference to view with name '").append(this.view).append("'"); + } + else { + sb.append("View is [").append(this.view).append(']'); + } + if (useDefaultModel()) { + sb.append("; default model "); + } + else { + sb.append("; redirect model "); + } + sb.append(getModel()); + } + else { + sb.append("Request handled directly"); + } + return sb.toString(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/support/UriComponentsContributor.java b/spring-web/src/main/java/org/springframework/web/method/support/UriComponentsContributor.java new file mode 100644 index 0000000000000000000000000000000000000000..368ba65668afef3754183218e793226572b7cfd5 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/support/UriComponentsContributor.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.support; + +import java.util.Map; + +import org.springframework.core.MethodParameter; +import org.springframework.core.convert.ConversionService; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * Strategy for contributing to the building of a {@link UriComponents} by + * looking at a method parameter and an argument value and deciding what + * part of the target URL should be updated. + * + * @author Oliver Gierke + * @author Rossen Stoyanchev + * @since 4.0 + */ +public interface UriComponentsContributor { + + /** + * Whether this contributor supports the given method parameter. + */ + boolean supportsParameter(MethodParameter parameter); + + /** + * Process the given method argument and either update the + * {@link UriComponentsBuilder} or add to the map with URI variables + * to use to expand the URI after all arguments are processed. + * @param parameter the controller method parameter (never {@code null}) + * @param value the argument value (possibly {@code null}) + * @param builder the builder to update (never {@code null}) + * @param uriVariables a map to add URI variables to (never {@code null}) + * @param conversionService a ConversionService to format values as Strings + */ + void contributeMethodArgument(MethodParameter parameter, Object value, UriComponentsBuilder builder, + Map uriVariables, ConversionService conversionService); + +} diff --git a/spring-web/src/main/java/org/springframework/web/method/support/package-info.java b/spring-web/src/main/java/org/springframework/web/method/support/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..4d4de4a0572ebd79b072bad9c808f66ce3158103 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/method/support/package-info.java @@ -0,0 +1,9 @@ +/** + * Generic support classes for handler method processing. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.method.support; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/multipart/MaxUploadSizeExceededException.java b/spring-web/src/main/java/org/springframework/web/multipart/MaxUploadSizeExceededException.java new file mode 100644 index 0000000000000000000000000000000000000000..9d8c4f18cfe7a5852ebdebe49770de42edcafd74 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/MaxUploadSizeExceededException.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart; + +import org.springframework.lang.Nullable; + +/** + * MultipartException subclass thrown when an upload exceeds the + * maximum upload size allowed. + * + * @author Juergen Hoeller + * @since 1.0.1 + */ +@SuppressWarnings("serial") +public class MaxUploadSizeExceededException extends MultipartException { + + private final long maxUploadSize; + + + /** + * Constructor for MaxUploadSizeExceededException. + * @param maxUploadSize the maximum upload size allowed, + * or -1 if the size limit isn't known + */ + public MaxUploadSizeExceededException(long maxUploadSize) { + this(maxUploadSize, null); + } + + /** + * Constructor for MaxUploadSizeExceededException. + * @param maxUploadSize the maximum upload size allowed, + * or -1 if the size limit isn't known + * @param ex root cause from multipart parsing API in use + */ + public MaxUploadSizeExceededException(long maxUploadSize, @Nullable Throwable ex) { + super("Maximum upload size " + (maxUploadSize >= 0 ? "of " + maxUploadSize + " bytes " : "") + "exceeded", ex); + this.maxUploadSize = maxUploadSize; + } + + + /** + * Return the maximum upload size allowed, + * or -1 if the size limit isn't known. + */ + public long getMaxUploadSize() { + return this.maxUploadSize; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/MultipartException.java b/spring-web/src/main/java/org/springframework/web/multipart/MultipartException.java new file mode 100644 index 0000000000000000000000000000000000000000..401cd572bd6c5eb732e33f90c80009884208aa7f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/MultipartException.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart; + +import org.springframework.core.NestedRuntimeException; +import org.springframework.lang.Nullable; + +/** + * Exception thrown when multipart resolution fails. + * + * @author Trevor D. Cook + * @author Juergen Hoeller + * @since 29.09.2003 + * @see MultipartResolver#resolveMultipart + * @see org.springframework.web.multipart.support.MultipartFilter + */ +@SuppressWarnings("serial") +public class MultipartException extends NestedRuntimeException { + + /** + * Constructor for MultipartException. + * @param msg the detail message + */ + public MultipartException(String msg) { + super(msg); + } + + /** + * Constructor for MultipartException. + * @param msg the detail message + * @param cause the root cause from the multipart parsing API in use + */ + public MultipartException(String msg, @Nullable Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/MultipartFile.java b/spring-web/src/main/java/org/springframework/web/multipart/MultipartFile.java new file mode 100644 index 0000000000000000000000000000000000000000..f5deb9ecffb2b6e057c409b1f5a6b2eb1bbdc7af --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/MultipartFile.java @@ -0,0 +1,144 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; + +import org.springframework.core.io.InputStreamSource; +import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; +import org.springframework.util.FileCopyUtils; + +/** + * A representation of an uploaded file received in a multipart request. + * + *

The file contents are either stored in memory or temporarily on disk. + * In either case, the user is responsible for copying file contents to a + * session-level or persistent store as and if desired. The temporary storage + * will be cleared at the end of request processing. + * + * @author Juergen Hoeller + * @author Trevor D. Cook + * @since 29.09.2003 + * @see org.springframework.web.multipart.MultipartHttpServletRequest + * @see org.springframework.web.multipart.MultipartResolver + */ +public interface MultipartFile extends InputStreamSource { + + /** + * Return the name of the parameter in the multipart form. + * @return the name of the parameter (never {@code null} or empty) + */ + String getName(); + + /** + * Return the original filename in the client's filesystem. + *

This may contain path information depending on the browser used, + * but it typically will not with any other than Opera. + * @return the original filename, or the empty String if no file has been chosen + * in the multipart form, or {@code null} if not defined or not available + * @see org.apache.commons.fileupload.FileItem#getName() + * @see org.springframework.web.multipart.commons.CommonsMultipartFile#setPreserveFilename + */ + @Nullable + String getOriginalFilename(); + + /** + * Return the content type of the file. + * @return the content type, or {@code null} if not defined + * (or no file has been chosen in the multipart form) + */ + @Nullable + String getContentType(); + + /** + * Return whether the uploaded file is empty, that is, either no file has + * been chosen in the multipart form or the chosen file has no content. + */ + boolean isEmpty(); + + /** + * Return the size of the file in bytes. + * @return the size of the file, or 0 if empty + */ + long getSize(); + + /** + * Return the contents of the file as an array of bytes. + * @return the contents of the file as bytes, or an empty byte array if empty + * @throws IOException in case of access errors (if the temporary store fails) + */ + byte[] getBytes() throws IOException; + + /** + * Return an InputStream to read the contents of the file from. + *

The user is responsible for closing the returned stream. + * @return the contents of the file as stream, or an empty stream if empty + * @throws IOException in case of access errors (if the temporary store fails) + */ + @Override + InputStream getInputStream() throws IOException; + + /** + * Return a Resource representation of this MultipartFile. This can be used + * as input to the {@code RestTemplate} or the {@code WebClient} to expose + * content length and the filename along with the InputStream. + * @return this MultipartFile adapted to the Resource contract + * @since 5.1 + */ + default Resource getResource() { + return new MultipartFileResource(this); + } + + /** + * Transfer the received file to the given destination file. + *

This may either move the file in the filesystem, copy the file in the + * filesystem, or save memory-held contents to the destination file. If the + * destination file already exists, it will be deleted first. + *

If the target file has been moved in the filesystem, this operation + * cannot be invoked again afterwards. Therefore, call this method just once + * in order to work with any storage mechanism. + *

NOTE: Depending on the underlying provider, temporary storage + * may be container-dependent, including the base directory for relative + * destinations specified here (e.g. with Servlet 3.0 multipart handling). + * For absolute destinations, the target file may get renamed/moved from its + * temporary location or newly copied, even if a temporary copy already exists. + * @param dest the destination file (typically absolute) + * @throws IOException in case of reading or writing errors + * @throws IllegalStateException if the file has already been moved + * in the filesystem and is not available anymore for another transfer + * @see org.apache.commons.fileupload.FileItem#write(File) + * @see javax.servlet.http.Part#write(String) + */ + void transferTo(File dest) throws IOException, IllegalStateException; + + /** + * Transfer the received file to the given destination file. + *

The default implementation simply copies the file input stream. + * @since 5.1 + * @see #getInputStream() + * @see #transferTo(File) + */ + default void transferTo(Path dest) throws IOException, IllegalStateException { + FileCopyUtils.copy(getInputStream(), Files.newOutputStream(dest)); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/MultipartFileResource.java b/spring-web/src/main/java/org/springframework/web/multipart/MultipartFileResource.java new file mode 100644 index 0000000000000000000000000000000000000000..1cfe7326b7ae89b0781dbece06b34fa86393d402 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/MultipartFileResource.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart; + +import java.io.IOException; +import java.io.InputStream; + +import org.springframework.core.io.AbstractResource; +import org.springframework.util.Assert; + +/** + * Adapt {@link MultipartFile} to {@link org.springframework.core.io.Resource}, + * exposing the content as {@code InputStream} and also overriding + * {@link #contentLength()} as well as {@link #getFilename()}. + * + * @author Rossen Stoyanchev + * @since 5.1 + */ +class MultipartFileResource extends AbstractResource { + + private final MultipartFile multipartFile; + + + public MultipartFileResource(MultipartFile multipartFile) { + Assert.notNull(multipartFile, "MultipartFile must not be null"); + this.multipartFile = multipartFile; + } + + + /** + * This implementation always returns {@code true}. + */ + @Override + public boolean exists() { + return true; + } + + /** + * This implementation always returns {@code true}. + */ + @Override + public boolean isOpen() { + return true; + } + + @Override + public long contentLength() { + return this.multipartFile.getSize(); + } + + @Override + public String getFilename() { + return this.multipartFile.getOriginalFilename(); + } + + /** + * This implementation throws IllegalStateException if attempting to + * read the underlying stream multiple times. + */ + @Override + public InputStream getInputStream() throws IOException, IllegalStateException { + return this.multipartFile.getInputStream(); + } + + /** + * This implementation returns a description that has the Multipart name. + */ + @Override + public String getDescription() { + return "MultipartFile resource [" + this.multipartFile.getName() + "]"; + } + + + @Override + public boolean equals(Object other) { + return (this == other || (other instanceof MultipartFileResource && + ((MultipartFileResource) other).multipartFile.equals(this.multipartFile))); + } + + @Override + public int hashCode() { + return this.multipartFile.hashCode(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/MultipartHttpServletRequest.java b/spring-web/src/main/java/org/springframework/web/multipart/MultipartHttpServletRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..c083b42fdc6f199084a7a2e5311b36bab41a4fdd --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/MultipartHttpServletRequest.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2011 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; + +/** + * Provides additional methods for dealing with multipart content within a + * servlet request, allowing to access uploaded files. + * Implementations also need to override the standard + * {@link javax.servlet.ServletRequest} methods for parameter access, making + * multipart parameters available. + * + *

A concrete implementation is + * {@link org.springframework.web.multipart.support.DefaultMultipartHttpServletRequest}. + * As an intermediate step, + * {@link org.springframework.web.multipart.support.AbstractMultipartHttpServletRequest} + * can be subclassed. + * + * @author Juergen Hoeller + * @author Trevor D. Cook + * @since 29.09.2003 + * @see MultipartResolver + * @see MultipartFile + * @see javax.servlet.http.HttpServletRequest#getParameter + * @see javax.servlet.http.HttpServletRequest#getParameterNames + * @see javax.servlet.http.HttpServletRequest#getParameterMap + * @see org.springframework.web.multipart.support.DefaultMultipartHttpServletRequest + * @see org.springframework.web.multipart.support.AbstractMultipartHttpServletRequest + */ +public interface MultipartHttpServletRequest extends HttpServletRequest, MultipartRequest { + + /** + * Return this request's method as a convenient HttpMethod instance. + */ + @Nullable + HttpMethod getRequestMethod(); + + /** + * Return this request's headers as a convenient HttpHeaders instance. + */ + HttpHeaders getRequestHeaders(); + + /** + * Return the headers associated with the specified part of the multipart request. + *

If the underlying implementation supports access to headers, then all headers are returned. + * Otherwise, the returned headers will include a 'Content-Type' header at the very least. + */ + @Nullable + HttpHeaders getMultipartHeaders(String paramOrFileName); + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/MultipartRequest.java b/spring-web/src/main/java/org/springframework/web/multipart/MultipartRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..f5339b4716671219fc7bdbf2516db1ff2d1eded6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/MultipartRequest.java @@ -0,0 +1,87 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * This interface defines the multipart request access operations that are exposed + * for actual multipart requests. It is extended by {@link MultipartHttpServletRequest}. + * + * @author Juergen Hoeller + * @author Arjen Poutsma + * @since 2.5.2 + */ +public interface MultipartRequest { + + /** + * Return an {@link java.util.Iterator} of String objects containing the + * parameter names of the multipart files contained in this request. These + * are the field names of the form (like with normal parameters), not the + * original file names. + * @return the names of the files + */ + Iterator getFileNames(); + + /** + * Return the contents plus description of an uploaded file in this request, + * or {@code null} if it does not exist. + * @param name a String specifying the parameter name of the multipart file + * @return the uploaded content in the form of a {@link MultipartFile} object + */ + @Nullable + MultipartFile getFile(String name); + + /** + * Return the contents plus description of uploaded files in this request, + * or an empty list if it does not exist. + * @param name a String specifying the parameter name of the multipart file + * @return the uploaded content in the form of a {@link MultipartFile} list + * @since 3.0 + */ + List getFiles(String name); + + /** + * Return a {@link java.util.Map} of the multipart files contained in this request. + * @return a map containing the parameter names as keys, and the + * {@link MultipartFile} objects as values + */ + Map getFileMap(); + + /** + * Return a {@link MultiValueMap} of the multipart files contained in this request. + * @return a map containing the parameter names as keys, and a list of + * {@link MultipartFile} objects as values + * @since 3.0 + */ + MultiValueMap getMultiFileMap(); + + /** + * Determine the content type of the specified request part. + * @param paramOrFileName the name of the part + * @return the associated content type, or {@code null} if not defined + * @since 3.1 + */ + @Nullable + String getMultipartContentType(String paramOrFileName); + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/MultipartResolver.java b/spring-web/src/main/java/org/springframework/web/multipart/MultipartResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..c8d7b642391169ee42ae73cca792db0246df3009 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/MultipartResolver.java @@ -0,0 +1,122 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart; + +import javax.servlet.http.HttpServletRequest; + +/** + * A strategy interface for multipart file upload resolution in accordance + * with RFC 1867. + * Implementations are typically usable both within an application context + * and standalone. + * + *

There are two concrete implementations included in Spring, as of Spring 3.1: + *

    + *
  • {@link org.springframework.web.multipart.commons.CommonsMultipartResolver} + * for Apache Commons FileUpload + *
  • {@link org.springframework.web.multipart.support.StandardServletMultipartResolver} + * for the Servlet 3.0+ Part API + *
+ * + *

There is no default resolver implementation used for Spring + * {@link org.springframework.web.servlet.DispatcherServlet DispatcherServlets}, + * as an application might choose to parse its multipart requests itself. To define + * an implementation, create a bean with the id "multipartResolver" in a + * {@link org.springframework.web.servlet.DispatcherServlet DispatcherServlet's} + * application context. Such a resolver gets applied to all requests handled + * by that {@link org.springframework.web.servlet.DispatcherServlet}. + * + *

If a {@link org.springframework.web.servlet.DispatcherServlet} detects a + * multipart request, it will resolve it via the configured {@link MultipartResolver} + * and pass on a wrapped {@link javax.servlet.http.HttpServletRequest}. Controllers + * can then cast their given request to the {@link MultipartHttpServletRequest} + * interface, which allows for access to any {@link MultipartFile MultipartFiles}. + * Note that this cast is only supported in case of an actual multipart request. + * + *

+ * public ModelAndView handleRequest(HttpServletRequest request, HttpServletResponse response) {
+ *   MultipartHttpServletRequest multipartRequest = (MultipartHttpServletRequest) request;
+ *   MultipartFile multipartFile = multipartRequest.getFile("image");
+ *   ...
+ * }
+ * + * Instead of direct access, command or form controllers can register a + * {@link org.springframework.web.multipart.support.ByteArrayMultipartFileEditor} + * or {@link org.springframework.web.multipart.support.StringMultipartFileEditor} + * with their data binder, to automatically apply multipart content to form + * bean properties. + * + *

As an alternative to using a {@link MultipartResolver} with a + * {@link org.springframework.web.servlet.DispatcherServlet}, + * a {@link org.springframework.web.multipart.support.MultipartFilter} can be + * registered in {@code web.xml}. It will delegate to a corresponding + * {@link MultipartResolver} bean in the root application context. This is mainly + * intended for applications that do not use Spring's own web MVC framework. + * + *

Note: There is hardly ever a need to access the {@link MultipartResolver} + * itself from application code. It will simply do its work behind the scenes, + * making {@link MultipartHttpServletRequest MultipartHttpServletRequests} + * available to controllers. + * + * @author Juergen Hoeller + * @author Trevor D. Cook + * @since 29.09.2003 + * @see MultipartHttpServletRequest + * @see MultipartFile + * @see org.springframework.web.multipart.commons.CommonsMultipartResolver + * @see org.springframework.web.multipart.support.ByteArrayMultipartFileEditor + * @see org.springframework.web.multipart.support.StringMultipartFileEditor + * @see org.springframework.web.servlet.DispatcherServlet + */ +public interface MultipartResolver { + + /** + * Determine if the given request contains multipart content. + *

Will typically check for content type "multipart/form-data", but the actually + * accepted requests might depend on the capabilities of the resolver implementation. + * @param request the servlet request to be evaluated + * @return whether the request contains multipart content + */ + boolean isMultipart(HttpServletRequest request); + + /** + * Parse the given HTTP request into multipart files and parameters, + * and wrap the request inside a + * {@link org.springframework.web.multipart.MultipartHttpServletRequest} + * object that provides access to file descriptors and makes contained + * parameters accessible via the standard ServletRequest methods. + * @param request the servlet request to wrap (must be of a multipart content type) + * @return the wrapped servlet request + * @throws MultipartException if the servlet request is not multipart, or if + * implementation-specific problems are encountered (such as exceeding file size limits) + * @see MultipartHttpServletRequest#getFile + * @see MultipartHttpServletRequest#getFileNames + * @see MultipartHttpServletRequest#getFileMap + * @see javax.servlet.http.HttpServletRequest#getParameter + * @see javax.servlet.http.HttpServletRequest#getParameterNames + * @see javax.servlet.http.HttpServletRequest#getParameterMap + */ + MultipartHttpServletRequest resolveMultipart(HttpServletRequest request) throws MultipartException; + + /** + * Cleanup any resources used for the multipart handling, + * like a storage for the uploaded files. + * @param request the request to cleanup resources for + */ + void cleanupMultipart(MultipartHttpServletRequest request); + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/commons/CommonsFileUploadSupport.java b/spring-web/src/main/java/org/springframework/web/multipart/commons/CommonsFileUploadSupport.java new file mode 100644 index 0000000000000000000000000000000000000000..99fcee5f8c4816051bb70a2e1a5ae81a8484b46d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/commons/CommonsFileUploadSupport.java @@ -0,0 +1,374 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.commons; + +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.nio.charset.Charset; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.commons.fileupload.FileItem; +import org.apache.commons.fileupload.FileItemFactory; +import org.apache.commons.fileupload.FileUpload; +import org.apache.commons.fileupload.disk.DiskFileItemFactory; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.io.Resource; +import org.springframework.core.log.LogFormatUtils; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.util.WebUtils; + +/** + * Base class for multipart resolvers that use Apache Commons FileUpload + * 1.2 or above. + * + *

Provides common configuration properties and parsing functionality + * for multipart requests, using a Map of Spring CommonsMultipartFile instances + * as representation of uploaded files and a String-based parameter Map as + * representation of uploaded form fields. + * + * @author Juergen Hoeller + * @since 2.0 + * @see CommonsMultipartFile + * @see CommonsMultipartResolver + */ +public abstract class CommonsFileUploadSupport { + + protected final Log logger = LogFactory.getLog(getClass()); + + private final DiskFileItemFactory fileItemFactory; + + private final FileUpload fileUpload; + + private boolean uploadTempDirSpecified = false; + + private boolean preserveFilename = false; + + + /** + * Instantiate a new CommonsFileUploadSupport with its + * corresponding FileItemFactory and FileUpload instances. + * @see #newFileItemFactory + * @see #newFileUpload + */ + public CommonsFileUploadSupport() { + this.fileItemFactory = newFileItemFactory(); + this.fileUpload = newFileUpload(getFileItemFactory()); + } + + + /** + * Return the underlying {@code org.apache.commons.fileupload.disk.DiskFileItemFactory} + * instance. There is hardly any need to access this. + * @return the underlying DiskFileItemFactory instance + */ + public DiskFileItemFactory getFileItemFactory() { + return this.fileItemFactory; + } + + /** + * Return the underlying {@code org.apache.commons.fileupload.FileUpload} + * instance. There is hardly any need to access this. + * @return the underlying FileUpload instance + */ + public FileUpload getFileUpload() { + return this.fileUpload; + } + + /** + * Set the maximum allowed size (in bytes) before an upload gets rejected. + * -1 indicates no limit (the default). + * @param maxUploadSize the maximum upload size allowed + * @see org.apache.commons.fileupload.FileUploadBase#setSizeMax + */ + public void setMaxUploadSize(long maxUploadSize) { + this.fileUpload.setSizeMax(maxUploadSize); + } + + /** + * Set the maximum allowed size (in bytes) for each individual file before + * an upload gets rejected. -1 indicates no limit (the default). + * @param maxUploadSizePerFile the maximum upload size per file + * @since 4.2 + * @see org.apache.commons.fileupload.FileUploadBase#setFileSizeMax + */ + public void setMaxUploadSizePerFile(long maxUploadSizePerFile) { + this.fileUpload.setFileSizeMax(maxUploadSizePerFile); + } + + /** + * Set the maximum allowed size (in bytes) before uploads are written to disk. + * Uploaded files will still be received past this amount, but they will not be + * stored in memory. Default is 10240, according to Commons FileUpload. + * @param maxInMemorySize the maximum in memory size allowed + * @see org.apache.commons.fileupload.disk.DiskFileItemFactory#setSizeThreshold + */ + public void setMaxInMemorySize(int maxInMemorySize) { + this.fileItemFactory.setSizeThreshold(maxInMemorySize); + } + + /** + * Set the default character encoding to use for parsing requests, + * to be applied to headers of individual parts and to form fields. + * Default is ISO-8859-1, according to the Servlet spec. + *

If the request specifies a character encoding itself, the request + * encoding will override this setting. This also allows for generically + * overriding the character encoding in a filter that invokes the + * {@code ServletRequest.setCharacterEncoding} method. + * @param defaultEncoding the character encoding to use + * @see javax.servlet.ServletRequest#getCharacterEncoding + * @see javax.servlet.ServletRequest#setCharacterEncoding + * @see WebUtils#DEFAULT_CHARACTER_ENCODING + * @see org.apache.commons.fileupload.FileUploadBase#setHeaderEncoding + */ + public void setDefaultEncoding(String defaultEncoding) { + this.fileUpload.setHeaderEncoding(defaultEncoding); + } + + /** + * Determine the default encoding to use for parsing requests. + * @see #setDefaultEncoding + */ + protected String getDefaultEncoding() { + String encoding = getFileUpload().getHeaderEncoding(); + if (encoding == null) { + encoding = WebUtils.DEFAULT_CHARACTER_ENCODING; + } + return encoding; + } + + /** + * Set the temporary directory where uploaded files get stored. + * Default is the servlet container's temporary directory for the web application. + * @see org.springframework.web.util.WebUtils#TEMP_DIR_CONTEXT_ATTRIBUTE + */ + public void setUploadTempDir(Resource uploadTempDir) throws IOException { + if (!uploadTempDir.exists() && !uploadTempDir.getFile().mkdirs()) { + throw new IllegalArgumentException("Given uploadTempDir [" + uploadTempDir + "] could not be created"); + } + this.fileItemFactory.setRepository(uploadTempDir.getFile()); + this.uploadTempDirSpecified = true; + } + + /** + * Return the temporary directory where uploaded files get stored. + * @see #setUploadTempDir + */ + protected boolean isUploadTempDirSpecified() { + return this.uploadTempDirSpecified; + } + + /** + * Set whether to preserve the filename as sent by the client, not stripping off + * path information in {@link CommonsMultipartFile#getOriginalFilename()}. + *

Default is "false", stripping off path information that may prefix the + * actual filename e.g. from Opera. Switch this to "true" for preserving the + * client-specified filename as-is, including potential path separators. + * @since 4.3.5 + * @see MultipartFile#getOriginalFilename() + * @see CommonsMultipartFile#setPreserveFilename(boolean) + */ + public void setPreserveFilename(boolean preserveFilename) { + this.preserveFilename = preserveFilename; + } + + + /** + * Factory method for a Commons DiskFileItemFactory instance. + *

Default implementation returns a standard DiskFileItemFactory. + * Can be overridden to use a custom subclass, e.g. for testing purposes. + * @return the new DiskFileItemFactory instance + */ + protected DiskFileItemFactory newFileItemFactory() { + return new DiskFileItemFactory(); + } + + /** + * Factory method for a Commons FileUpload instance. + *

To be implemented by subclasses. + * @param fileItemFactory the Commons FileItemFactory to build upon + * @return the Commons FileUpload instance + */ + protected abstract FileUpload newFileUpload(FileItemFactory fileItemFactory); + + + /** + * Determine an appropriate FileUpload instance for the given encoding. + *

Default implementation returns the shared FileUpload instance + * if the encoding matches, else creates a new FileUpload instance + * with the same configuration other than the desired encoding. + * @param encoding the character encoding to use + * @return an appropriate FileUpload instance. + */ + protected FileUpload prepareFileUpload(@Nullable String encoding) { + FileUpload fileUpload = getFileUpload(); + FileUpload actualFileUpload = fileUpload; + + // Use new temporary FileUpload instance if the request specifies + // its own encoding that does not match the default encoding. + if (encoding != null && !encoding.equals(fileUpload.getHeaderEncoding())) { + actualFileUpload = newFileUpload(getFileItemFactory()); + actualFileUpload.setSizeMax(fileUpload.getSizeMax()); + actualFileUpload.setFileSizeMax(fileUpload.getFileSizeMax()); + actualFileUpload.setHeaderEncoding(encoding); + } + + return actualFileUpload; + } + + /** + * Parse the given List of Commons FileItems into a Spring MultipartParsingResult, + * containing Spring MultipartFile instances and a Map of multipart parameter. + * @param fileItems the Commons FileItems to parse + * @param encoding the encoding to use for form fields + * @return the Spring MultipartParsingResult + * @see CommonsMultipartFile#CommonsMultipartFile(org.apache.commons.fileupload.FileItem) + */ + protected MultipartParsingResult parseFileItems(List fileItems, String encoding) { + MultiValueMap multipartFiles = new LinkedMultiValueMap<>(); + Map multipartParameters = new HashMap<>(); + Map multipartParameterContentTypes = new HashMap<>(); + + // Extract multipart files and multipart parameters. + for (FileItem fileItem : fileItems) { + if (fileItem.isFormField()) { + String value; + String partEncoding = determineEncoding(fileItem.getContentType(), encoding); + try { + value = fileItem.getString(partEncoding); + } + catch (UnsupportedEncodingException ex) { + if (logger.isWarnEnabled()) { + logger.warn("Could not decode multipart item '" + fileItem.getFieldName() + + "' with encoding '" + partEncoding + "': using platform default"); + } + value = fileItem.getString(); + } + String[] curParam = multipartParameters.get(fileItem.getFieldName()); + if (curParam == null) { + // simple form field + multipartParameters.put(fileItem.getFieldName(), new String[] {value}); + } + else { + // array of simple form fields + String[] newParam = StringUtils.addStringToArray(curParam, value); + multipartParameters.put(fileItem.getFieldName(), newParam); + } + multipartParameterContentTypes.put(fileItem.getFieldName(), fileItem.getContentType()); + } + else { + // multipart file field + CommonsMultipartFile file = createMultipartFile(fileItem); + multipartFiles.add(file.getName(), file); + LogFormatUtils.traceDebug(logger, traceOn -> + "Part '" + file.getName() + "', size " + file.getSize() + + " bytes, filename='" + file.getOriginalFilename() + "'" + + (traceOn ? ", storage=" + file.getStorageDescription() : "") + ); + } + } + return new MultipartParsingResult(multipartFiles, multipartParameters, multipartParameterContentTypes); + } + + /** + * Create a {@link CommonsMultipartFile} wrapper for the given Commons {@link FileItem}. + * @param fileItem the Commons FileItem to wrap + * @return the corresponding CommonsMultipartFile (potentially a custom subclass) + * @since 4.3.5 + * @see #setPreserveFilename(boolean) + * @see CommonsMultipartFile#setPreserveFilename(boolean) + */ + protected CommonsMultipartFile createMultipartFile(FileItem fileItem) { + CommonsMultipartFile multipartFile = new CommonsMultipartFile(fileItem); + multipartFile.setPreserveFilename(this.preserveFilename); + return multipartFile; + } + + /** + * Cleanup the Spring MultipartFiles created during multipart parsing, + * potentially holding temporary data on disk. + *

Deletes the underlying Commons FileItem instances. + * @param multipartFiles a Collection of MultipartFile instances + * @see org.apache.commons.fileupload.FileItem#delete() + */ + protected void cleanupFileItems(MultiValueMap multipartFiles) { + for (List files : multipartFiles.values()) { + for (MultipartFile file : files) { + if (file instanceof CommonsMultipartFile) { + CommonsMultipartFile cmf = (CommonsMultipartFile) file; + cmf.getFileItem().delete(); + LogFormatUtils.traceDebug(logger, traceOn -> + "Cleaning up part '" + cmf.getName() + + "', filename '" + cmf.getOriginalFilename() + "'" + + (traceOn ? ", stored " + cmf.getStorageDescription() : "")); + } + } + } + } + + private String determineEncoding(String contentTypeHeader, String defaultEncoding) { + if (!StringUtils.hasText(contentTypeHeader)) { + return defaultEncoding; + } + MediaType contentType = MediaType.parseMediaType(contentTypeHeader); + Charset charset = contentType.getCharset(); + return (charset != null ? charset.name() : defaultEncoding); + } + + + /** + * Holder for a Map of Spring MultipartFiles and a Map of + * multipart parameters. + */ + protected static class MultipartParsingResult { + + private final MultiValueMap multipartFiles; + + private final Map multipartParameters; + + private final Map multipartParameterContentTypes; + + public MultipartParsingResult(MultiValueMap mpFiles, + Map mpParams, Map mpParamContentTypes) { + + this.multipartFiles = mpFiles; + this.multipartParameters = mpParams; + this.multipartParameterContentTypes = mpParamContentTypes; + } + + public MultiValueMap getMultipartFiles() { + return this.multipartFiles; + } + + public Map getMultipartParameters() { + return this.multipartParameters; + } + + public Map getMultipartParameterContentTypes() { + return this.multipartParameterContentTypes; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/commons/CommonsMultipartFile.java b/spring-web/src/main/java/org/springframework/web/multipart/commons/CommonsMultipartFile.java new file mode 100644 index 0000000000000000000000000000000000000000..abdafd425d16bb9d91b64a0a8505d753ef808484 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/commons/CommonsMultipartFile.java @@ -0,0 +1,235 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.commons; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.Serializable; +import java.nio.file.Files; +import java.nio.file.Path; + +import org.apache.commons.fileupload.FileItem; +import org.apache.commons.fileupload.FileUploadException; +import org.apache.commons.fileupload.disk.DiskFileItem; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogFormatUtils; +import org.springframework.util.FileCopyUtils; +import org.springframework.util.StreamUtils; +import org.springframework.web.multipart.MultipartFile; + +/** + * {@link MultipartFile} implementation for Apache Commons FileUpload. + * + * @author Trevor D. Cook + * @author Juergen Hoeller + * @since 29.09.2003 + * @see CommonsMultipartResolver + */ +@SuppressWarnings("serial") +public class CommonsMultipartFile implements MultipartFile, Serializable { + + protected static final Log logger = LogFactory.getLog(CommonsMultipartFile.class); + + private final FileItem fileItem; + + private final long size; + + private boolean preserveFilename = false; + + + /** + * Create an instance wrapping the given FileItem. + * @param fileItem the FileItem to wrap + */ + public CommonsMultipartFile(FileItem fileItem) { + this.fileItem = fileItem; + this.size = this.fileItem.getSize(); + } + + + /** + * Return the underlying {@code org.apache.commons.fileupload.FileItem} + * instance. There is hardly any need to access this. + */ + public final FileItem getFileItem() { + return this.fileItem; + } + + /** + * Set whether to preserve the filename as sent by the client, not stripping off + * path information in {@link CommonsMultipartFile#getOriginalFilename()}. + *

Default is "false", stripping off path information that may prefix the + * actual filename e.g. from Opera. Switch this to "true" for preserving the + * client-specified filename as-is, including potential path separators. + * @since 4.3.5 + * @see #getOriginalFilename() + * @see CommonsMultipartResolver#setPreserveFilename(boolean) + */ + public void setPreserveFilename(boolean preserveFilename) { + this.preserveFilename = preserveFilename; + } + + + @Override + public String getName() { + return this.fileItem.getFieldName(); + } + + @Override + public String getOriginalFilename() { + String filename = this.fileItem.getName(); + if (filename == null) { + // Should never happen. + return ""; + } + if (this.preserveFilename) { + // Do not try to strip off a path... + return filename; + } + + // Check for Unix-style path + int unixSep = filename.lastIndexOf('/'); + // Check for Windows-style path + int winSep = filename.lastIndexOf('\\'); + // Cut off at latest possible point + int pos = (winSep > unixSep ? winSep : unixSep); + if (pos != -1) { + // Any sort of path separator found... + return filename.substring(pos + 1); + } + else { + // A plain name + return filename; + } + } + + @Override + public String getContentType() { + return this.fileItem.getContentType(); + } + + @Override + public boolean isEmpty() { + return (this.size == 0); + } + + @Override + public long getSize() { + return this.size; + } + + @Override + public byte[] getBytes() { + if (!isAvailable()) { + throw new IllegalStateException("File has been moved - cannot be read again"); + } + byte[] bytes = this.fileItem.get(); + return (bytes != null ? bytes : new byte[0]); + } + + @Override + public InputStream getInputStream() throws IOException { + if (!isAvailable()) { + throw new IllegalStateException("File has been moved - cannot be read again"); + } + InputStream inputStream = this.fileItem.getInputStream(); + return (inputStream != null ? inputStream : StreamUtils.emptyInput()); + } + + @Override + public void transferTo(File dest) throws IOException, IllegalStateException { + if (!isAvailable()) { + throw new IllegalStateException("File has already been moved - cannot be transferred again"); + } + + if (dest.exists() && !dest.delete()) { + throw new IOException( + "Destination file [" + dest.getAbsolutePath() + "] already exists and could not be deleted"); + } + + try { + this.fileItem.write(dest); + LogFormatUtils.traceDebug(logger, traceOn -> { + String action = "transferred"; + if (!this.fileItem.isInMemory()) { + action = (isAvailable() ? "copied" : "moved"); + } + return "Part '" + getName() + "', filename '" + getOriginalFilename() + "'" + + (traceOn ? ", stored " + getStorageDescription() : "") + + ": " + action + " to [" + dest.getAbsolutePath() + "]"; + }); + } + catch (FileUploadException ex) { + throw new IllegalStateException(ex.getMessage(), ex); + } + catch (IllegalStateException | IOException ex) { + // Pass through IllegalStateException when coming from FileItem directly, + // or propagate an exception from I/O operations within FileItem.write + throw ex; + } + catch (Exception ex) { + throw new IOException("File transfer failed", ex); + } + } + + @Override + public void transferTo(Path dest) throws IOException, IllegalStateException { + if (!isAvailable()) { + throw new IllegalStateException("File has already been moved - cannot be transferred again"); + } + + FileCopyUtils.copy(this.fileItem.getInputStream(), Files.newOutputStream(dest)); + } + + /** + * Determine whether the multipart content is still available. + * If a temporary file has been moved, the content is no longer available. + */ + protected boolean isAvailable() { + // If in memory, it's available. + if (this.fileItem.isInMemory()) { + return true; + } + // Check actual existence of temporary file. + if (this.fileItem instanceof DiskFileItem) { + return ((DiskFileItem) this.fileItem).getStoreLocation().exists(); + } + // Check whether current file size is different than original one. + return (this.fileItem.getSize() == this.size); + } + + /** + * Return a description for the storage location of the multipart content. + * Tries to be as specific as possible: mentions the file location in case + * of a temporary file. + */ + public String getStorageDescription() { + if (this.fileItem.isInMemory()) { + return "in memory"; + } + else if (this.fileItem instanceof DiskFileItem) { + return "at [" + ((DiskFileItem) this.fileItem).getStoreLocation().getAbsolutePath() + "]"; + } + else { + return "on disk"; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/commons/CommonsMultipartResolver.java b/spring-web/src/main/java/org/springframework/web/multipart/commons/CommonsMultipartResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..eed600aff08a1dade28290654e4df0562b6c935e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/commons/CommonsMultipartResolver.java @@ -0,0 +1,204 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.commons; + +import java.util.List; + +import javax.servlet.ServletContext; +import javax.servlet.http.HttpServletRequest; + +import org.apache.commons.fileupload.FileItem; +import org.apache.commons.fileupload.FileItemFactory; +import org.apache.commons.fileupload.FileUpload; +import org.apache.commons.fileupload.FileUploadBase; +import org.apache.commons.fileupload.FileUploadException; +import org.apache.commons.fileupload.servlet.ServletFileUpload; + +import org.springframework.util.Assert; +import org.springframework.web.context.ServletContextAware; +import org.springframework.web.multipart.MaxUploadSizeExceededException; +import org.springframework.web.multipart.MultipartException; +import org.springframework.web.multipart.MultipartHttpServletRequest; +import org.springframework.web.multipart.MultipartResolver; +import org.springframework.web.multipart.support.AbstractMultipartHttpServletRequest; +import org.springframework.web.multipart.support.DefaultMultipartHttpServletRequest; +import org.springframework.web.util.WebUtils; + +/** + * Servlet-based {@link MultipartResolver} implementation for + * Apache Commons FileUpload + * 1.2 or above. + * + *

Provides "maxUploadSize", "maxInMemorySize" and "defaultEncoding" settings as + * bean properties (inherited from {@link CommonsFileUploadSupport}). See corresponding + * ServletFileUpload / DiskFileItemFactory properties ("sizeMax", "sizeThreshold", + * "headerEncoding") for details in terms of defaults and accepted values. + * + *

Saves temporary files to the servlet container's temporary directory. + * Needs to be initialized either by an application context or + * via the constructor that takes a ServletContext (for standalone usage). + * + * @author Trevor D. Cook + * @author Juergen Hoeller + * @since 29.09.2003 + * @see #CommonsMultipartResolver(ServletContext) + * @see #setResolveLazily + * @see org.apache.commons.fileupload.servlet.ServletFileUpload + * @see org.apache.commons.fileupload.disk.DiskFileItemFactory + */ +public class CommonsMultipartResolver extends CommonsFileUploadSupport + implements MultipartResolver, ServletContextAware { + + private boolean resolveLazily = false; + + + /** + * Constructor for use as bean. Determines the servlet container's + * temporary directory via the ServletContext passed in as through the + * ServletContextAware interface (typically by a WebApplicationContext). + * @see #setServletContext + * @see org.springframework.web.context.ServletContextAware + * @see org.springframework.web.context.WebApplicationContext + */ + public CommonsMultipartResolver() { + super(); + } + + /** + * Constructor for standalone usage. Determines the servlet container's + * temporary directory via the given ServletContext. + * @param servletContext the ServletContext to use + */ + public CommonsMultipartResolver(ServletContext servletContext) { + this(); + setServletContext(servletContext); + } + + + /** + * Set whether to resolve the multipart request lazily at the time of + * file or parameter access. + *

Default is "false", resolving the multipart elements immediately, throwing + * corresponding exceptions at the time of the {@link #resolveMultipart} call. + * Switch this to "true" for lazy multipart parsing, throwing parse exceptions + * once the application attempts to obtain multipart files or parameters. + */ + public void setResolveLazily(boolean resolveLazily) { + this.resolveLazily = resolveLazily; + } + + /** + * Initialize the underlying {@code org.apache.commons.fileupload.servlet.ServletFileUpload} + * instance. Can be overridden to use a custom subclass, e.g. for testing purposes. + * @param fileItemFactory the Commons FileItemFactory to use + * @return the new ServletFileUpload instance + */ + @Override + protected FileUpload newFileUpload(FileItemFactory fileItemFactory) { + return new ServletFileUpload(fileItemFactory); + } + + @Override + public void setServletContext(ServletContext servletContext) { + if (!isUploadTempDirSpecified()) { + getFileItemFactory().setRepository(WebUtils.getTempDir(servletContext)); + } + } + + + @Override + public boolean isMultipart(HttpServletRequest request) { + return ServletFileUpload.isMultipartContent(request); + } + + @Override + public MultipartHttpServletRequest resolveMultipart(final HttpServletRequest request) throws MultipartException { + Assert.notNull(request, "Request must not be null"); + if (this.resolveLazily) { + return new DefaultMultipartHttpServletRequest(request) { + @Override + protected void initializeMultipart() { + MultipartParsingResult parsingResult = parseRequest(request); + setMultipartFiles(parsingResult.getMultipartFiles()); + setMultipartParameters(parsingResult.getMultipartParameters()); + setMultipartParameterContentTypes(parsingResult.getMultipartParameterContentTypes()); + } + }; + } + else { + MultipartParsingResult parsingResult = parseRequest(request); + return new DefaultMultipartHttpServletRequest(request, parsingResult.getMultipartFiles(), + parsingResult.getMultipartParameters(), parsingResult.getMultipartParameterContentTypes()); + } + } + + /** + * Parse the given servlet request, resolving its multipart elements. + * @param request the request to parse + * @return the parsing result + * @throws MultipartException if multipart resolution failed. + */ + protected MultipartParsingResult parseRequest(HttpServletRequest request) throws MultipartException { + String encoding = determineEncoding(request); + FileUpload fileUpload = prepareFileUpload(encoding); + try { + List fileItems = ((ServletFileUpload) fileUpload).parseRequest(request); + return parseFileItems(fileItems, encoding); + } + catch (FileUploadBase.SizeLimitExceededException ex) { + throw new MaxUploadSizeExceededException(fileUpload.getSizeMax(), ex); + } + catch (FileUploadBase.FileSizeLimitExceededException ex) { + throw new MaxUploadSizeExceededException(fileUpload.getFileSizeMax(), ex); + } + catch (FileUploadException ex) { + throw new MultipartException("Failed to parse multipart servlet request", ex); + } + } + + /** + * Determine the encoding for the given request. + * Can be overridden in subclasses. + *

The default implementation checks the request encoding, + * falling back to the default encoding specified for this resolver. + * @param request current HTTP request + * @return the encoding for the request (never {@code null}) + * @see javax.servlet.ServletRequest#getCharacterEncoding + * @see #setDefaultEncoding + */ + protected String determineEncoding(HttpServletRequest request) { + String encoding = request.getCharacterEncoding(); + if (encoding == null) { + encoding = getDefaultEncoding(); + } + return encoding; + } + + @Override + public void cleanupMultipart(MultipartHttpServletRequest request) { + if (!(request instanceof AbstractMultipartHttpServletRequest) || + ((AbstractMultipartHttpServletRequest) request).isResolved()) { + try { + cleanupFileItems(request.getMultiFileMap()); + } + catch (Throwable ex) { + logger.warn("Failed to perform multipart cleanup for servlet request", ex); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/commons/package-info.java b/spring-web/src/main/java/org/springframework/web/multipart/commons/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..03b636815d45f42b59315d8992fbd9c4bbdf20a7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/commons/package-info.java @@ -0,0 +1,10 @@ +/** + * MultipartResolver implementation for + * Apache Commons FileUpload. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.multipart.commons; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/multipart/package-info.java b/spring-web/src/main/java/org/springframework/web/multipart/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..8620295ce9bcd90aae3411a26ec3eb7ccebc3484 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/package-info.java @@ -0,0 +1,12 @@ +/** + * Multipart resolution framework for handling file uploads. + * Provides a MultipartResolver strategy interface, + * and a generic extension of the HttpServletRequest interface + * for accessing multipart files in web application code. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.multipart; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/multipart/support/AbstractMultipartHttpServletRequest.java b/spring-web/src/main/java/org/springframework/web/multipart/support/AbstractMultipartHttpServletRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..3c2a8834c563da06b03aec2215336f13c3abdbbe --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/support/AbstractMultipartHttpServletRequest.java @@ -0,0 +1,154 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.support; + +import java.util.Collections; +import java.util.Enumeration; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.multipart.MultipartHttpServletRequest; + +/** + * Abstract base implementation of the MultipartHttpServletRequest interface. + * Provides management of pre-generated MultipartFile instances. + * + * @author Juergen Hoeller + * @author Arjen Poutsma + * @since 06.10.2003 + */ +public abstract class AbstractMultipartHttpServletRequest extends HttpServletRequestWrapper + implements MultipartHttpServletRequest { + + @Nullable + private MultiValueMap multipartFiles; + + + /** + * Wrap the given HttpServletRequest in a MultipartHttpServletRequest. + * @param request the request to wrap + */ + protected AbstractMultipartHttpServletRequest(HttpServletRequest request) { + super(request); + } + + + @Override + public HttpServletRequest getRequest() { + return (HttpServletRequest) super.getRequest(); + } + + @Override + public HttpMethod getRequestMethod() { + return HttpMethod.resolve(getRequest().getMethod()); + } + + @Override + public HttpHeaders getRequestHeaders() { + HttpHeaders headers = new HttpHeaders(); + Enumeration headerNames = getHeaderNames(); + while (headerNames.hasMoreElements()) { + String headerName = headerNames.nextElement(); + headers.put(headerName, Collections.list(getHeaders(headerName))); + } + return headers; + } + + @Override + public Iterator getFileNames() { + return getMultipartFiles().keySet().iterator(); + } + + @Override + public MultipartFile getFile(String name) { + return getMultipartFiles().getFirst(name); + } + + @Override + public List getFiles(String name) { + List multipartFiles = getMultipartFiles().get(name); + if (multipartFiles != null) { + return multipartFiles; + } + else { + return Collections.emptyList(); + } + } + + @Override + public Map getFileMap() { + return getMultipartFiles().toSingleValueMap(); + } + + @Override + public MultiValueMap getMultiFileMap() { + return getMultipartFiles(); + } + + /** + * Determine whether the underlying multipart request has been resolved. + * @return {@code true} when eagerly initialized or lazily triggered, + * {@code false} in case of a lazy-resolution request that got aborted + * before any parameters or multipart files have been accessed + * @since 4.3.15 + * @see #getMultipartFiles() + */ + public boolean isResolved() { + return (this.multipartFiles != null); + } + + + /** + * Set a Map with parameter names as keys and list of MultipartFile objects as values. + * To be invoked by subclasses on initialization. + */ + protected final void setMultipartFiles(MultiValueMap multipartFiles) { + this.multipartFiles = + new LinkedMultiValueMap<>(Collections.unmodifiableMap(multipartFiles)); + } + + /** + * Obtain the MultipartFile Map for retrieval, + * lazily initializing it if necessary. + * @see #initializeMultipart() + */ + protected MultiValueMap getMultipartFiles() { + if (this.multipartFiles == null) { + initializeMultipart(); + } + return this.multipartFiles; + } + + /** + * Lazily initialize the multipart request, if possible. + * Only called if not already eagerly initialized. + */ + protected void initializeMultipart() { + throw new IllegalStateException("Multipart request not initialized"); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/support/ByteArrayMultipartFileEditor.java b/spring-web/src/main/java/org/springframework/web/multipart/support/ByteArrayMultipartFileEditor.java new file mode 100644 index 0000000000000000000000000000000000000000..b5439e55899a204a819cb47b81a5b0d1108dcfc4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/support/ByteArrayMultipartFileEditor.java @@ -0,0 +1,59 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.support; + +import java.io.IOException; + +import org.springframework.beans.propertyeditors.ByteArrayPropertyEditor; +import org.springframework.lang.Nullable; +import org.springframework.web.multipart.MultipartFile; + +/** + * Custom {@link java.beans.PropertyEditor} for converting + * {@link MultipartFile MultipartFiles} to byte arrays. + * + * @author Juergen Hoeller + * @since 13.10.2003 + */ +public class ByteArrayMultipartFileEditor extends ByteArrayPropertyEditor { + + @Override + public void setValue(@Nullable Object value) { + if (value instanceof MultipartFile) { + MultipartFile multipartFile = (MultipartFile) value; + try { + super.setValue(multipartFile.getBytes()); + } + catch (IOException ex) { + throw new IllegalArgumentException("Cannot read contents of multipart file", ex); + } + } + else if (value instanceof byte[]) { + super.setValue(value); + } + else { + super.setValue(value != null ? value.toString().getBytes() : null); + } + } + + @Override + public String getAsText() { + byte[] value = (byte[]) getValue(); + return (value != null ? new String(value) : ""); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/support/DefaultMultipartHttpServletRequest.java b/spring-web/src/main/java/org/springframework/web/multipart/support/DefaultMultipartHttpServletRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..d882620af27b2d1550d56cb2fe1ec942614dcffe --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/support/DefaultMultipartHttpServletRequest.java @@ -0,0 +1,199 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.support; + +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; +import org.springframework.web.multipart.MultipartFile; + +/** + * Default implementation of the + * {@link org.springframework.web.multipart.MultipartHttpServletRequest} + * interface. Provides management of pre-generated parameter values. + * + *

Used by {@link org.springframework.web.multipart.commons.CommonsMultipartResolver}. + * + * @author Trevor D. Cook + * @author Juergen Hoeller + * @author Arjen Poutsma + * @since 29.09.2003 + * @see org.springframework.web.multipart.MultipartResolver + */ +public class DefaultMultipartHttpServletRequest extends AbstractMultipartHttpServletRequest { + + private static final String CONTENT_TYPE = "Content-Type"; + + @Nullable + private Map multipartParameters; + + @Nullable + private Map multipartParameterContentTypes; + + + /** + * Wrap the given HttpServletRequest in a MultipartHttpServletRequest. + * @param request the servlet request to wrap + * @param mpFiles a map of the multipart files + * @param mpParams a map of the parameters to expose, + * with Strings as keys and String arrays as values + */ + public DefaultMultipartHttpServletRequest(HttpServletRequest request, MultiValueMap mpFiles, + Map mpParams, Map mpParamContentTypes) { + + super(request); + setMultipartFiles(mpFiles); + setMultipartParameters(mpParams); + setMultipartParameterContentTypes(mpParamContentTypes); + } + + /** + * Wrap the given HttpServletRequest in a MultipartHttpServletRequest. + * @param request the servlet request to wrap + */ + public DefaultMultipartHttpServletRequest(HttpServletRequest request) { + super(request); + } + + + @Override + @Nullable + public String getParameter(String name) { + String[] values = getMultipartParameters().get(name); + if (values != null) { + return (values.length > 0 ? values[0] : null); + } + return super.getParameter(name); + } + + @Override + public String[] getParameterValues(String name) { + String[] parameterValues = super.getParameterValues(name); + String[] mpValues = getMultipartParameters().get(name); + if (mpValues == null) { + return parameterValues; + } + if (parameterValues == null || getQueryString() == null) { + return mpValues; + } + else { + String[] result = new String[mpValues.length + parameterValues.length]; + System.arraycopy(mpValues, 0, result, 0, mpValues.length); + System.arraycopy(parameterValues, 0, result, mpValues.length, parameterValues.length); + return result; + } + } + + @Override + public Enumeration getParameterNames() { + Map multipartParameters = getMultipartParameters(); + if (multipartParameters.isEmpty()) { + return super.getParameterNames(); + } + + Set paramNames = new LinkedHashSet<>(); + paramNames.addAll(Collections.list(super.getParameterNames())); + paramNames.addAll(multipartParameters.keySet()); + return Collections.enumeration(paramNames); + } + + @Override + public Map getParameterMap() { + Map result = new LinkedHashMap<>(); + Enumeration names = getParameterNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + result.put(name, getParameterValues(name)); + } + return result; + } + + @Override + public String getMultipartContentType(String paramOrFileName) { + MultipartFile file = getFile(paramOrFileName); + if (file != null) { + return file.getContentType(); + } + else { + return getMultipartParameterContentTypes().get(paramOrFileName); + } + } + + @Override + public HttpHeaders getMultipartHeaders(String paramOrFileName) { + String contentType = getMultipartContentType(paramOrFileName); + if (contentType != null) { + HttpHeaders headers = new HttpHeaders(); + headers.add(CONTENT_TYPE, contentType); + return headers; + } + else { + return null; + } + } + + + /** + * Set a Map with parameter names as keys and String array objects as values. + * To be invoked by subclasses on initialization. + */ + protected final void setMultipartParameters(Map multipartParameters) { + this.multipartParameters = multipartParameters; + } + + /** + * Obtain the multipart parameter Map for retrieval, + * lazily initializing it if necessary. + * @see #initializeMultipart() + */ + protected Map getMultipartParameters() { + if (this.multipartParameters == null) { + initializeMultipart(); + } + return this.multipartParameters; + } + + /** + * Set a Map with parameter names as keys and content type Strings as values. + * To be invoked by subclasses on initialization. + */ + protected final void setMultipartParameterContentTypes(Map multipartParameterContentTypes) { + this.multipartParameterContentTypes = multipartParameterContentTypes; + } + + /** + * Obtain the multipart parameter content type Map for retrieval, + * lazily initializing it if necessary. + * @see #initializeMultipart() + */ + protected Map getMultipartParameterContentTypes() { + if (this.multipartParameterContentTypes == null) { + initializeMultipart(); + } + return this.multipartParameterContentTypes; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/support/MissingServletRequestPartException.java b/spring-web/src/main/java/org/springframework/web/multipart/support/MissingServletRequestPartException.java new file mode 100644 index 0000000000000000000000000000000000000000..414405d8038f9a82f0a0b78eb8e94f1d5ccc4db7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/support/MissingServletRequestPartException.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.support; + +import javax.servlet.ServletException; + +import org.springframework.web.multipart.MultipartResolver; + +/** + * Raised when the part of a "multipart/form-data" request identified by its + * name cannot be found. + * + *

This may be because the request is not a multipart/form-data request, + * because the part is not present in the request, or because the web + * application is not configured correctly for processing multipart requests, + * e.g. no {@link MultipartResolver}. + * + * @author Rossen Stoyanchev + * @since 3.1 + */ +@SuppressWarnings("serial") +public class MissingServletRequestPartException extends ServletException { + + private final String requestPartName; + + + /** + * Constructor for MissingServletRequestPartException. + * @param requestPartName the name of the missing part of the multipart request + */ + public MissingServletRequestPartException(String requestPartName) { + super("Required request part '" + requestPartName + "' is not present"); + this.requestPartName = requestPartName; + } + + + /** + * Return the name of the offending part of the multipart request. + */ + public String getRequestPartName() { + return this.requestPartName; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/support/MultipartFilter.java b/spring-web/src/main/java/org/springframework/web/multipart/support/MultipartFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..1cd8fb2c3cece43ae9455a73da09d2dd4cc17eab --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/support/MultipartFilter.java @@ -0,0 +1,168 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.support; + +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.WebApplicationContextUtils; +import org.springframework.web.filter.OncePerRequestFilter; +import org.springframework.web.multipart.MultipartHttpServletRequest; +import org.springframework.web.multipart.MultipartResolver; + +/** + * Servlet Filter that resolves multipart requests via a {@link MultipartResolver}. + * in the root web application context. + * + *

Looks up the MultipartResolver in Spring's root web application context. + * Supports a "multipartResolverBeanName" filter init-param in {@code web.xml}; + * the default bean name is "filterMultipartResolver". + * + *

If no MultipartResolver bean is found, this filter falls back to a default + * MultipartResolver: {@link StandardServletMultipartResolver} for Servlet 3.0, + * based on a multipart-config section in {@code web.xml}. + * Note however that at present the Servlet specification only defines how to + * enable multipart configuration on a Servlet and as a result multipart request + * processing is likely not possible in a Filter unless the Servlet container + * provides a workaround such as Tomcat's "allowCasualMultipartParsing" property. + * + *

MultipartResolver lookup is customizable: Override this filter's + * {@code lookupMultipartResolver} method to use a custom MultipartResolver + * instance, for example if not using a Spring web application context. + * Note that the lookup method should not create a new MultipartResolver instance + * for each call but rather return a reference to a pre-built instance. + * + *

Note: This filter is an alternative to using DispatcherServlet's + * MultipartResolver support, for example for web applications with custom web views + * which do not use Spring's web MVC, or for custom filters applied before a Spring MVC + * DispatcherServlet (e.g. {@link org.springframework.web.filter.HiddenHttpMethodFilter}). + * In any case, this filter should not be combined with servlet-specific multipart resolution. + * + * @author Juergen Hoeller + * @since 08.10.2003 + * @see #setMultipartResolverBeanName + * @see #lookupMultipartResolver + * @see org.springframework.web.multipart.MultipartResolver + * @see org.springframework.web.servlet.DispatcherServlet + */ +public class MultipartFilter extends OncePerRequestFilter { + + /** + * The default name for the multipart resolver bean. + */ + public static final String DEFAULT_MULTIPART_RESOLVER_BEAN_NAME = "filterMultipartResolver"; + + private final MultipartResolver defaultMultipartResolver = new StandardServletMultipartResolver(); + + private String multipartResolverBeanName = DEFAULT_MULTIPART_RESOLVER_BEAN_NAME; + + + /** + * Set the bean name of the MultipartResolver to fetch from Spring's + * root application context. Default is "filterMultipartResolver". + */ + public void setMultipartResolverBeanName(String multipartResolverBeanName) { + this.multipartResolverBeanName = multipartResolverBeanName; + } + + /** + * Return the bean name of the MultipartResolver to fetch from Spring's + * root application context. + */ + protected String getMultipartResolverBeanName() { + return this.multipartResolverBeanName; + } + + + /** + * Check for a multipart request via this filter's MultipartResolver, + * and wrap the original request with a MultipartHttpServletRequest if appropriate. + *

All later elements in the filter chain, most importantly servlets, benefit + * from proper parameter extraction in the multipart case, and are able to cast to + * MultipartHttpServletRequest if they need to. + */ + @Override + protected void doFilterInternal( + HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + MultipartResolver multipartResolver = lookupMultipartResolver(request); + + HttpServletRequest processedRequest = request; + if (multipartResolver.isMultipart(processedRequest)) { + if (logger.isTraceEnabled()) { + logger.trace("Resolving multipart request"); + } + processedRequest = multipartResolver.resolveMultipart(processedRequest); + } + else { + // A regular request... + if (logger.isTraceEnabled()) { + logger.trace("Not a multipart request"); + } + } + + try { + filterChain.doFilter(processedRequest, response); + } + finally { + if (processedRequest instanceof MultipartHttpServletRequest) { + multipartResolver.cleanupMultipart((MultipartHttpServletRequest) processedRequest); + } + } + } + + /** + * Look up the MultipartResolver that this filter should use, + * taking the current HTTP request as argument. + *

The default implementation delegates to the {@code lookupMultipartResolver} + * without arguments. + * @return the MultipartResolver to use + * @see #lookupMultipartResolver() + */ + protected MultipartResolver lookupMultipartResolver(HttpServletRequest request) { + return lookupMultipartResolver(); + } + + /** + * Look for a MultipartResolver bean in the root web application context. + * Supports a "multipartResolverBeanName" filter init param; the default + * bean name is "filterMultipartResolver". + *

This can be overridden to use a custom MultipartResolver instance, + * for example if not using a Spring web application context. + * @return the MultipartResolver instance + */ + protected MultipartResolver lookupMultipartResolver() { + WebApplicationContext wac = WebApplicationContextUtils.getWebApplicationContext(getServletContext()); + String beanName = getMultipartResolverBeanName(); + if (wac != null && wac.containsBean(beanName)) { + if (logger.isDebugEnabled()) { + logger.debug("Using MultipartResolver '" + beanName + "' for MultipartFilter"); + } + return wac.getBean(beanName, MultipartResolver.class); + } + else { + return this.defaultMultipartResolver; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/support/MultipartResolutionDelegate.java b/spring-web/src/main/java/org/springframework/web/multipart/support/MultipartResolutionDelegate.java new file mode 100644 index 0000000000000000000000000000000000000000..1e753989affd3ef69543a7ffc08f9dd04d66b4c2 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/support/MultipartResolutionDelegate.java @@ -0,0 +1,174 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.support; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.Part; + +import org.springframework.core.MethodParameter; +import org.springframework.core.ResolvableType; +import org.springframework.lang.Nullable; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.multipart.MultipartHttpServletRequest; +import org.springframework.web.multipart.MultipartRequest; +import org.springframework.web.util.WebUtils; + +/** + * A common delegate for {@code HandlerMethodArgumentResolver} implementations + * which need to resolve {@link MultipartFile} and {@link Part} arguments. + * + * @author Juergen Hoeller + * @since 4.3 + */ +public abstract class MultipartResolutionDelegate { + + /** + * Indicates an unresolvable value. + */ + public static final Object UNRESOLVABLE = new Object(); + + + @Nullable + public static MultipartRequest resolveMultipartRequest(NativeWebRequest webRequest) { + MultipartRequest multipartRequest = webRequest.getNativeRequest(MultipartRequest.class); + if (multipartRequest != null) { + return multipartRequest; + } + HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class); + if (servletRequest != null && isMultipartContent(servletRequest)) { + return new StandardMultipartHttpServletRequest(servletRequest); + } + return null; + } + + public static boolean isMultipartRequest(HttpServletRequest request) { + return (WebUtils.getNativeRequest(request, MultipartHttpServletRequest.class) != null || + isMultipartContent(request)); + } + + private static boolean isMultipartContent(HttpServletRequest request) { + String contentType = request.getContentType(); + return (contentType != null && contentType.toLowerCase().startsWith("multipart/")); + } + + static MultipartHttpServletRequest asMultipartHttpServletRequest(HttpServletRequest request) { + MultipartHttpServletRequest unwrapped = WebUtils.getNativeRequest(request, MultipartHttpServletRequest.class); + if (unwrapped != null) { + return unwrapped; + } + return new StandardMultipartHttpServletRequest(request); + } + + + public static boolean isMultipartArgument(MethodParameter parameter) { + Class paramType = parameter.getNestedParameterType(); + return (MultipartFile.class == paramType || + isMultipartFileCollection(parameter) || isMultipartFileArray(parameter) || + (Part.class == paramType || isPartCollection(parameter) || isPartArray(parameter))); + } + + @Nullable + public static Object resolveMultipartArgument(String name, MethodParameter parameter, HttpServletRequest request) + throws Exception { + + MultipartHttpServletRequest multipartRequest = + WebUtils.getNativeRequest(request, MultipartHttpServletRequest.class); + boolean isMultipart = (multipartRequest != null || isMultipartContent(request)); + + if (MultipartFile.class == parameter.getNestedParameterType()) { + if (multipartRequest == null && isMultipart) { + multipartRequest = new StandardMultipartHttpServletRequest(request); + } + return (multipartRequest != null ? multipartRequest.getFile(name) : null); + } + else if (isMultipartFileCollection(parameter)) { + if (multipartRequest == null && isMultipart) { + multipartRequest = new StandardMultipartHttpServletRequest(request); + } + return (multipartRequest != null ? multipartRequest.getFiles(name) : null); + } + else if (isMultipartFileArray(parameter)) { + if (multipartRequest == null && isMultipart) { + multipartRequest = new StandardMultipartHttpServletRequest(request); + } + if (multipartRequest != null) { + List multipartFiles = multipartRequest.getFiles(name); + return multipartFiles.toArray(new MultipartFile[0]); + } + else { + return null; + } + } + else if (Part.class == parameter.getNestedParameterType()) { + return (isMultipart ? request.getPart(name): null); + } + else if (isPartCollection(parameter)) { + return (isMultipart ? resolvePartList(request, name) : null); + } + else if (isPartArray(parameter)) { + return (isMultipart ? resolvePartList(request, name).toArray(new Part[0]) : null); + } + else { + return UNRESOLVABLE; + } + } + + private static boolean isMultipartFileCollection(MethodParameter methodParam) { + return (MultipartFile.class == getCollectionParameterType(methodParam)); + } + + private static boolean isMultipartFileArray(MethodParameter methodParam) { + return (MultipartFile.class == methodParam.getNestedParameterType().getComponentType()); + } + + private static boolean isPartCollection(MethodParameter methodParam) { + return (Part.class == getCollectionParameterType(methodParam)); + } + + private static boolean isPartArray(MethodParameter methodParam) { + return (Part.class == methodParam.getNestedParameterType().getComponentType()); + } + + @Nullable + private static Class getCollectionParameterType(MethodParameter methodParam) { + Class paramType = methodParam.getNestedParameterType(); + if (Collection.class == paramType || List.class.isAssignableFrom(paramType)){ + Class valueType = ResolvableType.forMethodParameter(methodParam).asCollection().resolveGeneric(); + if (valueType != null) { + return valueType; + } + } + return null; + } + + private static List resolvePartList(HttpServletRequest request, String name) throws Exception { + Collection parts = request.getParts(); + List result = new ArrayList<>(parts.size()); + for (Part part : parts) { + if (part.getName().equals(name)) { + result.add(part); + } + } + return result; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..e390f90f733fe9f5ff882b72a21ed1234d7cf645 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequest.java @@ -0,0 +1,138 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.support; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.Part; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.lang.Nullable; +import org.springframework.web.multipart.MultipartException; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.multipart.MultipartHttpServletRequest; +import org.springframework.web.multipart.MultipartResolver; + +/** + * {@link ServerHttpRequest} implementation that accesses one part of a multipart + * request. If using {@link MultipartResolver} configuration the part is accessed + * through a {@link MultipartFile}. Or if using Servlet 3.0 multipart processing + * the part is accessed through {@code ServletRequest.getPart}. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.1 + */ +public class RequestPartServletServerHttpRequest extends ServletServerHttpRequest { + + private final MultipartHttpServletRequest multipartRequest; + + private final String requestPartName; + + private final HttpHeaders multipartHeaders; + + + /** + * Create a new {@code RequestPartServletServerHttpRequest} instance. + * @param request the current servlet request + * @param requestPartName the name of the part to adapt to the {@link ServerHttpRequest} contract + * @throws MissingServletRequestPartException if the request part cannot be found + * @throws MultipartException if MultipartHttpServletRequest cannot be initialized + */ + public RequestPartServletServerHttpRequest(HttpServletRequest request, String requestPartName) + throws MissingServletRequestPartException { + + super(request); + + this.multipartRequest = MultipartResolutionDelegate.asMultipartHttpServletRequest(request); + this.requestPartName = requestPartName; + + HttpHeaders multipartHeaders = this.multipartRequest.getMultipartHeaders(requestPartName); + if (multipartHeaders == null) { + throw new MissingServletRequestPartException(requestPartName); + } + this.multipartHeaders = multipartHeaders; + } + + + @Override + public HttpHeaders getHeaders() { + return this.multipartHeaders; + } + + @Override + public InputStream getBody() throws IOException { + // Prefer Servlet Part resolution to cover file as well as parameter streams + boolean servletParts = (this.multipartRequest instanceof StandardMultipartHttpServletRequest); + if (servletParts) { + Part part = retrieveServletPart(); + if (part != null) { + return part.getInputStream(); + } + } + + // Spring-style distinction between MultipartFile and String parameters + MultipartFile file = this.multipartRequest.getFile(this.requestPartName); + if (file != null) { + return file.getInputStream(); + } + String paramValue = this.multipartRequest.getParameter(this.requestPartName); + if (paramValue != null) { + return new ByteArrayInputStream(paramValue.getBytes(determineCharset())); + } + + // Fallback: Servlet Part resolution even if not indicated + if (!servletParts) { + Part part = retrieveServletPart(); + if (part != null) { + return part.getInputStream(); + } + } + + throw new IllegalStateException("No body available for request part '" + this.requestPartName + "'"); + } + + @Nullable + private Part retrieveServletPart() { + try { + return this.multipartRequest.getPart(this.requestPartName); + } + catch (Exception ex) { + throw new MultipartException("Failed to retrieve request part '" + this.requestPartName + "'", ex); + } + } + + private Charset determineCharset() { + MediaType contentType = getHeaders().getContentType(); + if (contentType != null) { + Charset charset = contentType.getCharset(); + if (charset != null) { + return charset; + } + } + String encoding = this.multipartRequest.getCharacterEncoding(); + return (encoding != null ? Charset.forName(encoding) : FORM_CHARSET); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequest.java b/spring-web/src/main/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..6970c8e12ea7e6f9e7e4af16b75736fbbb11250f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequest.java @@ -0,0 +1,290 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.support; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.Serializable; +import java.io.UnsupportedEncodingException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + +import javax.mail.internet.MimeUtility; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.Part; + +import org.springframework.http.ContentDisposition; +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.FileCopyUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.multipart.MaxUploadSizeExceededException; +import org.springframework.web.multipart.MultipartException; +import org.springframework.web.multipart.MultipartFile; + +/** + * Spring MultipartHttpServletRequest adapter, wrapping a Servlet 3.0 HttpServletRequest + * and its Part objects. Parameters get exposed through the native request's getParameter + * methods - without any custom processing on our side. + * + * @author Juergen Hoeller + * @author Rossen Stoyanchev + * @since 3.1 + * @see StandardServletMultipartResolver + */ +public class StandardMultipartHttpServletRequest extends AbstractMultipartHttpServletRequest { + + @Nullable + private Set multipartParameterNames; + + + /** + * Create a new StandardMultipartHttpServletRequest wrapper for the given request, + * immediately parsing the multipart content. + * @param request the servlet request to wrap + * @throws MultipartException if parsing failed + */ + public StandardMultipartHttpServletRequest(HttpServletRequest request) throws MultipartException { + this(request, false); + } + + /** + * Create a new StandardMultipartHttpServletRequest wrapper for the given request. + * @param request the servlet request to wrap + * @param lazyParsing whether multipart parsing should be triggered lazily on + * first access of multipart files or parameters + * @throws MultipartException if an immediate parsing attempt failed + * @since 3.2.9 + */ + public StandardMultipartHttpServletRequest(HttpServletRequest request, boolean lazyParsing) + throws MultipartException { + + super(request); + if (!lazyParsing) { + parseRequest(request); + } + } + + + private void parseRequest(HttpServletRequest request) { + try { + Collection parts = request.getParts(); + this.multipartParameterNames = new LinkedHashSet<>(parts.size()); + MultiValueMap files = new LinkedMultiValueMap<>(parts.size()); + for (Part part : parts) { + String headerValue = part.getHeader(HttpHeaders.CONTENT_DISPOSITION); + ContentDisposition disposition = ContentDisposition.parse(headerValue); + String filename = disposition.getFilename(); + if (filename != null) { + if (filename.startsWith("=?") && filename.endsWith("?=")) { + filename = MimeDelegate.decode(filename); + } + files.add(part.getName(), new StandardMultipartFile(part, filename)); + } + else { + this.multipartParameterNames.add(part.getName()); + } + } + setMultipartFiles(files); + } + catch (Throwable ex) { + handleParseFailure(ex); + } + } + + protected void handleParseFailure(Throwable ex) { + String msg = ex.getMessage(); + if (msg != null && msg.contains("size") && msg.contains("exceed")) { + throw new MaxUploadSizeExceededException(-1, ex); + } + throw new MultipartException("Failed to parse multipart servlet request", ex); + } + + @Override + protected void initializeMultipart() { + parseRequest(getRequest()); + } + + @Override + public Enumeration getParameterNames() { + if (this.multipartParameterNames == null) { + initializeMultipart(); + } + if (this.multipartParameterNames.isEmpty()) { + return super.getParameterNames(); + } + + // Servlet 3.0 getParameterNames() not guaranteed to include multipart form items + // (e.g. on WebLogic 12) -> need to merge them here to be on the safe side + Set paramNames = new LinkedHashSet<>(); + Enumeration paramEnum = super.getParameterNames(); + while (paramEnum.hasMoreElements()) { + paramNames.add(paramEnum.nextElement()); + } + paramNames.addAll(this.multipartParameterNames); + return Collections.enumeration(paramNames); + } + + @Override + public Map getParameterMap() { + if (this.multipartParameterNames == null) { + initializeMultipart(); + } + if (this.multipartParameterNames.isEmpty()) { + return super.getParameterMap(); + } + + // Servlet 3.0 getParameterMap() not guaranteed to include multipart form items + // (e.g. on WebLogic 12) -> need to merge them here to be on the safe side + Map paramMap = new LinkedHashMap<>(super.getParameterMap()); + for (String paramName : this.multipartParameterNames) { + if (!paramMap.containsKey(paramName)) { + paramMap.put(paramName, getParameterValues(paramName)); + } + } + return paramMap; + } + + @Override + public String getMultipartContentType(String paramOrFileName) { + try { + Part part = getPart(paramOrFileName); + return (part != null ? part.getContentType() : null); + } + catch (Throwable ex) { + throw new MultipartException("Could not access multipart servlet request", ex); + } + } + + @Override + public HttpHeaders getMultipartHeaders(String paramOrFileName) { + try { + Part part = getPart(paramOrFileName); + if (part != null) { + HttpHeaders headers = new HttpHeaders(); + for (String headerName : part.getHeaderNames()) { + headers.put(headerName, new ArrayList<>(part.getHeaders(headerName))); + } + return headers; + } + else { + return null; + } + } + catch (Throwable ex) { + throw new MultipartException("Could not access multipart servlet request", ex); + } + } + + + /** + * Spring MultipartFile adapter, wrapping a Servlet 3.0 Part object. + */ + @SuppressWarnings("serial") + private static class StandardMultipartFile implements MultipartFile, Serializable { + + private final Part part; + + private final String filename; + + public StandardMultipartFile(Part part, String filename) { + this.part = part; + this.filename = filename; + } + + @Override + public String getName() { + return this.part.getName(); + } + + @Override + public String getOriginalFilename() { + return this.filename; + } + + @Override + public String getContentType() { + return this.part.getContentType(); + } + + @Override + public boolean isEmpty() { + return (this.part.getSize() == 0); + } + + @Override + public long getSize() { + return this.part.getSize(); + } + + @Override + public byte[] getBytes() throws IOException { + return FileCopyUtils.copyToByteArray(this.part.getInputStream()); + } + + @Override + public InputStream getInputStream() throws IOException { + return this.part.getInputStream(); + } + + @Override + public void transferTo(File dest) throws IOException, IllegalStateException { + this.part.write(dest.getPath()); + if (dest.isAbsolute() && !dest.exists()) { + // Servlet 3.0 Part.write is not guaranteed to support absolute file paths: + // may translate the given path to a relative location within a temp dir + // (e.g. on Jetty whereas Tomcat and Undertow detect absolute paths). + // At least we offloaded the file from memory storage; it'll get deleted + // from the temp dir eventually in any case. And for our user's purposes, + // we can manually copy it to the requested location as a fallback. + FileCopyUtils.copy(this.part.getInputStream(), Files.newOutputStream(dest.toPath())); + } + } + + @Override + public void transferTo(Path dest) throws IOException, IllegalStateException { + FileCopyUtils.copy(this.part.getInputStream(), Files.newOutputStream(dest)); + } + } + + + /** + * Inner class to avoid a hard dependency on the JavaMail API. + */ + private static class MimeDelegate { + + public static String decode(String value) { + try { + return MimeUtility.decodeText(value); + } + catch (UnsupportedEncodingException ex) { + throw new IllegalStateException(ex); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/support/StandardServletMultipartResolver.java b/spring-web/src/main/java/org/springframework/web/multipart/support/StandardServletMultipartResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..e2946ab9d53d91930a57f219544236c9d292b5db --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/support/StandardServletMultipartResolver.java @@ -0,0 +1,109 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.support; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.Part; + +import org.apache.commons.logging.LogFactory; + +import org.springframework.util.StringUtils; +import org.springframework.web.multipart.MultipartException; +import org.springframework.web.multipart.MultipartHttpServletRequest; +import org.springframework.web.multipart.MultipartResolver; + +/** + * Standard implementation of the {@link MultipartResolver} interface, + * based on the Servlet 3.0 {@link javax.servlet.http.Part} API. + * To be added as "multipartResolver" bean to a Spring DispatcherServlet context, + * without any extra configuration at the bean level (see below). + * + *

Note: In order to use Servlet 3.0 based multipart parsing, + * you need to mark the affected servlet with a "multipart-config" section in + * {@code web.xml}, or with a {@link javax.servlet.MultipartConfigElement} + * in programmatic servlet registration, or (in case of a custom servlet class) + * possibly with a {@link javax.servlet.annotation.MultipartConfig} annotation + * on your servlet class. Configuration settings such as maximum sizes or + * storage locations need to be applied at that servlet registration level; + * Servlet 3.0 does not allow for them to be set at the MultipartResolver level. + * + *

+ * public class AppInitializer extends AbstractAnnotationConfigDispatcherServletInitializer {
+ *	 // ...
+ *	 @Override
+ *	 protected void customizeRegistration(ServletRegistration.Dynamic registration) {
+ *     // Optionally also set maxFileSize, maxRequestSize, fileSizeThreshold
+ *     registration.setMultipartConfig(new MultipartConfigElement("/tmp"));
+ *   }
+ * }
+ * 
+ * + * @author Juergen Hoeller + * @since 3.1 + * @see #setResolveLazily + * @see HttpServletRequest#getParts() + * @see org.springframework.web.multipart.commons.CommonsMultipartResolver + */ +public class StandardServletMultipartResolver implements MultipartResolver { + + private boolean resolveLazily = false; + + + /** + * Set whether to resolve the multipart request lazily at the time of + * file or parameter access. + *

Default is "false", resolving the multipart elements immediately, throwing + * corresponding exceptions at the time of the {@link #resolveMultipart} call. + * Switch this to "true" for lazy multipart parsing, throwing parse exceptions + * once the application attempts to obtain multipart files or parameters. + * @since 3.2.9 + */ + public void setResolveLazily(boolean resolveLazily) { + this.resolveLazily = resolveLazily; + } + + + @Override + public boolean isMultipart(HttpServletRequest request) { + return StringUtils.startsWithIgnoreCase(request.getContentType(), "multipart/"); + } + + @Override + public MultipartHttpServletRequest resolveMultipart(HttpServletRequest request) throws MultipartException { + return new StandardMultipartHttpServletRequest(request, this.resolveLazily); + } + + @Override + public void cleanupMultipart(MultipartHttpServletRequest request) { + if (!(request instanceof AbstractMultipartHttpServletRequest) || + ((AbstractMultipartHttpServletRequest) request).isResolved()) { + // To be on the safe side: explicitly delete the parts, + // but only actual file parts (for Resin compatibility) + try { + for (Part part : request.getParts()) { + if (request.getFile(part.getName()) != null) { + part.delete(); + } + } + } + catch (Throwable ex) { + LogFactory.getLog(getClass()).warn("Failed to perform cleanup of multipart items", ex); + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/support/StringMultipartFileEditor.java b/spring-web/src/main/java/org/springframework/web/multipart/support/StringMultipartFileEditor.java new file mode 100644 index 0000000000000000000000000000000000000000..015663c274df83ae83291ff749a0cb47d8cce58c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/support/StringMultipartFileEditor.java @@ -0,0 +1,80 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.support; + +import java.beans.PropertyEditorSupport; +import java.io.IOException; + +import org.springframework.lang.Nullable; +import org.springframework.web.multipart.MultipartFile; + +/** + * Custom {@link java.beans.PropertyEditor} for converting + * {@link MultipartFile MultipartFiles} to Strings. + * + *

Allows one to specify the charset to use. + * + * @author Juergen Hoeller + * @since 13.10.2003 + */ +public class StringMultipartFileEditor extends PropertyEditorSupport { + + @Nullable + private final String charsetName; + + + /** + * Create a new {@link StringMultipartFileEditor}, using the default charset. + */ + public StringMultipartFileEditor() { + this.charsetName = null; + } + + /** + * Create a new {@link StringMultipartFileEditor}, using the given charset. + * @param charsetName valid charset name + * @see java.lang.String#String(byte[],String) + */ + public StringMultipartFileEditor(String charsetName) { + this.charsetName = charsetName; + } + + + @Override + public void setAsText(String text) { + setValue(text); + } + + @Override + public void setValue(Object value) { + if (value instanceof MultipartFile) { + MultipartFile multipartFile = (MultipartFile) value; + try { + super.setValue(this.charsetName != null ? + new String(multipartFile.getBytes(), this.charsetName) : + new String(multipartFile.getBytes())); + } + catch (IOException ex) { + throw new IllegalArgumentException("Cannot read contents of multipart file", ex); + } + } + else { + super.setValue(value); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/multipart/support/package-info.java b/spring-web/src/main/java/org/springframework/web/multipart/support/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..776fdfcf0b0e1eb1c64a056ead3a4aaa645dab11 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/multipart/support/package-info.java @@ -0,0 +1,11 @@ +/** + * Support classes for the multipart resolution framework. + * Contains property editors for multipart files, and a Servlet filter + * for multipart handling without Spring's Web MVC. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.multipart.support; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/package-info.java b/spring-web/src/main/java/org/springframework/web/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..dd944f1f0a72c9f059ae8d958c97df20ad4046dc --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/package-info.java @@ -0,0 +1,10 @@ +/** + * Common, generic interfaces that define minimal boundary points + * between Spring's web infrastructure and other framework modules. + */ +@NonNullApi +@NonNullFields +package org.springframework.web; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/server/DefaultServerWebExchangeBuilder.java b/spring-web/src/main/java/org/springframework/web/server/DefaultServerWebExchangeBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..69fd787779541434d81851292c863be3b5158fbf --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/DefaultServerWebExchangeBuilder.java @@ -0,0 +1,128 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server; + +import java.security.Principal; +import java.util.function.Consumer; + +import reactor.core.publisher.Mono; + +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Package-private implementation of {@link ServerWebExchange.Builder}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +class DefaultServerWebExchangeBuilder implements ServerWebExchange.Builder { + + private final ServerWebExchange delegate; + + @Nullable + private ServerHttpRequest request; + + @Nullable + private ServerHttpResponse response; + + @Nullable + private Mono principalMono; + + + DefaultServerWebExchangeBuilder(ServerWebExchange delegate) { + Assert.notNull(delegate, "Delegate is required"); + this.delegate = delegate; + } + + + @Override + public ServerWebExchange.Builder request(Consumer consumer) { + ServerHttpRequest.Builder builder = this.delegate.getRequest().mutate(); + consumer.accept(builder); + return request(builder.build()); + } + + @Override + public ServerWebExchange.Builder request(ServerHttpRequest request) { + this.request = request; + return this; + } + + @Override + public ServerWebExchange.Builder response(ServerHttpResponse response) { + this.response = response; + return this; + } + + @Override + public ServerWebExchange.Builder principal(Mono principalMono) { + this.principalMono = principalMono; + return this; + } + + @Override + public ServerWebExchange build() { + return new MutativeDecorator(this.delegate, this.request, this.response, this.principalMono); + } + + + /** + * An immutable wrapper of an exchange returning property overrides -- given + * to the constructor -- or original values otherwise. + */ + private static class MutativeDecorator extends ServerWebExchangeDecorator { + + @Nullable + private final ServerHttpRequest request; + + @Nullable + private final ServerHttpResponse response; + + @Nullable + private final Mono principalMono; + + public MutativeDecorator(ServerWebExchange delegate, @Nullable ServerHttpRequest request, + @Nullable ServerHttpResponse response, @Nullable Mono principalMono) { + + super(delegate); + this.request = request; + this.response = response; + this.principalMono = principalMono; + } + + @Override + public ServerHttpRequest getRequest() { + return (this.request != null ? this.request : getDelegate().getRequest()); + } + + @Override + public ServerHttpResponse getResponse() { + return (this.response != null ? this.response : getDelegate().getResponse()); + } + + @SuppressWarnings("unchecked") + @Override + public Mono getPrincipal() { + return (this.principalMono != null ? (Mono) this.principalMono : getDelegate().getPrincipal()); + } + } + +} + diff --git a/spring-web/src/main/java/org/springframework/web/server/MediaTypeNotSupportedStatusException.java b/spring-web/src/main/java/org/springframework/web/server/MediaTypeNotSupportedStatusException.java new file mode 100644 index 0000000000000000000000000000000000000000..690e7a136b0f3d363d7a415b215bc052f03e5be6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/MediaTypeNotSupportedStatusException.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server; + +import java.util.Collections; +import java.util.List; + +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; + +/** + * Exception for errors that fit response status 415 (unsupported media type). + * + * @author Rossen Stoyanchev + * @since 5.0 + * @deprecated in favor of {@link UnsupportedMediaTypeStatusException}, + * with this class never thrown by Spring code and to be removed in 5.3 + */ +@Deprecated +@SuppressWarnings("serial") +public class MediaTypeNotSupportedStatusException extends ResponseStatusException { + + private final List supportedMediaTypes; + + + /** + * Constructor for when the Content-Type is invalid. + */ + public MediaTypeNotSupportedStatusException(String reason) { + super(HttpStatus.UNSUPPORTED_MEDIA_TYPE, reason); + this.supportedMediaTypes = Collections.emptyList(); + } + + /** + * Constructor for when the Content-Type is not supported. + */ + public MediaTypeNotSupportedStatusException(List supportedMediaTypes) { + super(HttpStatus.UNSUPPORTED_MEDIA_TYPE, "Unsupported media type", null); + this.supportedMediaTypes = Collections.unmodifiableList(supportedMediaTypes); + } + + + /** + * Return the list of supported content types in cases when the Accept + * header is parsed but not supported, or an empty list otherwise. + */ + public List getSupportedMediaTypes() { + return this.supportedMediaTypes; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/MethodNotAllowedException.java b/spring-web/src/main/java/org/springframework/web/server/MethodNotAllowedException.java new file mode 100644 index 0000000000000000000000000000000000000000..40cf7d46607a53806f6867306fb6b0481d8a6ffa --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/MethodNotAllowedException.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +/** + * Exception for errors that fit response status 405 (method not allowed). + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +@SuppressWarnings("serial") +public class MethodNotAllowedException extends ResponseStatusException { + + private final String method; + + private final Set httpMethods; + + + public MethodNotAllowedException(HttpMethod method, Collection supportedMethods) { + this(method.name(), supportedMethods); + } + + public MethodNotAllowedException(String method, @Nullable Collection supportedMethods) { + super(HttpStatus.METHOD_NOT_ALLOWED, "Request method '" + method + "' not supported"); + Assert.notNull(method, "'method' is required"); + if (supportedMethods == null) { + supportedMethods = Collections.emptySet(); + } + this.method = method; + this.httpMethods = Collections.unmodifiableSet(new HashSet<>(supportedMethods)); + } + + + /** + * Return a Map with an "Allow" header. + * @since 5.1.11 + */ + @SuppressWarnings("deprecation") + @Override + public Map getHeaders() { + return getResponseHeaders().toSingleValueMap(); + } + + /** + * Return HttpHeaders with an "Allow" header. + * @since 5.1.13 + */ + @Override + public HttpHeaders getResponseHeaders() { + if (CollectionUtils.isEmpty(this.httpMethods)) { + return HttpHeaders.EMPTY; + } + HttpHeaders headers = new HttpHeaders(); + headers.setAllow(this.httpMethods); + return headers; + } + + /** + * Return the HTTP method for the failed request. + */ + public String getHttpMethod() { + return this.method; + } + + /** + * Return the list of supported HTTP methods. + */ + public Set getSupportedMethods() { + return this.httpMethods; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/NotAcceptableStatusException.java b/spring-web/src/main/java/org/springframework/web/server/NotAcceptableStatusException.java new file mode 100644 index 0000000000000000000000000000000000000000..3d37d91fa5cac850cc127aea3bec14c0bc9009b2 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/NotAcceptableStatusException.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.util.CollectionUtils; + +/** + * Exception for errors that fit response status 406 (not acceptable). + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +@SuppressWarnings("serial") +public class NotAcceptableStatusException extends ResponseStatusException { + + private final List supportedMediaTypes; + + + /** + * Constructor for when the requested Content-Type is invalid. + */ + public NotAcceptableStatusException(String reason) { + super(HttpStatus.NOT_ACCEPTABLE, reason); + this.supportedMediaTypes = Collections.emptyList(); + } + + /** + * Constructor for when the requested Content-Type is not supported. + */ + public NotAcceptableStatusException(List supportedMediaTypes) { + super(HttpStatus.NOT_ACCEPTABLE, "Could not find acceptable representation"); + this.supportedMediaTypes = Collections.unmodifiableList(supportedMediaTypes); + } + + + /** + * Return a Map with an "Accept" header, or an empty map. + * @since 5.1.11 + */ + @SuppressWarnings("deprecation") + @Override + public Map getHeaders() { + return getResponseHeaders().toSingleValueMap(); + } + + /** + * Return HttpHeaders with an "Accept" header, or an empty instance. + * @since 5.1.13 + */ + @Override + public HttpHeaders getResponseHeaders() { + if (CollectionUtils.isEmpty(this.supportedMediaTypes)) { + return HttpHeaders.EMPTY; + } + HttpHeaders headers = new HttpHeaders(); + headers.setAccept(this.supportedMediaTypes); + return headers; + } + + /** + * Return the list of supported content types in cases when the Accept + * header is parsed but not supported, or an empty list otherwise. + */ + public List getSupportedMediaTypes() { + return this.supportedMediaTypes; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/ResponseStatusException.java b/spring-web/src/main/java/org/springframework/web/server/ResponseStatusException.java new file mode 100644 index 0000000000000000000000000000000000000000..67c8d78391eadf3b8f1b2d9f71d7e403838fb6d5 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/ResponseStatusException.java @@ -0,0 +1,128 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server; + +import java.util.Collections; +import java.util.Map; + +import org.springframework.core.NestedExceptionUtils; +import org.springframework.core.NestedRuntimeException; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Base class for exceptions associated with specific HTTP response status codes. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 5.0 + */ +@SuppressWarnings("serial") +public class ResponseStatusException extends NestedRuntimeException { + + private final HttpStatus status; + + @Nullable + private final String reason; + + + /** + * Constructor with a response status. + * @param status the HTTP status (required) + */ + public ResponseStatusException(HttpStatus status) { + this(status, null, null); + } + + /** + * Constructor with a response status and a reason to add to the exception + * message as explanation. + * @param status the HTTP status (required) + * @param reason the associated reason (optional) + */ + public ResponseStatusException(HttpStatus status, @Nullable String reason) { + this(status, reason, null); + } + + /** + * Constructor with a response status and a reason to add to the exception + * message as explanation, as well as a nested exception. + * @param status the HTTP status (required) + * @param reason the associated reason (optional) + * @param cause a nested exception (optional) + */ + public ResponseStatusException(HttpStatus status, @Nullable String reason, @Nullable Throwable cause) { + super(null, cause); + Assert.notNull(status, "HttpStatus is required"); + this.status = status; + this.reason = reason; + } + + + /** + * Return the HTTP status associated with this exception. + */ + public HttpStatus getStatus() { + return this.status; + } + + /** + * Return headers associated with the exception that should be added to the + * error response, e.g. "Allow", "Accept", etc. + *

The default implementation in this class returns an empty map. + * @since 5.1.11 + * @deprecated as of 5.1.13 in favor of {@link #getResponseHeaders()} + */ + @Deprecated + public Map getHeaders() { + return Collections.emptyMap(); + } + + /** + * Return headers associated with the exception that should be added to the + * error response, e.g. "Allow", "Accept", etc. + *

The default implementation in this class returns empty headers. + * @since 5.1.13 + */ + public HttpHeaders getResponseHeaders() { + Map headers = getHeaders(); + if (headers.isEmpty()) { + return HttpHeaders.EMPTY; + } + HttpHeaders result = new HttpHeaders(); + getHeaders().forEach(result::add); + return result; + } + + /** + * The reason explaining the exception (potentially {@code null} or empty). + */ + @Nullable + public String getReason() { + return this.reason; + } + + + @Override + public String getMessage() { + String msg = this.status + (this.reason != null ? " \"" + this.reason + "\"" : ""); + return NestedExceptionUtils.buildMessage(msg, getCause()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/ServerErrorException.java b/spring-web/src/main/java/org/springframework/web/server/ServerErrorException.java new file mode 100644 index 0000000000000000000000000000000000000000..8f994f81db8784fdceb3927956d1056a5e1bff95 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/ServerErrorException.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server; + +import java.lang.reflect.Method; + +import org.springframework.core.MethodParameter; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; + +/** + * Exception for an {@link HttpStatus#INTERNAL_SERVER_ERROR} that exposes extra + * information about a controller method that failed, or a controller method + * argument that could not be resolved. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +@SuppressWarnings("serial") +public class ServerErrorException extends ResponseStatusException { + + @Nullable + private final Method handlerMethod; + + @Nullable + private final MethodParameter parameter; + + + /** + * Constructor for a 500 error with a reason and an optional cause. + * @since 5.0.5 + */ + public ServerErrorException(String reason, @Nullable Throwable cause) { + super(HttpStatus.INTERNAL_SERVER_ERROR, reason, cause); + this.handlerMethod = null; + this.parameter = null; + } + + /** + * Constructor for a 500 error with a handler {@link Method} and an optional cause. + * @since 5.0.5 + */ + public ServerErrorException(String reason, Method handlerMethod, @Nullable Throwable cause) { + super(HttpStatus.INTERNAL_SERVER_ERROR, reason, cause); + this.handlerMethod = handlerMethod; + this.parameter = null; + } + + /** + * Constructor for a 500 error with a {@link MethodParameter} and an optional cause. + */ + public ServerErrorException(String reason, MethodParameter parameter, @Nullable Throwable cause) { + super(HttpStatus.INTERNAL_SERVER_ERROR, reason, cause); + this.handlerMethod = parameter.getMethod(); + this.parameter = parameter; + } + + /** + * Constructor for a 500 error linked to a specific {@code MethodParameter}. + * @deprecated in favor of {@link #ServerErrorException(String, MethodParameter, Throwable)} + */ + @Deprecated + public ServerErrorException(String reason, MethodParameter parameter) { + this(reason, parameter, null); + } + + /** + * Constructor for a 500 error with a reason only. + * @deprecated in favor of {@link #ServerErrorException(String, Throwable)} + */ + @Deprecated + public ServerErrorException(String reason) { + super(HttpStatus.INTERNAL_SERVER_ERROR, reason, null); + this.handlerMethod = null; + this.parameter = null; + } + + + /** + * Return the handler method associated with the error, if any. + * @since 5.0.5 + */ + @Nullable + public Method getHandlerMethod() { + return this.handlerMethod; + } + + /** + * Return the specific method parameter associated with the error, if any. + */ + @Nullable + public MethodParameter getMethodParameter() { + return this.parameter; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/ServerWebExchange.java b/spring-web/src/main/java/org/springframework/web/server/ServerWebExchange.java new file mode 100644 index 0000000000000000000000000000000000000000..47a63d7fe5a5f7038b5860de0cfc1f839992a539 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/ServerWebExchange.java @@ -0,0 +1,288 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server; + +import java.security.Principal; +import java.time.Instant; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; + +import reactor.core.publisher.Mono; + +import org.springframework.context.ApplicationContext; +import org.springframework.context.i18n.LocaleContext; +import org.springframework.http.codec.multipart.Part; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; + +/** + * Contract for an HTTP request-response interaction. Provides access to the HTTP + * request and response and also exposes additional server-side processing + * related properties and features such as request attributes. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface ServerWebExchange { + + /** + * Name of {@link #getAttributes() attribute} whose value can be used to + * correlate log messages for this exchange. Use {@link #getLogPrefix()} to + * obtain a consistently formatted prefix based on this attribute. + * @since 5.1 + * @see #getLogPrefix() + */ + String LOG_ID_ATTRIBUTE = ServerWebExchange.class.getName() + ".LOG_ID"; + + + /** + * Return the current HTTP request. + */ + ServerHttpRequest getRequest(); + + /** + * Return the current HTTP response. + */ + ServerHttpResponse getResponse(); + + /** + * Return a mutable map of request attributes for the current exchange. + */ + Map getAttributes(); + + /** + * Return the request attribute value if present. + * @param name the attribute name + * @param the attribute type + * @return the attribute value + */ + @SuppressWarnings("unchecked") + @Nullable + default T getAttribute(String name) { + return (T) getAttributes().get(name); + } + + /** + * Return the request attribute value or if not present raise an + * {@link IllegalArgumentException}. + * @param name the attribute name + * @param the attribute type + * @return the attribute value + */ + @SuppressWarnings("unchecked") + default T getRequiredAttribute(String name) { + T value = getAttribute(name); + Assert.notNull(value, () -> "Required attribute '" + name + "' is missing"); + return value; + } + + /** + * Return the request attribute value, or a default, fallback value. + * @param name the attribute name + * @param defaultValue a default value to return instead + * @param the attribute type + * @return the attribute value + */ + @SuppressWarnings("unchecked") + default T getAttributeOrDefault(String name, T defaultValue) { + return (T) getAttributes().getOrDefault(name, defaultValue); + } + + /** + * Return the web session for the current request. Always guaranteed to + * return an instance either matching to the session id requested by the + * client, or with a new session id either because the client did not + * specify one or because the underlying session had expired. Use of this + * method does not automatically create a session. See {@link WebSession} + * for more details. + */ + Mono getSession(); + + /** + * Return the authenticated user for the request, if any. + */ + Mono getPrincipal(); + + /** + * Return the form data from the body of the request if the Content-Type is + * {@code "application/x-www-form-urlencoded"} or an empty map otherwise. + *

Note: calling this method causes the request body to + * be read and parsed in full and the resulting {@code MultiValueMap} is + * cached so that this method is safe to call more than once. + */ + Mono> getFormData(); + + /** + * Return the parts of a multipart request if the Content-Type is + * {@code "multipart/form-data"} or an empty map otherwise. + *

Note: calling this method causes the request body to + * be read and parsed in full and the resulting {@code MultiValueMap} is + * cached so that this method is safe to call more than once. + */ + Mono> getMultipartData(); + + /** + * Return the {@link LocaleContext} using the configured + * {@link org.springframework.web.server.i18n.LocaleContextResolver}. + */ + LocaleContext getLocaleContext(); + + /** + * Return the {@link ApplicationContext} associated with the web application, + * if it was initialized with one via + * {@link org.springframework.web.server.adapter.WebHttpHandlerBuilder#applicationContext(ApplicationContext)}. + * @since 5.0.3 + * @see org.springframework.web.server.adapter.WebHttpHandlerBuilder#applicationContext(ApplicationContext) + */ + @Nullable + ApplicationContext getApplicationContext(); + + /** + * Returns {@code true} if the one of the {@code checkNotModified} methods + * in this contract were used and they returned true. + */ + boolean isNotModified(); + + /** + * An overloaded variant of {@link #checkNotModified(String, Instant)} with + * a last-modified timestamp only. + * @param lastModified the last-modified time + * @return whether the request qualifies as not modified + */ + boolean checkNotModified(Instant lastModified); + + /** + * An overloaded variant of {@link #checkNotModified(String, Instant)} with + * an {@code ETag} (entity tag) value only. + * @param etag the entity tag for the underlying resource. + * @return true if the request does not require further processing. + */ + boolean checkNotModified(String etag); + + /** + * Check whether the requested resource has been modified given the supplied + * {@code ETag} (entity tag) and last-modified timestamp as determined by + * the application. Also transparently prepares the response, setting HTTP + * status, and adding "ETag" and "Last-Modified" headers when applicable. + * This method works with conditional GET/HEAD requests as well as with + * conditional POST/PUT/DELETE requests. + *

Note: The HTTP specification recommends setting both + * ETag and Last-Modified values, but you can also use + * {@code #checkNotModified(String)} or + * {@link #checkNotModified(Instant)}. + * @param etag the entity tag that the application determined for the + * underlying resource. This parameter will be padded with quotes (") + * if necessary. + * @param lastModified the last-modified timestamp that the application + * determined for the underlying resource + * @return true if the request does not require further processing. + */ + boolean checkNotModified(@Nullable String etag, Instant lastModified); + + /** + * Transform the given url according to the registered transformation function(s). + * By default, this method returns the given {@code url}, though additional + * transformation functions can by registered with {@link #addUrlTransformer} + * @param url the URL to transform + * @return the transformed URL + */ + String transformUrl(String url); + + /** + * Register an additional URL transformation function for use with {@link #transformUrl}. + * The given function can be used to insert an id for authentication, a nonce for CSRF + * protection, etc. + *

Note that the given function is applied after any previously registered functions. + * @param transformer a URL transformation function to add + */ + void addUrlTransformer(Function transformer); + + /** + * Return a log message prefix to use to correlate messages for this exchange. + * The prefix is based on the value of the attribute {@link #LOG_ID_ATTRIBUTE} + * along with some extra formatting so that the prefix can be conveniently + * prepended with no further formatting no separators required. + * @return the log message prefix or an empty String if the + * {@link #LOG_ID_ATTRIBUTE} is not set. + * @since 5.1 + */ + String getLogPrefix(); + + /** + * Return a builder to mutate properties of this exchange by wrapping it + * with {@link ServerWebExchangeDecorator} and returning either mutated + * values or delegating back to this instance. + */ + default Builder mutate() { + return new DefaultServerWebExchangeBuilder(this); + } + + + /** + * Builder for mutating an existing {@link ServerWebExchange}. + * Removes the need + */ + interface Builder { + + /** + * Configure a consumer to modify the current request using a builder. + *

Effectively this: + *

+		 * exchange.mutate().request(builder-> builder.method(HttpMethod.PUT));
+		 *
+		 * // vs...
+		 *
+		 * ServerHttpRequest request = exchange.getRequest().mutate()
+		 *     .method(HttpMethod.PUT)
+		 *     .build();
+		 *
+		 * exchange.mutate().request(request);
+		 * 
+ * @see ServerHttpRequest#mutate() + */ + Builder request(Consumer requestBuilderConsumer); + + /** + * Set the request to use especially when there is a need to override + * {@link ServerHttpRequest} methods. To simply mutate request properties + * see {@link #request(Consumer)} instead. + * @see org.springframework.http.server.reactive.ServerHttpRequestDecorator + */ + Builder request(ServerHttpRequest request); + + /** + * Set the response to use. + * @see org.springframework.http.server.reactive.ServerHttpResponseDecorator + */ + Builder response(ServerHttpResponse response); + + /** + * Set the {@code Mono} to return for this exchange. + */ + Builder principal(Mono principalMono); + + /** + * Build a {@link ServerWebExchange} decorator with the mutated properties. + */ + ServerWebExchange build(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/ServerWebExchangeDecorator.java b/spring-web/src/main/java/org/springframework/web/server/ServerWebExchangeDecorator.java new file mode 100644 index 0000000000000000000000000000000000000000..2b753f466bec3d842a2184c231619a8caba55414 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/ServerWebExchangeDecorator.java @@ -0,0 +1,150 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.server; + +import java.security.Principal; +import java.time.Instant; +import java.util.Map; +import java.util.function.Function; + +import reactor.core.publisher.Mono; + +import org.springframework.context.ApplicationContext; +import org.springframework.context.i18n.LocaleContext; +import org.springframework.http.codec.multipart.Part; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; + +/** + * A convenient base class for classes that need to wrap another + * {@link ServerWebExchange}. Pre-implements all methods by delegating to the + * wrapped instance. + * + *

Note: if the purpose for using a decorator is to override + * properties like {@link #getPrincipal()}, consider using + * {@link ServerWebExchange#mutate()} instead. + * + * @author Rossen Stoyanchev + * @since 5.0 + * + * @see ServerWebExchange#mutate() + */ +public class ServerWebExchangeDecorator implements ServerWebExchange { + + private final ServerWebExchange delegate; + + + protected ServerWebExchangeDecorator(ServerWebExchange delegate) { + Assert.notNull(delegate, "ServerWebExchange 'delegate' is required."); + this.delegate = delegate; + } + + + public ServerWebExchange getDelegate() { + return this.delegate; + } + + // ServerWebExchange delegation methods... + + @Override + public ServerHttpRequest getRequest() { + return getDelegate().getRequest(); + } + + @Override + public ServerHttpResponse getResponse() { + return getDelegate().getResponse(); + } + + @Override + public Map getAttributes() { + return getDelegate().getAttributes(); + } + + @Override + public Mono getSession() { + return getDelegate().getSession(); + } + + @Override + public Mono getPrincipal() { + return getDelegate().getPrincipal(); + } + + @Override + public LocaleContext getLocaleContext() { + return getDelegate().getLocaleContext(); + } + + @Override + public ApplicationContext getApplicationContext() { + return getDelegate().getApplicationContext(); + } + + @Override + public Mono> getFormData() { + return getDelegate().getFormData(); + } + + @Override + public Mono> getMultipartData() { + return getDelegate().getMultipartData(); + } + + @Override + public boolean isNotModified() { + return getDelegate().isNotModified(); + } + + @Override + public boolean checkNotModified(Instant lastModified) { + return getDelegate().checkNotModified(lastModified); + } + + @Override + public boolean checkNotModified(String etag) { + return getDelegate().checkNotModified(etag); + } + + @Override + public boolean checkNotModified(@Nullable String etag, Instant lastModified) { + return getDelegate().checkNotModified(etag, lastModified); + } + + @Override + public String transformUrl(String url) { + return getDelegate().transformUrl(url); + } + + @Override + public void addUrlTransformer(Function transformer) { + getDelegate().addUrlTransformer(transformer); + } + + @Override + public String getLogPrefix() { + return getDelegate().getLogPrefix(); + } + + @Override + public String toString() { + return getClass().getSimpleName() + " [delegate=" + getDelegate() + "]"; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/ServerWebInputException.java b/spring-web/src/main/java/org/springframework/web/server/ServerWebInputException.java new file mode 100644 index 0000000000000000000000000000000000000000..fbb2cd9ef8bfb4f6d3af44cc14722d24913a6394 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/ServerWebInputException.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server; + +import org.springframework.core.MethodParameter; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; + +/** + * Exception for errors that fit response status 400 (bad request) for use in + * Spring Web applications. The exception provides additional fields (e.g. + * an optional {@link MethodParameter} if related to the error). + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +@SuppressWarnings("serial") +public class ServerWebInputException extends ResponseStatusException { + + @Nullable + private final MethodParameter parameter; + + + /** + * Constructor with an explanation only. + */ + public ServerWebInputException(String reason) { + this(reason, null, null); + } + + /** + * Constructor for a 400 error linked to a specific {@code MethodParameter}. + */ + public ServerWebInputException(String reason, @Nullable MethodParameter parameter) { + this(reason, parameter, null); + } + + /** + * Constructor for a 400 error with a root cause. + */ + public ServerWebInputException(String reason, @Nullable MethodParameter parameter, @Nullable Throwable cause) { + super(HttpStatus.BAD_REQUEST, reason, cause); + this.parameter = parameter; + } + + + /** + * Return the {@code MethodParameter} associated with this error, if any. + */ + @Nullable + public MethodParameter getMethodParameter() { + return this.parameter; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/UnsupportedMediaTypeStatusException.java b/spring-web/src/main/java/org/springframework/web/server/UnsupportedMediaTypeStatusException.java new file mode 100644 index 0000000000000000000000000000000000000000..1e1f7d9d05d94f3e608efee50b39ab65e80c7621 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/UnsupportedMediaTypeStatusException.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server; + +import java.util.Collections; +import java.util.List; + +import org.springframework.core.ResolvableType; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; + +/** + * Exception for errors that fit response status 415 (unsupported media type). + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +@SuppressWarnings("serial") +public class UnsupportedMediaTypeStatusException extends ResponseStatusException { + + @Nullable + private final MediaType contentType; + + private final List supportedMediaTypes; + + @Nullable + private final ResolvableType bodyType; + + + /** + * Constructor for when the specified Content-Type is invalid. + */ + public UnsupportedMediaTypeStatusException(@Nullable String reason) { + super(HttpStatus.UNSUPPORTED_MEDIA_TYPE, reason); + this.contentType = null; + this.supportedMediaTypes = Collections.emptyList(); + this.bodyType = null; + } + + /** + * Constructor for when the Content-Type can be parsed but is not supported. + */ + public UnsupportedMediaTypeStatusException(@Nullable MediaType contentType, List supportedTypes) { + this(contentType, supportedTypes, null); + } + + /** + * Constructor for when trying to encode from or decode to a specific Java type. + * @since 5.1 + */ + public UnsupportedMediaTypeStatusException(@Nullable MediaType contentType, List supportedTypes, + @Nullable ResolvableType bodyType) { + + super(HttpStatus.UNSUPPORTED_MEDIA_TYPE, initReason(contentType, bodyType)); + this.contentType = contentType; + this.supportedMediaTypes = Collections.unmodifiableList(supportedTypes); + this.bodyType = bodyType; + } + + private static String initReason(@Nullable MediaType contentType, @Nullable ResolvableType bodyType) { + return "Content type '" + (contentType != null ? contentType : "") + "' not supported" + + (bodyType != null ? " for bodyType=" + bodyType.toString() : ""); + } + + + /** + * Return the request Content-Type header if it was parsed successfully, + * or {@code null} otherwise. + */ + @Nullable + public MediaType getContentType() { + return this.contentType; + } + + /** + * Return the list of supported content types in cases when the Content-Type + * header is parsed but not supported, or an empty list otherwise. + */ + public List getSupportedMediaTypes() { + return this.supportedMediaTypes; + } + + /** + * Return the body type in the context of which this exception was generated. + *

This is applicable when the exception was raised as a result trying to + * encode from or decode to a specific Java type. + * @return the body type, or {@code null} if not available + * @since 5.1 + */ + @Nullable + public ResolvableType getBodyType() { + return this.bodyType; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/WebExceptionHandler.java b/spring-web/src/main/java/org/springframework/web/server/WebExceptionHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..3073607cb792faa55ce47a7f8d0eb16f7476bfdd --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/WebExceptionHandler.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server; + +import reactor.core.publisher.Mono; + +/** + * Contract for handling exceptions during web server exchange processing. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface WebExceptionHandler { + + /** + * Handle the given exception. A completion signal through the return value + * indicates error handling is complete while an error signal indicates the + * exception is still not handled. + * @param exchange the current exchange + * @param ex the exception to handle + * @return {@code Mono} to indicate when exception handling is complete + */ + Mono handle(ServerWebExchange exchange, Throwable ex); + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/WebFilter.java b/spring-web/src/main/java/org/springframework/web/server/WebFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..9a7ddeaf19b054d0d372b36d932c0651b703f905 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/WebFilter.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server; + +import reactor.core.publisher.Mono; + +/** + * Contract for interception-style, chained processing of Web requests that may + * be used to implement cross-cutting, application-agnostic requirements such + * as security, timeouts, and others. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface WebFilter { + + /** + * Process the Web request and (optionally) delegate to the next + * {@code WebFilter} through the given {@link WebFilterChain}. + * @param exchange the current server exchange + * @param chain provides a way to delegate to the next filter + * @return {@code Mono} to indicate when request processing is complete + */ + Mono filter(ServerWebExchange exchange, WebFilterChain chain); + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/WebFilterChain.java b/spring-web/src/main/java/org/springframework/web/server/WebFilterChain.java new file mode 100644 index 0000000000000000000000000000000000000000..dcaf8316cf59242702557452f2a824f3dd23e3cf --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/WebFilterChain.java @@ -0,0 +1,36 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server; + +import reactor.core.publisher.Mono; + +/** + * Contract to allow a {@link WebFilter} to delegate to the next in the chain. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface WebFilterChain { + + /** + * Delegate to the next {@code WebFilter} in the chain. + * @param exchange the current server exchange + * @return {@code Mono} to indicate when request handling is complete + */ + Mono filter(ServerWebExchange exchange); + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/WebHandler.java b/spring-web/src/main/java/org/springframework/web/server/WebHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..3d129d47705e5be93ea9eb212cd157c3c80f8154 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/WebHandler.java @@ -0,0 +1,44 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server; + +import reactor.core.publisher.Mono; + +import org.springframework.web.server.adapter.HttpWebHandlerAdapter; +import org.springframework.web.server.adapter.WebHttpHandlerBuilder; + +/** + * Contract to handle a web request. + * + *

Use {@link HttpWebHandlerAdapter} to adapt a {@code WebHandler} to an + * {@link org.springframework.http.server.reactive.HttpHandler HttpHandler}. + * The {@link WebHttpHandlerBuilder} provides a convenient way to do that while + * also optionally configuring one or more filters and/or exception handlers. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface WebHandler { + + /** + * Handle the web server exchange. + * @param exchange the current server exchange + * @return {@code Mono} to indicate when request handling is complete + */ + Mono handle(ServerWebExchange exchange); + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/WebSession.java b/spring-web/src/main/java/org/springframework/web/server/WebSession.java new file mode 100644 index 0000000000000000000000000000000000000000..cfa35d9667c41f162171016ad66c0b4e31c36572 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/WebSession.java @@ -0,0 +1,172 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server; + +import java.time.Duration; +import java.time.Instant; +import java.util.Map; + +import reactor.core.publisher.Mono; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Main contract for using a server-side session that provides access to session + * attributes across HTTP requests. + * + *

The creation of a {@code WebSession} instance does not automatically start + * a session thus causing the session id to be sent to the client (typically via + * a cookie). A session starts implicitly when session attributes are added. + * A session may also be created explicitly via {@link #start()}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface WebSession { + + /** + * Return a unique session identifier. + */ + String getId(); + + /** + * Return a map that holds session attributes. + */ + Map getAttributes(); + + /** + * Return the session attribute value if present. + * @param name the attribute name + * @param the attribute type + * @return the attribute value + */ + @SuppressWarnings("unchecked") + @Nullable + default T getAttribute(String name) { + return (T) getAttributes().get(name); + } + + /** + * Return the session attribute value or if not present raise an + * {@link IllegalArgumentException}. + * @param name the attribute name + * @param the attribute type + * @return the attribute value + */ + @SuppressWarnings("unchecked") + default T getRequiredAttribute(String name) { + T value = getAttribute(name); + Assert.notNull(value, () -> "Required attribute '" + name + "' is missing."); + return value; + } + + /** + * Return the session attribute value, or a default, fallback value. + * @param name the attribute name + * @param defaultValue a default value to return instead + * @param the attribute type + * @return the attribute value + */ + @SuppressWarnings("unchecked") + default T getAttributeOrDefault(String name, T defaultValue) { + return (T) getAttributes().getOrDefault(name, defaultValue); + } + + /** + * Force the creation of a session causing the session id to be sent when + * {@link #save()} is called. + */ + void start(); + + /** + * Whether a session with the client has been started explicitly via + * {@link #start()} or implicitly by adding session attributes. + * If "false" then the session id is not sent to the client and the + * {@link #save()} method is essentially a no-op. + */ + boolean isStarted(); + + /** + * Generate a new id for the session and update the underlying session + * storage to reflect the new id. After a successful call {@link #getId()} + * reflects the new session id. + * @return completion notification (success or error) + */ + Mono changeSessionId(); + + /** + * Invalidate the current session and clear session storage. + * @return completion notification (success or error) + */ + Mono invalidate(); + + /** + * Save the session through the {@code WebSessionStore} as follows: + *

    + *
  • If the session is new (i.e. created but never persisted), it must have + * been started explicitly via {@link #start()} or implicitly by adding + * attributes, or otherwise this method should have no effect. + *
  • If the session was retrieved through the {@code WebSessionStore}, + * the implementation for this method must check whether the session was + * {@link #invalidate() invalidated} and if so return an error. + *
+ *

Note that this method is not intended for direct use by applications. + * Instead it is automatically invoked just before the response is + * committed. + * @return {@code Mono} to indicate completion with success or error + */ + Mono save(); + + /** + * Return {@code true} if the session expired after {@link #getMaxIdleTime() + * maxIdleTime} elapsed. + *

Typically expiration checks should be automatically made when a session + * is accessed, a new {@code WebSession} instance created if necessary, at + * the start of request processing so that applications don't have to worry + * about expired session by default. + */ + boolean isExpired(); + + /** + * Return the time when the session was created. + */ + Instant getCreationTime(); + + /** + * Return the last time of session access as a result of user activity such + * as an HTTP request. Together with {@link #getMaxIdleTime() + * maxIdleTimeInSeconds} this helps to determine when a session is + * {@link #isExpired() expired}. + */ + Instant getLastAccessTime(); + + /** + * Configure the max amount of time that may elapse after the + * {@link #getLastAccessTime() lastAccessTime} before a session is considered + * expired. A negative value indicates the session should not expire. + */ + void setMaxIdleTime(Duration maxIdleTime); + + /** + * Return the maximum time after the {@link #getLastAccessTime() + * lastAccessTime} before a session expires. A negative time indicates the + * session doesn't expire. + */ + Duration getMaxIdleTime(); + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/adapter/AbstractReactiveWebInitializer.java b/spring-web/src/main/java/org/springframework/web/server/adapter/AbstractReactiveWebInitializer.java new file mode 100644 index 0000000000000000000000000000000000000000..afafa743394f9fc4468137ca6921bd7464d15697 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/adapter/AbstractReactiveWebInitializer.java @@ -0,0 +1,162 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.adapter; + +import javax.servlet.ServletContext; +import javax.servlet.ServletContextEvent; +import javax.servlet.ServletContextListener; +import javax.servlet.ServletException; +import javax.servlet.ServletRegistration; + +import org.springframework.context.ApplicationContext; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ServletHttpHandlerAdapter; +import org.springframework.util.Assert; +import org.springframework.web.WebApplicationInitializer; + +/** + * Base class for a {@link org.springframework.web.WebApplicationInitializer} + * that installs a Spring Reactive Web Application on a Servlet container. + * + *

Spring configuration is loaded and given to + * {@link WebHttpHandlerBuilder#applicationContext WebHttpHandlerBuilder} + * which scans the context looking for specific beans and creates a reactive + * {@link HttpHandler}. The resulting handler is installed as a Servlet through + * the {@link ServletHttpHandlerAdapter}. + * + * @author Rossen Stoyanchev + * @since 5.0.2 + */ +public abstract class AbstractReactiveWebInitializer implements WebApplicationInitializer { + + /** + * The default servlet name to use. See {@link #getServletName}. + */ + public static final String DEFAULT_SERVLET_NAME = "http-handler-adapter"; + + + @Override + public void onStartup(ServletContext servletContext) throws ServletException { + String servletName = getServletName(); + Assert.hasLength(servletName, "getServletName() must not return null or empty"); + + ApplicationContext applicationContext = createApplicationContext(); + Assert.notNull(applicationContext, "createApplicationContext() must not return null"); + + refreshApplicationContext(applicationContext); + registerCloseListener(servletContext, applicationContext); + + HttpHandler httpHandler = WebHttpHandlerBuilder.applicationContext(applicationContext).build(); + ServletHttpHandlerAdapter servlet = new ServletHttpHandlerAdapter(httpHandler); + + ServletRegistration.Dynamic registration = servletContext.addServlet(servletName, servlet); + if (registration == null) { + throw new IllegalStateException("Failed to register servlet with name '" + servletName + "'. " + + "Check if there is another servlet registered under the same name."); + } + + registration.setLoadOnStartup(1); + registration.addMapping(getServletMapping()); + registration.setAsyncSupported(true); + } + + /** + * Return the name to use to register the {@link ServletHttpHandlerAdapter}. + *

By default this is {@link #DEFAULT_SERVLET_NAME}. + */ + protected String getServletName() { + return DEFAULT_SERVLET_NAME; + } + + /** + * Return the Spring configuration that contains application beans including + * the ones detected by {@link WebHttpHandlerBuilder#applicationContext}. + */ + protected ApplicationContext createApplicationContext() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + Class[] configClasses = getConfigClasses(); + Assert.notEmpty(configClasses, "No Spring configuration provided through getConfigClasses()"); + context.register(configClasses); + return context; + } + + /** + * Specify {@link org.springframework.context.annotation.Configuration @Configuration} + * and/or {@link org.springframework.stereotype.Component @Component} + * classes that make up the application configuration. The config classes + * are given to {@linkplain #createApplicationContext()}. + */ + protected abstract Class[] getConfigClasses(); + + /** + * Refresh the given application context, if necessary. + */ + protected void refreshApplicationContext(ApplicationContext context) { + if (context instanceof ConfigurableApplicationContext) { + ConfigurableApplicationContext cac = (ConfigurableApplicationContext) context; + if (!cac.isActive()) { + cac.refresh(); + } + } + } + + /** + * Register a {@link ServletContextListener} that closes the given + * application context when the servlet context is destroyed. + * @param servletContext the servlet context to listen to + * @param applicationContext the application context that is to be + * closed when {@code servletContext} is destroyed + */ + protected void registerCloseListener(ServletContext servletContext, ApplicationContext applicationContext) { + if (applicationContext instanceof ConfigurableApplicationContext) { + ConfigurableApplicationContext cac = (ConfigurableApplicationContext) applicationContext; + ServletContextDestroyedListener listener = new ServletContextDestroyedListener(cac); + servletContext.addListener(listener); + } + } + + /** + * Return the Servlet mapping to use. Only the default Servlet mapping '/' + * and path-based Servlet mappings such as '/api/*' are supported. + *

By default this is set to '/'. + */ + protected String getServletMapping() { + return "/"; + } + + + private static class ServletContextDestroyedListener implements ServletContextListener { + + private final ConfigurableApplicationContext applicationContext; + + public ServletContextDestroyedListener(ConfigurableApplicationContext applicationContext) { + this.applicationContext = applicationContext; + } + + @Override + public void contextInitialized(ServletContextEvent sce) { + } + + @Override + public void contextDestroyed(ServletContextEvent sce) { + this.applicationContext.close(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/adapter/DefaultServerWebExchange.java b/spring-web/src/main/java/org/springframework/web/server/adapter/DefaultServerWebExchange.java new file mode 100644 index 0000000000000000000000000000000000000000..058206b5b5006f8e5c3627a263fb730827a534a7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/adapter/DefaultServerWebExchange.java @@ -0,0 +1,383 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.adapter; + +import java.security.Principal; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; + +import reactor.core.publisher.Mono; + +import org.springframework.context.ApplicationContext; +import org.springframework.context.i18n.LocaleContext; +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Hints; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.InvalidMediaTypeException; +import org.springframework.http.MediaType; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.http.codec.multipart.Part; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; +import org.springframework.web.server.i18n.LocaleContextResolver; +import org.springframework.web.server.session.WebSessionManager; + +/** + * Default implementation of {@link ServerWebExchange}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class DefaultServerWebExchange implements ServerWebExchange { + + private static final List SAFE_METHODS = Arrays.asList(HttpMethod.GET, HttpMethod.HEAD); + + private static final ResolvableType FORM_DATA_TYPE = + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class); + + private static final ResolvableType MULTIPART_DATA_TYPE = ResolvableType.forClassWithGenerics( + MultiValueMap.class, String.class, Part.class); + + private static final Mono> EMPTY_FORM_DATA = + Mono.just(CollectionUtils.unmodifiableMultiValueMap(new LinkedMultiValueMap(0))) + .cache(); + + private static final Mono> EMPTY_MULTIPART_DATA = + Mono.just(CollectionUtils.unmodifiableMultiValueMap(new LinkedMultiValueMap(0))) + .cache(); + + + private final ServerHttpRequest request; + + private final ServerHttpResponse response; + + private final Map attributes = new ConcurrentHashMap<>(); + + private final Mono sessionMono; + + private final LocaleContextResolver localeContextResolver; + + private final Mono> formDataMono; + + private final Mono> multipartDataMono; + + @Nullable + private final ApplicationContext applicationContext; + + private volatile boolean notModified; + + private Function urlTransformer = url -> url; + + @Nullable + private Object logId; + + private String logPrefix = ""; + + + public DefaultServerWebExchange(ServerHttpRequest request, ServerHttpResponse response, + WebSessionManager sessionManager, ServerCodecConfigurer codecConfigurer, + LocaleContextResolver localeContextResolver) { + + this(request, response, sessionManager, codecConfigurer, localeContextResolver, null); + } + + DefaultServerWebExchange(ServerHttpRequest request, ServerHttpResponse response, + WebSessionManager sessionManager, ServerCodecConfigurer codecConfigurer, + LocaleContextResolver localeContextResolver, @Nullable ApplicationContext applicationContext) { + + Assert.notNull(request, "'request' is required"); + Assert.notNull(response, "'response' is required"); + Assert.notNull(sessionManager, "'sessionManager' is required"); + Assert.notNull(codecConfigurer, "'codecConfigurer' is required"); + Assert.notNull(localeContextResolver, "'localeContextResolver' is required"); + + // Initialize before first call to getLogPrefix() + this.attributes.put(ServerWebExchange.LOG_ID_ATTRIBUTE, request.getId()); + + this.request = request; + this.response = response; + this.sessionMono = sessionManager.getSession(this).cache(); + this.localeContextResolver = localeContextResolver; + this.formDataMono = initFormData(request, codecConfigurer, getLogPrefix()); + this.multipartDataMono = initMultipartData(request, codecConfigurer, getLogPrefix()); + this.applicationContext = applicationContext; + } + + @SuppressWarnings("unchecked") + private static Mono> initFormData(ServerHttpRequest request, + ServerCodecConfigurer configurer, String logPrefix) { + + try { + MediaType contentType = request.getHeaders().getContentType(); + if (MediaType.APPLICATION_FORM_URLENCODED.isCompatibleWith(contentType)) { + return ((HttpMessageReader>) configurer.getReaders().stream() + .filter(reader -> reader.canRead(FORM_DATA_TYPE, MediaType.APPLICATION_FORM_URLENCODED)) + .findFirst() + .orElseThrow(() -> new IllegalStateException("No form data HttpMessageReader."))) + .readMono(FORM_DATA_TYPE, request, Hints.from(Hints.LOG_PREFIX_HINT, logPrefix)) + .switchIfEmpty(EMPTY_FORM_DATA) + .cache(); + } + } + catch (InvalidMediaTypeException ex) { + // Ignore + } + return EMPTY_FORM_DATA; + } + + @SuppressWarnings("unchecked") + private static Mono> initMultipartData(ServerHttpRequest request, + ServerCodecConfigurer configurer, String logPrefix) { + + try { + MediaType contentType = request.getHeaders().getContentType(); + if (MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) { + return ((HttpMessageReader>) configurer.getReaders().stream() + .filter(reader -> reader.canRead(MULTIPART_DATA_TYPE, MediaType.MULTIPART_FORM_DATA)) + .findFirst() + .orElseThrow(() -> new IllegalStateException("No multipart HttpMessageReader."))) + .readMono(MULTIPART_DATA_TYPE, request, Hints.from(Hints.LOG_PREFIX_HINT, logPrefix)) + .switchIfEmpty(EMPTY_MULTIPART_DATA) + .cache(); + } + } + catch (InvalidMediaTypeException ex) { + // Ignore + } + return EMPTY_MULTIPART_DATA; + } + + + @Override + public ServerHttpRequest getRequest() { + return this.request; + } + + private HttpHeaders getRequestHeaders() { + return getRequest().getHeaders(); + } + + @Override + public ServerHttpResponse getResponse() { + return this.response; + } + + private HttpHeaders getResponseHeaders() { + return getResponse().getHeaders(); + } + + @Override + public Map getAttributes() { + return this.attributes; + } + + @Override + public Mono getSession() { + return this.sessionMono; + } + + @Override + public Mono getPrincipal() { + return Mono.empty(); + } + + @Override + public Mono> getFormData() { + return this.formDataMono; + } + + @Override + public Mono> getMultipartData() { + return this.multipartDataMono; + } + + @Override + public LocaleContext getLocaleContext() { + return this.localeContextResolver.resolveLocaleContext(this); + } + + @Override + @Nullable + public ApplicationContext getApplicationContext() { + return this.applicationContext; + } + + @Override + public boolean isNotModified() { + return this.notModified; + } + + @Override + public boolean checkNotModified(Instant lastModified) { + return checkNotModified(null, lastModified); + } + + @Override + public boolean checkNotModified(String etag) { + return checkNotModified(etag, Instant.MIN); + } + + @Override + public boolean checkNotModified(@Nullable String etag, Instant lastModified) { + HttpStatus status = getResponse().getStatusCode(); + if (this.notModified || (status != null && !HttpStatus.OK.equals(status))) { + return this.notModified; + } + + // Evaluate conditions in order of precedence. + // See https://tools.ietf.org/html/rfc7232#section-6 + + if (validateIfUnmodifiedSince(lastModified)) { + if (this.notModified) { + getResponse().setStatusCode(HttpStatus.PRECONDITION_FAILED); + } + return this.notModified; + } + + boolean validated = validateIfNoneMatch(etag); + if (!validated) { + validateIfModifiedSince(lastModified); + } + + // Update response + + boolean isHttpGetOrHead = SAFE_METHODS.contains(getRequest().getMethod()); + if (this.notModified) { + getResponse().setStatusCode(isHttpGetOrHead ? + HttpStatus.NOT_MODIFIED : HttpStatus.PRECONDITION_FAILED); + } + if (isHttpGetOrHead) { + if (lastModified.isAfter(Instant.EPOCH) && getResponseHeaders().getLastModified() == -1) { + getResponseHeaders().setLastModified(lastModified.toEpochMilli()); + } + if (StringUtils.hasLength(etag) && getResponseHeaders().getETag() == null) { + getResponseHeaders().setETag(padEtagIfNecessary(etag)); + } + } + + return this.notModified; + } + + private boolean validateIfUnmodifiedSince(Instant lastModified) { + if (lastModified.isBefore(Instant.EPOCH)) { + return false; + } + long ifUnmodifiedSince = getRequestHeaders().getIfUnmodifiedSince(); + if (ifUnmodifiedSince == -1) { + return false; + } + // We will perform this validation... + Instant sinceInstant = Instant.ofEpochMilli(ifUnmodifiedSince); + this.notModified = sinceInstant.isBefore(lastModified.truncatedTo(ChronoUnit.SECONDS)); + return true; + } + + private boolean validateIfNoneMatch(@Nullable String etag) { + if (!StringUtils.hasLength(etag)) { + return false; + } + List ifNoneMatch; + try { + ifNoneMatch = getRequestHeaders().getIfNoneMatch(); + } + catch (IllegalArgumentException ex) { + return false; + } + if (ifNoneMatch.isEmpty()) { + return false; + } + // We will perform this validation... + etag = padEtagIfNecessary(etag); + if (etag.startsWith("W/")) { + etag = etag.substring(2); + } + for (String clientEtag : ifNoneMatch) { + // Compare weak/strong ETags as per https://tools.ietf.org/html/rfc7232#section-2.3 + if (StringUtils.hasLength(clientEtag)) { + if (clientEtag.startsWith("W/")) { + clientEtag = clientEtag.substring(2); + } + if (clientEtag.equals(etag)) { + this.notModified = true; + break; + } + } + } + return true; + } + + private String padEtagIfNecessary(String etag) { + if (!StringUtils.hasLength(etag)) { + return etag; + } + if ((etag.startsWith("\"") || etag.startsWith("W/\"")) && etag.endsWith("\"")) { + return etag; + } + return "\"" + etag + "\""; + } + + private boolean validateIfModifiedSince(Instant lastModified) { + if (lastModified.isBefore(Instant.EPOCH)) { + return false; + } + long ifModifiedSince = getRequestHeaders().getIfModifiedSince(); + if (ifModifiedSince == -1) { + return false; + } + // We will perform this validation... + this.notModified = ChronoUnit.SECONDS.between(lastModified, Instant.ofEpochMilli(ifModifiedSince)) >= 0; + return true; + } + + @Override + public String transformUrl(String url) { + return this.urlTransformer.apply(url); + } + + @Override + public void addUrlTransformer(Function transformer) { + Assert.notNull(transformer, "'encoder' must not be null"); + this.urlTransformer = this.urlTransformer.andThen(transformer); + } + + @Override + public String getLogPrefix() { + Object value = getAttribute(LOG_ID_ATTRIBUTE); + if (this.logId != value) { + this.logId = value; + this.logPrefix = value != null ? "[" + value + "] " : ""; + } + return this.logPrefix; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/adapter/ForwardedHeaderTransformer.java b/spring-web/src/main/java/org/springframework/web/server/adapter/ForwardedHeaderTransformer.java new file mode 100644 index 0000000000000000000000000000000000000000..b3990166cb1906b5a336c33c163454be38bd1be1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/adapter/ForwardedHeaderTransformer.java @@ -0,0 +1,142 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.adapter; + +import java.net.URI; +import java.util.Collections; +import java.util.Locale; +import java.util.Set; +import java.util.function.Function; + +import org.springframework.context.ApplicationContext; +import org.springframework.http.HttpHeaders; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.lang.Nullable; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * Extract values from "Forwarded" and "X-Forwarded-*" headers to override + * the request URI (i.e. {@link ServerHttpRequest#getURI()}) so it reflects + * the client-originated protocol and address. + * + *

Alternatively if {@link #setRemoveOnly removeOnly} is set to "true", + * then "Forwarded" and "X-Forwarded-*" headers are only removed, and not used. + * + *

An instance of this class is typically declared as a bean with the name + * "forwardedHeaderTransformer" and detected by + * {@link WebHttpHandlerBuilder#applicationContext(ApplicationContext)}, or it + * can also be registered directly via + * {@link WebHttpHandlerBuilder#forwardedHeaderTransformer(ForwardedHeaderTransformer)}. + * + * @author Rossen Stoyanchev + * @since 5.1 + * @see https://tools.ietf.org/html/rfc7239 + */ +public class ForwardedHeaderTransformer implements Function { + + static final Set FORWARDED_HEADER_NAMES = + Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(8, Locale.ENGLISH)); + + static { + FORWARDED_HEADER_NAMES.add("Forwarded"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Host"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Port"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Ssl"); + } + + + private boolean removeOnly; + + + /** + * Enable mode in which any "Forwarded" or "X-Forwarded-*" headers are + * removed only and the information in them ignored. + * @param removeOnly whether to discard and ignore forwarded headers + */ + public void setRemoveOnly(boolean removeOnly) { + this.removeOnly = removeOnly; + } + + /** + * Whether the "remove only" mode is on. + * @see #setRemoveOnly + */ + public boolean isRemoveOnly() { + return this.removeOnly; + } + + + /** + * Apply and remove, or remove Forwarded type headers. + * @param request the request + */ + @Override + public ServerHttpRequest apply(ServerHttpRequest request) { + if (hasForwardedHeaders(request)) { + ServerHttpRequest.Builder builder = request.mutate(); + if (!this.removeOnly) { + URI uri = UriComponentsBuilder.fromHttpRequest(request).build(true).toUri(); + builder.uri(uri); + String prefix = getForwardedPrefix(request); + if (prefix != null) { + builder.path(prefix + uri.getRawPath()); + builder.contextPath(prefix); + } + } + removeForwardedHeaders(builder); + request = builder.build(); + } + return request; + } + + /** + * Whether the request has any Forwarded headers. + * @param request the request + */ + protected boolean hasForwardedHeaders(ServerHttpRequest request) { + HttpHeaders headers = request.getHeaders(); + for (String headerName : FORWARDED_HEADER_NAMES) { + if (headers.containsKey(headerName)) { + return true; + } + } + return false; + } + + private void removeForwardedHeaders(ServerHttpRequest.Builder builder) { + builder.headers(map -> FORWARDED_HEADER_NAMES.forEach(map::remove)); + } + + + @Nullable + private static String getForwardedPrefix(ServerHttpRequest request) { + HttpHeaders headers = request.getHeaders(); + String prefix = headers.getFirst("X-Forwarded-Prefix"); + if (prefix != null) { + int endIndex = prefix.length(); + while (endIndex > 1 && prefix.charAt(endIndex - 1) == '/') { + endIndex--; + } + prefix = (endIndex != prefix.length() ? prefix.substring(0, endIndex) : prefix); + } + return prefix; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/adapter/HttpWebHandlerAdapter.java b/spring-web/src/main/java/org/springframework/web/server/adapter/HttpWebHandlerAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..bd3edea6aced57e6cf47b8e2cd263c551dcd536a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/adapter/HttpWebHandlerAdapter.java @@ -0,0 +1,306 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.adapter; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import reactor.core.publisher.Mono; + +import org.springframework.context.ApplicationContext; +import org.springframework.core.NestedExceptionUtils; +import org.springframework.core.log.LogFormatUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.codec.LoggingCodecSupport; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebHandler; +import org.springframework.web.server.handler.WebHandlerDecorator; +import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; +import org.springframework.web.server.i18n.LocaleContextResolver; +import org.springframework.web.server.session.DefaultWebSessionManager; +import org.springframework.web.server.session.WebSessionManager; + +/** + * Default adapter of {@link WebHandler} to the {@link HttpHandler} contract. + * + *

By default creates and configures a {@link DefaultServerWebExchange} and + * then invokes the target {@code WebHandler}. + * + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + * @since 5.0 + */ +public class HttpWebHandlerAdapter extends WebHandlerDecorator implements HttpHandler { + + /** + * Dedicated log category for disconnected client exceptions. + *

Servlet containers dn't expose a a client disconnected callback, see + * eclipse-ee4j/servlet-api#44. + *

To avoid filling logs with unnecessary stack traces, we make an + * effort to identify such network failures on a per-server basis, and then + * log under a separate log category a simple one-line message at DEBUG level + * or a full stack trace only at TRACE level. + */ + private static final String DISCONNECTED_CLIENT_LOG_CATEGORY = + "org.springframework.web.server.DisconnectedClient"; + + // Similar declaration exists in AbstractSockJsSession.. + private static final Set DISCONNECTED_CLIENT_EXCEPTIONS = new HashSet<>( + Arrays.asList("AbortedException", "ClientAbortException", "EOFException", "EofException")); + + + private static final Log logger = LogFactory.getLog(HttpWebHandlerAdapter.class); + + private static final Log lostClientLogger = LogFactory.getLog(DISCONNECTED_CLIENT_LOG_CATEGORY); + + + private WebSessionManager sessionManager = new DefaultWebSessionManager(); + + private ServerCodecConfigurer codecConfigurer = ServerCodecConfigurer.create(); + + private LocaleContextResolver localeContextResolver = new AcceptHeaderLocaleContextResolver(); + + @Nullable + private ForwardedHeaderTransformer forwardedHeaderTransformer; + + @Nullable + private ApplicationContext applicationContext; + + /** Whether to log potentially sensitive info (form data at DEBUG, headers at TRACE). */ + private boolean enableLoggingRequestDetails = false; + + + public HttpWebHandlerAdapter(WebHandler delegate) { + super(delegate); + } + + + /** + * Configure a custom {@link WebSessionManager} to use for managing web + * sessions. The provided instance is set on each created + * {@link DefaultServerWebExchange}. + *

By default this is set to {@link DefaultWebSessionManager}. + * @param sessionManager the session manager to use + */ + public void setSessionManager(WebSessionManager sessionManager) { + Assert.notNull(sessionManager, "WebSessionManager must not be null"); + this.sessionManager = sessionManager; + } + + /** + * Return the configured {@link WebSessionManager}. + */ + public WebSessionManager getSessionManager() { + return this.sessionManager; + } + + /** + * Configure a custom {@link ServerCodecConfigurer}. The provided instance is set on + * each created {@link DefaultServerWebExchange}. + *

By default this is set to {@link ServerCodecConfigurer#create()}. + * @param codecConfigurer the codec configurer to use + */ + public void setCodecConfigurer(ServerCodecConfigurer codecConfigurer) { + Assert.notNull(codecConfigurer, "ServerCodecConfigurer is required"); + this.codecConfigurer = codecConfigurer; + + this.enableLoggingRequestDetails = false; + this.codecConfigurer.getReaders().stream() + .filter(LoggingCodecSupport.class::isInstance) + .forEach(reader -> { + if (((LoggingCodecSupport) reader).isEnableLoggingRequestDetails()) { + this.enableLoggingRequestDetails = true; + } + }); + } + + /** + * Return the configured {@link ServerCodecConfigurer}. + */ + public ServerCodecConfigurer getCodecConfigurer() { + return this.codecConfigurer; + } + + /** + * Configure a custom {@link LocaleContextResolver}. The provided instance is set on + * each created {@link DefaultServerWebExchange}. + *

By default this is set to + * {@link org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver}. + * @param resolver the locale context resolver to use + */ + public void setLocaleContextResolver(LocaleContextResolver resolver) { + Assert.notNull(resolver, "LocaleContextResolver is required"); + this.localeContextResolver = resolver; + } + + /** + * Return the configured {@link LocaleContextResolver}. + */ + public LocaleContextResolver getLocaleContextResolver() { + return this.localeContextResolver; + } + + /** + * Enable processing of forwarded headers, either extracting and removing, + * or remove only. + *

By default this is not set. + * @param transformer the transformer to use + * @since 5.1 + */ + public void setForwardedHeaderTransformer(ForwardedHeaderTransformer transformer) { + Assert.notNull(transformer, "ForwardedHeaderTransformer is required"); + this.forwardedHeaderTransformer = transformer; + } + + /** + * Return the configured {@link ForwardedHeaderTransformer}. + * @since 5.1 + */ + @Nullable + public ForwardedHeaderTransformer getForwardedHeaderTransformer() { + return this.forwardedHeaderTransformer; + } + + /** + * Configure the {@code ApplicationContext} associated with the web application, + * if it was initialized with one via + * {@link org.springframework.web.server.adapter.WebHttpHandlerBuilder#applicationContext(ApplicationContext)}. + * @param applicationContext the context + * @since 5.0.3 + */ + public void setApplicationContext(ApplicationContext applicationContext) { + this.applicationContext = applicationContext; + } + + /** + * Return the configured {@code ApplicationContext}, if any. + * @since 5.0.3 + */ + @Nullable + public ApplicationContext getApplicationContext() { + return this.applicationContext; + } + + /** + * This method must be invoked after all properties have been set to + * complete initialization. + */ + public void afterPropertiesSet() { + if (logger.isDebugEnabled()) { + String value = this.enableLoggingRequestDetails ? + "shown which may lead to unsafe logging of potentially sensitive data" : + "masked to prevent unsafe logging of potentially sensitive data"; + logger.debug("enableLoggingRequestDetails='" + this.enableLoggingRequestDetails + + "': form data and headers will be " + value); + } + } + + + @Override + public Mono handle(ServerHttpRequest request, ServerHttpResponse response) { + if (this.forwardedHeaderTransformer != null) { + request = this.forwardedHeaderTransformer.apply(request); + } + ServerWebExchange exchange = createExchange(request, response); + + LogFormatUtils.traceDebug(logger, traceOn -> + exchange.getLogPrefix() + formatRequest(exchange.getRequest()) + + (traceOn ? ", headers=" + formatHeaders(exchange.getRequest().getHeaders()) : "")); + + return getDelegate().handle(exchange) + .doOnSuccess(aVoid -> logResponse(exchange)) + .onErrorResume(ex -> handleUnresolvedError(exchange, ex)) + .then(Mono.defer(response::setComplete)); + } + + protected ServerWebExchange createExchange(ServerHttpRequest request, ServerHttpResponse response) { + return new DefaultServerWebExchange(request, response, this.sessionManager, + getCodecConfigurer(), getLocaleContextResolver(), this.applicationContext); + } + + private String formatRequest(ServerHttpRequest request) { + String rawQuery = request.getURI().getRawQuery(); + String query = StringUtils.hasText(rawQuery) ? "?" + rawQuery : ""; + return "HTTP " + request.getMethod() + " \"" + request.getPath() + query + "\""; + } + + private void logResponse(ServerWebExchange exchange) { + LogFormatUtils.traceDebug(logger, traceOn -> { + HttpStatus status = exchange.getResponse().getStatusCode(); + return exchange.getLogPrefix() + "Completed " + (status != null ? status : "200 OK") + + (traceOn ? ", headers=" + formatHeaders(exchange.getResponse().getHeaders()) : ""); + }); + } + + private String formatHeaders(HttpHeaders responseHeaders) { + return this.enableLoggingRequestDetails ? + responseHeaders.toString() : responseHeaders.isEmpty() ? "{}" : "{masked}"; + } + + private Mono handleUnresolvedError(ServerWebExchange exchange, Throwable ex) { + ServerHttpRequest request = exchange.getRequest(); + ServerHttpResponse response = exchange.getResponse(); + String logPrefix = exchange.getLogPrefix(); + + // Sometimes a remote call error can look like a disconnected client. + // Try to set the response first before the "isDisconnectedClient" check. + + if (response.setStatusCode(HttpStatus.INTERNAL_SERVER_ERROR)) { + logger.error(logPrefix + "500 Server Error for " + formatRequest(request), ex); + return Mono.empty(); + } + else if (isDisconnectedClientError(ex)) { + if (lostClientLogger.isTraceEnabled()) { + lostClientLogger.trace(logPrefix + "Client went away", ex); + } + else if (lostClientLogger.isDebugEnabled()) { + lostClientLogger.debug(logPrefix + "Client went away: " + ex + + " (stacktrace at TRACE level for '" + DISCONNECTED_CLIENT_LOG_CATEGORY + "')"); + } + return Mono.empty(); + } + else { + // After the response is committed, propagate errors to the server... + logger.error(logPrefix + "Error [" + ex + "] for " + formatRequest(request) + + ", but ServerHttpResponse already committed (" + response.getStatusCode() + ")"); + return Mono.error(ex); + } + } + + private boolean isDisconnectedClientError(Throwable ex) { + String message = NestedExceptionUtils.getMostSpecificCause(ex).getMessage(); + if (message != null) { + String text = message.toLowerCase(); + if (text.contains("broken pipe") || text.contains("connection reset by peer")) { + return true; + } + } + return DISCONNECTED_CLIENT_EXCEPTIONS.contains(ex.getClass().getSimpleName()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/adapter/WebHttpHandlerBuilder.java b/spring-web/src/main/java/org/springframework/web/server/adapter/WebHttpHandlerBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..2bb6224846957ac433990b1db1fb494aff52c614 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/adapter/WebHttpHandlerBuilder.java @@ -0,0 +1,392 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.adapter; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.context.ApplicationContext; +import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebExceptionHandler; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebHandler; +import org.springframework.web.server.handler.ExceptionHandlingWebHandler; +import org.springframework.web.server.handler.FilteringWebHandler; +import org.springframework.web.server.i18n.LocaleContextResolver; +import org.springframework.web.server.session.DefaultWebSessionManager; +import org.springframework.web.server.session.WebSessionManager; + +/** + * This builder has two purposes: + * + *

One is to assemble a processing chain that consists of a target {@link WebHandler}, + * then decorated with a set of {@link WebFilter WebFilters}, then further decorated with + * a set of {@link WebExceptionHandler WebExceptionHandlers}. + * + *

The second purpose is to adapt the resulting processing chain to an {@link HttpHandler}: + * the lowest-level reactive HTTP handling abstraction which can then be used with any of the + * supported runtimes. The adaptation is done with the help of {@link HttpWebHandlerAdapter}. + * + *

The processing chain can be assembled manually via builder methods, or detected from + * a Spring {@link ApplicationContext} via {@link #applicationContext}, or a mix of both. + * + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + * @since 5.0 + * @see HttpWebHandlerAdapter + */ +public final class WebHttpHandlerBuilder { + + /** Well-known name for the target WebHandler in the bean factory. */ + public static final String WEB_HANDLER_BEAN_NAME = "webHandler"; + + /** Well-known name for the WebSessionManager in the bean factory. */ + public static final String WEB_SESSION_MANAGER_BEAN_NAME = "webSessionManager"; + + /** Well-known name for the ServerCodecConfigurer in the bean factory. */ + public static final String SERVER_CODEC_CONFIGURER_BEAN_NAME = "serverCodecConfigurer"; + + /** Well-known name for the LocaleContextResolver in the bean factory. */ + public static final String LOCALE_CONTEXT_RESOLVER_BEAN_NAME = "localeContextResolver"; + + /** Well-known name for the ForwardedHeaderTransformer in the bean factory. */ + public static final String FORWARDED_HEADER_TRANSFORMER_BEAN_NAME = "forwardedHeaderTransformer"; + + + private final WebHandler webHandler; + + @Nullable + private final ApplicationContext applicationContext; + + private final List filters = new ArrayList<>(); + + private final List exceptionHandlers = new ArrayList<>(); + + @Nullable + private WebSessionManager sessionManager; + + @Nullable + private ServerCodecConfigurer codecConfigurer; + + @Nullable + private LocaleContextResolver localeContextResolver; + + @Nullable + private ForwardedHeaderTransformer forwardedHeaderTransformer; + + + /** + * Private constructor to use when initialized from an ApplicationContext. + */ + private WebHttpHandlerBuilder(WebHandler webHandler, @Nullable ApplicationContext applicationContext) { + Assert.notNull(webHandler, "WebHandler must not be null"); + this.webHandler = webHandler; + this.applicationContext = applicationContext; + } + + /** + * Copy constructor. + */ + private WebHttpHandlerBuilder(WebHttpHandlerBuilder other) { + this.webHandler = other.webHandler; + this.applicationContext = other.applicationContext; + this.filters.addAll(other.filters); + this.exceptionHandlers.addAll(other.exceptionHandlers); + this.sessionManager = other.sessionManager; + this.codecConfigurer = other.codecConfigurer; + this.localeContextResolver = other.localeContextResolver; + this.forwardedHeaderTransformer = other.forwardedHeaderTransformer; + } + + + /** + * Static factory method to create a new builder instance. + * @param webHandler the target handler for the request + * @return the prepared builder + */ + public static WebHttpHandlerBuilder webHandler(WebHandler webHandler) { + return new WebHttpHandlerBuilder(webHandler, null); + } + + /** + * Static factory method to create a new builder instance by detecting beans + * in an {@link ApplicationContext}. The following are detected: + *

    + *
  • {@link WebHandler} [1] -- looked up by the name + * {@link #WEB_HANDLER_BEAN_NAME}. + *
  • {@link WebFilter} [0..N] -- detected by type and ordered, + * see {@link AnnotationAwareOrderComparator}. + *
  • {@link WebExceptionHandler} [0..N] -- detected by type and + * ordered. + *
  • {@link WebSessionManager} [0..1] -- looked up by the name + * {@link #WEB_SESSION_MANAGER_BEAN_NAME}. + *
  • {@link ServerCodecConfigurer} [0..1] -- looked up by the name + * {@link #SERVER_CODEC_CONFIGURER_BEAN_NAME}. + *
  • {@link LocaleContextResolver} [0..1] -- looked up by the name + * {@link #LOCALE_CONTEXT_RESOLVER_BEAN_NAME}. + *
+ * @param context the application context to use for the lookup + * @return the prepared builder + */ + public static WebHttpHandlerBuilder applicationContext(ApplicationContext context) { + WebHttpHandlerBuilder builder = new WebHttpHandlerBuilder( + context.getBean(WEB_HANDLER_BEAN_NAME, WebHandler.class), context); + + List webFilters = context + .getBeanProvider(WebFilter.class) + .orderedStream() + .collect(Collectors.toList()); + builder.filters(filters -> filters.addAll(webFilters)); + List exceptionHandlers = context + .getBeanProvider(WebExceptionHandler.class) + .orderedStream() + .collect(Collectors.toList()); + builder.exceptionHandlers(handlers -> handlers.addAll(exceptionHandlers)); + + try { + builder.sessionManager( + context.getBean(WEB_SESSION_MANAGER_BEAN_NAME, WebSessionManager.class)); + } + catch (NoSuchBeanDefinitionException ex) { + // Fall back on default + } + + try { + builder.codecConfigurer( + context.getBean(SERVER_CODEC_CONFIGURER_BEAN_NAME, ServerCodecConfigurer.class)); + } + catch (NoSuchBeanDefinitionException ex) { + // Fall back on default + } + + try { + builder.localeContextResolver( + context.getBean(LOCALE_CONTEXT_RESOLVER_BEAN_NAME, LocaleContextResolver.class)); + } + catch (NoSuchBeanDefinitionException ex) { + // Fall back on default + } + + try { + builder.localeContextResolver( + context.getBean(LOCALE_CONTEXT_RESOLVER_BEAN_NAME, LocaleContextResolver.class)); + } + catch (NoSuchBeanDefinitionException ex) { + // Fall back on default + } + + try { + builder.forwardedHeaderTransformer( + context.getBean(FORWARDED_HEADER_TRANSFORMER_BEAN_NAME, ForwardedHeaderTransformer.class)); + } + catch (NoSuchBeanDefinitionException ex) { + // Fall back on default + } + + return builder; + } + + + /** + * Add the given filter(s). + * @param filters the filter(s) to add that's + */ + public WebHttpHandlerBuilder filter(WebFilter... filters) { + if (!ObjectUtils.isEmpty(filters)) { + this.filters.addAll(Arrays.asList(filters)); + updateFilters(); + } + return this; + } + + /** + * Manipulate the "live" list of currently configured filters. + * @param consumer the consumer to use + */ + public WebHttpHandlerBuilder filters(Consumer> consumer) { + consumer.accept(this.filters); + updateFilters(); + return this; + } + + private void updateFilters() { + if (this.filters.isEmpty()) { + return; + } + + List filtersToUse = this.filters.stream() + .peek(filter -> { + if (filter instanceof ForwardedHeaderTransformer && this.forwardedHeaderTransformer == null) { + this.forwardedHeaderTransformer = (ForwardedHeaderTransformer) filter; + } + }) + .filter(filter -> !(filter instanceof ForwardedHeaderTransformer)) + .collect(Collectors.toList()); + + this.filters.clear(); + this.filters.addAll(filtersToUse); + } + + /** + * Add the given exception handler(s). + * @param handlers the exception handler(s) + */ + public WebHttpHandlerBuilder exceptionHandler(WebExceptionHandler... handlers) { + if (!ObjectUtils.isEmpty(handlers)) { + this.exceptionHandlers.addAll(Arrays.asList(handlers)); + } + return this; + } + + /** + * Manipulate the "live" list of currently configured exception handlers. + * @param consumer the consumer to use + */ + public WebHttpHandlerBuilder exceptionHandlers(Consumer> consumer) { + consumer.accept(this.exceptionHandlers); + return this; + } + + /** + * Configure the {@link WebSessionManager} to set on the + * {@link ServerWebExchange WebServerExchange}. + *

By default {@link DefaultWebSessionManager} is used. + * @param manager the session manager + * @see HttpWebHandlerAdapter#setSessionManager(WebSessionManager) + */ + public WebHttpHandlerBuilder sessionManager(WebSessionManager manager) { + this.sessionManager = manager; + return this; + } + + /** + * Whether a {@code WebSessionManager} is configured or not, either detected from an + * {@code ApplicationContext} or explicitly configured via {@link #sessionManager}. + * @since 5.0.9 + */ + public boolean hasSessionManager() { + return (this.sessionManager != null); + } + + /** + * Configure the {@link ServerCodecConfigurer} to set on the {@code WebServerExchange}. + * @param codecConfigurer the codec configurer + */ + public WebHttpHandlerBuilder codecConfigurer(ServerCodecConfigurer codecConfigurer) { + this.codecConfigurer = codecConfigurer; + return this; + } + + + /** + * Whether a {@code ServerCodecConfigurer} is configured or not, either detected from an + * {@code ApplicationContext} or explicitly configured via {@link #codecConfigurer}. + * @since 5.0.9 + */ + public boolean hasCodecConfigurer() { + return (this.codecConfigurer != null); + } + + /** + * Configure the {@link LocaleContextResolver} to set on the + * {@link ServerWebExchange WebServerExchange}. + * @param localeContextResolver the locale context resolver + */ + public WebHttpHandlerBuilder localeContextResolver(LocaleContextResolver localeContextResolver) { + this.localeContextResolver = localeContextResolver; + return this; + } + + /** + * Whether a {@code LocaleContextResolver} is configured or not, either detected from an + * {@code ApplicationContext} or explicitly configured via {@link #localeContextResolver}. + * @since 5.0.9 + */ + public boolean hasLocaleContextResolver() { + return (this.localeContextResolver != null); + } + + /** + * Configure the {@link ForwardedHeaderTransformer} for extracting and/or + * removing forwarded headers. + * @param transformer the transformer + * @since 5.1 + */ + public WebHttpHandlerBuilder forwardedHeaderTransformer(ForwardedHeaderTransformer transformer) { + this.forwardedHeaderTransformer = transformer; + return this; + } + + /** + * Whether a {@code ForwardedHeaderTransformer} is configured or not, either + * detected from an {@code ApplicationContext} or explicitly configured via + * {@link #forwardedHeaderTransformer(ForwardedHeaderTransformer)}. + * @since 5.1 + */ + public boolean hasForwardedHeaderTransformer() { + return (this.forwardedHeaderTransformer != null); + } + + + /** + * Build the {@link HttpHandler}. + */ + public HttpHandler build() { + WebHandler decorated = new FilteringWebHandler(this.webHandler, this.filters); + decorated = new ExceptionHandlingWebHandler(decorated, this.exceptionHandlers); + + HttpWebHandlerAdapter adapted = new HttpWebHandlerAdapter(decorated); + if (this.sessionManager != null) { + adapted.setSessionManager(this.sessionManager); + } + if (this.codecConfigurer != null) { + adapted.setCodecConfigurer(this.codecConfigurer); + } + if (this.localeContextResolver != null) { + adapted.setLocaleContextResolver(this.localeContextResolver); + } + if (this.forwardedHeaderTransformer != null) { + adapted.setForwardedHeaderTransformer(this.forwardedHeaderTransformer); + } + if (this.applicationContext != null) { + adapted.setApplicationContext(this.applicationContext); + } + adapted.afterPropertiesSet(); + + return adapted; + } + + /** + * Clone this {@link WebHttpHandlerBuilder}. + * @return the cloned builder instance + */ + @Override + public WebHttpHandlerBuilder clone() { + return new WebHttpHandlerBuilder(this); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/adapter/package-info.java b/spring-web/src/main/java/org/springframework/web/server/adapter/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..13828629a5200eb622314976502c23057238a795 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/adapter/package-info.java @@ -0,0 +1,11 @@ +/** + * Implementations to adapt to the underlying + * {@code org.springframework.http.client.reactive} reactive HTTP adapter + * and {@link org.springframework.http.server.reactive.HttpHandler}. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.server.adapter; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/server/handler/DefaultWebFilterChain.java b/spring-web/src/main/java/org/springframework/web/server/handler/DefaultWebFilterChain.java new file mode 100644 index 0000000000000000000000000000000000000000..619a55d0d542b6ce938c150eb9f6ef9319b4677b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/handler/DefaultWebFilterChain.java @@ -0,0 +1,125 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.handler; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.ListIterator; + +import reactor.core.publisher.Mono; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; +import org.springframework.web.server.WebHandler; + +/** + * Default implementation of {@link WebFilterChain}. + * + *

Each instance of this class represents one link in the chain. The public + * constructor {@link #DefaultWebFilterChain(WebHandler, List)} + * initializes the full chain and represents its first link. + * + *

This class is immutable and thread-safe. It can be created once and + * re-used to handle request concurrently. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class DefaultWebFilterChain implements WebFilterChain { + + private final List allFilters; + + private final WebHandler handler; + + @Nullable + private final WebFilter currentFilter; + + @Nullable + private final DefaultWebFilterChain next; + + + /** + * Public constructor with the list of filters and the target handler to use. + * @param handler the target handler + * @param filters the filters ahead of the handler + * @since 5.1 + */ + public DefaultWebFilterChain(WebHandler handler, List filters) { + Assert.notNull(handler, "WebHandler is required"); + this.allFilters = Collections.unmodifiableList(filters); + this.handler = handler; + DefaultWebFilterChain chain = initChain(filters, handler); + this.currentFilter = chain.currentFilter; + this.next = chain.next; + } + + private static DefaultWebFilterChain initChain(List filters, WebHandler handler) { + DefaultWebFilterChain chain = new DefaultWebFilterChain(filters, handler, null, null); + ListIterator iterator = filters.listIterator(filters.size()); + while (iterator.hasPrevious()) { + chain = new DefaultWebFilterChain(filters, handler, iterator.previous(), chain); + } + return chain; + } + + /** + * Private constructor to represent one link in the chain. + */ + private DefaultWebFilterChain(List allFilters, WebHandler handler, + @Nullable WebFilter currentFilter, @Nullable DefaultWebFilterChain next) { + + this.allFilters = allFilters; + this.currentFilter = currentFilter; + this.handler = handler; + this.next = next; + } + + /** + * Public constructor with the list of filters and the target handler to use. + * @param handler the target handler + * @param filters the filters ahead of the handler + * @deprecated as of 5.1 this constructor is deprecated in favor of + * {@link #DefaultWebFilterChain(WebHandler, List)}. + */ + @Deprecated + public DefaultWebFilterChain(WebHandler handler, WebFilter... filters) { + this(handler, Arrays.asList(filters)); + } + + + public List getFilters() { + return this.allFilters; + } + + public WebHandler getHandler() { + return this.handler; + } + + + @Override + public Mono filter(ServerWebExchange exchange) { + return Mono.defer(() -> + this.currentFilter != null && this.next != null ? + this.currentFilter.filter(exchange, this.next) : + this.handler.handle(exchange)); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/handler/ExceptionHandlingWebHandler.java b/spring-web/src/main/java/org/springframework/web/server/handler/ExceptionHandlingWebHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..9e8af92e650942c6e19e9b8b47e60c57a50f35d6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/handler/ExceptionHandlingWebHandler.java @@ -0,0 +1,74 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.handler; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import reactor.core.publisher.Mono; + +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebExceptionHandler; +import org.springframework.web.server.WebHandler; + +/** + * WebHandler decorator that invokes one or more {@link WebExceptionHandler WebExceptionHandlers} + * after the delegate {@link WebHandler}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ExceptionHandlingWebHandler extends WebHandlerDecorator { + + + private final List exceptionHandlers; + + + public ExceptionHandlingWebHandler(WebHandler delegate, List handlers) { + super(delegate); + this.exceptionHandlers = Collections.unmodifiableList(new ArrayList<>(handlers)); + } + + + /** + * Return a read-only list of the configured exception handlers. + */ + public List getExceptionHandlers() { + return this.exceptionHandlers; + } + + + @Override + public Mono handle(ServerWebExchange exchange) { + + Mono completion; + try { + completion = super.handle(exchange); + } + catch (Throwable ex) { + completion = Mono.error(ex); + } + + for (WebExceptionHandler handler : this.exceptionHandlers) { + completion = completion.onErrorResume(ex -> handler.handle(exchange, ex)); + } + + return completion; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/handler/FilteringWebHandler.java b/spring-web/src/main/java/org/springframework/web/server/handler/FilteringWebHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..7bb056619989d3ed15690bf1b1471864a681ea12 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/handler/FilteringWebHandler.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.handler; + +import java.util.List; + +import reactor.core.publisher.Mono; + +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebHandler; + +/** + * {@link WebHandlerDecorator} that invokes a chain of {@link WebFilter WebFilters} + * before invoking the delegate {@link WebHandler}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class FilteringWebHandler extends WebHandlerDecorator { + + private final DefaultWebFilterChain chain; + + + /** + * Constructor. + * @param filters the chain of filters + */ + public FilteringWebHandler(WebHandler handler, List filters) { + super(handler); + this.chain = new DefaultWebFilterChain(handler, filters); + } + + + /** + * Return a read-only list of the configured filters. + */ + public List getFilters() { + return this.chain.getFilters(); + } + + + @Override + public Mono handle(ServerWebExchange exchange) { + return this.chain.filter(exchange); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/handler/ResponseStatusExceptionHandler.java b/spring-web/src/main/java/org/springframework/web/server/handler/ResponseStatusExceptionHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..cda26cba2136ff05dce5d9d971dcb52dcee53298 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/handler/ResponseStatusExceptionHandler.java @@ -0,0 +1,125 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.handler; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpStatus; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.web.server.ResponseStatusException; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebExceptionHandler; + +/** + * Handle {@link ResponseStatusException} by setting the response status. + * + *

By default exception stack traces are not shown for successfully resolved + * exceptions. Use {@link #setWarnLogCategory(String)} to enable logging with + * stack traces. + * + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + * @since 5.0 + */ +public class ResponseStatusExceptionHandler implements WebExceptionHandler { + + private static final Log logger = LogFactory.getLog(ResponseStatusExceptionHandler.class); + + + @Nullable + private Log warnLogger; + + + /** + * Set the log category for warn logging. + *

Default is no warn logging. Specify this setting to activate warn + * logging into a specific category. + * @since 5.1 + * @see org.apache.commons.logging.LogFactory#getLog(String) + * @see java.util.logging.Logger#getLogger(String) + */ + public void setWarnLogCategory(String loggerName) { + this.warnLogger = LogFactory.getLog(loggerName); + } + + + @Override + public Mono handle(ServerWebExchange exchange, Throwable ex) { + if (!updateResponse(exchange.getResponse(), ex)) { + return Mono.error(ex); + } + + // Mirrors AbstractHandlerExceptionResolver in spring-webmvc... + String logPrefix = exchange.getLogPrefix(); + if (this.warnLogger != null && this.warnLogger.isWarnEnabled()) { + this.warnLogger.warn(logPrefix + formatError(ex, exchange.getRequest()), ex); + } + else if (logger.isDebugEnabled()) { + logger.debug(logPrefix + formatError(ex, exchange.getRequest())); + } + + return exchange.getResponse().setComplete(); + } + + + private String formatError(Throwable ex, ServerHttpRequest request) { + String reason = ex.getClass().getSimpleName() + ": " + ex.getMessage(); + String path = request.getURI().getRawPath(); + return "Resolved [" + reason + "] for HTTP " + request.getMethod() + " " + path; + } + + private boolean updateResponse(ServerHttpResponse response, Throwable ex) { + boolean result = false; + HttpStatus status = determineStatus(ex); + if (status != null) { + if (response.setStatusCode(status)) { + if (ex instanceof ResponseStatusException) { + ((ResponseStatusException) ex).getResponseHeaders() + .forEach((name, values) -> + values.forEach(value -> response.getHeaders().add(name, value))); + } + result = true; + } + } + else { + Throwable cause = ex.getCause(); + if (cause != null) { + result = updateResponse(response, cause); + } + } + return result; + } + + /** + * Determine the HTTP status implied by the given exception. + * @param ex the exception to introspect + * @return the associated HTTP status, if any + * @since 5.0.5 + */ + @Nullable + protected HttpStatus determineStatus(Throwable ex) { + if (ex instanceof ResponseStatusException) { + return ((ResponseStatusException) ex).getStatus(); + } + return null; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/handler/WebHandlerDecorator.java b/spring-web/src/main/java/org/springframework/web/server/handler/WebHandlerDecorator.java new file mode 100644 index 0000000000000000000000000000000000000000..4f917c54fe734be28107e579a82415718b556fe7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/handler/WebHandlerDecorator.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.handler; + +import reactor.core.publisher.Mono; + +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebHandler; + +/** + * {@link WebHandler} that decorates and delegates to another {@code WebHandler}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class WebHandlerDecorator implements WebHandler { + + private final WebHandler delegate; + + + public WebHandlerDecorator(WebHandler delegate) { + Assert.notNull(delegate, "'delegate' must not be null"); + this.delegate = delegate; + } + + + public WebHandler getDelegate() { + return this.delegate; + } + + + @Override + public Mono handle(ServerWebExchange exchange) { + return this.delegate.handle(exchange); + } + + @Override + public String toString() { + return getClass().getSimpleName() + " [delegate=" + this.delegate + "]"; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/handler/package-info.java b/spring-web/src/main/java/org/springframework/web/server/handler/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..25dc942f1e76d770c81512e3d448ff2a3cdeb9e6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/handler/package-info.java @@ -0,0 +1,10 @@ +/** + * Provides common WebHandler implementations and a + * {@link org.springframework.web.server.handler.WebHandlerDecorator}. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.server.handler; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/server/i18n/AcceptHeaderLocaleContextResolver.java b/spring-web/src/main/java/org/springframework/web/server/i18n/AcceptHeaderLocaleContextResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..cde909cf1928af50f7474dadbc075d5121b2d28c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/i18n/AcceptHeaderLocaleContextResolver.java @@ -0,0 +1,143 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.i18n; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +import org.springframework.context.i18n.LocaleContext; +import org.springframework.context.i18n.SimpleLocaleContext; +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; + +/** + * {@link LocaleContextResolver} implementation that simply uses the primary locale + * specified in the "Accept-Language" header of the HTTP request (that is, + * the locale sent by the client browser, normally that of the client's OS). + * + *

Note: Does not support {@link #setLocaleContext}, since the accept header + * can only be changed through changing the client's locale settings. + * + * @author Sebastien Deleuze + * @author Juergen Hoeller + * @since 5.0 + * @see HttpHeaders#getAcceptLanguageAsLocales() + */ +public class AcceptHeaderLocaleContextResolver implements LocaleContextResolver { + + private final List supportedLocales = new ArrayList<>(4); + + @Nullable + private Locale defaultLocale; + + + /** + * Configure supported locales to check against the requested locales + * determined via {@link HttpHeaders#getAcceptLanguageAsLocales()}. + * @param locales the supported locales + */ + public void setSupportedLocales(List locales) { + this.supportedLocales.clear(); + this.supportedLocales.addAll(locales); + } + + /** + * Return the configured list of supported locales. + */ + public List getSupportedLocales() { + return this.supportedLocales; + } + + /** + * Configure a fixed default locale to fall back on if the request does not + * have an "Accept-Language" header (not set by default). + * @param defaultLocale the default locale to use + */ + public void setDefaultLocale(@Nullable Locale defaultLocale) { + this.defaultLocale = defaultLocale; + } + + /** + * The configured default locale, if any. + *

This method may be overridden in subclasses. + */ + @Nullable + public Locale getDefaultLocale() { + return this.defaultLocale; + } + + + @Override + public LocaleContext resolveLocaleContext(ServerWebExchange exchange) { + List requestLocales = null; + try { + requestLocales = exchange.getRequest().getHeaders().getAcceptLanguageAsLocales(); + } + catch (IllegalArgumentException ex) { + // Invalid Accept-Language header: treat as empty for matching purposes + } + return new SimpleLocaleContext(resolveSupportedLocale(requestLocales)); + } + + @Nullable + private Locale resolveSupportedLocale(@Nullable List requestLocales) { + if (CollectionUtils.isEmpty(requestLocales)) { + return getDefaultLocale(); // may be null + } + List supportedLocales = getSupportedLocales(); + if (supportedLocales.isEmpty()) { + return requestLocales.get(0); // never null + } + + Locale languageMatch = null; + for (Locale locale : requestLocales) { + if (supportedLocales.contains(locale)) { + if (languageMatch == null || languageMatch.getLanguage().equals(locale.getLanguage())) { + // Full match: language + country, possibly narrowed from earlier language-only match + return locale; + } + } + else if (languageMatch == null) { + // Let's try to find a language-only match as a fallback + for (Locale candidate : supportedLocales) { + if (!StringUtils.hasLength(candidate.getCountry()) && + candidate.getLanguage().equals(locale.getLanguage())) { + languageMatch = candidate; + break; + } + } + } + } + if (languageMatch != null) { + return languageMatch; + } + + Locale defaultLocale = getDefaultLocale(); + return (defaultLocale != null ? defaultLocale : requestLocales.get(0)); + } + + @Override + public void setLocaleContext(ServerWebExchange exchange, @Nullable LocaleContext locale) { + throw new UnsupportedOperationException( + "Cannot change HTTP accept header - use a different locale context resolution strategy"); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/i18n/FixedLocaleContextResolver.java b/spring-web/src/main/java/org/springframework/web/server/i18n/FixedLocaleContextResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..4c731a4349022515d5d6df19ba5bf6533b3fc7eb --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/i18n/FixedLocaleContextResolver.java @@ -0,0 +1,95 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.i18n; + +import java.util.Locale; +import java.util.TimeZone; + +import org.springframework.context.i18n.LocaleContext; +import org.springframework.context.i18n.TimeZoneAwareLocaleContext; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; + +/** + * {@link LocaleContextResolver} implementation that always returns a fixed locale + * and optionally time zone. Default is the current JVM's default locale. + * + *

Note: Does not support {@link #setLocaleContext}, as the fixed locale and + * time zone cannot be changed. + * + * @author Sebastien Deleuze + * @since 5.0 + */ +public class FixedLocaleContextResolver implements LocaleContextResolver { + + private final Locale locale; + + @Nullable + private final TimeZone timeZone; + + + /** + * Create a default FixedLocaleResolver, exposing a configured default + * locale (or the JVM's default locale as fallback). + */ + public FixedLocaleContextResolver() { + this(Locale.getDefault()); + } + + /** + * Create a FixedLocaleResolver that exposes the given locale. + * @param locale the locale to expose + */ + public FixedLocaleContextResolver(Locale locale) { + this(locale, null); + } + + /** + * Create a FixedLocaleResolver that exposes the given locale and time zone. + * @param locale the locale to expose + * @param timeZone the time zone to expose + */ + public FixedLocaleContextResolver(Locale locale, @Nullable TimeZone timeZone) { + Assert.notNull(locale, "Locale must not be null"); + this.locale = locale; + this.timeZone = timeZone; + } + + + @Override + public LocaleContext resolveLocaleContext(ServerWebExchange exchange) { + return new TimeZoneAwareLocaleContext() { + @Override + public Locale getLocale() { + return locale; + } + @Override + @Nullable + public TimeZone getTimeZone() { + return timeZone; + } + }; + } + + @Override + public void setLocaleContext(ServerWebExchange exchange, @Nullable LocaleContext localeContext) { + throw new UnsupportedOperationException( + "Cannot change fixed locale - use a different locale context resolution strategy"); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/i18n/LocaleContextResolver.java b/spring-web/src/main/java/org/springframework/web/server/i18n/LocaleContextResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..e1d01a2a58b558bdcc766e72d32f3da1a19f431f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/i18n/LocaleContextResolver.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.i18n; + +import org.springframework.context.i18n.LocaleContext; +import org.springframework.lang.Nullable; +import org.springframework.web.server.ServerWebExchange; + +/** + * Interface for web-based locale context resolution strategies that allows + * for both locale context resolution via the request and locale context modification + * via the HTTP exchange. + * + *

The {@link org.springframework.context.i18n.LocaleContext} object can potentially + * includes associated time zone and other locale related information. + * + * @author Sebastien Deleuze + * @since 5.0 + * @see LocaleContext + */ +public interface LocaleContextResolver { + + /** + * Resolve the current locale context via the given exchange. + *

The returned context may be a + * {@link org.springframework.context.i18n.TimeZoneAwareLocaleContext}, + * containing a locale with associated time zone information. + * Simply apply an {@code instanceof} check and downcast accordingly. + *

Custom resolver implementations may also return extra settings in + * the returned context, which again can be accessed through downcasting. + * @param exchange current server exchange + * @return the current locale context (never {@code null}) + */ + LocaleContext resolveLocaleContext(ServerWebExchange exchange); + + /** + * Set the current locale context to the given one, + * potentially including a locale with associated time zone information. + * @param exchange current server exchange + * @param localeContext the new locale context, or {@code null} to clear the locale + * @throws UnsupportedOperationException if the LocaleResolver implementation + * does not support dynamic changing of the locale or time zone + * @see org.springframework.context.i18n.SimpleLocaleContext + * @see org.springframework.context.i18n.SimpleTimeZoneAwareLocaleContext + */ + void setLocaleContext(ServerWebExchange exchange, @Nullable LocaleContext localeContext); + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/i18n/package-info.java b/spring-web/src/main/java/org/springframework/web/server/i18n/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..c7724d0e0b2daf26039e5468ba7d9307b9fb49ba --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/i18n/package-info.java @@ -0,0 +1,10 @@ +/** + * Locale related support classes. + * Provides standard LocaleContextResolver implementations. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.server.i18n; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/server/package-info.java b/spring-web/src/main/java/org/springframework/web/server/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..603079df256234c811ecccec4703ee1c6bd84ea1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/package-info.java @@ -0,0 +1,12 @@ +/** + * Core interfaces and classes for Spring's generic, reactive web support. + * Builds on top of the {@code org.springframework.http.client.reactive} + * reactive HTTP adapter layer, providing additional constructs such as + * WebHandler, WebFilter, WebSession among others. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.server; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/server/session/CookieWebSessionIdResolver.java b/spring-web/src/main/java/org/springframework/web/server/session/CookieWebSessionIdResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..1f75c1e00806cddd5c4755e2a7ccb178ff2a3441 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/session/CookieWebSessionIdResolver.java @@ -0,0 +1,135 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.session; + +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import org.springframework.http.HttpCookie; +import org.springframework.http.ResponseCookie; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; +import org.springframework.web.server.ServerWebExchange; + +/** + * Cookie-based {@link WebSessionIdResolver}. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + */ +public class CookieWebSessionIdResolver implements WebSessionIdResolver { + + private String cookieName = "SESSION"; + + private Duration cookieMaxAge = Duration.ofSeconds(-1); + + @Nullable + private Consumer cookieInitializer = null; + + + /** + * Set the name of the cookie to use for the session id. + *

By default set to "SESSION". + * @param cookieName the cookie name + */ + public void setCookieName(String cookieName) { + Assert.hasText(cookieName, "'cookieName' must not be empty"); + this.cookieName = cookieName; + } + + /** + * Return the configured cookie name. + */ + public String getCookieName() { + return this.cookieName; + } + + /** + * Set the value for the "Max-Age" attribute of the cookie that holds the + * session id. For the range of values see {@link ResponseCookie#getMaxAge()}. + *

By default set to -1. + * @param maxAge the maxAge duration value + */ + public void setCookieMaxAge(Duration maxAge) { + this.cookieMaxAge = maxAge; + } + + /** + * Return the configured "Max-Age" attribute value for the session cookie. + */ + public Duration getCookieMaxAge() { + return this.cookieMaxAge; + } + + /** + * Add a {@link Consumer} for a {@code ResponseCookieBuilder} that will be invoked + * for each cookie being built, just before the call to {@code build()}. + * @param initializer consumer for a cookie builder + * @since 5.1 + */ + public void addCookieInitializer(Consumer initializer) { + this.cookieInitializer = this.cookieInitializer != null ? + this.cookieInitializer.andThen(initializer) : initializer; + } + + + @Override + public List resolveSessionIds(ServerWebExchange exchange) { + MultiValueMap cookieMap = exchange.getRequest().getCookies(); + List cookies = cookieMap.get(getCookieName()); + if (cookies == null) { + return Collections.emptyList(); + } + return cookies.stream().map(HttpCookie::getValue).collect(Collectors.toList()); + } + + @Override + public void setSessionId(ServerWebExchange exchange, String id) { + Assert.notNull(id, "'id' is required"); + ResponseCookie cookie = initSessionCookie(exchange, id, getCookieMaxAge()); + exchange.getResponse().getCookies().set(this.cookieName, cookie); + } + + @Override + public void expireSession(ServerWebExchange exchange) { + ResponseCookie cookie = initSessionCookie(exchange, "", Duration.ZERO); + exchange.getResponse().getCookies().set(this.cookieName, cookie); + } + + private ResponseCookie initSessionCookie( + ServerWebExchange exchange, String id, Duration maxAge) { + + ResponseCookie.ResponseCookieBuilder cookieBuilder = ResponseCookie.from(this.cookieName, id) + .path(exchange.getRequest().getPath().contextPath().value() + "/") + .maxAge(maxAge) + .httpOnly(true) + .secure("https".equalsIgnoreCase(exchange.getRequest().getURI().getScheme())) + .sameSite("Lax"); + + if (this.cookieInitializer != null) { + this.cookieInitializer.accept(cookieBuilder); + } + + return cookieBuilder.build(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java b/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..60e2d4352a5dfb43c13756a1c058f1ab98c8628b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.session; + +import java.util.List; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; + +/** + * Default implementation of {@link WebSessionManager} delegating to a + * {@link WebSessionIdResolver} for session id resolution and to a + * {@link WebSessionStore}. + * + * @author Rossen Stoyanchev + * @author Rob Winch + * @since 5.0 + */ +public class DefaultWebSessionManager implements WebSessionManager { + + private WebSessionIdResolver sessionIdResolver = new CookieWebSessionIdResolver(); + + private WebSessionStore sessionStore = new InMemoryWebSessionStore(); + + + /** + * Configure the id resolution strategy. + *

By default an instance of {@link CookieWebSessionIdResolver}. + * @param sessionIdResolver the resolver to use + */ + public void setSessionIdResolver(WebSessionIdResolver sessionIdResolver) { + Assert.notNull(sessionIdResolver, "WebSessionIdResolver is required"); + this.sessionIdResolver = sessionIdResolver; + } + + /** + * Return the configured {@link WebSessionIdResolver}. + */ + public WebSessionIdResolver getSessionIdResolver() { + return this.sessionIdResolver; + } + + /** + * Configure the persistence strategy. + *

By default an instance of {@link InMemoryWebSessionStore}. + * @param sessionStore the persistence strategy to use + */ + public void setSessionStore(WebSessionStore sessionStore) { + Assert.notNull(sessionStore, "WebSessionStore is required"); + this.sessionStore = sessionStore; + } + + /** + * Return the configured {@link WebSessionStore}. + */ + public WebSessionStore getSessionStore() { + return this.sessionStore; + } + + + @Override + public Mono getSession(ServerWebExchange exchange) { + return Mono.defer(() -> retrieveSession(exchange) + .switchIfEmpty(this.sessionStore.createWebSession()) + .doOnNext(session -> exchange.getResponse().beforeCommit(() -> save(exchange, session)))); + } + + private Mono retrieveSession(ServerWebExchange exchange) { + return Flux.fromIterable(getSessionIdResolver().resolveSessionIds(exchange)) + .concatMap(this.sessionStore::retrieveSession) + .next(); + } + + private Mono save(ServerWebExchange exchange, WebSession session) { + List ids = getSessionIdResolver().resolveSessionIds(exchange); + + if (!session.isStarted() || session.isExpired()) { + if (!ids.isEmpty()) { + // Expired on retrieve or while processing request, or invalidated.. + this.sessionIdResolver.expireSession(exchange); + } + return Mono.empty(); + } + + if (ids.isEmpty() || !session.getId().equals(ids.get(0))) { + this.sessionIdResolver.setSessionId(exchange, session.getId()); + } + + return session.save(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/session/HeaderWebSessionIdResolver.java b/spring-web/src/main/java/org/springframework/web/server/session/HeaderWebSessionIdResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..39e85cbb94f4f52fd7b8f2febf586ece794a25aa --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/session/HeaderWebSessionIdResolver.java @@ -0,0 +1,80 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.session; + +import java.util.Collections; +import java.util.List; + +import org.springframework.http.HttpHeaders; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; + +/** + * Request and response header-based {@link WebSessionIdResolver}. + * + * @author Greg Turnquist + * @author Rob Winch + * @since 5.0 + */ +public class HeaderWebSessionIdResolver implements WebSessionIdResolver { + + /** Default value for {@link #setHeaderName(String)}. */ + public static final String DEFAULT_HEADER_NAME = "SESSION"; + + + private String headerName = DEFAULT_HEADER_NAME; + + + /** + * Set the name of the session header to use for the session id. + * The name is used to extract the session id from the request headers as + * well to set the session id on the response headers. + *

By default set to {@code DEFAULT_HEADER_NAME} + * @param headerName the header name + */ + public void setHeaderName(String headerName) { + Assert.hasText(headerName, "'headerName' must not be empty"); + this.headerName = headerName; + } + + /** + * Return the configured header name. + * @return the configured header name + */ + public String getHeaderName() { + return this.headerName; + } + + + @Override + public List resolveSessionIds(ServerWebExchange exchange) { + HttpHeaders headers = exchange.getRequest().getHeaders(); + return headers.getOrDefault(getHeaderName(), Collections.emptyList()); + } + + @Override + public void setSessionId(ServerWebExchange exchange, String id) { + Assert.notNull(id, "'id' is required."); + exchange.getResponse().getHeaders().set(getHeaderName(), id); + } + + @Override + public void expireSession(ServerWebExchange exchange) { + this.setSessionId(exchange, ""); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java new file mode 100644 index 0000000000000000000000000000000000000000..ac86290a1ff517443b0e20ba27fffd247f6b5f05 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java @@ -0,0 +1,344 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.session; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneId; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; + +import reactor.core.publisher.Mono; + +import org.springframework.util.Assert; +import org.springframework.util.IdGenerator; +import org.springframework.util.JdkIdGenerator; +import org.springframework.web.server.WebSession; + +/** + * Simple Map-based storage for {@link WebSession} instances. + * + * @author Rossen Stoyanchev + * @author Rob Winch + * @since 5.0 + */ +public class InMemoryWebSessionStore implements WebSessionStore { + + private static final IdGenerator idGenerator = new JdkIdGenerator(); + + + private int maxSessions = 10000; + + private Clock clock = Clock.system(ZoneId.of("GMT")); + + private final Map sessions = new ConcurrentHashMap<>(); + + private final ExpiredSessionChecker expiredSessionChecker = new ExpiredSessionChecker(); + + + /** + * Set the maximum number of sessions that can be stored. Once the limit is + * reached, any attempt to store an additional session will result in an + * {@link IllegalStateException}. + *

By default set to 10000. + * @param maxSessions the maximum number of sessions + * @since 5.0.8 + */ + public void setMaxSessions(int maxSessions) { + this.maxSessions = maxSessions; + } + + /** + * Return the maximum number of sessions that can be stored. + * @since 5.0.8 + */ + public int getMaxSessions() { + return this.maxSessions; + } + + /** + * Configure the {@link Clock} to use to set lastAccessTime on every created + * session and to calculate if it is expired. + *

This may be useful to align to different timezone or to set the clock + * back in a test, e.g. {@code Clock.offset(clock, Duration.ofMinutes(-31))} + * in order to simulate session expiration. + *

By default this is {@code Clock.system(ZoneId.of("GMT"))}. + * @param clock the clock to use + */ + public void setClock(Clock clock) { + Assert.notNull(clock, "Clock is required"); + this.clock = clock; + removeExpiredSessions(); + } + + /** + * Return the configured clock for session lastAccessTime calculations. + */ + public Clock getClock() { + return this.clock; + } + + /** + * Return the map of sessions with an {@link Collections#unmodifiableMap + * unmodifiable} wrapper. This could be used for management purposes, to + * list active sessions, invalidate expired ones, etc. + * @since 5.0.8 + */ + public Map getSessions() { + return Collections.unmodifiableMap(this.sessions); + } + + + @Override + public Mono createWebSession() { + Instant now = this.clock.instant(); + this.expiredSessionChecker.checkIfNecessary(now); + return Mono.fromSupplier(() -> new InMemoryWebSession(now)); + } + + @Override + public Mono retrieveSession(String id) { + Instant now = this.clock.instant(); + this.expiredSessionChecker.checkIfNecessary(now); + InMemoryWebSession session = this.sessions.get(id); + if (session == null) { + return Mono.empty(); + } + else if (session.isExpired(now)) { + this.sessions.remove(id); + return Mono.empty(); + } + else { + session.updateLastAccessTime(now); + return Mono.just(session); + } + } + + @Override + public Mono removeSession(String id) { + this.sessions.remove(id); + return Mono.empty(); + } + + public Mono updateLastAccessTime(WebSession session) { + return Mono.fromSupplier(() -> { + Assert.isInstanceOf(InMemoryWebSession.class, session); + ((InMemoryWebSession) session).updateLastAccessTime(this.clock.instant()); + return session; + }); + } + + /** + * Check for expired sessions and remove them. Typically such checks are + * kicked off lazily during calls to {@link #createWebSession() create} or + * {@link #retrieveSession retrieve}, no less than 60 seconds apart. + * This method can be called to force a check at a specific time. + * @since 5.0.8 + */ + public void removeExpiredSessions() { + this.expiredSessionChecker.removeExpiredSessions(this.clock.instant()); + } + + + private class InMemoryWebSession implements WebSession { + + private final AtomicReference id = new AtomicReference<>(String.valueOf(idGenerator.generateId())); + + private final Map attributes = new ConcurrentHashMap<>(); + + private final Instant creationTime; + + private volatile Instant lastAccessTime; + + private volatile Duration maxIdleTime = Duration.ofMinutes(30); + + private final AtomicReference state = new AtomicReference<>(State.NEW); + + + public InMemoryWebSession(Instant creationTime) { + this.creationTime = creationTime; + this.lastAccessTime = this.creationTime; + } + + @Override + public String getId() { + return this.id.get(); + } + + @Override + public Map getAttributes() { + return this.attributes; + } + + @Override + public Instant getCreationTime() { + return this.creationTime; + } + + @Override + public Instant getLastAccessTime() { + return this.lastAccessTime; + } + + @Override + public void setMaxIdleTime(Duration maxIdleTime) { + this.maxIdleTime = maxIdleTime; + } + + @Override + public Duration getMaxIdleTime() { + return this.maxIdleTime; + } + + @Override + public void start() { + this.state.compareAndSet(State.NEW, State.STARTED); + } + + @Override + public boolean isStarted() { + return this.state.get().equals(State.STARTED) || !getAttributes().isEmpty(); + } + + @Override + public Mono changeSessionId() { + String currentId = this.id.get(); + InMemoryWebSessionStore.this.sessions.remove(currentId); + String newId = String.valueOf(idGenerator.generateId()); + this.id.set(newId); + InMemoryWebSessionStore.this.sessions.put(this.getId(), this); + return Mono.empty(); + } + + @Override + public Mono invalidate() { + this.state.set(State.EXPIRED); + getAttributes().clear(); + InMemoryWebSessionStore.this.sessions.remove(this.id.get()); + return Mono.empty(); + } + + @Override + public Mono save() { + + checkMaxSessionsLimit(); + + // Implicitly started session.. + if (!getAttributes().isEmpty()) { + this.state.compareAndSet(State.NEW, State.STARTED); + } + + if (isStarted()) { + // Save + InMemoryWebSessionStore.this.sessions.put(this.getId(), this); + + // Unless it was invalidated + if (this.state.get().equals(State.EXPIRED)) { + InMemoryWebSessionStore.this.sessions.remove(this.getId()); + return Mono.error(new IllegalStateException("Session was invalidated")); + } + } + + return Mono.empty(); + } + + private void checkMaxSessionsLimit() { + if (sessions.size() >= maxSessions) { + expiredSessionChecker.removeExpiredSessions(clock.instant()); + if (sessions.size() >= maxSessions) { + throw new IllegalStateException("Max sessions limit reached: " + sessions.size()); + } + } + } + + @Override + public boolean isExpired() { + return isExpired(clock.instant()); + } + + private boolean isExpired(Instant now) { + if (this.state.get().equals(State.EXPIRED)) { + return true; + } + if (checkExpired(now)) { + this.state.set(State.EXPIRED); + return true; + } + return false; + } + + private boolean checkExpired(Instant currentTime) { + return isStarted() && !this.maxIdleTime.isNegative() && + currentTime.minus(this.maxIdleTime).isAfter(this.lastAccessTime); + } + + private void updateLastAccessTime(Instant currentTime) { + this.lastAccessTime = currentTime; + } + } + + + private class ExpiredSessionChecker { + + /** Max time between expiration checks. */ + private static final int CHECK_PERIOD = 60 * 1000; + + + private final ReentrantLock lock = new ReentrantLock(); + + private Instant checkTime = clock.instant().plus(CHECK_PERIOD, ChronoUnit.MILLIS); + + + public void checkIfNecessary(Instant now) { + if (this.checkTime.isBefore(now)) { + removeExpiredSessions(now); + } + } + + public void removeExpiredSessions(Instant now) { + if (sessions.isEmpty()) { + return; + } + if (this.lock.tryLock()) { + try { + Iterator iterator = sessions.values().iterator(); + while (iterator.hasNext()) { + InMemoryWebSession session = iterator.next(); + if (session.isExpired(now)) { + iterator.remove(); + session.invalidate(); + } + } + } + finally { + this.checkTime = now.plus(CHECK_PERIOD, ChronoUnit.MILLIS); + this.lock.unlock(); + } + } + } + } + + + private enum State { NEW, STARTED, EXPIRED } + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/session/WebSessionIdResolver.java b/spring-web/src/main/java/org/springframework/web/server/session/WebSessionIdResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..712719d0de0f2cceb7a1d1544fb0095665142dcf --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/session/WebSessionIdResolver.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.session; + +import java.util.List; + +import org.springframework.web.server.ServerWebExchange; + +/** + * Contract for session id resolution strategies. Allows for session id + * resolution through the request and for sending the session id or expiring + * the session through the response. + * + * @author Rossen Stoyanchev + * @since 5.0 + * @see CookieWebSessionIdResolver + */ +public interface WebSessionIdResolver { + + /** + * Resolve the session id's associated with the request. + * @param exchange the current exchange + * @return the session id's or an empty list + */ + List resolveSessionIds(ServerWebExchange exchange); + + /** + * Send the given session id to the client. + * @param exchange the current exchange + * @param sessionId the session id + */ + void setSessionId(ServerWebExchange exchange, String sessionId); + + /** + * Instruct the client to end the current session. + * @param exchange the current exchange + */ + void expireSession(ServerWebExchange exchange); + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/session/WebSessionManager.java b/spring-web/src/main/java/org/springframework/web/server/session/WebSessionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..3ca79fb53e7b3aeffbef0b3f42b5e1070f169f8a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/session/WebSessionManager.java @@ -0,0 +1,44 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.session; + +import reactor.core.publisher.Mono; + +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; + +/** + * Main class for for access to the {@link WebSession} for an HTTP request. + * + * @author Rossen Stoyanchev + * @since 5.0 + * @see WebSessionIdResolver + * @see WebSessionStore + */ +public interface WebSessionManager { + + /** + * Return the {@link WebSession} for the given exchange. Always guaranteed + * to return an instance either matching to the session id requested by the + * client, or a new session either because the client did not specify one + * or because the underlying session expired. + * @param exchange the current exchange + * @return promise for the WebSession + */ + Mono getSession(ServerWebExchange exchange); + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java b/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java new file mode 100644 index 0000000000000000000000000000000000000000..ca8984f0c05328858ec9d6edbcc7cfd9cfeb4195 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.session; + +import reactor.core.publisher.Mono; + +import org.springframework.web.server.WebSession; + +/** + * Strategy for {@link WebSession} persistence. + * + * @author Rossen Stoyanchev + * @author Rob Winch + * @since 5.0 + */ +public interface WebSessionStore { + + /** + * Create a new WebSession. + *

Note that this does nothing more than create a new instance. + * The session can later be started explicitly via {@link WebSession#start()} + * or implicitly by adding attributes -- and then persisted via + * {@link WebSession#save()}. + * @return the created session instance + */ + Mono createWebSession(); + + /** + * Return the WebSession for the given id. + *

Note: This method should perform an expiration check, + * and if it has expired remove the session and return empty. This method + * should also update the lastAccessTime of retrieved sessions. + * @param sessionId the session to load + * @return the session, or an empty {@code Mono} . + */ + Mono retrieveSession(String sessionId); + + /** + * Remove the WebSession for the specified id. + * @param sessionId the id of the session to remove + * @return a completion notification (success or error) + */ + Mono removeSession(String sessionId); + + /** + * Update the last accessed timestamp to "now". + * @param webSession the session to update + * @return the session with the updated last access time + */ + Mono updateLastAccessTime(WebSession webSession); + +} diff --git a/spring-web/src/main/java/org/springframework/web/server/session/package-info.java b/spring-web/src/main/java/org/springframework/web/server/session/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..9df7670c70138c372b2a34d88848b7d5523704f8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/server/session/package-info.java @@ -0,0 +1,10 @@ +/** + * Auxiliary interfaces and implementation classes for + * {@link org.springframework.web.server.WebSession} support. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.server.session; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/util/AbstractUriTemplateHandler.java b/spring-web/src/main/java/org/springframework/web/util/AbstractUriTemplateHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..19066ea4ad7a6d06c1520001787bb7bd0c57194a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/AbstractUriTemplateHandler.java @@ -0,0 +1,143 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Abstract base class for {@link UriTemplateHandler} implementations. + * + *

Support {@link #setBaseUrl} and {@link #setDefaultUriVariables} properties + * that should be relevant regardless of the URI template expand and encode + * mechanism used in sub-classes. + * + * @author Rossen Stoyanchev + * @since 4.3 + * @deprecated as of 5.0 in favor of {@link DefaultUriBuilderFactory} + */ +@Deprecated +public abstract class AbstractUriTemplateHandler implements UriTemplateHandler { + + @Nullable + private String baseUrl; + + private final Map defaultUriVariables = new HashMap<>(); + + + /** + * Configure a base URL to prepend URI templates with. The base URL must + * have a scheme and host but may optionally contain a port and a path. + * The base URL must be fully expanded and encoded which can be done via + * {@link UriComponentsBuilder}. + * @param baseUrl the base URL. + */ + public void setBaseUrl(@Nullable String baseUrl) { + if (baseUrl != null) { + UriComponents uriComponents = UriComponentsBuilder.fromUriString(baseUrl).build(); + Assert.hasText(uriComponents.getScheme(), "'baseUrl' must have a scheme"); + Assert.hasText(uriComponents.getHost(), "'baseUrl' must have a host"); + Assert.isNull(uriComponents.getQuery(), "'baseUrl' cannot have a query"); + Assert.isNull(uriComponents.getFragment(), "'baseUrl' cannot have a fragment"); + } + this.baseUrl = baseUrl; + } + + /** + * Return the configured base URL. + */ + @Nullable + public String getBaseUrl() { + return this.baseUrl; + } + + /** + * Configure default URI variable values to use with every expanded URI + * template. These default values apply only when expanding with a Map, and + * not with an array, where the Map supplied to {@link #expand(String, Map)} + * can override the default values. + * @param defaultUriVariables the default URI variable values + * @since 4.3 + */ + public void setDefaultUriVariables(@Nullable Map defaultUriVariables) { + this.defaultUriVariables.clear(); + if (defaultUriVariables != null) { + this.defaultUriVariables.putAll(defaultUriVariables); + } + } + + /** + * Return a read-only copy of the configured default URI variables. + */ + public Map getDefaultUriVariables() { + return Collections.unmodifiableMap(this.defaultUriVariables); + } + + + @Override + public URI expand(String uriTemplate, Map uriVariables) { + if (!getDefaultUriVariables().isEmpty()) { + Map map = new HashMap<>(); + map.putAll(getDefaultUriVariables()); + map.putAll(uriVariables); + uriVariables = map; + } + URI url = expandInternal(uriTemplate, uriVariables); + return insertBaseUrl(url); + } + + @Override + public URI expand(String uriTemplate, Object... uriVariables) { + URI url = expandInternal(uriTemplate, uriVariables); + return insertBaseUrl(url); + } + + + /** + * Actually expand and encode the URI template. + */ + protected abstract URI expandInternal(String uriTemplate, Map uriVariables); + + /** + * Actually expand and encode the URI template. + */ + protected abstract URI expandInternal(String uriTemplate, Object... uriVariables); + + + /** + * Insert a base URL (if configured) unless the given URL has a host already. + */ + private URI insertBaseUrl(URI url) { + try { + String baseUrl = getBaseUrl(); + if (baseUrl != null && url.getHost() == null) { + url = new URI(baseUrl + url.toString()); + } + return url; + } + catch (URISyntaxException ex) { + throw new IllegalArgumentException("Invalid URL after inserting base URL: " + url, ex); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/ContentCachingRequestWrapper.java b/spring-web/src/main/java/org/springframework/web/util/ContentCachingRequestWrapper.java new file mode 100644 index 0000000000000000000000000000000000000000..0158ed87b97644ba02fb1a6b279f451463e6afee --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/ContentCachingRequestWrapper.java @@ -0,0 +1,282 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.io.BufferedReader; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.net.URLEncoder; +import java.util.Arrays; +import java.util.Enumeration; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import javax.servlet.ReadListener; +import javax.servlet.ServletInputStream; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; + +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; + +/** + * {@link javax.servlet.http.HttpServletRequest} wrapper that caches all content read from + * the {@linkplain #getInputStream() input stream} and {@linkplain #getReader() reader}, + * and allows this content to be retrieved via a {@link #getContentAsByteArray() byte array}. + * + *

Used e.g. by {@link org.springframework.web.filter.AbstractRequestLoggingFilter}. + * Note: As of Spring Framework 5.0, this wrapper is built on the Servlet 3.1 API. + * + * @author Juergen Hoeller + * @author Brian Clozel + * @since 4.1.3 + * @see ContentCachingResponseWrapper + */ +public class ContentCachingRequestWrapper extends HttpServletRequestWrapper { + + private static final String FORM_CONTENT_TYPE = "application/x-www-form-urlencoded"; + + + private final ByteArrayOutputStream cachedContent; + + @Nullable + private final Integer contentCacheLimit; + + @Nullable + private ServletInputStream inputStream; + + @Nullable + private BufferedReader reader; + + + /** + * Create a new ContentCachingRequestWrapper for the given servlet request. + * @param request the original servlet request + */ + public ContentCachingRequestWrapper(HttpServletRequest request) { + super(request); + int contentLength = request.getContentLength(); + this.cachedContent = new ByteArrayOutputStream(contentLength >= 0 ? contentLength : 1024); + this.contentCacheLimit = null; + } + + /** + * Create a new ContentCachingRequestWrapper for the given servlet request. + * @param request the original servlet request + * @param contentCacheLimit the maximum number of bytes to cache per request + * @since 4.3.6 + * @see #handleContentOverflow(int) + */ + public ContentCachingRequestWrapper(HttpServletRequest request, int contentCacheLimit) { + super(request); + this.cachedContent = new ByteArrayOutputStream(contentCacheLimit); + this.contentCacheLimit = contentCacheLimit; + } + + + @Override + public ServletInputStream getInputStream() throws IOException { + if (this.inputStream == null) { + this.inputStream = new ContentCachingInputStream(getRequest().getInputStream()); + } + return this.inputStream; + } + + @Override + public String getCharacterEncoding() { + String enc = super.getCharacterEncoding(); + return (enc != null ? enc : WebUtils.DEFAULT_CHARACTER_ENCODING); + } + + @Override + public BufferedReader getReader() throws IOException { + if (this.reader == null) { + this.reader = new BufferedReader(new InputStreamReader(getInputStream(), getCharacterEncoding())); + } + return this.reader; + } + + @Override + public String getParameter(String name) { + if (this.cachedContent.size() == 0 && isFormPost()) { + writeRequestParametersToCachedContent(); + } + return super.getParameter(name); + } + + @Override + public Map getParameterMap() { + if (this.cachedContent.size() == 0 && isFormPost()) { + writeRequestParametersToCachedContent(); + } + return super.getParameterMap(); + } + + @Override + public Enumeration getParameterNames() { + if (this.cachedContent.size() == 0 && isFormPost()) { + writeRequestParametersToCachedContent(); + } + return super.getParameterNames(); + } + + @Override + public String[] getParameterValues(String name) { + if (this.cachedContent.size() == 0 && isFormPost()) { + writeRequestParametersToCachedContent(); + } + return super.getParameterValues(name); + } + + + private boolean isFormPost() { + String contentType = getContentType(); + return (contentType != null && contentType.contains(FORM_CONTENT_TYPE) && + HttpMethod.POST.matches(getMethod())); + } + + private void writeRequestParametersToCachedContent() { + try { + if (this.cachedContent.size() == 0) { + String requestEncoding = getCharacterEncoding(); + Map form = super.getParameterMap(); + for (Iterator nameIterator = form.keySet().iterator(); nameIterator.hasNext(); ) { + String name = nameIterator.next(); + List values = Arrays.asList(form.get(name)); + for (Iterator valueIterator = values.iterator(); valueIterator.hasNext(); ) { + String value = valueIterator.next(); + this.cachedContent.write(URLEncoder.encode(name, requestEncoding).getBytes()); + if (value != null) { + this.cachedContent.write('='); + this.cachedContent.write(URLEncoder.encode(value, requestEncoding).getBytes()); + if (valueIterator.hasNext()) { + this.cachedContent.write('&'); + } + } + } + if (nameIterator.hasNext()) { + this.cachedContent.write('&'); + } + } + } + } + catch (IOException ex) { + throw new IllegalStateException("Failed to write request parameters to cached content", ex); + } + } + + /** + * Return the cached request content as a byte array. + *

The returned array will never be larger than the content cache limit. + * @see #ContentCachingRequestWrapper(HttpServletRequest, int) + */ + public byte[] getContentAsByteArray() { + return this.cachedContent.toByteArray(); + } + + /** + * Template method for handling a content overflow: specifically, a request + * body being read that exceeds the specified content cache limit. + *

The default implementation is empty. Subclasses may override this to + * throw a payload-too-large exception or the like. + * @param contentCacheLimit the maximum number of bytes to cache per request + * which has just been exceeded + * @since 4.3.6 + * @see #ContentCachingRequestWrapper(HttpServletRequest, int) + */ + protected void handleContentOverflow(int contentCacheLimit) { + } + + + private class ContentCachingInputStream extends ServletInputStream { + + private final ServletInputStream is; + + private boolean overflow = false; + + public ContentCachingInputStream(ServletInputStream is) { + this.is = is; + } + + @Override + public int read() throws IOException { + int ch = this.is.read(); + if (ch != -1 && !this.overflow) { + if (contentCacheLimit != null && cachedContent.size() == contentCacheLimit) { + this.overflow = true; + handleContentOverflow(contentCacheLimit); + } + else { + cachedContent.write(ch); + } + } + return ch; + } + + @Override + public int read(byte[] b) throws IOException { + int count = this.is.read(b); + writeToCache(b, 0, count); + return count; + } + + private void writeToCache(final byte[] b, final int off, int count) { + if (!this.overflow && count > 0) { + if (contentCacheLimit != null && + count + cachedContent.size() > contentCacheLimit) { + this.overflow = true; + cachedContent.write(b, off, contentCacheLimit - cachedContent.size()); + handleContentOverflow(contentCacheLimit); + return; + } + cachedContent.write(b, off, count); + } + } + + @Override + public int read(final byte[] b, final int off, final int len) throws IOException { + int count = this.is.read(b, off, len); + writeToCache(b, off, count); + return count; + } + + @Override + public int readLine(final byte[] b, final int off, final int len) throws IOException { + int count = this.is.readLine(b, off, len); + writeToCache(b, off, count); + return count; + } + + @Override + public boolean isFinished() { + return this.is.isFinished(); + } + + @Override + public boolean isReady() { + return this.is.isReady(); + } + + @Override + public void setReadListener(ReadListener readListener) { + this.is.setReadListener(readListener); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java new file mode 100644 index 0000000000000000000000000000000000000000..950393aca0928b1b6fa947f8fe8bb197449c3dc6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java @@ -0,0 +1,296 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.io.UnsupportedEncodingException; + +import javax.servlet.ServletOutputStream; +import javax.servlet.WriteListener; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; + +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.FastByteArrayOutputStream; + +/** + * {@link javax.servlet.http.HttpServletResponse} wrapper that caches all content written to + * the {@linkplain #getOutputStream() output stream} and {@linkplain #getWriter() writer}, + * and allows this content to be retrieved via a {@link #getContentAsByteArray() byte array}. + * + *

Used e.g. by {@link org.springframework.web.filter.ShallowEtagHeaderFilter}. + * Note: As of Spring Framework 5.0, this wrapper is built on the Servlet 3.1 API. + * + * @author Juergen Hoeller + * @since 4.1.3 + * @see ContentCachingRequestWrapper + */ +public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { + + private final FastByteArrayOutputStream content = new FastByteArrayOutputStream(1024); + + @Nullable + private ServletOutputStream outputStream; + + @Nullable + private PrintWriter writer; + + private int statusCode = HttpServletResponse.SC_OK; + + @Nullable + private Integer contentLength; + + + /** + * Create a new ContentCachingResponseWrapper for the given servlet response. + * @param response the original servlet response + */ + public ContentCachingResponseWrapper(HttpServletResponse response) { + super(response); + } + + + @Override + public void setStatus(int sc) { + super.setStatus(sc); + this.statusCode = sc; + } + + @SuppressWarnings("deprecation") + @Override + public void setStatus(int sc, String sm) { + super.setStatus(sc, sm); + this.statusCode = sc; + } + + @Override + public void sendError(int sc) throws IOException { + copyBodyToResponse(false); + try { + super.sendError(sc); + } + catch (IllegalStateException ex) { + // Possibly on Tomcat when called too late: fall back to silent setStatus + super.setStatus(sc); + } + this.statusCode = sc; + } + + @Override + @SuppressWarnings("deprecation") + public void sendError(int sc, String msg) throws IOException { + copyBodyToResponse(false); + try { + super.sendError(sc, msg); + } + catch (IllegalStateException ex) { + // Possibly on Tomcat when called too late: fall back to silent setStatus + super.setStatus(sc, msg); + } + this.statusCode = sc; + } + + @Override + public void sendRedirect(String location) throws IOException { + copyBodyToResponse(false); + super.sendRedirect(location); + } + + @Override + public ServletOutputStream getOutputStream() throws IOException { + if (this.outputStream == null) { + this.outputStream = new ResponseServletOutputStream(getResponse().getOutputStream()); + } + return this.outputStream; + } + + @Override + public PrintWriter getWriter() throws IOException { + if (this.writer == null) { + String characterEncoding = getCharacterEncoding(); + this.writer = (characterEncoding != null ? new ResponsePrintWriter(characterEncoding) : + new ResponsePrintWriter(WebUtils.DEFAULT_CHARACTER_ENCODING)); + } + return this.writer; + } + + @Override + public void flushBuffer() throws IOException { + // do not flush the underlying response as the content as not been copied to it yet + } + + @Override + public void setContentLength(int len) { + if (len > this.content.size()) { + this.content.resize(len); + } + this.contentLength = len; + } + + // Overrides Servlet 3.1 setContentLengthLong(long) at runtime + public void setContentLengthLong(long len) { + if (len > Integer.MAX_VALUE) { + throw new IllegalArgumentException("Content-Length exceeds ContentCachingResponseWrapper's maximum (" + + Integer.MAX_VALUE + "): " + len); + } + int lenInt = (int) len; + if (lenInt > this.content.size()) { + this.content.resize(lenInt); + } + this.contentLength = lenInt; + } + + @Override + public void setBufferSize(int size) { + if (size > this.content.size()) { + this.content.resize(size); + } + } + + @Override + public void resetBuffer() { + this.content.reset(); + } + + @Override + public void reset() { + super.reset(); + this.content.reset(); + } + + /** + * Return the status code as specified on the response. + */ + public int getStatusCode() { + return this.statusCode; + } + + /** + * Return the cached response content as a byte array. + */ + public byte[] getContentAsByteArray() { + return this.content.toByteArray(); + } + + /** + * Return an {@link InputStream} to the cached content. + * @since 4.2 + */ + public InputStream getContentInputStream() { + return this.content.getInputStream(); + } + + /** + * Return the current size of the cached content. + * @since 4.2 + */ + public int getContentSize() { + return this.content.size(); + } + + /** + * Copy the complete cached body content to the response. + * @since 4.2 + */ + public void copyBodyToResponse() throws IOException { + copyBodyToResponse(true); + } + + /** + * Copy the cached body content to the response. + * @param complete whether to set a corresponding content length + * for the complete cached body content + * @since 4.2 + */ + protected void copyBodyToResponse(boolean complete) throws IOException { + if (this.content.size() > 0) { + HttpServletResponse rawResponse = (HttpServletResponse) getResponse(); + if ((complete || this.contentLength != null) && !rawResponse.isCommitted()) { + if (rawResponse.getHeader(HttpHeaders.TRANSFER_ENCODING) == null) { + rawResponse.setContentLength(complete ? this.content.size() : this.contentLength); + } + this.contentLength = null; + } + this.content.writeTo(rawResponse.getOutputStream()); + this.content.reset(); + if (complete) { + super.flushBuffer(); + } + } + } + + + private class ResponseServletOutputStream extends ServletOutputStream { + + private final ServletOutputStream os; + + public ResponseServletOutputStream(ServletOutputStream os) { + this.os = os; + } + + @Override + public void write(int b) throws IOException { + content.write(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + content.write(b, off, len); + } + + @Override + public boolean isReady() { + return this.os.isReady(); + } + + @Override + public void setWriteListener(WriteListener writeListener) { + this.os.setWriteListener(writeListener); + } + } + + + private class ResponsePrintWriter extends PrintWriter { + + public ResponsePrintWriter(String characterEncoding) throws UnsupportedEncodingException { + super(new OutputStreamWriter(content, characterEncoding)); + } + + @Override + public void write(char[] buf, int off, int len) { + super.write(buf, off, len); + super.flush(); + } + + @Override + public void write(String s, int off, int len) { + super.write(s, off, len); + super.flush(); + } + + @Override + public void write(int c) { + super.write(c); + super.flush(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/CookieGenerator.java b/spring-web/src/main/java/org/springframework/web/util/CookieGenerator.java new file mode 100644 index 0000000000000000000000000000000000000000..2324274e816cdd27cceff2879291c8e4a48ddfda --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/CookieGenerator.java @@ -0,0 +1,246 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletResponse; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Helper class for cookie generation, carrying cookie descriptor settings + * as bean properties and being able to add and remove cookie to/from a + * given response. + * + *

Can serve as base class for components that generate specific cookies, + * such as CookieLocaleResolver and CookieThemeResolver. + * + * @author Juergen Hoeller + * @since 1.1.4 + * @see #addCookie + * @see #removeCookie + * @see org.springframework.web.servlet.i18n.CookieLocaleResolver + * @see org.springframework.web.servlet.theme.CookieThemeResolver + */ +public class CookieGenerator { + + /** + * Default path that cookies will be visible to: "/", i.e. the entire server. + */ + public static final String DEFAULT_COOKIE_PATH = "/"; + + + protected final Log logger = LogFactory.getLog(getClass()); + + @Nullable + private String cookieName; + + @Nullable + private String cookieDomain; + + private String cookiePath = DEFAULT_COOKIE_PATH; + + @Nullable + private Integer cookieMaxAge; + + private boolean cookieSecure = false; + + private boolean cookieHttpOnly = false; + + + /** + * Use the given name for cookies created by this generator. + * @see javax.servlet.http.Cookie#getName() + */ + public void setCookieName(@Nullable String cookieName) { + this.cookieName = cookieName; + } + + /** + * Return the given name for cookies created by this generator. + */ + @Nullable + public String getCookieName() { + return this.cookieName; + } + + /** + * Use the given domain for cookies created by this generator. + * The cookie is only visible to servers in this domain. + * @see javax.servlet.http.Cookie#setDomain + */ + public void setCookieDomain(@Nullable String cookieDomain) { + this.cookieDomain = cookieDomain; + } + + /** + * Return the domain for cookies created by this generator, if any. + */ + @Nullable + public String getCookieDomain() { + return this.cookieDomain; + } + + /** + * Use the given path for cookies created by this generator. + * The cookie is only visible to URLs in this path and below. + * @see javax.servlet.http.Cookie#setPath + */ + public void setCookiePath(String cookiePath) { + this.cookiePath = cookiePath; + } + + /** + * Return the path for cookies created by this generator. + */ + public String getCookiePath() { + return this.cookiePath; + } + + /** + * Use the given maximum age (in seconds) for cookies created by this generator. + * Useful special value: -1 ... not persistent, deleted when client shuts down. + *

Default is no specific maximum age at all, using the Servlet container's + * default. + * @see javax.servlet.http.Cookie#setMaxAge + */ + public void setCookieMaxAge(@Nullable Integer cookieMaxAge) { + this.cookieMaxAge = cookieMaxAge; + } + + /** + * Return the maximum age for cookies created by this generator. + */ + @Nullable + public Integer getCookieMaxAge() { + return this.cookieMaxAge; + } + + /** + * Set whether the cookie should only be sent using a secure protocol, + * such as HTTPS (SSL). This is an indication to the receiving browser, + * not processed by the HTTP server itself. + *

Default is "false". + * @see javax.servlet.http.Cookie#setSecure + */ + public void setCookieSecure(boolean cookieSecure) { + this.cookieSecure = cookieSecure; + } + + /** + * Return whether the cookie should only be sent using a secure protocol, + * such as HTTPS (SSL). + */ + public boolean isCookieSecure() { + return this.cookieSecure; + } + + /** + * Set whether the cookie is supposed to be marked with the "HttpOnly" attribute. + *

Default is "false". + * @see javax.servlet.http.Cookie#setHttpOnly + */ + public void setCookieHttpOnly(boolean cookieHttpOnly) { + this.cookieHttpOnly = cookieHttpOnly; + } + + /** + * Return whether the cookie is supposed to be marked with the "HttpOnly" attribute. + */ + public boolean isCookieHttpOnly() { + return this.cookieHttpOnly; + } + + + /** + * Add a cookie with the given value to the response, + * using the cookie descriptor settings of this generator. + *

Delegates to {@link #createCookie} for cookie creation. + * @param response the HTTP response to add the cookie to + * @param cookieValue the value of the cookie to add + * @see #setCookieName + * @see #setCookieDomain + * @see #setCookiePath + * @see #setCookieMaxAge + */ + public void addCookie(HttpServletResponse response, String cookieValue) { + Assert.notNull(response, "HttpServletResponse must not be null"); + Cookie cookie = createCookie(cookieValue); + Integer maxAge = getCookieMaxAge(); + if (maxAge != null) { + cookie.setMaxAge(maxAge); + } + if (isCookieSecure()) { + cookie.setSecure(true); + } + if (isCookieHttpOnly()) { + cookie.setHttpOnly(true); + } + response.addCookie(cookie); + if (logger.isTraceEnabled()) { + logger.trace("Added cookie [" + getCookieName() + "=" + cookieValue + "]"); + } + } + + /** + * Remove the cookie that this generator describes from the response. + * Will generate a cookie with empty value and max age 0. + *

Delegates to {@link #createCookie} for cookie creation. + * @param response the HTTP response to remove the cookie from + * @see #setCookieName + * @see #setCookieDomain + * @see #setCookiePath + */ + public void removeCookie(HttpServletResponse response) { + Assert.notNull(response, "HttpServletResponse must not be null"); + Cookie cookie = createCookie(""); + cookie.setMaxAge(0); + if (isCookieSecure()) { + cookie.setSecure(true); + } + if (isCookieHttpOnly()) { + cookie.setHttpOnly(true); + } + response.addCookie(cookie); + if (logger.isTraceEnabled()) { + logger.trace("Removed cookie '" + getCookieName() + "'"); + } + } + + /** + * Create a cookie with the given value, using the cookie descriptor + * settings of this generator (except for "cookieMaxAge"). + * @param cookieValue the value of the cookie to crate + * @return the cookie + * @see #setCookieName + * @see #setCookieDomain + * @see #setCookiePath + */ + protected Cookie createCookie(String cookieValue) { + Cookie cookie = new Cookie(getCookieName(), cookieValue); + if (getCookieDomain() != null) { + cookie.setDomain(getCookieDomain()); + } + cookie.setPath(getCookiePath()); + return cookie; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/DefaultUriBuilderFactory.java b/spring-web/src/main/java/org/springframework/web/util/DefaultUriBuilderFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..8f2822bc651d487857c7c15e14e209d9b13aaff1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/DefaultUriBuilderFactory.java @@ -0,0 +1,398 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.net.URI; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; + +/** + * {@code UriBuilderFactory} that relies on {@link UriComponentsBuilder} for + * the actual building of the URI. + * + *

Provides options to create {@link UriBuilder} instances with a common + * base URI, alternative encoding mode strategies, among others. + * + * @author Rossen Stoyanchev + * @since 5.0 + * @see UriComponentsBuilder + */ +public class DefaultUriBuilderFactory implements UriBuilderFactory { + + @Nullable + private final UriComponentsBuilder baseUri; + + private EncodingMode encodingMode = EncodingMode.TEMPLATE_AND_VALUES; + + private final Map defaultUriVariables = new HashMap<>(); + + private boolean parsePath = true; + + + /** + * Default constructor without a base URI. + *

The target address must be specified on each UriBuilder. + */ + public DefaultUriBuilderFactory() { + this.baseUri = null; + } + + /** + * Constructor with a base URI. + *

The given URI template is parsed via + * {@link UriComponentsBuilder#fromUriString} and then applied as a base URI + * to every UriBuilder via {@link UriComponentsBuilder#uriComponents} unless + * the UriBuilder itself was created with a URI template that already has a + * target address. + * @param baseUriTemplate the URI template to use a base URL + */ + public DefaultUriBuilderFactory(String baseUriTemplate) { + this.baseUri = UriComponentsBuilder.fromUriString(baseUriTemplate); + } + + /** + * Variant of {@link #DefaultUriBuilderFactory(String)} with a + * {@code UriComponentsBuilder}. + */ + public DefaultUriBuilderFactory(UriComponentsBuilder baseUri) { + this.baseUri = baseUri; + } + + + /** + * Set the encoding mode to use. + *

By default this is set to {@link EncodingMode#TEMPLATE_AND_VALUES + * EncodingMode.TEMPLATE_AND_VALUES}. + *

Note: In 5.1 the default was changed from + * {@link EncodingMode#URI_COMPONENT EncodingMode.URI_COMPONENT}. + * Consequently the {@code WebClient}, which relies on the built-in default + * has also been switched to the new default. The {@code RestTemplate} + * however sets this explicitly to {@link EncodingMode#URI_COMPONENT + * EncodingMode.URI_COMPONENT} explicitly for historic and backwards + * compatibility reasons. + * @param encodingMode the encoding mode to use + */ + public void setEncodingMode(EncodingMode encodingMode) { + this.encodingMode = encodingMode; + } + + /** + * Return the configured encoding mode. + */ + public EncodingMode getEncodingMode() { + return this.encodingMode; + } + + /** + * Provide default URI variable values to use when expanding URI templates + * with a Map of variables. + * @param defaultUriVariables default URI variable values + */ + public void setDefaultUriVariables(@Nullable Map defaultUriVariables) { + this.defaultUriVariables.clear(); + if (defaultUriVariables != null) { + this.defaultUriVariables.putAll(defaultUriVariables); + } + } + + /** + * Return the configured default URI variable values. + */ + public Map getDefaultUriVariables() { + return Collections.unmodifiableMap(this.defaultUriVariables); + } + + /** + * Whether to parse the input path into path segments if the encoding mode + * is set to {@link EncodingMode#URI_COMPONENT EncodingMode.URI_COMPONENT}, + * which ensures that URI variables in the path are encoded according to + * path segment rules and for example a '/' is encoded. + *

By default this is set to {@code true}. + * @param parsePath whether to parse the path into path segments + */ + public void setParsePath(boolean parsePath) { + this.parsePath = parsePath; + } + + /** + * Whether to parse the path into path segments if the encoding mode is set + * to {@link EncodingMode#URI_COMPONENT EncodingMode.URI_COMPONENT}. + */ + public boolean shouldParsePath() { + return this.parsePath; + } + + + // UriTemplateHandler + + @Override + public URI expand(String uriTemplate, Map uriVars) { + return uriString(uriTemplate).build(uriVars); + } + + @Override + public URI expand(String uriTemplate, Object... uriVars) { + return uriString(uriTemplate).build(uriVars); + } + + // UriBuilderFactory + + @Override + public UriBuilder uriString(String uriTemplate) { + return new DefaultUriBuilder(uriTemplate); + } + + @Override + public UriBuilder builder() { + return new DefaultUriBuilder(""); + } + + + /** + * Enum to represent multiple URI encoding strategies. The following are + * available: + *

    + *
  • {@link #TEMPLATE_AND_VALUES} + *
  • {@link #VALUES_ONLY} + *
  • {@link #URI_COMPONENT} + *
  • {@link #NONE} + *
+ * @see #setEncodingMode + */ + public enum EncodingMode { + + /** + * Pre-encode the URI template first, then strictly encode URI variables + * when expanded, with the following rules: + *
    + *
  • For the URI template replace only non-ASCII and illegal + * (within a given URI component type) characters with escaped octets. + *
  • For URI variables do the same and also replace characters with + * reserved meaning. + *
+ *

For most cases, this mode is most likely to give the expected + * result because in treats URI variables as opaque data to be fully + * encoded, while {@link #URI_COMPONENT} by comparison is useful only + * if intentionally expanding URI variables with reserved characters. + * @since 5.0.8 + * @see UriComponentsBuilder#encode() + */ + TEMPLATE_AND_VALUES, + + /** + * Does not encode the URI template and instead applies strict encoding + * to URI variables via {@link UriUtils#encodeUriVariables} prior to + * expanding them into the template. + * @see UriUtils#encodeUriVariables(Object...) + * @see UriUtils#encodeUriVariables(Map) + */ + VALUES_ONLY, + + /** + * Expand URI variables first, and then encode the resulting URI + * component values, replacing only non-ASCII and illegal + * (within a given URI component type) characters, but not characters + * with reserved meaning. + * @see UriComponents#encode() + */ + URI_COMPONENT, + + /** + * No encoding should be applied. + */ + NONE + } + + + /** + * {@link DefaultUriBuilderFactory} specific implementation of UriBuilder. + */ + private class DefaultUriBuilder implements UriBuilder { + + private final UriComponentsBuilder uriComponentsBuilder; + + public DefaultUriBuilder(String uriTemplate) { + this.uriComponentsBuilder = initUriComponentsBuilder(uriTemplate); + } + + private UriComponentsBuilder initUriComponentsBuilder(String uriTemplate) { + UriComponentsBuilder result; + if (!StringUtils.hasLength(uriTemplate)) { + result = (baseUri != null ? baseUri.cloneBuilder() : UriComponentsBuilder.newInstance()); + } + else if (baseUri != null) { + UriComponentsBuilder builder = UriComponentsBuilder.fromUriString(uriTemplate); + UriComponents uri = builder.build(); + result = (uri.getHost() == null ? baseUri.cloneBuilder().uriComponents(uri) : builder); + } + else { + result = UriComponentsBuilder.fromUriString(uriTemplate); + } + if (encodingMode.equals(EncodingMode.TEMPLATE_AND_VALUES)) { + result.encode(); + } + parsePathIfNecessary(result); + return result; + } + + private void parsePathIfNecessary(UriComponentsBuilder result) { + if (parsePath && encodingMode.equals(EncodingMode.URI_COMPONENT)) { + UriComponents uric = result.build(); + String path = uric.getPath(); + result.replacePath(null); + for (String segment : uric.getPathSegments()) { + result.pathSegment(segment); + } + if (path != null && path.endsWith("/")) { + result.path("/"); + } + } + } + + + @Override + public DefaultUriBuilder scheme(@Nullable String scheme) { + this.uriComponentsBuilder.scheme(scheme); + return this; + } + + @Override + public DefaultUriBuilder userInfo(@Nullable String userInfo) { + this.uriComponentsBuilder.userInfo(userInfo); + return this; + } + + @Override + public DefaultUriBuilder host(@Nullable String host) { + this.uriComponentsBuilder.host(host); + return this; + } + + @Override + public DefaultUriBuilder port(int port) { + this.uriComponentsBuilder.port(port); + return this; + } + + @Override + public DefaultUriBuilder port(@Nullable String port) { + this.uriComponentsBuilder.port(port); + return this; + } + + @Override + public DefaultUriBuilder path(String path) { + this.uriComponentsBuilder.path(path); + return this; + } + + @Override + public DefaultUriBuilder replacePath(@Nullable String path) { + this.uriComponentsBuilder.replacePath(path); + return this; + } + + @Override + public DefaultUriBuilder pathSegment(String... pathSegments) { + this.uriComponentsBuilder.pathSegment(pathSegments); + return this; + } + + @Override + public DefaultUriBuilder query(String query) { + this.uriComponentsBuilder.query(query); + return this; + } + + @Override + public DefaultUriBuilder replaceQuery(@Nullable String query) { + this.uriComponentsBuilder.replaceQuery(query); + return this; + } + + @Override + public DefaultUriBuilder queryParam(String name, Object... values) { + this.uriComponentsBuilder.queryParam(name, values); + return this; + } + + @Override + public DefaultUriBuilder replaceQueryParam(String name, Object... values) { + this.uriComponentsBuilder.replaceQueryParam(name, values); + return this; + } + + @Override + public DefaultUriBuilder queryParams(MultiValueMap params) { + this.uriComponentsBuilder.queryParams(params); + return this; + } + + @Override + public DefaultUriBuilder replaceQueryParams(MultiValueMap params) { + this.uriComponentsBuilder.replaceQueryParams(params); + return this; + } + + @Override + public DefaultUriBuilder fragment(@Nullable String fragment) { + this.uriComponentsBuilder.fragment(fragment); + return this; + } + + @Override + public URI build(Map uriVars) { + if (!defaultUriVariables.isEmpty()) { + Map map = new HashMap<>(); + map.putAll(defaultUriVariables); + map.putAll(uriVars); + uriVars = map; + } + if (encodingMode.equals(EncodingMode.VALUES_ONLY)) { + uriVars = UriUtils.encodeUriVariables(uriVars); + } + UriComponents uric = this.uriComponentsBuilder.build().expand(uriVars); + return createUri(uric); + } + + @Override + public URI build(Object... uriVars) { + if (ObjectUtils.isEmpty(uriVars) && !defaultUriVariables.isEmpty()) { + return build(Collections.emptyMap()); + } + if (encodingMode.equals(EncodingMode.VALUES_ONLY)) { + uriVars = UriUtils.encodeUriVariables(uriVars); + } + UriComponents uric = this.uriComponentsBuilder.build().expand(uriVars); + return createUri(uric); + } + + private URI createUri(UriComponents uric) { + if (encodingMode.equals(EncodingMode.URI_COMPONENT)) { + uric = uric.encode(); + } + return URI.create(uric.toString()); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/DefaultUriTemplateHandler.java b/spring-web/src/main/java/org/springframework/web/util/DefaultUriTemplateHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..385608dc84c682a4556b825f870b127da1357d04 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/DefaultUriTemplateHandler.java @@ -0,0 +1,158 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Map; + +/** + * Default implementation of {@link UriTemplateHandler} based on the use of + * {@link UriComponentsBuilder} for expanding and encoding variables. + * + *

There are also several properties to customize how URI template handling + * is performed, including a {@link #setBaseUrl baseUrl} to be used as a prefix + * for all URI templates and a couple of encoding related options — + * {@link #setParsePath parsePath} and {@link #setStrictEncoding strictEncoding} + * respectively. + * + * @author Rossen Stoyanchev + * @since 4.2 + * @deprecated as of 5.0 in favor of {@link DefaultUriBuilderFactory}. + *

Note: {@link DefaultUriBuilderFactory} has a different + * default for the {@link #setParsePath(boolean) parsePath} property (from + * false to true). + */ +@Deprecated +public class DefaultUriTemplateHandler extends AbstractUriTemplateHandler { + + private boolean parsePath; + + private boolean strictEncoding; + + + /** + * Whether to parse the path of a URI template string into path segments. + *

If set to {@code true} the URI template path is immediately decomposed + * into path segments any URI variables expanded into it are then subject to + * path segment encoding rules. In effect URI variables in the path have any + * "/" characters percent encoded. + *

By default this is set to {@code false} in which case the path is kept + * as a full path and expanded URI variables will preserve "/" characters. + * @param parsePath whether to parse the path into path segments + */ + public void setParsePath(boolean parsePath) { + this.parsePath = parsePath; + } + + /** + * Whether the handler is configured to parse the path into path segments. + */ + public boolean shouldParsePath() { + return this.parsePath; + } + + /** + * Whether to encode characters outside the unreserved set as defined in + * RFC 3986 Section 2. + * This ensures a URI variable value will not contain any characters with a + * reserved purpose. + *

By default this is set to {@code false} in which case only characters + * illegal for the given URI component are encoded. For example when expanding + * a URI variable into a path segment the "/" character is illegal and + * encoded. The ";" character however is legal and not encoded even though + * it has a reserved purpose. + *

Note: this property supersedes the need to also set + * the {@link #setParsePath parsePath} property. + * @param strictEncoding whether to perform strict encoding + * @since 4.3 + */ + public void setStrictEncoding(boolean strictEncoding) { + this.strictEncoding = strictEncoding; + } + + /** + * Whether to strictly encode any character outside the unreserved set. + */ + public boolean isStrictEncoding() { + return this.strictEncoding; + } + + + @Override + protected URI expandInternal(String uriTemplate, Map uriVariables) { + UriComponentsBuilder uriComponentsBuilder = initUriComponentsBuilder(uriTemplate); + UriComponents uriComponents = expandAndEncode(uriComponentsBuilder, uriVariables); + return createUri(uriComponents); + } + + @Override + protected URI expandInternal(String uriTemplate, Object... uriVariables) { + UriComponentsBuilder uriComponentsBuilder = initUriComponentsBuilder(uriTemplate); + UriComponents uriComponents = expandAndEncode(uriComponentsBuilder, uriVariables); + return createUri(uriComponents); + } + + /** + * Create a {@code UriComponentsBuilder} from the URI template string. + * This implementation also breaks up the path into path segments depending + * on whether {@link #setParsePath parsePath} is enabled. + */ + protected UriComponentsBuilder initUriComponentsBuilder(String uriTemplate) { + UriComponentsBuilder builder = UriComponentsBuilder.fromUriString(uriTemplate); + if (shouldParsePath() && !isStrictEncoding()) { + List pathSegments = builder.build().getPathSegments(); + builder.replacePath(null); + for (String pathSegment : pathSegments) { + builder.pathSegment(pathSegment); + } + } + return builder; + } + + protected UriComponents expandAndEncode(UriComponentsBuilder builder, Map uriVariables) { + if (!isStrictEncoding()) { + return builder.buildAndExpand(uriVariables).encode(); + } + else { + Map encodedUriVars = UriUtils.encodeUriVariables(uriVariables); + return builder.buildAndExpand(encodedUriVars); + } + } + + protected UriComponents expandAndEncode(UriComponentsBuilder builder, Object[] uriVariables) { + if (!isStrictEncoding()) { + return builder.buildAndExpand(uriVariables).encode(); + } + else { + Object[] encodedUriVars = UriUtils.encodeUriVariables(uriVariables); + return builder.buildAndExpand(encodedUriVars); + } + } + + private URI createUri(UriComponents uriComponents) { + try { + // Avoid further encoding (in the case of strictEncoding=true) + return new URI(uriComponents.toUriString()); + } + catch (URISyntaxException ex) { + throw new IllegalStateException("Could not create URI object: " + ex.getMessage(), ex); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/HierarchicalUriComponents.java b/spring-web/src/main/java/org/springframework/web/util/HierarchicalUriComponents.java new file mode 100644 index 0000000000000000000000000000000000000000..22f13c5e9ce63221afa7cb21e223da18bf2c3cce --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/HierarchicalUriComponents.java @@ -0,0 +1,1074 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.io.ByteArrayOutputStream; +import java.io.Serializable; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.function.BiFunction; +import java.util.function.UnaryOperator; + +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; + +/** + * Extension of {@link UriComponents} for hierarchical URIs. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @author Rossen Stoyanchev + * @author Phillip Webb + * @since 3.1.3 + * @see Hierarchical URIs + */ +@SuppressWarnings("serial") +final class HierarchicalUriComponents extends UriComponents { + + private static final char PATH_DELIMITER = '/'; + + private static final String PATH_DELIMITER_STRING = "/"; + + private static final MultiValueMap EMPTY_QUERY_PARAMS = + CollectionUtils.unmodifiableMultiValueMap(new LinkedMultiValueMap<>()); + + + /** + * Represents an empty path. + */ + static final PathComponent NULL_PATH_COMPONENT = new PathComponent() { + @Override + public String getPath() { + return ""; + } + @Override + public List getPathSegments() { + return Collections.emptyList(); + } + @Override + public PathComponent encode(BiFunction encoder) { + return this; + } + @Override + public void verify() { + } + @Override + public PathComponent expand(UriTemplateVariables uriVariables, @Nullable UnaryOperator encoder) { + return this; + } + @Override + public void copyToUriComponentsBuilder(UriComponentsBuilder builder) { + } + @Override + public boolean equals(Object other) { + return (this == other); + } + @Override + public int hashCode() { + return getClass().hashCode(); + } + }; + + + @Nullable + private final String userInfo; + + @Nullable + private final String host; + + @Nullable + private final String port; + + private final PathComponent path; + + private final MultiValueMap queryParams; + + private final EncodeState encodeState; + + @Nullable + private UnaryOperator variableEncoder; + + + /** + * Package-private constructor. All arguments are optional, and can be {@code null}. + * @param scheme the scheme + * @param userInfo the user info + * @param host the host + * @param port the port + * @param path the path + * @param query the query parameters + * @param fragment the fragment + * @param encoded whether the components are already encoded + */ + HierarchicalUriComponents(@Nullable String scheme, @Nullable String fragment, @Nullable String userInfo, + @Nullable String host, @Nullable String port, @Nullable PathComponent path, + @Nullable MultiValueMap query, boolean encoded) { + + super(scheme, fragment); + + this.userInfo = userInfo; + this.host = host; + this.port = port; + this.path = path != null ? path : NULL_PATH_COMPONENT; + this.queryParams = query != null ? CollectionUtils.unmodifiableMultiValueMap(query) : EMPTY_QUERY_PARAMS; + this.encodeState = encoded ? EncodeState.FULLY_ENCODED : EncodeState.RAW; + + // Check for illegal characters.. + if (encoded) { + verify(); + } + } + + private HierarchicalUriComponents(@Nullable String scheme, @Nullable String fragment, + @Nullable String userInfo, @Nullable String host, @Nullable String port, + PathComponent path, MultiValueMap queryParams, + EncodeState encodeState, @Nullable UnaryOperator variableEncoder) { + + super(scheme, fragment); + + this.userInfo = userInfo; + this.host = host; + this.port = port; + this.path = path; + this.queryParams = queryParams; + this.encodeState = encodeState; + this.variableEncoder = variableEncoder; + } + + + // Component getters + + @Override + @Nullable + public String getSchemeSpecificPart() { + return null; + } + + @Override + @Nullable + public String getUserInfo() { + return this.userInfo; + } + + @Override + @Nullable + public String getHost() { + return this.host; + } + + @Override + public int getPort() { + if (this.port == null) { + return -1; + } + else if (this.port.contains("{")) { + throw new IllegalStateException( + "The port contains a URI variable but has not been expanded yet: " + this.port); + } + return Integer.parseInt(this.port); + } + + @Override + @NonNull + public String getPath() { + return this.path.getPath(); + } + + @Override + public List getPathSegments() { + return this.path.getPathSegments(); + } + + @Override + @Nullable + public String getQuery() { + if (!this.queryParams.isEmpty()) { + StringBuilder queryBuilder = new StringBuilder(); + this.queryParams.forEach((name, values) -> { + if (CollectionUtils.isEmpty(values)) { + if (queryBuilder.length() != 0) { + queryBuilder.append('&'); + } + queryBuilder.append(name); + } + else { + for (Object value : values) { + if (queryBuilder.length() != 0) { + queryBuilder.append('&'); + } + queryBuilder.append(name); + if (value != null) { + queryBuilder.append('=').append(value.toString()); + } + } + } + }); + return queryBuilder.toString(); + } + else { + return null; + } + } + + /** + * Return the map of query parameters. Empty if no query has been set. + */ + @Override + public MultiValueMap getQueryParams() { + return this.queryParams; + } + + + // Encoding + + /** + * Identical to {@link #encode()} but skipping over URI variable placeholders. + * Also {@link #variableEncoder} is initialized with the given charset for + * use later when URI variables are expanded. + */ + HierarchicalUriComponents encodeTemplate(Charset charset) { + if (this.encodeState.isEncoded()) { + return this; + } + + // Remember the charset to encode URI variables later.. + this.variableEncoder = value -> encodeUriComponent(value, charset, Type.URI); + + UriTemplateEncoder encoder = new UriTemplateEncoder(charset); + String schemeTo = (getScheme() != null ? encoder.apply(getScheme(), Type.SCHEME) : null); + String fragmentTo = (getFragment() != null ? encoder.apply(getFragment(), Type.FRAGMENT) : null); + String userInfoTo = (getUserInfo() != null ? encoder.apply(getUserInfo(), Type.USER_INFO) : null); + String hostTo = (getHost() != null ? encoder.apply(getHost(), getHostType()) : null); + PathComponent pathTo = this.path.encode(encoder); + MultiValueMap queryParamsTo = encodeQueryParams(encoder); + + return new HierarchicalUriComponents(schemeTo, fragmentTo, userInfoTo, + hostTo, this.port, pathTo, queryParamsTo, EncodeState.TEMPLATE_ENCODED, this.variableEncoder); + } + + @Override + public HierarchicalUriComponents encode(Charset charset) { + if (this.encodeState.isEncoded()) { + return this; + } + String scheme = getScheme(); + String fragment = getFragment(); + String schemeTo = (scheme != null ? encodeUriComponent(scheme, charset, Type.SCHEME) : null); + String fragmentTo = (fragment != null ? encodeUriComponent(fragment, charset, Type.FRAGMENT) : null); + String userInfoTo = (this.userInfo != null ? encodeUriComponent(this.userInfo, charset, Type.USER_INFO) : null); + String hostTo = (this.host != null ? encodeUriComponent(this.host, charset, getHostType()) : null); + BiFunction encoder = (s, type) -> encodeUriComponent(s, charset, type); + PathComponent pathTo = this.path.encode(encoder); + MultiValueMap queryParamsTo = encodeQueryParams(encoder); + + return new HierarchicalUriComponents(schemeTo, fragmentTo, userInfoTo, + hostTo, this.port, pathTo, queryParamsTo, EncodeState.FULLY_ENCODED, null); + } + + private MultiValueMap encodeQueryParams(BiFunction encoder) { + int size = this.queryParams.size(); + MultiValueMap result = new LinkedMultiValueMap<>(size); + this.queryParams.forEach((key, values) -> { + String name = encoder.apply(key, Type.QUERY_PARAM); + List encodedValues = new ArrayList<>(values.size()); + for (String value : values) { + encodedValues.add(value != null ? encoder.apply(value, Type.QUERY_PARAM) : null); + } + result.put(name, encodedValues); + }); + return CollectionUtils.unmodifiableMultiValueMap(result); + } + + /** + * Encode the given source into an encoded String using the rules specified + * by the given component and with the given options. + * @param source the source String + * @param encoding the encoding of the source String + * @param type the URI component for the source + * @return the encoded URI + * @throws IllegalArgumentException when the given value is not a valid URI component + */ + static String encodeUriComponent(String source, String encoding, Type type) { + return encodeUriComponent(source, Charset.forName(encoding), type); + } + + /** + * Encode the given source into an encoded String using the rules specified + * by the given component and with the given options. + * @param source the source String + * @param charset the encoding of the source String + * @param type the URI component for the source + * @return the encoded URI + * @throws IllegalArgumentException when the given value is not a valid URI component + */ + static String encodeUriComponent(String source, Charset charset, Type type) { + if (!StringUtils.hasLength(source)) { + return source; + } + Assert.notNull(charset, "Charset must not be null"); + Assert.notNull(type, "Type must not be null"); + + byte[] bytes = source.getBytes(charset); + boolean original = true; + for (byte b : bytes) { + if (b < 0) { + b += 256; + } + if (!type.isAllowed(b)) { + original = false; + break; + } + } + if (original) { + return source; + } + + ByteArrayOutputStream bos = new ByteArrayOutputStream(bytes.length); + for (byte b : bytes) { + if (b < 0) { + b += 256; + } + if (type.isAllowed(b)) { + bos.write(b); + } + else { + bos.write('%'); + char hex1 = Character.toUpperCase(Character.forDigit((b >> 4) & 0xF, 16)); + char hex2 = Character.toUpperCase(Character.forDigit(b & 0xF, 16)); + bos.write(hex1); + bos.write(hex2); + } + } + return new String(bos.toByteArray(), charset); + } + + private Type getHostType() { + return (this.host != null && this.host.startsWith("[") ? Type.HOST_IPV6 : Type.HOST_IPV4); + } + + // Verifying + + /** + * Check if any of the URI components contain any illegal characters. + * @throws IllegalArgumentException if any component has illegal characters + */ + private void verify() { + verifyUriComponent(getScheme(), Type.SCHEME); + verifyUriComponent(this.userInfo, Type.USER_INFO); + verifyUriComponent(this.host, getHostType()); + this.path.verify(); + this.queryParams.forEach((key, values) -> { + verifyUriComponent(key, Type.QUERY_PARAM); + for (String value : values) { + verifyUriComponent(value, Type.QUERY_PARAM); + } + }); + verifyUriComponent(getFragment(), Type.FRAGMENT); + } + + private static void verifyUriComponent(@Nullable String source, Type type) { + if (source == null) { + return; + } + int length = source.length(); + for (int i = 0; i < length; i++) { + char ch = source.charAt(i); + if (ch == '%') { + if ((i + 2) < length) { + char hex1 = source.charAt(i + 1); + char hex2 = source.charAt(i + 2); + int u = Character.digit(hex1, 16); + int l = Character.digit(hex2, 16); + if (u == -1 || l == -1) { + throw new IllegalArgumentException("Invalid encoded sequence \"" + + source.substring(i) + "\""); + } + i += 2; + } + else { + throw new IllegalArgumentException("Invalid encoded sequence \"" + + source.substring(i) + "\""); + } + } + else if (!type.isAllowed(ch)) { + throw new IllegalArgumentException("Invalid character '" + ch + "' for " + + type.name() + " in \"" + source + "\""); + } + } + } + + + // Expanding + + @Override + protected HierarchicalUriComponents expandInternal(UriTemplateVariables uriVariables) { + Assert.state(!this.encodeState.equals(EncodeState.FULLY_ENCODED), + "URI components already encoded, and could not possibly contain '{' or '}'."); + + String schemeTo = expandUriComponent(getScheme(), uriVariables, this.variableEncoder); + String fragmentTo = expandUriComponent(getFragment(), uriVariables, this.variableEncoder); + String userInfoTo = expandUriComponent(this.userInfo, uriVariables, this.variableEncoder); + String hostTo = expandUriComponent(this.host, uriVariables, this.variableEncoder); + String portTo = expandUriComponent(this.port, uriVariables, this.variableEncoder); + PathComponent pathTo = this.path.expand(uriVariables, this.variableEncoder); + MultiValueMap queryParamsTo = expandQueryParams(uriVariables); + + return new HierarchicalUriComponents(schemeTo, fragmentTo, userInfoTo, + hostTo, portTo, pathTo, queryParamsTo, this.encodeState, this.variableEncoder); + } + + private MultiValueMap expandQueryParams(UriTemplateVariables variables) { + int size = this.queryParams.size(); + MultiValueMap result = new LinkedMultiValueMap<>(size); + UriTemplateVariables queryVariables = new QueryUriTemplateVariables(variables); + this.queryParams.forEach((key, values) -> { + String name = expandUriComponent(key, queryVariables, this.variableEncoder); + List expandedValues = new ArrayList<>(values.size()); + for (String value : values) { + expandedValues.add(expandUriComponent(value, queryVariables, this.variableEncoder)); + } + result.put(name, expandedValues); + }); + return CollectionUtils.unmodifiableMultiValueMap(result); + } + + @Override + public UriComponents normalize() { + String normalizedPath = StringUtils.cleanPath(getPath()); + FullPathComponent path = new FullPathComponent(normalizedPath); + return new HierarchicalUriComponents(getScheme(), getFragment(), this.userInfo, this.host, this.port, + path, this.queryParams, this.encodeState, this.variableEncoder); + } + + + // Other functionality + + @Override + public String toUriString() { + StringBuilder uriBuilder = new StringBuilder(); + if (getScheme() != null) { + uriBuilder.append(getScheme()).append(':'); + } + if (this.userInfo != null || this.host != null) { + uriBuilder.append("//"); + if (this.userInfo != null) { + uriBuilder.append(this.userInfo).append('@'); + } + if (this.host != null) { + uriBuilder.append(this.host); + } + if (getPort() != -1) { + uriBuilder.append(':').append(this.port); + } + } + String path = getPath(); + if (StringUtils.hasLength(path)) { + if (uriBuilder.length() != 0 && path.charAt(0) != PATH_DELIMITER) { + uriBuilder.append(PATH_DELIMITER); + } + uriBuilder.append(path); + } + String query = getQuery(); + if (query != null) { + uriBuilder.append('?').append(query); + } + if (getFragment() != null) { + uriBuilder.append('#').append(getFragment()); + } + return uriBuilder.toString(); + } + + @Override + public URI toUri() { + try { + if (this.encodeState.isEncoded()) { + return new URI(toUriString()); + } + else { + String path = getPath(); + if (StringUtils.hasLength(path) && path.charAt(0) != PATH_DELIMITER) { + // Only prefix the path delimiter if something exists before it + if (getScheme() != null || getUserInfo() != null || getHost() != null || getPort() != -1) { + path = PATH_DELIMITER + path; + } + } + return new URI(getScheme(), getUserInfo(), getHost(), getPort(), path, getQuery(), getFragment()); + } + } + catch (URISyntaxException ex) { + throw new IllegalStateException("Could not create URI object: " + ex.getMessage(), ex); + } + } + + @Override + protected void copyToUriComponentsBuilder(UriComponentsBuilder builder) { + if (getScheme() != null) { + builder.scheme(getScheme()); + } + if (getUserInfo() != null) { + builder.userInfo(getUserInfo()); + } + if (getHost() != null) { + builder.host(getHost()); + } + // Avoid parsing the port, may have URI variable.. + if (this.port != null) { + builder.port(this.port); + } + this.path.copyToUriComponentsBuilder(builder); + if (!getQueryParams().isEmpty()) { + builder.queryParams(getQueryParams()); + } + if (getFragment() != null) { + builder.fragment(getFragment()); + } + } + + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof HierarchicalUriComponents)) { + return false; + } + HierarchicalUriComponents otherComp = (HierarchicalUriComponents) other; + return (ObjectUtils.nullSafeEquals(getScheme(), otherComp.getScheme()) && + ObjectUtils.nullSafeEquals(getUserInfo(), otherComp.getUserInfo()) && + ObjectUtils.nullSafeEquals(getHost(), otherComp.getHost()) && + getPort() == otherComp.getPort() && + this.path.equals(otherComp.path) && + this.queryParams.equals(otherComp.queryParams) && + ObjectUtils.nullSafeEquals(getFragment(), otherComp.getFragment())); + } + + @Override + public int hashCode() { + int result = ObjectUtils.nullSafeHashCode(getScheme()); + result = 31 * result + ObjectUtils.nullSafeHashCode(this.userInfo); + result = 31 * result + ObjectUtils.nullSafeHashCode(this.host); + result = 31 * result + ObjectUtils.nullSafeHashCode(this.port); + result = 31 * result + this.path.hashCode(); + result = 31 * result + this.queryParams.hashCode(); + result = 31 * result + ObjectUtils.nullSafeHashCode(getFragment()); + return result; + } + + + // Nested types + + /** + * Enumeration used to identify the allowed characters per URI component. + *

Contains methods to indicate whether a given character is valid in a specific URI component. + * @see RFC 3986 + */ + enum Type { + + SCHEME { + @Override + public boolean isAllowed(int c) { + return isAlpha(c) || isDigit(c) || '+' == c || '-' == c || '.' == c; + } + }, + AUTHORITY { + @Override + public boolean isAllowed(int c) { + return isUnreserved(c) || isSubDelimiter(c) || ':' == c || '@' == c; + } + }, + USER_INFO { + @Override + public boolean isAllowed(int c) { + return isUnreserved(c) || isSubDelimiter(c) || ':' == c; + } + }, + HOST_IPV4 { + @Override + public boolean isAllowed(int c) { + return isUnreserved(c) || isSubDelimiter(c); + } + }, + HOST_IPV6 { + @Override + public boolean isAllowed(int c) { + return isUnreserved(c) || isSubDelimiter(c) || '[' == c || ']' == c || ':' == c; + } + }, + PORT { + @Override + public boolean isAllowed(int c) { + return isDigit(c); + } + }, + PATH { + @Override + public boolean isAllowed(int c) { + return isPchar(c) || '/' == c; + } + }, + PATH_SEGMENT { + @Override + public boolean isAllowed(int c) { + return isPchar(c); + } + }, + QUERY { + @Override + public boolean isAllowed(int c) { + return isPchar(c) || '/' == c || '?' == c; + } + }, + QUERY_PARAM { + @Override + public boolean isAllowed(int c) { + if ('=' == c || '&' == c) { + return false; + } + else { + return isPchar(c) || '/' == c || '?' == c; + } + } + }, + FRAGMENT { + @Override + public boolean isAllowed(int c) { + return isPchar(c) || '/' == c || '?' == c; + } + }, + URI { + @Override + public boolean isAllowed(int c) { + return isUnreserved(c); + } + }; + + /** + * Indicates whether the given character is allowed in this URI component. + * @return {@code true} if the character is allowed; {@code false} otherwise + */ + public abstract boolean isAllowed(int c); + + /** + * Indicates whether the given character is in the {@code ALPHA} set. + * @see RFC 3986, appendix A + */ + protected boolean isAlpha(int c) { + return (c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z'); + } + + /** + * Indicates whether the given character is in the {@code DIGIT} set. + * @see RFC 3986, appendix A + */ + protected boolean isDigit(int c) { + return (c >= '0' && c <= '9'); + } + + /** + * Indicates whether the given character is in the {@code gen-delims} set. + * @see RFC 3986, appendix A + */ + protected boolean isGenericDelimiter(int c) { + return (':' == c || '/' == c || '?' == c || '#' == c || '[' == c || ']' == c || '@' == c); + } + + /** + * Indicates whether the given character is in the {@code sub-delims} set. + * @see RFC 3986, appendix A + */ + protected boolean isSubDelimiter(int c) { + return ('!' == c || '$' == c || '&' == c || '\'' == c || '(' == c || ')' == c || '*' == c || '+' == c || + ',' == c || ';' == c || '=' == c); + } + + /** + * Indicates whether the given character is in the {@code reserved} set. + * @see RFC 3986, appendix A + */ + protected boolean isReserved(int c) { + return (isGenericDelimiter(c) || isSubDelimiter(c)); + } + + /** + * Indicates whether the given character is in the {@code unreserved} set. + * @see RFC 3986, appendix A + */ + protected boolean isUnreserved(int c) { + return (isAlpha(c) || isDigit(c) || '-' == c || '.' == c || '_' == c || '~' == c); + } + + /** + * Indicates whether the given character is in the {@code pchar} set. + * @see RFC 3986, appendix A + */ + protected boolean isPchar(int c) { + return (isUnreserved(c) || isSubDelimiter(c) || ':' == c || '@' == c); + } + } + + + private enum EncodeState { + + /** + * Not encoded. + */ + RAW, + + /** + * URI vars expanded first and then each URI component encoded by + * quoting only illegal characters within that URI component. + */ + FULLY_ENCODED, + + /** + * URI template encoded first by quoting illegal characters only, and + * then URI vars encoded more strictly when expanded, by quoting both + * illegal chars and chars with reserved meaning. + */ + TEMPLATE_ENCODED; + + + public boolean isEncoded() { + return this.equals(FULLY_ENCODED) || this.equals(TEMPLATE_ENCODED); + } + } + + + private static class UriTemplateEncoder implements BiFunction { + + private final Charset charset; + + private final StringBuilder currentLiteral = new StringBuilder(); + + private final StringBuilder currentVariable = new StringBuilder(); + + private final StringBuilder output = new StringBuilder(); + + + public UriTemplateEncoder(Charset charset) { + this.charset = charset; + } + + + @Override + public String apply(String source, Type type) { + + // Only URI variable (nothing to encode).. + if (source.length() > 1 && source.charAt(0) == '{' && source.charAt(source.length() -1) == '}') { + return source; + } + + // Only literal (encode full source).. + if (source.indexOf('{') == -1) { + return encodeUriComponent(source, this.charset, type); + } + + // Mixed literal parts and URI variables, maybe (encode literal parts only).. + int level = 0; + clear(this.currentLiteral); + clear(this.currentVariable); + clear(this.output); + for (char c : source.toCharArray()) { + if (c == '{') { + level++; + if (level == 1) { + encodeAndAppendCurrentLiteral(type); + } + } + if (c == '}' && level > 0) { + level--; + this.currentVariable.append('}'); + if (level == 0) { + this.output.append(this.currentVariable); + clear(this.currentVariable); + } + } + else if (level > 0) { + this.currentVariable.append(c); + } + else { + this.currentLiteral.append(c); + } + } + if (level > 0) { + this.currentLiteral.append(this.currentVariable); + } + encodeAndAppendCurrentLiteral(type); + return this.output.toString(); + } + + private void encodeAndAppendCurrentLiteral(Type type) { + this.output.append(encodeUriComponent(this.currentLiteral.toString(), this.charset, type)); + clear(this.currentLiteral); + } + + private void clear(StringBuilder sb) { + sb.delete(0, sb.length()); + } + } + + + /** + * Defines the contract for path (segments). + */ + interface PathComponent extends Serializable { + + String getPath(); + + List getPathSegments(); + + PathComponent encode(BiFunction encoder); + + void verify(); + + PathComponent expand(UriTemplateVariables uriVariables, @Nullable UnaryOperator encoder); + + void copyToUriComponentsBuilder(UriComponentsBuilder builder); + } + + + /** + * Represents a path backed by a String. + */ + static final class FullPathComponent implements PathComponent { + + private final String path; + + public FullPathComponent(@Nullable String path) { + this.path = (path != null ? path : ""); + } + + @Override + public String getPath() { + return this.path; + } + + @Override + public List getPathSegments() { + String[] segments = StringUtils.tokenizeToStringArray(getPath(), PATH_DELIMITER_STRING); + return Collections.unmodifiableList(Arrays.asList(segments)); + } + + @Override + public PathComponent encode(BiFunction encoder) { + String encodedPath = encoder.apply(getPath(), Type.PATH); + return new FullPathComponent(encodedPath); + } + + @Override + public void verify() { + verifyUriComponent(getPath(), Type.PATH); + } + + @Override + public PathComponent expand(UriTemplateVariables uriVariables, @Nullable UnaryOperator encoder) { + String expandedPath = expandUriComponent(getPath(), uriVariables, encoder); + return new FullPathComponent(expandedPath); + } + + @Override + public void copyToUriComponentsBuilder(UriComponentsBuilder builder) { + builder.path(getPath()); + } + + @Override + public boolean equals(Object other) { + return (this == other || (other instanceof FullPathComponent && + getPath().equals(((FullPathComponent) other).getPath()))); + } + + @Override + public int hashCode() { + return getPath().hashCode(); + } + } + + + /** + * Represents a path backed by a String list (i.e. path segments). + */ + static final class PathSegmentComponent implements PathComponent { + + private final List pathSegments; + + public PathSegmentComponent(List pathSegments) { + Assert.notNull(pathSegments, "List must not be null"); + this.pathSegments = Collections.unmodifiableList(new ArrayList<>(pathSegments)); + } + + @Override + public String getPath() { + StringBuilder pathBuilder = new StringBuilder(); + pathBuilder.append(PATH_DELIMITER); + for (Iterator iterator = this.pathSegments.iterator(); iterator.hasNext(); ) { + String pathSegment = iterator.next(); + pathBuilder.append(pathSegment); + if (iterator.hasNext()) { + pathBuilder.append(PATH_DELIMITER); + } + } + return pathBuilder.toString(); + } + + @Override + public List getPathSegments() { + return this.pathSegments; + } + + @Override + public PathComponent encode(BiFunction encoder) { + List pathSegments = getPathSegments(); + List encodedPathSegments = new ArrayList<>(pathSegments.size()); + for (String pathSegment : pathSegments) { + String encodedPathSegment = encoder.apply(pathSegment, Type.PATH_SEGMENT); + encodedPathSegments.add(encodedPathSegment); + } + return new PathSegmentComponent(encodedPathSegments); + } + + @Override + public void verify() { + for (String pathSegment : getPathSegments()) { + verifyUriComponent(pathSegment, Type.PATH_SEGMENT); + } + } + + @Override + public PathComponent expand(UriTemplateVariables uriVariables, @Nullable UnaryOperator encoder) { + List pathSegments = getPathSegments(); + List expandedPathSegments = new ArrayList<>(pathSegments.size()); + for (String pathSegment : pathSegments) { + String expandedPathSegment = expandUriComponent(pathSegment, uriVariables, encoder); + expandedPathSegments.add(expandedPathSegment); + } + return new PathSegmentComponent(expandedPathSegments); + } + + @Override + public void copyToUriComponentsBuilder(UriComponentsBuilder builder) { + builder.pathSegment(StringUtils.toStringArray(getPathSegments())); + } + + @Override + public boolean equals(Object other) { + return (this == other || (other instanceof PathSegmentComponent && + getPathSegments().equals(((PathSegmentComponent) other).getPathSegments()))); + } + + @Override + public int hashCode() { + return getPathSegments().hashCode(); + } + } + + + /** + * Represents a collection of PathComponents. + */ + static final class PathComponentComposite implements PathComponent { + + private final List pathComponents; + + public PathComponentComposite(List pathComponents) { + Assert.notNull(pathComponents, "PathComponent List must not be null"); + this.pathComponents = pathComponents; + } + + @Override + public String getPath() { + StringBuilder pathBuilder = new StringBuilder(); + for (PathComponent pathComponent : this.pathComponents) { + pathBuilder.append(pathComponent.getPath()); + } + return pathBuilder.toString(); + } + + @Override + public List getPathSegments() { + List result = new ArrayList<>(); + for (PathComponent pathComponent : this.pathComponents) { + result.addAll(pathComponent.getPathSegments()); + } + return result; + } + + @Override + public PathComponent encode(BiFunction encoder) { + List encodedComponents = new ArrayList<>(this.pathComponents.size()); + for (PathComponent pathComponent : this.pathComponents) { + encodedComponents.add(pathComponent.encode(encoder)); + } + return new PathComponentComposite(encodedComponents); + } + + @Override + public void verify() { + for (PathComponent pathComponent : this.pathComponents) { + pathComponent.verify(); + } + } + + @Override + public PathComponent expand(UriTemplateVariables uriVariables, @Nullable UnaryOperator encoder) { + List expandedComponents = new ArrayList<>(this.pathComponents.size()); + for (PathComponent pathComponent : this.pathComponents) { + expandedComponents.add(pathComponent.expand(uriVariables, encoder)); + } + return new PathComponentComposite(expandedComponents); + } + + @Override + public void copyToUriComponentsBuilder(UriComponentsBuilder builder) { + for (PathComponent pathComponent : this.pathComponents) { + pathComponent.copyToUriComponentsBuilder(builder); + } + } + } + + + private static class QueryUriTemplateVariables implements UriTemplateVariables { + + private final UriTemplateVariables delegate; + + public QueryUriTemplateVariables(UriTemplateVariables delegate) { + this.delegate = delegate; + } + + @Override + public Object getValue(@Nullable String name) { + Object value = this.delegate.getValue(name); + if (ObjectUtils.isArray(value)) { + value = StringUtils.arrayToCommaDelimitedString(ObjectUtils.toObjectArray(value)); + } + return value; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/HtmlCharacterEntityDecoder.java b/spring-web/src/main/java/org/springframework/web/util/HtmlCharacterEntityDecoder.java new file mode 100644 index 0000000000000000000000000000000000000000..2277edfff3adfa9b2afe3812782dcc8af1644f89 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/HtmlCharacterEntityDecoder.java @@ -0,0 +1,154 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +/** + * Helper for decoding HTML Strings by replacing character + * entity references with the referred character. + * + * @author Juergen Hoeller + * @author Martin Kersten + * @since 1.2.1 + */ +class HtmlCharacterEntityDecoder { + + private static final int MAX_REFERENCE_SIZE = 10; + + + private final HtmlCharacterEntityReferences characterEntityReferences; + + private final String originalMessage; + + private final StringBuilder decodedMessage; + + private int currentPosition = 0; + + private int nextPotentialReferencePosition = -1; + + private int nextSemicolonPosition = -2; + + + public HtmlCharacterEntityDecoder(HtmlCharacterEntityReferences characterEntityReferences, String original) { + this.characterEntityReferences = characterEntityReferences; + this.originalMessage = original; + this.decodedMessage = new StringBuilder(original.length()); + } + + + public String decode() { + while (this.currentPosition < this.originalMessage.length()) { + findNextPotentialReference(this.currentPosition); + copyCharactersTillPotentialReference(); + processPossibleReference(); + } + return this.decodedMessage.toString(); + } + + private void findNextPotentialReference(int startPosition) { + this.nextPotentialReferencePosition = Math.max(startPosition, this.nextSemicolonPosition - MAX_REFERENCE_SIZE); + + do { + this.nextPotentialReferencePosition = + this.originalMessage.indexOf('&', this.nextPotentialReferencePosition); + + if (this.nextSemicolonPosition != -1 && + this.nextSemicolonPosition < this.nextPotentialReferencePosition) { + this.nextSemicolonPosition = this.originalMessage.indexOf(';', this.nextPotentialReferencePosition + 1); + } + + boolean isPotentialReference = (this.nextPotentialReferencePosition != -1 && + this.nextSemicolonPosition != -1 && + this.nextPotentialReferencePosition - this.nextSemicolonPosition < MAX_REFERENCE_SIZE); + + if (isPotentialReference) { + break; + } + if (this.nextPotentialReferencePosition == -1) { + break; + } + if (this.nextSemicolonPosition == -1) { + this.nextPotentialReferencePosition = -1; + break; + } + + this.nextPotentialReferencePosition = this.nextPotentialReferencePosition + 1; + } + while (this.nextPotentialReferencePosition != -1); + } + + private void copyCharactersTillPotentialReference() { + if (this.nextPotentialReferencePosition != this.currentPosition) { + int skipUntilIndex = (this.nextPotentialReferencePosition != -1 ? + this.nextPotentialReferencePosition : this.originalMessage.length()); + if (skipUntilIndex - this.currentPosition > 3) { + this.decodedMessage.append(this.originalMessage, this.currentPosition, skipUntilIndex); + this.currentPosition = skipUntilIndex; + } + else { + while (this.currentPosition < skipUntilIndex) { + this.decodedMessage.append(this.originalMessage.charAt(this.currentPosition++)); + } + } + } + } + + private void processPossibleReference() { + if (this.nextPotentialReferencePosition != -1) { + boolean isNumberedReference = (this.originalMessage.charAt(this.currentPosition + 1) == '#'); + boolean wasProcessable = isNumberedReference ? processNumberedReference() : processNamedReference(); + if (wasProcessable) { + this.currentPosition = this.nextSemicolonPosition + 1; + } + else { + char currentChar = this.originalMessage.charAt(this.currentPosition); + this.decodedMessage.append(currentChar); + this.currentPosition++; + } + } + } + + private boolean processNumberedReference() { + char referenceChar = this.originalMessage.charAt(this.nextPotentialReferencePosition + 2); + boolean isHexNumberedReference = (referenceChar == 'x' || referenceChar == 'X'); + try { + int value = (!isHexNumberedReference ? + Integer.parseInt(getReferenceSubstring(2)) : + Integer.parseInt(getReferenceSubstring(3), 16)); + this.decodedMessage.append((char) value); + return true; + } + catch (NumberFormatException ex) { + return false; + } + } + + private boolean processNamedReference() { + String referenceName = getReferenceSubstring(1); + char mappedCharacter = this.characterEntityReferences.convertToCharacter(referenceName); + if (mappedCharacter != HtmlCharacterEntityReferences.CHAR_NULL) { + this.decodedMessage.append(mappedCharacter); + return true; + } + return false; + } + + private String getReferenceSubstring(int referenceOffset) { + return this.originalMessage.substring( + this.nextPotentialReferencePosition + referenceOffset, this.nextSemicolonPosition); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/HtmlCharacterEntityReferences.java b/spring-web/src/main/java/org/springframework/web/util/HtmlCharacterEntityReferences.java new file mode 100644 index 0000000000000000000000000000000000000000..5ffeafcf58c53770af16a678df3cbe688e5af24d --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/HtmlCharacterEntityReferences.java @@ -0,0 +1,171 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Represents a set of character entity references defined by the + * HTML 4.0 standard. + * + *

A complete description of the HTML 4.0 character set can be found + * at http://www.w3.org/TR/html4/charset.html. + * + * @author Juergen Hoeller + * @author Martin Kersten + * @author Craig Andrews + * @since 1.2.1 + */ +class HtmlCharacterEntityReferences { + + private static final String PROPERTIES_FILE = "HtmlCharacterEntityReferences.properties"; + + static final char REFERENCE_START = '&'; + + static final String DECIMAL_REFERENCE_START = "&#"; + + static final String HEX_REFERENCE_START = "&#x"; + + static final char REFERENCE_END = ';'; + + static final char CHAR_NULL = (char) -1; + + + private final String[] characterToEntityReferenceMap = new String[3000]; + + private final Map entityReferenceToCharacterMap = new HashMap<>(512); + + + /** + * Returns a new set of character entity references reflecting the HTML 4.0 character set. + */ + public HtmlCharacterEntityReferences() { + Properties entityReferences = new Properties(); + + // Load reference definition file + InputStream is = HtmlCharacterEntityReferences.class.getResourceAsStream(PROPERTIES_FILE); + if (is == null) { + throw new IllegalStateException( + "Cannot find reference definition file [HtmlCharacterEntityReferences.properties] as class path resource"); + } + try { + try { + entityReferences.load(is); + } + finally { + is.close(); + } + } + catch (IOException ex) { + throw new IllegalStateException( + "Failed to parse reference definition file [HtmlCharacterEntityReferences.properties]: " + ex.getMessage()); + } + + // Parse reference definition properties + Enumeration keys = entityReferences.propertyNames(); + while (keys.hasMoreElements()) { + String key = (String) keys.nextElement(); + int referredChar = Integer.parseInt(key); + Assert.isTrue((referredChar < 1000 || (referredChar >= 8000 && referredChar < 10000)), + () -> "Invalid reference to special HTML entity: " + referredChar); + int index = (referredChar < 1000 ? referredChar : referredChar - 7000); + String reference = entityReferences.getProperty(key); + this.characterToEntityReferenceMap[index] = REFERENCE_START + reference + REFERENCE_END; + this.entityReferenceToCharacterMap.put(reference, (char) referredChar); + } + } + + + /** + * Return the number of supported entity references. + */ + public int getSupportedReferenceCount() { + return this.entityReferenceToCharacterMap.size(); + } + + /** + * Return true if the given character is mapped to a supported entity reference. + */ + public boolean isMappedToReference(char character) { + return isMappedToReference(character, WebUtils.DEFAULT_CHARACTER_ENCODING); + } + + /** + * Return true if the given character is mapped to a supported entity reference. + */ + public boolean isMappedToReference(char character, String encoding) { + return (convertToReference(character, encoding) != null); + } + + /** + * Return the reference mapped to the given character, or {@code null} if none found. + */ + @Nullable + public String convertToReference(char character) { + return convertToReference(character, WebUtils.DEFAULT_CHARACTER_ENCODING); + } + + /** + * Return the reference mapped to the given character, or {@code null} if none found. + * @since 4.1.2 + */ + @Nullable + public String convertToReference(char character, String encoding) { + if (encoding.startsWith("UTF-")){ + switch (character){ + case '<': + return "<"; + case '>': + return ">"; + case '"': + return """; + case '&': + return "&"; + case '\'': + return "'"; + } + } + else if (character < 1000 || (character >= 8000 && character < 10000)) { + int index = (character < 1000 ? character : character - 7000); + String entityReference = this.characterToEntityReferenceMap[index]; + if (entityReference != null) { + return entityReference; + } + } + return null; + } + + /** + * Return the char mapped to the given entityReference or -1. + */ + public char convertToCharacter(String entityReference) { + Character referredCharacter = this.entityReferenceToCharacterMap.get(entityReference); + if (referredCharacter != null) { + return referredCharacter; + } + return CHAR_NULL; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/HtmlUtils.java b/spring-web/src/main/java/org/springframework/web/util/HtmlUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..0eab5c4e33af360f29c35313d049e6b3342ebc8f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/HtmlUtils.java @@ -0,0 +1,222 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import org.springframework.util.Assert; + +/** + * Utility class for HTML escaping. + * + *

Escapes and unescapes based on the W3C HTML 4.01 recommendation, handling + * character entity references. + * + *

Reference: + * http://www.w3.org/TR/html4/charset.html + * + *

For a comprehensive set of String escaping utilities, consider + * Apache Commons Text + * and its {@code StringEscapeUtils} class. We do not use that class here in order + * to avoid a runtime dependency on Commons Text just for HTML escaping. Furthermore, + * Spring's HTML escaping is more flexible and 100% HTML 4.0 compliant. + * + * @author Juergen Hoeller + * @author Martin Kersten + * @author Craig Andrews + * @since 01.03.2003 + */ +public abstract class HtmlUtils { + + /** + * Shared instance of pre-parsed HTML character entity references. + */ + private static final HtmlCharacterEntityReferences characterEntityReferences = + new HtmlCharacterEntityReferences(); + + + /** + * Turn special characters into HTML character references. + *

Handles complete character set defined in HTML 4.01 recommendation. + *

Escapes all special characters to their corresponding + * entity reference (e.g. {@code <}). + *

Reference: + * + * http://www.w3.org/TR/html4/sgml/entities.html + * + * @param input the (unescaped) input string + * @return the escaped string + */ + public static String htmlEscape(String input) { + return htmlEscape(input, WebUtils.DEFAULT_CHARACTER_ENCODING); + } + + /** + * Turn special characters into HTML character references. + *

Handles complete character set defined in HTML 4.01 recommendation. + *

Escapes all special characters to their corresponding + * entity reference (e.g. {@code <}) at least as required by the + * specified encoding. In other words, if a special character does + * not have to be escaped for the given encoding, it may not be. + *

Reference: + * + * http://www.w3.org/TR/html4/sgml/entities.html + * + * @param input the (unescaped) input string + * @param encoding the name of a supported {@link java.nio.charset.Charset charset} + * @return the escaped string + * @since 4.1.2 + */ + public static String htmlEscape(String input, String encoding) { + Assert.notNull(input, "Input is required"); + Assert.notNull(encoding, "Encoding is required"); + StringBuilder escaped = new StringBuilder(input.length() * 2); + for (int i = 0; i < input.length(); i++) { + char character = input.charAt(i); + String reference = characterEntityReferences.convertToReference(character, encoding); + if (reference != null) { + escaped.append(reference); + } + else { + escaped.append(character); + } + } + return escaped.toString(); + } + + /** + * Turn special characters into HTML character references. + *

Handles complete character set defined in HTML 4.01 recommendation. + *

Escapes all special characters to their corresponding numeric + * reference in decimal format (&#Decimal;). + *

Reference: + * + * http://www.w3.org/TR/html4/sgml/entities.html + * + * @param input the (unescaped) input string + * @return the escaped string + */ + public static String htmlEscapeDecimal(String input) { + return htmlEscapeDecimal(input, WebUtils.DEFAULT_CHARACTER_ENCODING); + } + + /** + * Turn special characters into HTML character references. + *

Handles complete character set defined in HTML 4.01 recommendation. + *

Escapes all special characters to their corresponding numeric + * reference in decimal format (&#Decimal;) at least as required by the + * specified encoding. In other words, if a special character does + * not have to be escaped for the given encoding, it may not be. + *

Reference: + * + * http://www.w3.org/TR/html4/sgml/entities.html + * + * @param input the (unescaped) input string + * @param encoding the name of a supported {@link java.nio.charset.Charset charset} + * @return the escaped string + * @since 4.1.2 + */ + public static String htmlEscapeDecimal(String input, String encoding) { + Assert.notNull(input, "Input is required"); + Assert.notNull(encoding, "Encoding is required"); + StringBuilder escaped = new StringBuilder(input.length() * 2); + for (int i = 0; i < input.length(); i++) { + char character = input.charAt(i); + if (characterEntityReferences.isMappedToReference(character, encoding)) { + escaped.append(HtmlCharacterEntityReferences.DECIMAL_REFERENCE_START); + escaped.append((int) character); + escaped.append(HtmlCharacterEntityReferences.REFERENCE_END); + } + else { + escaped.append(character); + } + } + return escaped.toString(); + } + + /** + * Turn special characters into HTML character references. + *

Handles complete character set defined in HTML 4.01 recommendation. + *

Escapes all special characters to their corresponding numeric + * reference in hex format (&#xHex;). + *

Reference: + * + * http://www.w3.org/TR/html4/sgml/entities.html + * + * @param input the (unescaped) input string + * @return the escaped string + */ + public static String htmlEscapeHex(String input) { + return htmlEscapeHex(input, WebUtils.DEFAULT_CHARACTER_ENCODING); + } + + /** + * Turn special characters into HTML character references. + *

Handles complete character set defined in HTML 4.01 recommendation. + *

Escapes all special characters to their corresponding numeric + * reference in hex format (&#xHex;) at least as required by the + * specified encoding. In other words, if a special character does + * not have to be escaped for the given encoding, it may not be. + *

Reference: + * + * http://www.w3.org/TR/html4/sgml/entities.html + * + * @param input the (unescaped) input string + * @param encoding the name of a supported {@link java.nio.charset.Charset charset} + * @return the escaped string + * @since 4.1.2 + */ + public static String htmlEscapeHex(String input, String encoding) { + Assert.notNull(input, "Input is required"); + Assert.notNull(encoding, "Encoding is required"); + StringBuilder escaped = new StringBuilder(input.length() * 2); + for (int i = 0; i < input.length(); i++) { + char character = input.charAt(i); + if (characterEntityReferences.isMappedToReference(character, encoding)) { + escaped.append(HtmlCharacterEntityReferences.HEX_REFERENCE_START); + escaped.append(Integer.toString(character, 16)); + escaped.append(HtmlCharacterEntityReferences.REFERENCE_END); + } + else { + escaped.append(character); + } + } + return escaped.toString(); + } + + /** + * Turn HTML character references into their plain text UNICODE equivalent. + *

Handles complete character set defined in HTML 4.01 recommendation + * and all reference types (decimal, hex, and entity). + *

Correctly converts the following formats: + *

+ * &#Entity; - (Example: &amp;) case sensitive + * &#Decimal; - (Example: &#68;)
+ * &#xHex; - (Example: &#xE5;) case insensitive
+ *
+ *

Gracefully handles malformed character references by copying original + * characters as is when encountered. + *

Reference: + * + * http://www.w3.org/TR/html4/sgml/entities.html + * + * @param input the (escaped) input string + * @return the unescaped string + */ + public static String htmlUnescape(String input) { + return new HtmlCharacterEntityDecoder(characterEntityReferences, input).decode(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/HttpSessionMutexListener.java b/spring-web/src/main/java/org/springframework/web/util/HttpSessionMutexListener.java new file mode 100644 index 0000000000000000000000000000000000000000..369904116fe2df93c62dc1217f05491022020dba --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/HttpSessionMutexListener.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.io.Serializable; + +import javax.servlet.http.HttpSessionEvent; +import javax.servlet.http.HttpSessionListener; + +/** + * Servlet HttpSessionListener that automatically exposes the session mutex + * when an HttpSession gets created. To be registered as a listener in + * {@code web.xml}. + * + *

The session mutex is guaranteed to be the same object during + * the entire lifetime of the session, available under the key defined + * by the {@code SESSION_MUTEX_ATTRIBUTE} constant. It serves as a + * safe reference to synchronize on for locking on the current session. + * + *

In many cases, the HttpSession reference itself is a safe mutex + * as well, since it will always be the same object reference for the + * same active logical session. However, this is not guaranteed across + * different servlet containers; the only 100% safe way is a session mutex. + * + * @author Juergen Hoeller + * @since 1.2.7 + * @see WebUtils#SESSION_MUTEX_ATTRIBUTE + * @see WebUtils#getSessionMutex(javax.servlet.http.HttpSession) + * @see org.springframework.web.servlet.mvc.AbstractController#setSynchronizeOnSession + */ +public class HttpSessionMutexListener implements HttpSessionListener { + + @Override + public void sessionCreated(HttpSessionEvent event) { + event.getSession().setAttribute(WebUtils.SESSION_MUTEX_ATTRIBUTE, new Mutex()); + } + + @Override + public void sessionDestroyed(HttpSessionEvent event) { + event.getSession().removeAttribute(WebUtils.SESSION_MUTEX_ATTRIBUTE); + } + + + /** + * The mutex to be registered. + * Doesn't need to be anything but a plain Object to synchronize on. + * Should be serializable to allow for HttpSession persistence. + */ + @SuppressWarnings("serial") + private static class Mutex implements Serializable { + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/IntrospectorCleanupListener.java b/spring-web/src/main/java/org/springframework/web/util/IntrospectorCleanupListener.java new file mode 100644 index 0000000000000000000000000000000000000000..d3afef75f980df93dc261fdd2f2088a89d9826f3 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/IntrospectorCleanupListener.java @@ -0,0 +1,85 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.beans.Introspector; + +import javax.servlet.ServletContextEvent; +import javax.servlet.ServletContextListener; + +import org.springframework.beans.CachedIntrospectionResults; + +/** + * Listener that flushes the JDK's {@link java.beans.Introspector JavaBeans Introspector} + * cache on web app shutdown. Register this listener in your {@code web.xml} to + * guarantee proper release of the web application class loader and its loaded classes. + * + *

If the JavaBeans Introspector has been used to analyze application classes, + * the system-level Introspector cache will hold a hard reference to those classes. + * Consequently, those classes and the web application class loader will not be + * garbage-collected on web app shutdown! This listener performs proper cleanup, + * to allow for garbage collection to take effect. + * + *

Unfortunately, the only way to clean up the Introspector is to flush + * the entire cache, as there is no way to specifically determine the + * application's classes referenced there. This will remove cached + * introspection results for all other applications in the server too. + * + *

Note that this listener is not necessary when using Spring's beans + * infrastructure within the application, as Spring's own introspection results + * cache will immediately flush an analyzed class from the JavaBeans Introspector + * cache and only hold a cache within the application's own ClassLoader. + * + * Although Spring itself does not create JDK Introspector leaks, note that this + * listener should nevertheless be used in scenarios where the Spring framework classes + * themselves reside in a 'common' ClassLoader (such as the system ClassLoader). + * In such a scenario, this listener will properly clean up Spring's introspection cache. + * + *

Application classes hardly ever need to use the JavaBeans Introspector + * directly, so are normally not the cause of Introspector resource leaks. + * Rather, many libraries and frameworks do not clean up the Introspector: + * e.g. Struts and Quartz. + * + *

Note that a single such Introspector leak will cause the entire web + * app class loader to not get garbage collected! This has the consequence that + * you will see all the application's static class resources (like singletons) + * around after web app shutdown, which is not the fault of those classes! + * + *

This listener should be registered as the first one in {@code web.xml}, + * before any application listeners such as Spring's ContextLoaderListener. + * This allows the listener to take full effect at the right time of the lifecycle. + * + * @author Juergen Hoeller + * @since 1.1 + * @see java.beans.Introspector#flushCaches() + * @see org.springframework.beans.CachedIntrospectionResults#acceptClassLoader + * @see org.springframework.beans.CachedIntrospectionResults#clearClassLoader + */ +public class IntrospectorCleanupListener implements ServletContextListener { + + @Override + public void contextInitialized(ServletContextEvent event) { + CachedIntrospectionResults.acceptClassLoader(Thread.currentThread().getContextClassLoader()); + } + + @Override + public void contextDestroyed(ServletContextEvent event) { + CachedIntrospectionResults.clearClassLoader(Thread.currentThread().getContextClassLoader()); + Introspector.flushCaches(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/JavaScriptUtils.java b/spring-web/src/main/java/org/springframework/web/util/JavaScriptUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..e505d9797078973196a0dca6f0319ddfd2cafbb6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/JavaScriptUtils.java @@ -0,0 +1,101 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +/** + * Utility class for JavaScript escaping. + * Escapes based on the JavaScript 1.5 recommendation. + * + *

Reference: + * + * JavaScript Guide on Mozilla Developer Network. + * + * @author Juergen Hoeller + * @author Rob Harrop + * @author Rossen Stoyanchev + * @since 1.1.1 + */ +public abstract class JavaScriptUtils { + + /** + * Turn JavaScript special characters into escaped characters. + * @param input the input string + * @return the string with escaped characters + */ + public static String javaScriptEscape(String input) { + StringBuilder filtered = new StringBuilder(input.length()); + char prevChar = '\u0000'; + char c; + for (int i = 0; i < input.length(); i++) { + c = input.charAt(i); + if (c == '"') { + filtered.append("\\\""); + } + else if (c == '\'') { + filtered.append("\\'"); + } + else if (c == '\\') { + filtered.append("\\\\"); + } + else if (c == '/') { + filtered.append("\\/"); + } + else if (c == '\t') { + filtered.append("\\t"); + } + else if (c == '\n') { + if (prevChar != '\r') { + filtered.append("\\n"); + } + } + else if (c == '\r') { + filtered.append("\\n"); + } + else if (c == '\f') { + filtered.append("\\f"); + } + else if (c == '\b') { + filtered.append("\\b"); + } + // No '\v' in Java, use octal value for VT ascii char + else if (c == '\013') { + filtered.append("\\v"); + } + else if (c == '<') { + filtered.append("\\u003C"); + } + else if (c == '>') { + filtered.append("\\u003E"); + } + // Unicode for PS (line terminator in ECMA-262) + else if (c == '\u2028') { + filtered.append("\\u2028"); + } + // Unicode for LS (line terminator in ECMA-262) + else if (c == '\u2029') { + filtered.append("\\u2029"); + } + else { + filtered.append(c); + } + prevChar = c; + + } + return filtered.toString(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/NestedServletException.java b/spring-web/src/main/java/org/springframework/web/util/NestedServletException.java new file mode 100644 index 0000000000000000000000000000000000000000..1f8db8376beb6fdd98fb15649a608bd4a4857763 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/NestedServletException.java @@ -0,0 +1,85 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import javax.servlet.ServletException; + +import org.springframework.core.NestedExceptionUtils; +import org.springframework.lang.Nullable; + +/** + * Subclass of {@link ServletException} that properly handles a root cause in terms + * of message and stacktrace, just like NestedChecked/RuntimeException does. + * + *

Note that the plain ServletException doesn't expose its root cause at all, + * neither in the exception message nor in printed stack traces! While this might + * be fixed in later Servlet API variants (which even differ per vendor for the + * same API version), it is not reliably available on Servlet 2.4 (the minimum + * version required by Spring 3.x), which is why we need to do it ourselves. + * + *

The similarity between this class and the NestedChecked/RuntimeException + * class is unavoidable, as this class needs to derive from ServletException. + * + * @author Juergen Hoeller + * @since 1.2.5 + * @see #getMessage + * @see #printStackTrace + * @see org.springframework.core.NestedCheckedException + * @see org.springframework.core.NestedRuntimeException + */ +public class NestedServletException extends ServletException { + + /** Use serialVersionUID from Spring 1.2 for interoperability. */ + private static final long serialVersionUID = -5292377985529381145L; + + static { + // Eagerly load the NestedExceptionUtils class to avoid classloader deadlock + // issues on OSGi when calling getMessage(). Reported by Don Brown; SPR-5607. + NestedExceptionUtils.class.getName(); + } + + + /** + * Construct a {@code NestedServletException} with the specified detail message. + * @param msg the detail message + */ + public NestedServletException(String msg) { + super(msg); + } + + /** + * Construct a {@code NestedServletException} with the specified detail message + * and nested exception. + * @param msg the detail message + * @param cause the nested exception + */ + public NestedServletException(@Nullable String msg, @Nullable Throwable cause) { + super(msg, cause); + } + + + /** + * Return the detail message, including the message from the nested exception + * if there is one. + */ + @Override + @Nullable + public String getMessage() { + return NestedExceptionUtils.buildMessage(super.getMessage(), getCause()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/OpaqueUriComponents.java b/spring-web/src/main/java/org/springframework/web/util/OpaqueUriComponents.java new file mode 100644 index 0000000000000000000000000000000000000000..fd076077b2713386b6f1b47fecdae6ca7ab67cb1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/OpaqueUriComponents.java @@ -0,0 +1,181 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.Charset; +import java.util.Collections; +import java.util.List; + +import org.springframework.lang.Nullable; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; + +/** + * Extension of {@link UriComponents} for opaque URIs. + * + * @author Arjen Poutsma + * @author Phillip Webb + * @since 3.2 + * @see Hierarchical vs Opaque URIs + */ +@SuppressWarnings("serial") +final class OpaqueUriComponents extends UriComponents { + + private static final MultiValueMap QUERY_PARAMS_NONE = new LinkedMultiValueMap<>(); + + @Nullable + private final String ssp; + + + OpaqueUriComponents(@Nullable String scheme, @Nullable String schemeSpecificPart, @Nullable String fragment) { + super(scheme, fragment); + this.ssp = schemeSpecificPart; + } + + + @Override + @Nullable + public String getSchemeSpecificPart() { + return this.ssp; + } + + @Override + @Nullable + public String getUserInfo() { + return null; + } + + @Override + @Nullable + public String getHost() { + return null; + } + + @Override + public int getPort() { + return -1; + } + + @Override + @Nullable + public String getPath() { + return null; + } + + @Override + public List getPathSegments() { + return Collections.emptyList(); + } + + @Override + @Nullable + public String getQuery() { + return null; + } + + @Override + public MultiValueMap getQueryParams() { + return QUERY_PARAMS_NONE; + } + + @Override + public UriComponents encode(Charset charset) { + return this; + } + + @Override + protected UriComponents expandInternal(UriTemplateVariables uriVariables) { + String expandedScheme = expandUriComponent(getScheme(), uriVariables); + String expandedSsp = expandUriComponent(getSchemeSpecificPart(), uriVariables); + String expandedFragment = expandUriComponent(getFragment(), uriVariables); + return new OpaqueUriComponents(expandedScheme, expandedSsp, expandedFragment); + } + + @Override + public UriComponents normalize() { + return this; + } + + @Override + public String toUriString() { + StringBuilder uriBuilder = new StringBuilder(); + + if (getScheme() != null) { + uriBuilder.append(getScheme()); + uriBuilder.append(':'); + } + if (this.ssp != null) { + uriBuilder.append(this.ssp); + } + if (getFragment() != null) { + uriBuilder.append('#'); + uriBuilder.append(getFragment()); + } + + return uriBuilder.toString(); + } + + @Override + public URI toUri() { + try { + return new URI(getScheme(), this.ssp, getFragment()); + } + catch (URISyntaxException ex) { + throw new IllegalStateException("Could not create URI object: " + ex.getMessage(), ex); + } + } + + @Override + protected void copyToUriComponentsBuilder(UriComponentsBuilder builder) { + if (getScheme() != null) { + builder.scheme(getScheme()); + } + if (getSchemeSpecificPart() != null) { + builder.schemeSpecificPart(getSchemeSpecificPart()); + } + if (getFragment() != null) { + builder.fragment(getFragment()); + } + } + + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof OpaqueUriComponents)) { + return false; + } + OpaqueUriComponents otherComp = (OpaqueUriComponents) other; + return (ObjectUtils.nullSafeEquals(getScheme(), otherComp.getScheme()) && + ObjectUtils.nullSafeEquals(this.ssp, otherComp.ssp) && + ObjectUtils.nullSafeEquals(getFragment(), otherComp.getFragment())); + } + + @Override + public int hashCode() { + int result = ObjectUtils.nullSafeHashCode(getScheme()); + result = 31 * result + ObjectUtils.nullSafeHashCode(this.ssp); + result = 31 * result + ObjectUtils.nullSafeHashCode(getFragment()); + return result; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/ServletContextPropertyUtils.java b/spring-web/src/main/java/org/springframework/web/util/ServletContextPropertyUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..ffaefa75b4be0425cd20ee36366969535b2c0ddb --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/ServletContextPropertyUtils.java @@ -0,0 +1,122 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import javax.servlet.ServletContext; + +import org.springframework.lang.Nullable; +import org.springframework.util.PropertyPlaceholderHelper; +import org.springframework.util.SystemPropertyUtils; + +/** + * Helper class for resolving placeholders in texts. Usually applied to file paths. + * + *

A text may contain {@code ${...}} placeholders, to be resolved as servlet context + * init parameters or system properties: e.g. {@code ${user.dir}}. Default values can + * be supplied using the ":" separator between key and value. + * + * @author Juergen Hoeller + * @author Marten Deinum + * @since 3.2.2 + * @see SystemPropertyUtils + * @see ServletContext#getInitParameter(String) + */ +public abstract class ServletContextPropertyUtils { + + private static final PropertyPlaceholderHelper strictHelper = + new PropertyPlaceholderHelper(SystemPropertyUtils.PLACEHOLDER_PREFIX, + SystemPropertyUtils.PLACEHOLDER_SUFFIX, SystemPropertyUtils.VALUE_SEPARATOR, false); + + private static final PropertyPlaceholderHelper nonStrictHelper = + new PropertyPlaceholderHelper(SystemPropertyUtils.PLACEHOLDER_PREFIX, + SystemPropertyUtils.PLACEHOLDER_SUFFIX, SystemPropertyUtils.VALUE_SEPARATOR, true); + + + /** + * Resolve ${...} placeholders in the given text, replacing them with corresponding + * servlet context init parameter or system property values. + * @param text the String to resolve + * @param servletContext the servletContext to use for lookups. + * @return the resolved String + * @throws IllegalArgumentException if there is an unresolvable placeholder + * @see SystemPropertyUtils#PLACEHOLDER_PREFIX + * @see SystemPropertyUtils#PLACEHOLDER_SUFFIX + * @see SystemPropertyUtils#resolvePlaceholders(String, boolean) + */ + public static String resolvePlaceholders(String text, ServletContext servletContext) { + return resolvePlaceholders(text, servletContext, false); + } + + /** + * Resolve ${...} placeholders in the given text, replacing them with corresponding + * servlet context init parameter or system property values. Unresolvable placeholders + * with no default value are ignored and passed through unchanged if the flag is set to true. + * @param text the String to resolve + * @param servletContext the servletContext to use for lookups. + * @param ignoreUnresolvablePlaceholders flag to determine is unresolved placeholders are ignored + * @return the resolved String + * @throws IllegalArgumentException if there is an unresolvable placeholder and the flag is false + * @see SystemPropertyUtils#PLACEHOLDER_PREFIX + * @see SystemPropertyUtils#PLACEHOLDER_SUFFIX + * @see SystemPropertyUtils#resolvePlaceholders(String, boolean) + */ + public static String resolvePlaceholders( + String text, ServletContext servletContext, boolean ignoreUnresolvablePlaceholders) { + + if (text.isEmpty()) { + return text; + } + PropertyPlaceholderHelper helper = (ignoreUnresolvablePlaceholders ? nonStrictHelper : strictHelper); + return helper.replacePlaceholders(text, new ServletContextPlaceholderResolver(text, servletContext)); + } + + + private static class ServletContextPlaceholderResolver implements PropertyPlaceholderHelper.PlaceholderResolver { + + private final String text; + + private final ServletContext servletContext; + + public ServletContextPlaceholderResolver(String text, ServletContext servletContext) { + this.text = text; + this.servletContext = servletContext; + } + + @Override + @Nullable + public String resolvePlaceholder(String placeholderName) { + try { + String propVal = this.servletContext.getInitParameter(placeholderName); + if (propVal == null) { + // Fall back to system properties. + propVal = System.getProperty(placeholderName); + if (propVal == null) { + // Fall back to searching the system environment. + propVal = System.getenv(placeholderName); + } + } + return propVal; + } + catch (Throwable ex) { + System.err.println("Could not resolve placeholder '" + placeholderName + "' in [" + + this.text + "] as ServletContext init-parameter or system property: " + ex); + return null; + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/TagUtils.java b/spring-web/src/main/java/org/springframework/web/util/TagUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..f1e206547b2adb478ecc0728480efce0aead0846 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/TagUtils.java @@ -0,0 +1,139 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import javax.servlet.jsp.PageContext; +import javax.servlet.jsp.tagext.Tag; + +import org.springframework.util.Assert; + +/** + * Utility class for tag library related code, exposing functionality + * such as translating {@link String Strings} to web scopes. + * + *

+ *

    + *
  • {@code page} will be transformed to + * {@link javax.servlet.jsp.PageContext#PAGE_SCOPE PageContext.PAGE_SCOPE} + *
  • {@code request} will be transformed to + * {@link javax.servlet.jsp.PageContext#REQUEST_SCOPE PageContext.REQUEST_SCOPE} + *
  • {@code session} will be transformed to + * {@link javax.servlet.jsp.PageContext#SESSION_SCOPE PageContext.SESSION_SCOPE} + *
  • {@code application} will be transformed to + * {@link javax.servlet.jsp.PageContext#APPLICATION_SCOPE PageContext.APPLICATION_SCOPE} + *
+ * + * @author Alef Arendsen + * @author Rob Harrop + * @author Juergen Hoeller + * @author Rick Evans + */ +public abstract class TagUtils { + + /** Constant identifying the page scope. */ + public static final String SCOPE_PAGE = "page"; + + /** Constant identifying the request scope. */ + public static final String SCOPE_REQUEST = "request"; + + /** Constant identifying the session scope. */ + public static final String SCOPE_SESSION = "session"; + + /** Constant identifying the application scope. */ + public static final String SCOPE_APPLICATION = "application"; + + + /** + * Determines the scope for a given input {@code String}. + *

If the {@code String} does not match 'request', 'session', + * 'page' or 'application', the method will return {@link PageContext#PAGE_SCOPE}. + * @param scope the {@code String} to inspect + * @return the scope found, or {@link PageContext#PAGE_SCOPE} if no scope matched + * @throws IllegalArgumentException if the supplied {@code scope} is {@code null} + */ + public static int getScope(String scope) { + Assert.notNull(scope, "Scope to search for cannot be null"); + if (scope.equals(SCOPE_REQUEST)) { + return PageContext.REQUEST_SCOPE; + } + else if (scope.equals(SCOPE_SESSION)) { + return PageContext.SESSION_SCOPE; + } + else if (scope.equals(SCOPE_APPLICATION)) { + return PageContext.APPLICATION_SCOPE; + } + else { + return PageContext.PAGE_SCOPE; + } + } + + /** + * Determine whether the supplied {@link Tag} has any ancestor tag + * of the supplied type. + * @param tag the tag whose ancestors are to be checked + * @param ancestorTagClass the ancestor {@link Class} being searched for + * @return {@code true} if the supplied {@link Tag} has any ancestor tag + * of the supplied type + * @throws IllegalArgumentException if either of the supplied arguments is {@code null}; + * or if the supplied {@code ancestorTagClass} is not type-assignable to + * the {@link Tag} class + */ + public static boolean hasAncestorOfType(Tag tag, Class ancestorTagClass) { + Assert.notNull(tag, "Tag cannot be null"); + Assert.notNull(ancestorTagClass, "Ancestor tag class cannot be null"); + if (!Tag.class.isAssignableFrom(ancestorTagClass)) { + throw new IllegalArgumentException( + "Class '" + ancestorTagClass.getName() + "' is not a valid Tag type"); + } + Tag ancestor = tag.getParent(); + while (ancestor != null) { + if (ancestorTagClass.isAssignableFrom(ancestor.getClass())) { + return true; + } + ancestor = ancestor.getParent(); + } + return false; + } + + /** + * Determine whether the supplied {@link Tag} has any ancestor tag + * of the supplied type, throwing an {@link IllegalStateException} + * if not. + * @param tag the tag whose ancestors are to be checked + * @param ancestorTagClass the ancestor {@link Class} being searched for + * @param tagName the name of the {@code tag}; for example '{@code option}' + * @param ancestorTagName the name of the ancestor {@code tag}; for example '{@code select}' + * @throws IllegalStateException if the supplied {@code tag} does not + * have a tag of the supplied {@code parentTagClass} as an ancestor + * @throws IllegalArgumentException if any of the supplied arguments is {@code null}, + * or in the case of the {@link String}-typed arguments, is composed wholly + * of whitespace; or if the supplied {@code ancestorTagClass} is not + * type-assignable to the {@link Tag} class + * @see #hasAncestorOfType(javax.servlet.jsp.tagext.Tag, Class) + */ + public static void assertHasAncestorOfType(Tag tag, Class ancestorTagClass, String tagName, + String ancestorTagName) { + + Assert.hasText(tagName, "'tagName' must not be empty"); + Assert.hasText(ancestorTagName, "'ancestorTagName' must not be empty"); + if (!TagUtils.hasAncestorOfType(tag, ancestorTagClass)) { + throw new IllegalStateException("The '" + tagName + + "' tag can only be used inside a valid '" + ancestorTagName + "' tag."); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/UriBuilder.java b/spring-web/src/main/java/org/springframework/web/util/UriBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..f0b8f1dcc77363c4e53bcb06f9da27ec35d3f8dd --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/UriBuilder.java @@ -0,0 +1,173 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.net.URI; +import java.util.Map; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * Builder-style methods to prepare and expand a URI template with variables. + * + *

Effectively a generalization of {@link UriComponentsBuilder} but with + * shortcuts to expand directly into {@link URI} rather than + * {@link UriComponents} and also leaving common concerns such as encoding + * preferences, a base URI, and others as implementation concerns. + * + *

Typically obtained via {@link UriBuilderFactory} which serves as a central + * component configured once and used to create many URLs. + * + * @author Rossen Stoyanchev + * @since 5.0 + * @see UriBuilderFactory + * @see UriComponentsBuilder + */ +public interface UriBuilder { + + /** + * Set the URI scheme which may contain URI template variables, + * and may also be {@code null} to clear the scheme of this builder. + * @param scheme the URI scheme + */ + UriBuilder scheme(@Nullable String scheme); + + /** + * Set the URI user info which may contain URI template variables, and + * may also be {@code null} to clear the user info of this builder. + * @param userInfo the URI user info + */ + UriBuilder userInfo(@Nullable String userInfo); + + /** + * Set the URI host which may contain URI template variables, and may also + * be {@code null} to clear the host of this builder. + * @param host the URI host + */ + UriBuilder host(@Nullable String host); + + /** + * Set the URI port. Passing {@code -1} will clear the port of this builder. + * @param port the URI port + */ + UriBuilder port(int port); + + /** + * Set the URI port . Use this method only when the port needs to be + * parameterized with a URI variable. Otherwise use {@link #port(int)}. + * Passing {@code null} will clear the port of this builder. + * @param port the URI port + */ + UriBuilder port(@Nullable String port); + + /** + * Append the given path to the existing path of this builder. + * The given path may contain URI template variables. + * @param path the URI path + */ + UriBuilder path(String path); + + /** + * Set the path of this builder overriding the existing path values. + * @param path the URI path, or {@code null} for an empty path + */ + UriBuilder replacePath(@Nullable String path); + + /** + * Append path segments to the existing path. Each path segment may contain + * URI template variables and should not contain any slashes. + * Use {@code path("/")} subsequently to ensure a trailing slash. + * @param pathSegments the URI path segments + */ + UriBuilder pathSegment(String... pathSegments) throws IllegalArgumentException; + + /** + * Append the given query to the existing query of this builder. + * The given query may contain URI template variables. + *

Note: The presence of reserved characters can prevent + * correct parsing of the URI string. For example if a query parameter + * contains {@code '='} or {@code '&'} characters, the query string cannot + * be parsed unambiguously. Such values should be substituted for URI + * variables to enable correct parsing: + *

+	 * builder.query("filter={value}").uriString("hot&cold");
+	 * 
+ * @param query the query string + */ + UriBuilder query(String query); + + /** + * Set the query of this builder overriding all existing query parameters. + * @param query the query string, or {@code null} to remove all query params + */ + UriBuilder replaceQuery(@Nullable String query); + + /** + * Append the given query parameter to the existing query parameters. The + * given name or any of the values may contain URI template variables. If no + * values are given, the resulting URI will contain the query parameter name + * only (i.e. {@code ?foo} instead of {@code ?foo=bar}. + * @param name the query parameter name + * @param values the query parameter values + */ + UriBuilder queryParam(String name, Object... values); + + /** + * Add the given query parameters. + * @param params the params + */ + UriBuilder queryParams(MultiValueMap params); + + /** + * Set the query parameter values overriding all existing query values for + * the same parameter. If no values are given, the query parameter is removed. + * @param name the query parameter name + * @param values the query parameter values + */ + UriBuilder replaceQueryParam(String name, Object... values); + + /** + * Set the query parameter values overriding all existing query values. + * @param params the query parameter name + */ + UriBuilder replaceQueryParams(MultiValueMap params); + + /** + * Set the URI fragment. The given fragment may contain URI template variables, + * and may also be {@code null} to clear the fragment of this builder. + * @param fragment the URI fragment + */ + UriBuilder fragment(@Nullable String fragment); + + /** + * Build a {@link URI} instance and replaces URI template variables + * with the values from an array. + * @param uriVariables the map of URI variables + * @return the URI + */ + URI build(Object... uriVariables); + + /** + * Build a {@link URI} instance and replaces URI template variables + * with the values from a map. + * @param uriVariables the map of URI variables + * @return the URI + */ + URI build(Map uriVariables); + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/UriBuilderFactory.java b/spring-web/src/main/java/org/springframework/web/util/UriBuilderFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..ddd5f5dd8335ab743aa569d3abd1af1c63172ba4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/UriBuilderFactory.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +/** + * Factory to create {@link UriBuilder} instances with shared configuration + * such as a base URI, an encoding mode strategy, and others across all URI + * builder instances created through a factory. + * + * @author Rossen Stoyanchev + * @since 5.0 + * @see DefaultUriBuilderFactory + */ +public interface UriBuilderFactory extends UriTemplateHandler { + + /** + * Initialize a builder with the given URI template. + * @param uriTemplate the URI template to use + * @return the builder instance + */ + UriBuilder uriString(String uriTemplate); + + /** + * Create a URI builder with default settings. + * @return the builder instance + */ + UriBuilder builder(); + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/UriComponents.java b/spring-web/src/main/java/org/springframework/web/util/UriComponents.java new file mode 100644 index 0000000000000000000000000000000000000000..c29b09f584eaaec380e4718eaea6a13ad52730bf --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/UriComponents.java @@ -0,0 +1,373 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.io.Serializable; +import java.net.URI; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.function.UnaryOperator; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; + +/** + * Represents an immutable collection of URI components, mapping component type to + * String values. Contains convenience getters for all components. Effectively similar + * to {@link java.net.URI}, but with more powerful encoding options and support for + * URI template variables. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @author Rossen Stoyanchev + * @since 3.1 + * @see UriComponentsBuilder + */ +@SuppressWarnings("serial") +public abstract class UriComponents implements Serializable { + + /** Captures URI template variable names. */ + private static final Pattern NAMES_PATTERN = Pattern.compile("\\{([^/]+?)\\}"); + + + @Nullable + private final String scheme; + + @Nullable + private final String fragment; + + + protected UriComponents(@Nullable String scheme, @Nullable String fragment) { + this.scheme = scheme; + this.fragment = fragment; + } + + + // Component getters + + /** + * Return the scheme. Can be {@code null}. + */ + @Nullable + public final String getScheme() { + return this.scheme; + } + + /** + * Return the fragment. Can be {@code null}. + */ + @Nullable + public final String getFragment() { + return this.fragment; + } + + /** + * Return the scheme specific part. Can be {@code null}. + */ + @Nullable + public abstract String getSchemeSpecificPart(); + + /** + * Return the user info. Can be {@code null}. + */ + @Nullable + public abstract String getUserInfo(); + + /** + * Return the host. Can be {@code null}. + */ + @Nullable + public abstract String getHost(); + + /** + * Return the port. {@code -1} if no port has been set. + */ + public abstract int getPort(); + + /** + * Return the path. Can be {@code null}. + */ + @Nullable + public abstract String getPath(); + + /** + * Return the list of path segments. Empty if no path has been set. + */ + public abstract List getPathSegments(); + + /** + * Return the query. Can be {@code null}. + */ + @Nullable + public abstract String getQuery(); + + /** + * Return the map of query parameters. Empty if no query has been set. + */ + public abstract MultiValueMap getQueryParams(); + + + /** + * Invoke this after expanding URI variables to encode the + * resulting URI component values. + *

In comparison to {@link UriComponentsBuilder#encode()}, this method + * only replaces non-ASCII and illegal (within a given URI + * component type) characters, but not characters with reserved meaning. + * For most cases, {@link UriComponentsBuilder#encode()} is more likely + * to give the expected result. + * @see UriComponentsBuilder#encode() + */ + public final UriComponents encode() { + return encode(StandardCharsets.UTF_8); + } + + /** + * A variant of {@link #encode()} with a charset other than "UTF-8". + * @param charset the charset to use for encoding + * @see UriComponentsBuilder#encode(Charset) + */ + public abstract UriComponents encode(Charset charset); + + /** + * Replace all URI template variables with the values from a given map. + *

The given map keys represent variable names; the corresponding values + * represent variable values. The order of variables is not significant. + * @param uriVariables the map of URI variables + * @return the expanded URI components + */ + public final UriComponents expand(Map uriVariables) { + Assert.notNull(uriVariables, "'uriVariables' must not be null"); + return expandInternal(new MapTemplateVariables(uriVariables)); + } + + /** + * Replace all URI template variables with the values from a given array. + *

The given array represents variable values. The order of variables is significant. + * @param uriVariableValues the URI variable values + * @return the expanded URI components + */ + public final UriComponents expand(Object... uriVariableValues) { + Assert.notNull(uriVariableValues, "'uriVariableValues' must not be null"); + return expandInternal(new VarArgsTemplateVariables(uriVariableValues)); + } + + /** + * Replace all URI template variables with the values from the given + * {@link UriTemplateVariables}. + * @param uriVariables the URI template values + * @return the expanded URI components + */ + public final UriComponents expand(UriTemplateVariables uriVariables) { + Assert.notNull(uriVariables, "'uriVariables' must not be null"); + return expandInternal(uriVariables); + } + + /** + * Replace all URI template variables with the values from the given {@link + * UriTemplateVariables}. + * @param uriVariables the URI template values + * @return the expanded URI components + */ + abstract UriComponents expandInternal(UriTemplateVariables uriVariables); + + /** + * Normalize the path removing sequences like "path/..". Note that + * normalization is applied to the full path, and not to individual path + * segments. + * @see org.springframework.util.StringUtils#cleanPath(String) + */ + public abstract UriComponents normalize(); + + /** + * Concatenate all URI components to return the fully formed URI String. + *

This method does nothing more than a simple concatenation based on + * current values. That means it could produce different results if invoked + * before vs after methods that can change individual values such as + * {@code encode}, {@code expand}, or {@code normalize}. + */ + public abstract String toUriString(); + + /** + * Create a {@link URI} from this instance as follows: + *

If the current instance is {@link #encode() encoded}, form the full + * URI String via {@link #toUriString()}, and then pass it to the single + * argument {@link URI} constructor which preserves percent encoding. + *

If not yet encoded, pass individual URI component values to the + * multi-argument {@link URI} constructor which quotes illegal characters + * that cannot appear in their respective URI component. + */ + public abstract URI toUri(); + + /** + * A simple pass-through to {@link #toUriString()}. + */ + @Override + public final String toString() { + return toUriString(); + } + + /** + * Set all components of the given UriComponentsBuilder. + * @since 4.2 + */ + protected abstract void copyToUriComponentsBuilder(UriComponentsBuilder builder); + + + // Static expansion helpers + + @Nullable + static String expandUriComponent(@Nullable String source, UriTemplateVariables uriVariables) { + return expandUriComponent(source, uriVariables, null); + } + + @Nullable + static String expandUriComponent(@Nullable String source, UriTemplateVariables uriVariables, + @Nullable UnaryOperator encoder) { + + if (source == null) { + return null; + } + if (source.indexOf('{') == -1) { + return source; + } + if (source.indexOf(':') != -1) { + source = sanitizeSource(source); + } + Matcher matcher = NAMES_PATTERN.matcher(source); + StringBuffer sb = new StringBuffer(); + while (matcher.find()) { + String match = matcher.group(1); + String varName = getVariableName(match); + Object varValue = uriVariables.getValue(varName); + if (UriTemplateVariables.SKIP_VALUE.equals(varValue)) { + continue; + } + String formatted = getVariableValueAsString(varValue); + formatted = encoder != null ? encoder.apply(formatted) : Matcher.quoteReplacement(formatted); + matcher.appendReplacement(sb, formatted); + } + matcher.appendTail(sb); + return sb.toString(); + } + + /** + * Remove nested "{}" such as in URI vars with regular expressions. + */ + private static String sanitizeSource(String source) { + int level = 0; + StringBuilder sb = new StringBuilder(); + for (char c : source.toCharArray()) { + if (c == '{') { + level++; + } + if (c == '}') { + level--; + } + if (level > 1 || (level == 1 && c == '}')) { + continue; + } + sb.append(c); + } + return sb.toString(); + } + + private static String getVariableName(String match) { + int colonIdx = match.indexOf(':'); + return (colonIdx != -1 ? match.substring(0, colonIdx) : match); + } + + private static String getVariableValueAsString(@Nullable Object variableValue) { + return (variableValue != null ? variableValue.toString() : ""); + } + + + /** + * Defines the contract for URI Template variables. + * @see HierarchicalUriComponents#expand + */ + public interface UriTemplateVariables { + + /** + * Constant for a value that indicates a URI variable name should be + * ignored and left as is. This is useful for partial expanding of some + * but not all URI variables. + */ + Object SKIP_VALUE = UriTemplateVariables.class; + + /** + * Get the value for the given URI variable name. + * If the value is {@code null}, an empty String is expanded. + * If the value is {@link #SKIP_VALUE}, the URI variable is not expanded. + * @param name the variable name + * @return the variable value, possibly {@code null} or {@link #SKIP_VALUE} + */ + @Nullable + Object getValue(@Nullable String name); + } + + + /** + * URI template variables backed by a map. + */ + private static class MapTemplateVariables implements UriTemplateVariables { + + private final Map uriVariables; + + public MapTemplateVariables(Map uriVariables) { + this.uriVariables = uriVariables; + } + + @Override + @Nullable + public Object getValue(@Nullable String name) { + if (!this.uriVariables.containsKey(name)) { + throw new IllegalArgumentException("Map has no value for '" + name + "'"); + } + return this.uriVariables.get(name); + } + } + + + /** + * URI template variables backed by a variable argument array. + */ + private static class VarArgsTemplateVariables implements UriTemplateVariables { + + private final Iterator valueIterator; + + public VarArgsTemplateVariables(Object... uriVariableValues) { + this.valueIterator = Arrays.asList(uriVariableValues).iterator(); + } + + @Override + @Nullable + public Object getValue(@Nullable String name) { + if (!this.valueIterator.hasNext()) { + throw new IllegalArgumentException("Not enough variable values available to expand '" + name + "'"); + } + return this.valueIterator.next(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..9b20bfd15fe3494ca1704503b76d2b998b63162f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java @@ -0,0 +1,1076 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.net.URI; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpRequest; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.util.HierarchicalUriComponents.PathComponent; +import org.springframework.web.util.UriComponents.UriTemplateVariables; + +/** + * Builder for {@link UriComponents}. + * + *

Typical usage involves: + *

    + *
  1. Create a {@code UriComponentsBuilder} with one of the static factory methods + * (such as {@link #fromPath(String)} or {@link #fromUri(URI)})
  2. + *
  3. Set the various URI components through the respective methods ({@link #scheme(String)}, + * {@link #userInfo(String)}, {@link #host(String)}, {@link #port(int)}, {@link #path(String)}, + * {@link #pathSegment(String...)}, {@link #queryParam(String, Object...)}, and + * {@link #fragment(String)}.
  4. + *
  5. Build the {@link UriComponents} instance with the {@link #build()} method.
  6. + *
+ * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Phillip Webb + * @author Oliver Gierke + * @author Brian Clozel + * @author Sebastien Deleuze + * @since 3.1 + * @see #newInstance() + * @see #fromPath(String) + * @see #fromUri(URI) + */ +public class UriComponentsBuilder implements UriBuilder, Cloneable { + + private static final Pattern QUERY_PARAM_PATTERN = Pattern.compile("([^&=]+)(=?)([^&]+)?"); + + private static final String SCHEME_PATTERN = "([^:/?#]+):"; + + private static final String HTTP_PATTERN = "(?i)(http|https):"; + + private static final String USERINFO_PATTERN = "([^@\\[/?#]*)"; + + private static final String HOST_IPV4_PATTERN = "[^\\[/?#:]*"; + + private static final String HOST_IPV6_PATTERN = "\\[[\\p{XDigit}\\:\\.]*[%\\p{Alnum}]*\\]"; + + private static final String HOST_PATTERN = "(" + HOST_IPV6_PATTERN + "|" + HOST_IPV4_PATTERN + ")"; + + private static final String PORT_PATTERN = "(\\d*(?:\\{[^/]+?\\})?)"; + + private static final String PATH_PATTERN = "([^?#]*)"; + + private static final String QUERY_PATTERN = "([^#]*)"; + + private static final String LAST_PATTERN = "(.*)"; + + // Regex patterns that matches URIs. See RFC 3986, appendix B + private static final Pattern URI_PATTERN = Pattern.compile( + "^(" + SCHEME_PATTERN + ")?" + "(//(" + USERINFO_PATTERN + "@)?" + HOST_PATTERN + "(:" + PORT_PATTERN + + ")?" + ")?" + PATH_PATTERN + "(\\?" + QUERY_PATTERN + ")?" + "(#" + LAST_PATTERN + ")?"); + + private static final Pattern HTTP_URL_PATTERN = Pattern.compile( + "^" + HTTP_PATTERN + "(//(" + USERINFO_PATTERN + "@)?" + HOST_PATTERN + "(:" + PORT_PATTERN + ")?" + ")?" + + PATH_PATTERN + "(\\?" + LAST_PATTERN + ")?"); + + private static final Pattern FORWARDED_HOST_PATTERN = Pattern.compile("host=\"?([^;,\"]+)\"?"); + + private static final Pattern FORWARDED_PROTO_PATTERN = Pattern.compile("proto=\"?([^;,\"]+)\"?"); + + + @Nullable + private String scheme; + + @Nullable + private String ssp; + + @Nullable + private String userInfo; + + @Nullable + private String host; + + @Nullable + private String port; + + private CompositePathComponentBuilder pathBuilder; + + private final MultiValueMap queryParams = new LinkedMultiValueMap<>(); + + @Nullable + private String fragment; + + private final Map uriVariables = new HashMap<>(4); + + private boolean encodeTemplate; + + private Charset charset = StandardCharsets.UTF_8; + + + /** + * Default constructor. Protected to prevent direct instantiation. + * @see #newInstance() + * @see #fromPath(String) + * @see #fromUri(URI) + */ + protected UriComponentsBuilder() { + this.pathBuilder = new CompositePathComponentBuilder(); + } + + /** + * Create a deep copy of the given UriComponentsBuilder. + * @param other the other builder to copy from + * @since 4.1.3 + */ + protected UriComponentsBuilder(UriComponentsBuilder other) { + this.scheme = other.scheme; + this.ssp = other.ssp; + this.userInfo = other.userInfo; + this.host = other.host; + this.port = other.port; + this.pathBuilder = other.pathBuilder.cloneBuilder(); + this.uriVariables.putAll(other.uriVariables); + this.queryParams.addAll(other.queryParams); + this.fragment = other.fragment; + this.encodeTemplate = other.encodeTemplate; + this.charset = other.charset; + } + + + // Factory methods + + /** + * Create a new, empty builder. + * @return the new {@code UriComponentsBuilder} + */ + public static UriComponentsBuilder newInstance() { + return new UriComponentsBuilder(); + } + + /** + * Create a builder that is initialized with the given path. + * @param path the path to initialize with + * @return the new {@code UriComponentsBuilder} + */ + public static UriComponentsBuilder fromPath(String path) { + UriComponentsBuilder builder = new UriComponentsBuilder(); + builder.path(path); + return builder; + } + + /** + * Create a builder that is initialized from the given {@code URI}. + *

Note: the components in the resulting builder will be + * in fully encoded (raw) form and further changes must also supply values + * that are fully encoded, for example via methods in {@link UriUtils}. + * In addition please use {@link #build(boolean)} with a value of "true" to + * build the {@link UriComponents} instance in order to indicate that the + * components are encoded. + * @param uri the URI to initialize with + * @return the new {@code UriComponentsBuilder} + */ + public static UriComponentsBuilder fromUri(URI uri) { + UriComponentsBuilder builder = new UriComponentsBuilder(); + builder.uri(uri); + return builder; + } + + /** + * Create a builder that is initialized with the given URI string. + *

Note: The presence of reserved characters can prevent + * correct parsing of the URI string. For example if a query parameter + * contains {@code '='} or {@code '&'} characters, the query string cannot + * be parsed unambiguously. Such values should be substituted for URI + * variables to enable correct parsing: + *

+	 * String uriString = "/hotels/42?filter={value}";
+	 * UriComponentsBuilder.fromUriString(uriString).buildAndExpand("hot&cold");
+	 * 
+ * @param uri the URI string to initialize with + * @return the new {@code UriComponentsBuilder} + */ + public static UriComponentsBuilder fromUriString(String uri) { + Assert.notNull(uri, "URI must not be null"); + Matcher matcher = URI_PATTERN.matcher(uri); + if (matcher.matches()) { + UriComponentsBuilder builder = new UriComponentsBuilder(); + String scheme = matcher.group(2); + String userInfo = matcher.group(5); + String host = matcher.group(6); + String port = matcher.group(8); + String path = matcher.group(9); + String query = matcher.group(11); + String fragment = matcher.group(13); + boolean opaque = false; + if (StringUtils.hasLength(scheme)) { + String rest = uri.substring(scheme.length()); + if (!rest.startsWith(":/")) { + opaque = true; + } + } + builder.scheme(scheme); + if (opaque) { + String ssp = uri.substring(scheme.length() + 1); + if (StringUtils.hasLength(fragment)) { + ssp = ssp.substring(0, ssp.length() - (fragment.length() + 1)); + } + builder.schemeSpecificPart(ssp); + } + else { + builder.userInfo(userInfo); + builder.host(host); + if (StringUtils.hasLength(port)) { + builder.port(port); + } + builder.path(path); + builder.query(query); + } + if (StringUtils.hasText(fragment)) { + builder.fragment(fragment); + } + return builder; + } + else { + throw new IllegalArgumentException("[" + uri + "] is not a valid URI"); + } + } + + /** + * Create a URI components builder from the given HTTP URL String. + *

Note: The presence of reserved characters can prevent + * correct parsing of the URI string. For example if a query parameter + * contains {@code '='} or {@code '&'} characters, the query string cannot + * be parsed unambiguously. Such values should be substituted for URI + * variables to enable correct parsing: + *

+	 * String urlString = "https://example.com/hotels/42?filter={value}";
+	 * UriComponentsBuilder.fromHttpUrl(urlString).buildAndExpand("hot&cold");
+	 * 
+ * @param httpUrl the source URI + * @return the URI components of the URI + */ + public static UriComponentsBuilder fromHttpUrl(String httpUrl) { + Assert.notNull(httpUrl, "HTTP URL must not be null"); + Matcher matcher = HTTP_URL_PATTERN.matcher(httpUrl); + if (matcher.matches()) { + UriComponentsBuilder builder = new UriComponentsBuilder(); + String scheme = matcher.group(1); + builder.scheme(scheme != null ? scheme.toLowerCase() : null); + builder.userInfo(matcher.group(4)); + String host = matcher.group(5); + if (StringUtils.hasLength(scheme) && !StringUtils.hasLength(host)) { + throw new IllegalArgumentException("[" + httpUrl + "] is not a valid HTTP URL"); + } + builder.host(host); + String port = matcher.group(7); + if (StringUtils.hasLength(port)) { + builder.port(port); + } + builder.path(matcher.group(8)); + builder.query(matcher.group(10)); + return builder; + } + else { + throw new IllegalArgumentException("[" + httpUrl + "] is not a valid HTTP URL"); + } + } + + /** + * Create a new {@code UriComponents} object from the URI associated with + * the given HttpRequest while also overlaying with values from the headers + * "Forwarded" (RFC 7239), + * or "X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" if + * "Forwarded" is not found. + * @param request the source request + * @return the URI components of the URI + * @since 4.1.5 + */ + public static UriComponentsBuilder fromHttpRequest(HttpRequest request) { + return fromUri(request.getURI()).adaptFromForwardedHeaders(request.getHeaders()); + } + + /** + * Create an instance by parsing the "Origin" header of an HTTP request. + * @see RFC 6454 + */ + public static UriComponentsBuilder fromOriginHeader(String origin) { + Matcher matcher = URI_PATTERN.matcher(origin); + if (matcher.matches()) { + UriComponentsBuilder builder = new UriComponentsBuilder(); + String scheme = matcher.group(2); + String host = matcher.group(6); + String port = matcher.group(8); + if (StringUtils.hasLength(scheme)) { + builder.scheme(scheme); + } + builder.host(host); + if (StringUtils.hasLength(port)) { + builder.port(port); + } + return builder; + } + else { + throw new IllegalArgumentException("[" + origin + "] is not a valid \"Origin\" header value"); + } + } + + + // Encode methods + + /** + * Request to have the URI template pre-encoded at build time, and + * URI variables encoded separately when expanded. + *

In comparison to {@link UriComponents#encode()}, this method has the + * same effect on the URI template, i.e. each URI component is encoded by + * replacing non-ASCII and illegal (within the URI component type) characters + * with escaped octets. However URI variables are encoded more strictly, by + * also escaping characters with reserved meaning. + *

For most cases, this method is more likely to give the expected result + * because in treats URI variables as opaque data to be fully encoded, while + * {@link UriComponents#encode()} is useful only if intentionally expanding + * URI variables that contain reserved characters. + *

For example ';' is legal in a path but has reserved meaning. This + * method replaces ";" with "%3B" in URI variables but not in the URI + * template. By contrast, {@link UriComponents#encode()} never replaces ";" + * since it is a legal character in a path. + * @since 5.0.8 + */ + public final UriComponentsBuilder encode() { + return encode(StandardCharsets.UTF_8); + } + + /** + * A variant of {@link #encode()} with a charset other than "UTF-8". + * @param charset the charset to use for encoding + * @since 5.0.8 + */ + public UriComponentsBuilder encode(Charset charset) { + this.encodeTemplate = true; + this.charset = charset; + return this; + } + + + // Build methods + + /** + * Build a {@code UriComponents} instance from the various components contained in this builder. + * @return the URI components + */ + public UriComponents build() { + return build(false); + } + + /** + * Variant of {@link #build()} to create a {@link UriComponents} instance + * when components are already fully encoded. This is useful for example if + * the builder was created via {@link UriComponentsBuilder#fromUri(URI)}. + * @param encoded whether the components in this builder are already encoded + * @return the URI components + * @throws IllegalArgumentException if any of the components contain illegal + * characters that should have been encoded. + */ + public UriComponents build(boolean encoded) { + return buildInternal(encoded ? EncodingHint.FULLY_ENCODED : + (this.encodeTemplate ? EncodingHint.ENCODE_TEMPLATE : EncodingHint.NONE)); + } + + private UriComponents buildInternal(EncodingHint hint) { + UriComponents result; + if (this.ssp != null) { + result = new OpaqueUriComponents(this.scheme, this.ssp, this.fragment); + } + else { + HierarchicalUriComponents uric = new HierarchicalUriComponents(this.scheme, this.fragment, + this.userInfo, this.host, this.port, this.pathBuilder.build(), this.queryParams, + hint == EncodingHint.FULLY_ENCODED); + result = (hint == EncodingHint.ENCODE_TEMPLATE ? uric.encodeTemplate(this.charset) : uric); + } + if (!this.uriVariables.isEmpty()) { + result = result.expand(name -> this.uriVariables.getOrDefault(name, UriTemplateVariables.SKIP_VALUE)); + } + return result; + } + + /** + * Build a {@code UriComponents} instance and replaces URI template variables + * with the values from a map. This is a shortcut method which combines + * calls to {@link #build()} and then {@link UriComponents#expand(Map)}. + * @param uriVariables the map of URI variables + * @return the URI components with expanded values + */ + public UriComponents buildAndExpand(Map uriVariables) { + return build().expand(uriVariables); + } + + /** + * Build a {@code UriComponents} instance and replaces URI template variables + * with the values from an array. This is a shortcut method which combines + * calls to {@link #build()} and then {@link UriComponents#expand(Object...)}. + * @param uriVariableValues the URI variable values + * @return the URI components with expanded values + */ + public UriComponents buildAndExpand(Object... uriVariableValues) { + return build().expand(uriVariableValues); + } + + @Override + public URI build(Object... uriVariables) { + return buildInternal(EncodingHint.ENCODE_TEMPLATE).expand(uriVariables).toUri(); + } + + @Override + public URI build(Map uriVariables) { + return buildInternal(EncodingHint.ENCODE_TEMPLATE).expand(uriVariables).toUri(); + } + + /** + * Build a URI String. + *

Effectively, a shortcut for building, encoding, and returning the + * String representation: + *

+	 * String uri = builder.build().encode().toUriString()
+	 * 
+ *

However if {@link #uriVariables(Map) URI variables} have been provided + * then the URI template is pre-encoded separately from URI variables (see + * {@link #encode()} for details), i.e. equivalent to: + *

+	 * String uri = builder.encode().build().toUriString()
+	 * 
+ * @since 4.1 + * @see UriComponents#toUriString() + */ + public String toUriString() { + return (this.uriVariables.isEmpty() ? build().encode().toUriString() : + buildInternal(EncodingHint.ENCODE_TEMPLATE).toUriString()); + } + + + // Instance methods + + /** + * Initialize components of this builder from components of the given URI. + * @param uri the URI + * @return this UriComponentsBuilder + */ + public UriComponentsBuilder uri(URI uri) { + Assert.notNull(uri, "URI must not be null"); + this.scheme = uri.getScheme(); + if (uri.isOpaque()) { + this.ssp = uri.getRawSchemeSpecificPart(); + resetHierarchicalComponents(); + } + else { + if (uri.getRawUserInfo() != null) { + this.userInfo = uri.getRawUserInfo(); + } + if (uri.getHost() != null) { + this.host = uri.getHost(); + } + if (uri.getPort() != -1) { + this.port = String.valueOf(uri.getPort()); + } + if (StringUtils.hasLength(uri.getRawPath())) { + this.pathBuilder = new CompositePathComponentBuilder(); + this.pathBuilder.addPath(uri.getRawPath()); + } + if (StringUtils.hasLength(uri.getRawQuery())) { + this.queryParams.clear(); + query(uri.getRawQuery()); + } + resetSchemeSpecificPart(); + } + if (uri.getRawFragment() != null) { + this.fragment = uri.getRawFragment(); + } + return this; + } + + /** + * Set or append individual URI components of this builder from the values + * of the given {@link UriComponents} instance. + *

For the semantics of each component (i.e. set vs append) check the + * builder methods on this class. For example {@link #host(String)} sets + * while {@link #path(String)} appends. + * @param uriComponents the UriComponents to copy from + * @return this UriComponentsBuilder + */ + public UriComponentsBuilder uriComponents(UriComponents uriComponents) { + Assert.notNull(uriComponents, "UriComponents must not be null"); + uriComponents.copyToUriComponentsBuilder(this); + return this; + } + + /** + * Set the URI scheme. The given scheme may contain URI template variables, + * and may also be {@code null} to clear the scheme of this builder. + * @param scheme the URI scheme + * @return this UriComponentsBuilder + */ + @Override + public UriComponentsBuilder scheme(@Nullable String scheme) { + this.scheme = scheme; + return this; + } + + /** + * Set the URI scheme-specific-part. When invoked, this method overwrites + * {@linkplain #userInfo(String) user-info}, {@linkplain #host(String) host}, + * {@linkplain #port(int) port}, {@linkplain #path(String) path}, and + * {@link #query(String) query}. + * @param ssp the URI scheme-specific-part, may contain URI template parameters + * @return this UriComponentsBuilder + */ + public UriComponentsBuilder schemeSpecificPart(String ssp) { + this.ssp = ssp; + resetHierarchicalComponents(); + return this; + } + + /** + * Set the URI user info. The given user info may contain URI template variables, + * and may also be {@code null} to clear the user info of this builder. + * @param userInfo the URI user info + * @return this UriComponentsBuilder + */ + @Override + public UriComponentsBuilder userInfo(@Nullable String userInfo) { + this.userInfo = userInfo; + resetSchemeSpecificPart(); + return this; + } + + /** + * Set the URI host. The given host may contain URI template variables, + * and may also be {@code null} to clear the host of this builder. + * @param host the URI host + * @return this UriComponentsBuilder + */ + @Override + public UriComponentsBuilder host(@Nullable String host) { + this.host = host; + resetSchemeSpecificPart(); + return this; + } + + /** + * Set the URI port. Passing {@code -1} will clear the port of this builder. + * @param port the URI port + * @return this UriComponentsBuilder + */ + @Override + public UriComponentsBuilder port(int port) { + Assert.isTrue(port >= -1, "Port must be >= -1"); + this.port = String.valueOf(port); + resetSchemeSpecificPart(); + return this; + } + + /** + * Set the URI port. Use this method only when the port needs to be + * parameterized with a URI variable. Otherwise use {@link #port(int)}. + * Passing {@code null} will clear the port of this builder. + * @param port the URI port + * @return this UriComponentsBuilder + */ + @Override + public UriComponentsBuilder port(@Nullable String port) { + this.port = port; + resetSchemeSpecificPart(); + return this; + } + + /** + * Append the given path to the existing path of this builder. + * The given path may contain URI template variables. + * @param path the URI path + * @return this UriComponentsBuilder + */ + @Override + public UriComponentsBuilder path(String path) { + this.pathBuilder.addPath(path); + resetSchemeSpecificPart(); + return this; + } + + /** + * Append path segments to the existing path. Each path segment may contain + * URI template variables and should not contain any slashes. + * Use {@code path("/")} subsequently to ensure a trailing slash. + * @param pathSegments the URI path segments + * @return this UriComponentsBuilder + */ + @Override + public UriComponentsBuilder pathSegment(String... pathSegments) throws IllegalArgumentException { + this.pathBuilder.addPathSegments(pathSegments); + resetSchemeSpecificPart(); + return this; + } + + /** + * Set the path of this builder overriding all existing path and path segment values. + * @param path the URI path (a {@code null} value results in an empty path) + * @return this UriComponentsBuilder + */ + @Override + public UriComponentsBuilder replacePath(@Nullable String path) { + this.pathBuilder = new CompositePathComponentBuilder(); + if (path != null) { + this.pathBuilder.addPath(path); + } + resetSchemeSpecificPart(); + return this; + } + + /** + * Append the given query to the existing query of this builder. + * The given query may contain URI template variables. + *

Note: The presence of reserved characters can prevent + * correct parsing of the URI string. For example if a query parameter + * contains {@code '='} or {@code '&'} characters, the query string cannot + * be parsed unambiguously. Such values should be substituted for URI + * variables to enable correct parsing: + *

+	 * UriComponentsBuilder.fromUriString("/hotels/42")
+	 * 	.query("filter={value}")
+	 * 	.buildAndExpand("hot&cold");
+	 * 
+ * @param query the query string + * @return this UriComponentsBuilder + */ + @Override + public UriComponentsBuilder query(@Nullable String query) { + if (query != null) { + Matcher matcher = QUERY_PARAM_PATTERN.matcher(query); + while (matcher.find()) { + String name = matcher.group(1); + String eq = matcher.group(2); + String value = matcher.group(3); + queryParam(name, (value != null ? value : (StringUtils.hasLength(eq) ? "" : null))); + } + } + else { + this.queryParams.clear(); + } + resetSchemeSpecificPart(); + return this; + } + + /** + * Set the query of this builder overriding all existing query parameters. + * @param query the query string; a {@code null} value removes all query parameters. + * @return this UriComponentsBuilder + */ + @Override + public UriComponentsBuilder replaceQuery(@Nullable String query) { + this.queryParams.clear(); + if (query != null) { + query(query); + } + resetSchemeSpecificPart(); + return this; + } + + /** + * Append the given query parameter to the existing query parameters. The + * given name or any of the values may contain URI template variables. If no + * values are given, the resulting URI will contain the query parameter name + * only (i.e. {@code ?foo} instead of {@code ?foo=bar}). + * @param name the query parameter name + * @param values the query parameter values + * @return this UriComponentsBuilder + */ + @Override + public UriComponentsBuilder queryParam(String name, Object... values) { + Assert.notNull(name, "Name must not be null"); + if (!ObjectUtils.isEmpty(values)) { + for (Object value : values) { + String valueAsString = (value != null ? value.toString() : null); + this.queryParams.add(name, valueAsString); + } + } + else { + this.queryParams.add(name, null); + } + resetSchemeSpecificPart(); + return this; + } + + /** + * Add the given query parameters. + * @param params the params + * @return this UriComponentsBuilder + * @since 4.0 + */ + @Override + public UriComponentsBuilder queryParams(@Nullable MultiValueMap params) { + if (params != null) { + this.queryParams.addAll(params); + } + return this; + } + + /** + * Set the query parameter values overriding all existing query values for + * the same parameter. If no values are given, the query parameter is removed. + * @param name the query parameter name + * @param values the query parameter values + * @return this UriComponentsBuilder + */ + @Override + public UriComponentsBuilder replaceQueryParam(String name, Object... values) { + Assert.notNull(name, "Name must not be null"); + this.queryParams.remove(name); + if (!ObjectUtils.isEmpty(values)) { + queryParam(name, values); + } + resetSchemeSpecificPart(); + return this; + } + + /** + * Set the query parameter values overriding all existing query values. + * @param params the query parameter name + * @return this UriComponentsBuilder + * @since 4.2 + */ + @Override + public UriComponentsBuilder replaceQueryParams(@Nullable MultiValueMap params) { + this.queryParams.clear(); + if (params != null) { + this.queryParams.putAll(params); + } + return this; + } + + /** + * Set the URI fragment. The given fragment may contain URI template variables, + * and may also be {@code null} to clear the fragment of this builder. + * @param fragment the URI fragment + * @return this UriComponentsBuilder + */ + @Override + public UriComponentsBuilder fragment(@Nullable String fragment) { + if (fragment != null) { + Assert.hasLength(fragment, "Fragment must not be empty"); + this.fragment = fragment; + } + else { + this.fragment = null; + } + return this; + } + + /** + * Configure URI variables to be expanded at build time. + *

The provided variables may be a subset of all required ones. At build + * time, the available ones are expanded, while unresolved URI placeholders + * are left in place and can still be expanded later. + *

In contrast to {@link UriComponents#expand(Map)} or + * {@link #buildAndExpand(Map)}, this method is useful when you need to + * supply URI variables without building the {@link UriComponents} instance + * just yet, or perhaps pre-expand some shared default values such as host + * and port. + * @param uriVariables the URI variables to use + * @return this UriComponentsBuilder + * @since 5.0.8 + */ + public UriComponentsBuilder uriVariables(Map uriVariables) { + this.uriVariables.putAll(uriVariables); + return this; + } + + /** + * Adapt this builder's scheme+host+port from the given headers, specifically + * "Forwarded" (RFC 7239, + * or "X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" if + * "Forwarded" is not found. + *

Note: this method uses values from forwarded headers, + * if present, in order to reflect the client-originated protocol and address. + * Consider using the {@code ForwardedHeaderFilter} in order to choose from a + * central place whether to extract and use, or to discard such headers. + * See the Spring Framework reference for more on this filter. + * @param headers the HTTP headers to consider + * @return this UriComponentsBuilder + * @since 4.2.7 + */ + UriComponentsBuilder adaptFromForwardedHeaders(HttpHeaders headers) { + try { + String forwardedHeader = headers.getFirst("Forwarded"); + if (StringUtils.hasText(forwardedHeader)) { + String forwardedToUse = StringUtils.tokenizeToStringArray(forwardedHeader, ",")[0]; + Matcher matcher = FORWARDED_PROTO_PATTERN.matcher(forwardedToUse); + if (matcher.find()) { + scheme(matcher.group(1).trim()); + port(null); + } + else if (isForwardedSslOn(headers)) { + scheme("https"); + port(null); + } + matcher = FORWARDED_HOST_PATTERN.matcher(forwardedToUse); + if (matcher.find()) { + adaptForwardedHost(matcher.group(1).trim()); + } + } + else { + String protocolHeader = headers.getFirst("X-Forwarded-Proto"); + if (StringUtils.hasText(protocolHeader)) { + scheme(StringUtils.tokenizeToStringArray(protocolHeader, ",")[0]); + port(null); + } + else if (isForwardedSslOn(headers)) { + scheme("https"); + port(null); + } + String hostHeader = headers.getFirst("X-Forwarded-Host"); + if (StringUtils.hasText(hostHeader)) { + adaptForwardedHost(StringUtils.tokenizeToStringArray(hostHeader, ",")[0]); + } + String portHeader = headers.getFirst("X-Forwarded-Port"); + if (StringUtils.hasText(portHeader)) { + port(Integer.parseInt(StringUtils.tokenizeToStringArray(portHeader, ",")[0])); + } + } + } + catch (NumberFormatException ex) { + throw new IllegalArgumentException("Failed to parse a port from \"forwarded\"-type headers. " + + "If not behind a trusted proxy, consider using ForwardedHeaderFilter " + + "with the removeOnly=true. Request headers: " + headers); + } + + if (this.scheme != null && ((this.scheme.equals("http") && "80".equals(this.port)) || + (this.scheme.equals("https") && "443".equals(this.port)))) { + port(null); + } + + return this; + } + + private boolean isForwardedSslOn(HttpHeaders headers) { + String forwardedSsl = headers.getFirst("X-Forwarded-Ssl"); + return StringUtils.hasText(forwardedSsl) && forwardedSsl.equalsIgnoreCase("on"); + } + + private void adaptForwardedHost(String hostToUse) { + int portSeparatorIdx = hostToUse.lastIndexOf(':'); + if (portSeparatorIdx > hostToUse.lastIndexOf(']')) { + host(hostToUse.substring(0, portSeparatorIdx)); + port(Integer.parseInt(hostToUse.substring(portSeparatorIdx + 1))); + } + else { + host(hostToUse); + port(null); + } + } + + private void resetHierarchicalComponents() { + this.userInfo = null; + this.host = null; + this.port = null; + this.pathBuilder = new CompositePathComponentBuilder(); + this.queryParams.clear(); + } + + private void resetSchemeSpecificPart() { + this.ssp = null; + } + + + /** + * Public declaration of Object's {@code clone()} method. + * Delegates to {@link #cloneBuilder()}. + */ + @Override + public Object clone() { + return cloneBuilder(); + } + + /** + * Clone this {@code UriComponentsBuilder}. + * @return the cloned {@code UriComponentsBuilder} object + * @since 4.2.7 + */ + public UriComponentsBuilder cloneBuilder() { + return new UriComponentsBuilder(this); + } + + + private interface PathComponentBuilder { + + @Nullable + PathComponent build(); + + PathComponentBuilder cloneBuilder(); + } + + + private static class CompositePathComponentBuilder implements PathComponentBuilder { + + private final LinkedList builders = new LinkedList<>(); + + public void addPathSegments(String... pathSegments) { + if (!ObjectUtils.isEmpty(pathSegments)) { + PathSegmentComponentBuilder psBuilder = getLastBuilder(PathSegmentComponentBuilder.class); + FullPathComponentBuilder fpBuilder = getLastBuilder(FullPathComponentBuilder.class); + if (psBuilder == null) { + psBuilder = new PathSegmentComponentBuilder(); + this.builders.add(psBuilder); + if (fpBuilder != null) { + fpBuilder.removeTrailingSlash(); + } + } + psBuilder.append(pathSegments); + } + } + + public void addPath(String path) { + if (StringUtils.hasText(path)) { + PathSegmentComponentBuilder psBuilder = getLastBuilder(PathSegmentComponentBuilder.class); + FullPathComponentBuilder fpBuilder = getLastBuilder(FullPathComponentBuilder.class); + if (psBuilder != null) { + path = (path.startsWith("/") ? path : "/" + path); + } + if (fpBuilder == null) { + fpBuilder = new FullPathComponentBuilder(); + this.builders.add(fpBuilder); + } + fpBuilder.append(path); + } + } + + @SuppressWarnings("unchecked") + @Nullable + private T getLastBuilder(Class builderClass) { + if (!this.builders.isEmpty()) { + PathComponentBuilder last = this.builders.getLast(); + if (builderClass.isInstance(last)) { + return (T) last; + } + } + return null; + } + + @Override + public PathComponent build() { + int size = this.builders.size(); + List components = new ArrayList<>(size); + for (PathComponentBuilder componentBuilder : this.builders) { + PathComponent pathComponent = componentBuilder.build(); + if (pathComponent != null) { + components.add(pathComponent); + } + } + if (components.isEmpty()) { + return HierarchicalUriComponents.NULL_PATH_COMPONENT; + } + if (components.size() == 1) { + return components.get(0); + } + return new HierarchicalUriComponents.PathComponentComposite(components); + } + + @Override + public CompositePathComponentBuilder cloneBuilder() { + CompositePathComponentBuilder compositeBuilder = new CompositePathComponentBuilder(); + for (PathComponentBuilder builder : this.builders) { + compositeBuilder.builders.add(builder.cloneBuilder()); + } + return compositeBuilder; + } + } + + + private static class FullPathComponentBuilder implements PathComponentBuilder { + + private final StringBuilder path = new StringBuilder(); + + public void append(String path) { + this.path.append(path); + } + + @Override + public PathComponent build() { + if (this.path.length() == 0) { + return null; + } + String path = this.path.toString(); + while (true) { + int index = path.indexOf("//"); + if (index == -1) { + break; + } + path = path.substring(0, index) + path.substring(index + 1); + } + return new HierarchicalUriComponents.FullPathComponent(path); + } + + public void removeTrailingSlash() { + int index = this.path.length() - 1; + if (this.path.charAt(index) == '/') { + this.path.deleteCharAt(index); + } + } + + @Override + public FullPathComponentBuilder cloneBuilder() { + FullPathComponentBuilder builder = new FullPathComponentBuilder(); + builder.append(this.path.toString()); + return builder; + } + } + + + private static class PathSegmentComponentBuilder implements PathComponentBuilder { + + private final List pathSegments = new LinkedList<>(); + + public void append(String... pathSegments) { + for (String pathSegment : pathSegments) { + if (StringUtils.hasText(pathSegment)) { + this.pathSegments.add(pathSegment); + } + } + } + + @Override + public PathComponent build() { + return (this.pathSegments.isEmpty() ? null : + new HierarchicalUriComponents.PathSegmentComponent(this.pathSegments)); + } + + @Override + public PathSegmentComponentBuilder cloneBuilder() { + PathSegmentComponentBuilder builder = new PathSegmentComponentBuilder(); + builder.pathSegments.addAll(this.pathSegments); + return builder; + } + } + + + private enum EncodingHint { ENCODE_TEMPLATE, FULLY_ENCODED, NONE } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/UriTemplate.java b/spring-web/src/main/java/org/springframework/web/util/UriTemplate.java new file mode 100644 index 0000000000000000000000000000000000000000..ffe01470a7d05bea83906c38449dbf3d969caa0c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/UriTemplate.java @@ -0,0 +1,251 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.io.Serializable; +import java.net.URI; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Representation of a URI template that can be expanded with URI variables via + * {@link #expand(Map)}, {@link #expand(Object[])}, or matched to a URL via + * {@link #match(String)}. This class is designed to be thread-safe and + * reusable, and allows any number of expand or match calls. + * + *

Note: this class uses {@link UriComponentsBuilder} + * internally to expand URI templates, and is merely a shortcut for already + * prepared URI templates. For more dynamic preparation and extra flexibility, + * e.g. around URI encoding, consider using {@code UriComponentsBuilder} or the + * higher level {@link DefaultUriBuilderFactory} which adds several encoding + * modes on top of {@code UriComponentsBuilder}. See the + * reference docs + * for further details. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @author Rossen Stoyanchev + * @since 3.0 + */ +@SuppressWarnings("serial") +public class UriTemplate implements Serializable { + + private final String uriTemplate; + + private final UriComponents uriComponents; + + private final List variableNames; + + private final Pattern matchPattern; + + + /** + * Construct a new {@code UriTemplate} with the given URI String. + * @param uriTemplate the URI template string + */ + public UriTemplate(String uriTemplate) { + Assert.hasText(uriTemplate, "'uriTemplate' must not be null"); + this.uriTemplate = uriTemplate; + this.uriComponents = UriComponentsBuilder.fromUriString(uriTemplate).build(); + + TemplateInfo info = TemplateInfo.parse(uriTemplate); + this.variableNames = Collections.unmodifiableList(info.getVariableNames()); + this.matchPattern = info.getMatchPattern(); + } + + + /** + * Return the names of the variables in the template, in order. + * @return the template variable names + */ + public List getVariableNames() { + return this.variableNames; + } + + /** + * Given the Map of variables, expands this template into a URI. The Map keys represent variable names, + * the Map values variable values. The order of variables is not significant. + *

Example: + *

+	 * UriTemplate template = new UriTemplate("https://example.com/hotels/{hotel}/bookings/{booking}");
+	 * Map<String, String> uriVariables = new HashMap<String, String>();
+	 * uriVariables.put("booking", "42");
+	 * uriVariables.put("hotel", "Rest & Relax");
+	 * System.out.println(template.expand(uriVariables));
+	 * 
+ * will print:
{@code https://example.com/hotels/Rest%20%26%20Relax/bookings/42}
+ * @param uriVariables the map of URI variables + * @return the expanded URI + * @throws IllegalArgumentException if {@code uriVariables} is {@code null}; + * or if it does not contain values for all the variable names + */ + public URI expand(Map uriVariables) { + UriComponents expandedComponents = this.uriComponents.expand(uriVariables); + UriComponents encodedComponents = expandedComponents.encode(); + return encodedComponents.toUri(); + } + + /** + * Given an array of variables, expand this template into a full URI. The array represent variable values. + * The order of variables is significant. + *

Example: + *

+	 * UriTemplate template = new UriTemplate("https://example.com/hotels/{hotel}/bookings/{booking}");
+	 * System.out.println(template.expand("Rest & Relax", 42));
+	 * 
+ * will print:
{@code https://example.com/hotels/Rest%20%26%20Relax/bookings/42}
+ * @param uriVariableValues the array of URI variables + * @return the expanded URI + * @throws IllegalArgumentException if {@code uriVariables} is {@code null} + * or if it does not contain sufficient variables + */ + public URI expand(Object... uriVariableValues) { + UriComponents expandedComponents = this.uriComponents.expand(uriVariableValues); + UriComponents encodedComponents = expandedComponents.encode(); + return encodedComponents.toUri(); + } + + /** + * Indicate whether the given URI matches this template. + * @param uri the URI to match to + * @return {@code true} if it matches; {@code false} otherwise + */ + public boolean matches(@Nullable String uri) { + if (uri == null) { + return false; + } + Matcher matcher = this.matchPattern.matcher(uri); + return matcher.matches(); + } + + /** + * Match the given URI to a map of variable values. Keys in the returned map are variable names, + * values are variable values, as occurred in the given URI. + *

Example: + *

+	 * UriTemplate template = new UriTemplate("https://example.com/hotels/{hotel}/bookings/{booking}");
+	 * System.out.println(template.match("https://example.com/hotels/1/bookings/42"));
+	 * 
+ * will print:
{@code {hotel=1, booking=42}}
+ * @param uri the URI to match to + * @return a map of variable values + */ + public Map match(String uri) { + Assert.notNull(uri, "'uri' must not be null"); + Map result = new LinkedHashMap<>(this.variableNames.size()); + Matcher matcher = this.matchPattern.matcher(uri); + if (matcher.find()) { + for (int i = 1; i <= matcher.groupCount(); i++) { + String name = this.variableNames.get(i - 1); + String value = matcher.group(i); + result.put(name, value); + } + } + return result; + } + + @Override + public String toString() { + return this.uriTemplate; + } + + + /** + * Helper to extract variable names and regex for matching to actual URLs. + */ + private static final class TemplateInfo { + + private final List variableNames; + + private final Pattern pattern; + + private TemplateInfo(List vars, Pattern pattern) { + this.variableNames = vars; + this.pattern = pattern; + } + + public List getVariableNames() { + return this.variableNames; + } + + public Pattern getMatchPattern() { + return this.pattern; + } + + public static TemplateInfo parse(String uriTemplate) { + int level = 0; + List variableNames = new ArrayList<>(); + StringBuilder pattern = new StringBuilder(); + StringBuilder builder = new StringBuilder(); + for (int i = 0 ; i < uriTemplate.length(); i++) { + char c = uriTemplate.charAt(i); + if (c == '{') { + level++; + if (level == 1) { + // start of URI variable + pattern.append(quote(builder)); + builder = new StringBuilder(); + continue; + } + } + else if (c == '}') { + level--; + if (level == 0) { + // end of URI variable + String variable = builder.toString(); + int idx = variable.indexOf(':'); + if (idx == -1) { + pattern.append("([^/]*)"); + variableNames.add(variable); + } + else { + if (idx + 1 == variable.length()) { + throw new IllegalArgumentException( + "No custom regular expression specified after ':' in \"" + variable + "\""); + } + String regex = variable.substring(idx + 1); + pattern.append('('); + pattern.append(regex); + pattern.append(')'); + variableNames.add(variable.substring(0, idx)); + } + builder = new StringBuilder(); + continue; + } + } + builder.append(c); + } + if (builder.length() > 0) { + pattern.append(quote(builder)); + } + return new TemplateInfo(variableNames, Pattern.compile(pattern.toString())); + } + + private static String quote(StringBuilder builder) { + return (builder.length() > 0 ? Pattern.quote(builder.toString()) : ""); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/UriTemplateHandler.java b/spring-web/src/main/java/org/springframework/web/util/UriTemplateHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..0b7d06504b262323eb502867c62492d5c7cb5dfb --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/UriTemplateHandler.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.net.URI; +import java.util.Map; + +/** + * Defines methods for expanding a URI template with variables. + * + * @author Rossen Stoyanchev + * @since 4.2 + * @see org.springframework.web.client.RestTemplate#setUriTemplateHandler(UriTemplateHandler) + */ +public interface UriTemplateHandler { + + /** + * Expand the given URI template with a map of URI variables. + * @param uriTemplate the URI template + * @param uriVariables variable values + * @return the created URI instance + */ + URI expand(String uriTemplate, Map uriVariables); + + /** + * Expand the given URI template with an array of URI variables. + * @param uriTemplate the URI template + * @param uriVariables variable values + * @return the created URI instance + */ + URI expand(String uriTemplate, Object... uriVariables); + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/UriUtils.java b/spring-web/src/main/java/org/springframework/web/util/UriUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..c2d5d29ba8437022edc0f84d8109e86f11847f76 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/UriUtils.java @@ -0,0 +1,386 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.net.URI; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +/** + * Utility methods for URI encoding and decoding based on RFC 3986. + * + *

There are two types of encode methods: + *

    + *
  • {@code "encodeXyz"} -- these encode a specific URI component (e.g. path, + * query) by percent encoding illegal characters, which includes non-US-ASCII + * characters, and also characters that are otherwise illegal within the given + * URI component type, as defined in RFC 3986. The effect of this method, with + * regards to encoding, is comparable to using the multi-argument constructor + * of {@link URI}. + *
  • {@code "encode"} and {@code "encodeUriVariables"} -- these can be used + * to encode URI variable values by percent encoding all characters that are + * either illegal, or have any reserved meaning, anywhere within a URI. + *
+ * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @author Rossen Stoyanchev + * @since 3.0 + * @see RFC 3986 + */ +public abstract class UriUtils { + + /** + * Encode the given URI scheme with the given encoding. + * @param scheme the scheme to be encoded + * @param encoding the character encoding to encode to + * @return the encoded scheme + */ + public static String encodeScheme(String scheme, String encoding) { + return encode(scheme, encoding, HierarchicalUriComponents.Type.SCHEME); + } + + /** + * Encode the given URI scheme with the given encoding. + * @param scheme the scheme to be encoded + * @param charset the character encoding to encode to + * @return the encoded scheme + * @since 5.0 + */ + public static String encodeScheme(String scheme, Charset charset) { + return encode(scheme, charset, HierarchicalUriComponents.Type.SCHEME); + } + + /** + * Encode the given URI authority with the given encoding. + * @param authority the authority to be encoded + * @param encoding the character encoding to encode to + * @return the encoded authority + */ + public static String encodeAuthority(String authority, String encoding) { + return encode(authority, encoding, HierarchicalUriComponents.Type.AUTHORITY); + } + + /** + * Encode the given URI authority with the given encoding. + * @param authority the authority to be encoded + * @param charset the character encoding to encode to + * @return the encoded authority + * @since 5.0 + */ + public static String encodeAuthority(String authority, Charset charset) { + return encode(authority, charset, HierarchicalUriComponents.Type.AUTHORITY); + } + + /** + * Encode the given URI user info with the given encoding. + * @param userInfo the user info to be encoded + * @param encoding the character encoding to encode to + * @return the encoded user info + */ + public static String encodeUserInfo(String userInfo, String encoding) { + return encode(userInfo, encoding, HierarchicalUriComponents.Type.USER_INFO); + } + + /** + * Encode the given URI user info with the given encoding. + * @param userInfo the user info to be encoded + * @param charset the character encoding to encode to + * @return the encoded user info + * @since 5.0 + */ + public static String encodeUserInfo(String userInfo, Charset charset) { + return encode(userInfo, charset, HierarchicalUriComponents.Type.USER_INFO); + } + + /** + * Encode the given URI host with the given encoding. + * @param host the host to be encoded + * @param encoding the character encoding to encode to + * @return the encoded host + */ + public static String encodeHost(String host, String encoding) { + return encode(host, encoding, HierarchicalUriComponents.Type.HOST_IPV4); + } + + /** + * Encode the given URI host with the given encoding. + * @param host the host to be encoded + * @param charset the character encoding to encode to + * @return the encoded host + * @since 5.0 + */ + public static String encodeHost(String host, Charset charset) { + return encode(host, charset, HierarchicalUriComponents.Type.HOST_IPV4); + } + + /** + * Encode the given URI port with the given encoding. + * @param port the port to be encoded + * @param encoding the character encoding to encode to + * @return the encoded port + */ + public static String encodePort(String port, String encoding) { + return encode(port, encoding, HierarchicalUriComponents.Type.PORT); + } + + /** + * Encode the given URI port with the given encoding. + * @param port the port to be encoded + * @param charset the character encoding to encode to + * @return the encoded port + * @since 5.0 + */ + public static String encodePort(String port, Charset charset) { + return encode(port, charset, HierarchicalUriComponents.Type.PORT); + } + + /** + * Encode the given URI path with the given encoding. + * @param path the path to be encoded + * @param encoding the character encoding to encode to + * @return the encoded path + */ + public static String encodePath(String path, String encoding) { + return encode(path, encoding, HierarchicalUriComponents.Type.PATH); + } + + /** + * Encode the given URI path with the given encoding. + * @param path the path to be encoded + * @param charset the character encoding to encode to + * @return the encoded path + * @since 5.0 + */ + public static String encodePath(String path, Charset charset) { + return encode(path, charset, HierarchicalUriComponents.Type.PATH); + } + + /** + * Encode the given URI path segment with the given encoding. + * @param segment the segment to be encoded + * @param encoding the character encoding to encode to + * @return the encoded segment + */ + public static String encodePathSegment(String segment, String encoding) { + return encode(segment, encoding, HierarchicalUriComponents.Type.PATH_SEGMENT); + } + + /** + * Encode the given URI path segment with the given encoding. + * @param segment the segment to be encoded + * @param charset the character encoding to encode to + * @return the encoded segment + * @since 5.0 + */ + public static String encodePathSegment(String segment, Charset charset) { + return encode(segment, charset, HierarchicalUriComponents.Type.PATH_SEGMENT); + } + + /** + * Encode the given URI query with the given encoding. + * @param query the query to be encoded + * @param encoding the character encoding to encode to + * @return the encoded query + */ + public static String encodeQuery(String query, String encoding) { + return encode(query, encoding, HierarchicalUriComponents.Type.QUERY); + } + + /** + * Encode the given URI query with the given encoding. + * @param query the query to be encoded + * @param charset the character encoding to encode to + * @return the encoded query + * @since 5.0 + */ + public static String encodeQuery(String query, Charset charset) { + return encode(query, charset, HierarchicalUriComponents.Type.QUERY); + } + + /** + * Encode the given URI query parameter with the given encoding. + * @param queryParam the query parameter to be encoded + * @param encoding the character encoding to encode to + * @return the encoded query parameter + */ + public static String encodeQueryParam(String queryParam, String encoding) { + + return encode(queryParam, encoding, HierarchicalUriComponents.Type.QUERY_PARAM); + } + + /** + * Encode the given URI query parameter with the given encoding. + * @param queryParam the query parameter to be encoded + * @param charset the character encoding to encode to + * @return the encoded query parameter + * @since 5.0 + */ + public static String encodeQueryParam(String queryParam, Charset charset) { + return encode(queryParam, charset, HierarchicalUriComponents.Type.QUERY_PARAM); + } + + /** + * Encode the given URI fragment with the given encoding. + * @param fragment the fragment to be encoded + * @param encoding the character encoding to encode to + * @return the encoded fragment + */ + public static String encodeFragment(String fragment, String encoding) { + return encode(fragment, encoding, HierarchicalUriComponents.Type.FRAGMENT); + } + + /** + * Encode the given URI fragment with the given encoding. + * @param fragment the fragment to be encoded + * @param charset the character encoding to encode to + * @return the encoded fragment + * @since 5.0 + */ + public static String encodeFragment(String fragment, Charset charset) { + return encode(fragment, charset, HierarchicalUriComponents.Type.FRAGMENT); + } + + + /** + * Variant of {@link #encode(String, Charset)} with a String charset. + * @param source the String to be encoded + * @param encoding the character encoding to encode to + * @return the encoded String + */ + public static String encode(String source, String encoding) { + return encode(source, encoding, HierarchicalUriComponents.Type.URI); + } + + /** + * Encode all characters that are either illegal, or have any reserved + * meaning, anywhere within a URI, as defined in + * RFC 3986. + * This is useful to ensure that the given String will be preserved as-is + * and will not have any o impact on the structure or meaning of the URI. + * @param source the String to be encoded + * @param charset the character encoding to encode to + * @return the encoded String + * @since 5.0 + */ + public static String encode(String source, Charset charset) { + return encode(source, charset, HierarchicalUriComponents.Type.URI); + } + + /** + * Convenience method to apply {@link #encode(String, Charset)} to all + * given URI variable values. + * @param uriVariables the URI variable values to be encoded + * @return the encoded String + * @since 5.0 + */ + public static Map encodeUriVariables(Map uriVariables) { + Map result = new LinkedHashMap<>(uriVariables.size()); + uriVariables.forEach((key, value) -> { + String stringValue = (value != null ? value.toString() : ""); + result.put(key, encode(stringValue, StandardCharsets.UTF_8)); + }); + return result; + } + + /** + * Convenience method to apply {@link #encode(String, Charset)} to all + * given URI variable values. + * @param uriVariables the URI variable values to be encoded + * @return the encoded String + * @since 5.0 + */ + public static Object[] encodeUriVariables(Object... uriVariables) { + return Arrays.stream(uriVariables) + .map(value -> { + String stringValue = (value != null ? value.toString() : ""); + return encode(stringValue, StandardCharsets.UTF_8); + }) + .toArray(); + } + + private static String encode(String scheme, String encoding, HierarchicalUriComponents.Type type) { + return HierarchicalUriComponents.encodeUriComponent(scheme, encoding, type); + } + + private static String encode(String scheme, Charset charset, HierarchicalUriComponents.Type type) { + return HierarchicalUriComponents.encodeUriComponent(scheme, charset, type); + } + + + /** + * Decode the given encoded URI component. + *

See {@link StringUtils#uriDecode(String, Charset)} for the decoding rules. + * @param source the encoded String + * @param encoding the character encoding to use + * @return the decoded value + * @throws IllegalArgumentException when the given source contains invalid encoded sequences + * @see StringUtils#uriDecode(String, Charset) + * @see java.net.URLDecoder#decode(String, String) + */ + public static String decode(String source, String encoding) { + return StringUtils.uriDecode(source, Charset.forName(encoding)); + } + + /** + * Decode the given encoded URI component. + *

See {@link StringUtils#uriDecode(String, Charset)} for the decoding rules. + * @param source the encoded String + * @param charset the character encoding to use + * @return the decoded value + * @throws IllegalArgumentException when the given source contains invalid encoded sequences + * @since 5.0 + * @see StringUtils#uriDecode(String, Charset) + * @see java.net.URLDecoder#decode(String, String) + */ + public static String decode(String source, Charset charset) { + return StringUtils.uriDecode(source, charset); + } + + /** + * Extract the file extension from the given URI path. + * @param path the URI path (e.g. "/products/index.html") + * @return the extracted file extension (e.g. "html") + * @since 4.3.2 + */ + @Nullable + public static String extractFileExtension(String path) { + int end = path.indexOf('?'); + int fragmentIndex = path.indexOf('#'); + if (fragmentIndex != -1 && (end == -1 || fragmentIndex < end)) { + end = fragmentIndex; + } + if (end == -1) { + end = path.length(); + } + int begin = path.lastIndexOf('/', end) + 1; + int paramIndex = path.indexOf(';', begin); + end = (paramIndex != -1 && paramIndex < end ? paramIndex : end); + int extIndex = path.lastIndexOf('.', end); + if (extIndex != -1 && extIndex > begin) { + return path.substring(extIndex + 1, end); + } + return null; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/UrlPathHelper.java b/spring-web/src/main/java/org/springframework/web/util/UrlPathHelper.java new file mode 100644 index 0000000000000000000000000000000000000000..2c5676a70914b602aa005d95ac31dc6da3af0686 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/UrlPathHelper.java @@ -0,0 +1,675 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.net.URLDecoder; +import java.nio.charset.UnsupportedCharsetException; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Properties; + +import javax.servlet.http.HttpServletRequest; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * Helper class for URL path matching. Provides support for URL paths in + * {@code RequestDispatcher} includes and support for consistent URL decoding. + * + *

Used by {@link org.springframework.web.servlet.handler.AbstractUrlHandlerMapping} + * and {@link org.springframework.web.servlet.support.RequestContext} for path matching + * and/or URI determination. + * + * @author Juergen Hoeller + * @author Rob Harrop + * @author Rossen Stoyanchev + * @since 14.01.2004 + * @see #getLookupPathForRequest + * @see javax.servlet.RequestDispatcher + */ +public class UrlPathHelper { + + /** + * Special WebSphere request attribute, indicating the original request URI. + * Preferable over the standard Servlet 2.4 forward attribute on WebSphere, + * simply because we need the very first URI in the request forwarding chain. + */ + private static final String WEBSPHERE_URI_ATTRIBUTE = "com.ibm.websphere.servlet.uri_non_decoded"; + + private static final Log logger = LogFactory.getLog(UrlPathHelper.class); + + @Nullable + static volatile Boolean websphereComplianceFlag; + + + private boolean alwaysUseFullPath = false; + + private boolean urlDecode = true; + + private boolean removeSemicolonContent = true; + + private String defaultEncoding = WebUtils.DEFAULT_CHARACTER_ENCODING; + + private boolean readOnly = false; + + + /** + * Whether URL lookups should always use the full path within the current + * web application context, i.e. within + * {@link javax.servlet.ServletContext#getContextPath()}. + *

If set to {@literal false} the path within the current servlet mapping + * is used instead if applicable (i.e. in the case of a prefix based Servlet + * mapping such as "/myServlet/*"). + *

By default this is set to "false". + */ + public void setAlwaysUseFullPath(boolean alwaysUseFullPath) { + checkReadOnly(); + this.alwaysUseFullPath = alwaysUseFullPath; + } + + /** + * Whether the context path and request URI should be decoded -- both of + * which are returned undecoded by the Servlet API, in contrast to + * the servlet path. + *

Either the request encoding or the default Servlet spec encoding + * (ISO-8859-1) is used when set to "true". + *

By default this is set to {@literal true}. + *

Note: Be aware the servlet path will not match when + * compared to encoded paths. Therefore use of {@code urlDecode=false} is + * not compatible with a prefix-based Servlet mapping and likewise implies + * also setting {@code alwaysUseFullPath=true}. + * @see #getServletPath + * @see #getContextPath + * @see #getRequestUri + * @see WebUtils#DEFAULT_CHARACTER_ENCODING + * @see javax.servlet.ServletRequest#getCharacterEncoding() + * @see java.net.URLDecoder#decode(String, String) + */ + public void setUrlDecode(boolean urlDecode) { + checkReadOnly(); + this.urlDecode = urlDecode; + } + + /** + * Whether to decode the request URI when determining the lookup path. + * @since 4.3.13 + */ + public boolean isUrlDecode() { + return this.urlDecode; + } + + /** + * Set if ";" (semicolon) content should be stripped from the request URI. + *

Default is "true". + */ + public void setRemoveSemicolonContent(boolean removeSemicolonContent) { + checkReadOnly(); + this.removeSemicolonContent = removeSemicolonContent; + } + + /** + * Whether configured to remove ";" (semicolon) content from the request URI. + */ + public boolean shouldRemoveSemicolonContent() { + checkReadOnly(); + return this.removeSemicolonContent; + } + + /** + * Set the default character encoding to use for URL decoding. + * Default is ISO-8859-1, according to the Servlet spec. + *

If the request specifies a character encoding itself, the request + * encoding will override this setting. This also allows for generically + * overriding the character encoding in a filter that invokes the + * {@code ServletRequest.setCharacterEncoding} method. + * @param defaultEncoding the character encoding to use + * @see #determineEncoding + * @see javax.servlet.ServletRequest#getCharacterEncoding() + * @see javax.servlet.ServletRequest#setCharacterEncoding(String) + * @see WebUtils#DEFAULT_CHARACTER_ENCODING + */ + public void setDefaultEncoding(String defaultEncoding) { + checkReadOnly(); + this.defaultEncoding = defaultEncoding; + } + + /** + * Return the default character encoding to use for URL decoding. + */ + protected String getDefaultEncoding() { + return this.defaultEncoding; + } + + /** + * Switch to read-only mode where further configuration changes are not allowed. + */ + private void setReadOnly() { + this.readOnly = true; + } + + private void checkReadOnly() { + Assert.isTrue(!this.readOnly, "This instance cannot be modified"); + } + + + /** + * Return the mapping lookup path for the given request, within the current + * servlet mapping if applicable, else within the web application. + *

Detects include request URL if called within a RequestDispatcher include. + * @param request current HTTP request + * @return the lookup path + * @see #getPathWithinServletMapping + * @see #getPathWithinApplication + */ + public String getLookupPathForRequest(HttpServletRequest request) { + // Always use full path within current servlet context? + if (this.alwaysUseFullPath) { + return getPathWithinApplication(request); + } + // Else, use path within current servlet mapping if applicable + String rest = getPathWithinServletMapping(request); + if (!"".equals(rest)) { + return rest; + } + else { + return getPathWithinApplication(request); + } + } + + /** + * Return the path within the servlet mapping for the given request, + * i.e. the part of the request's URL beyond the part that called the servlet, + * or "" if the whole URL has been used to identify the servlet. + *

Detects include request URL if called within a RequestDispatcher include. + *

E.g.: servlet mapping = "/*"; request URI = "/test/a" -> "/test/a". + *

E.g.: servlet mapping = "/"; request URI = "/test/a" -> "/test/a". + *

E.g.: servlet mapping = "/test/*"; request URI = "/test/a" -> "/a". + *

E.g.: servlet mapping = "/test"; request URI = "/test" -> "". + *

E.g.: servlet mapping = "/*.test"; request URI = "/a.test" -> "". + * @param request current HTTP request + * @return the path within the servlet mapping, or "" + * @see #getLookupPathForRequest + */ + public String getPathWithinServletMapping(HttpServletRequest request) { + String pathWithinApp = getPathWithinApplication(request); + String servletPath = getServletPath(request); + String sanitizedPathWithinApp = getSanitizedPath(pathWithinApp); + String path; + + // If the app container sanitized the servletPath, check against the sanitized version + if (servletPath.contains(sanitizedPathWithinApp)) { + path = getRemainingPath(sanitizedPathWithinApp, servletPath, false); + } + else { + path = getRemainingPath(pathWithinApp, servletPath, false); + } + + if (path != null) { + // Normal case: URI contains servlet path. + return path; + } + else { + // Special case: URI is different from servlet path. + String pathInfo = request.getPathInfo(); + if (pathInfo != null) { + // Use path info if available. Indicates index page within a servlet mapping? + // e.g. with index page: URI="/", servletPath="/index.html" + return pathInfo; + } + if (!this.urlDecode) { + // No path info... (not mapped by prefix, nor by extension, nor "/*") + // For the default servlet mapping (i.e. "/"), urlDecode=false can + // cause issues since getServletPath() returns a decoded path. + // If decoding pathWithinApp yields a match just use pathWithinApp. + path = getRemainingPath(decodeInternal(request, pathWithinApp), servletPath, false); + if (path != null) { + return pathWithinApp; + } + } + // Otherwise, use the full servlet path. + return servletPath; + } + } + + /** + * Return the path within the web application for the given request. + *

Detects include request URL if called within a RequestDispatcher include. + * @param request current HTTP request + * @return the path within the web application + * @see #getLookupPathForRequest + */ + public String getPathWithinApplication(HttpServletRequest request) { + String contextPath = getContextPath(request); + String requestUri = getRequestUri(request); + String path = getRemainingPath(requestUri, contextPath, true); + if (path != null) { + // Normal case: URI contains context path. + return (StringUtils.hasText(path) ? path : "/"); + } + else { + return requestUri; + } + } + + /** + * Match the given "mapping" to the start of the "requestUri" and if there + * is a match return the extra part. This method is needed because the + * context path and the servlet path returned by the HttpServletRequest are + * stripped of semicolon content unlike the requestUri. + */ + @Nullable + private String getRemainingPath(String requestUri, String mapping, boolean ignoreCase) { + int index1 = 0; + int index2 = 0; + for (; (index1 < requestUri.length()) && (index2 < mapping.length()); index1++, index2++) { + char c1 = requestUri.charAt(index1); + char c2 = mapping.charAt(index2); + if (c1 == ';') { + index1 = requestUri.indexOf('/', index1); + if (index1 == -1) { + return null; + } + c1 = requestUri.charAt(index1); + } + if (c1 == c2 || (ignoreCase && (Character.toLowerCase(c1) == Character.toLowerCase(c2)))) { + continue; + } + return null; + } + if (index2 != mapping.length()) { + return null; + } + else if (index1 == requestUri.length()) { + return ""; + } + else if (requestUri.charAt(index1) == ';') { + index1 = requestUri.indexOf('/', index1); + } + return (index1 != -1 ? requestUri.substring(index1) : ""); + } + + /** + * Sanitize the given path. Uses the following rules: + *

    + *
  • replace all "//" by "/"
  • + *
+ */ + private String getSanitizedPath(final String path) { + String sanitized = path; + while (true) { + int index = sanitized.indexOf("//"); + if (index < 0) { + break; + } + else { + sanitized = sanitized.substring(0, index) + sanitized.substring(index + 1); + } + } + return sanitized; + } + + /** + * Return the request URI for the given request, detecting an include request + * URL if called within a RequestDispatcher include. + *

As the value returned by {@code request.getRequestURI()} is not + * decoded by the servlet container, this method will decode it. + *

The URI that the web container resolves should be correct, but some + * containers like JBoss/Jetty incorrectly include ";" strings like ";jsessionid" + * in the URI. This method cuts off such incorrect appendices. + * @param request current HTTP request + * @return the request URI + */ + public String getRequestUri(HttpServletRequest request) { + String uri = (String) request.getAttribute(WebUtils.INCLUDE_REQUEST_URI_ATTRIBUTE); + if (uri == null) { + uri = request.getRequestURI(); + } + return decodeAndCleanUriString(request, uri); + } + + /** + * Return the context path for the given request, detecting an include request + * URL if called within a RequestDispatcher include. + *

As the value returned by {@code request.getContextPath()} is not + * decoded by the servlet container, this method will decode it. + * @param request current HTTP request + * @return the context path + */ + public String getContextPath(HttpServletRequest request) { + String contextPath = (String) request.getAttribute(WebUtils.INCLUDE_CONTEXT_PATH_ATTRIBUTE); + if (contextPath == null) { + contextPath = request.getContextPath(); + } + if ("/".equals(contextPath)) { + // Invalid case, but happens for includes on Jetty: silently adapt it. + contextPath = ""; + } + return decodeRequestString(request, contextPath); + } + + /** + * Return the servlet path for the given request, regarding an include request + * URL if called within a RequestDispatcher include. + *

As the value returned by {@code request.getServletPath()} is already + * decoded by the servlet container, this method will not attempt to decode it. + * @param request current HTTP request + * @return the servlet path + */ + public String getServletPath(HttpServletRequest request) { + String servletPath = (String) request.getAttribute(WebUtils.INCLUDE_SERVLET_PATH_ATTRIBUTE); + if (servletPath == null) { + servletPath = request.getServletPath(); + } + if (servletPath.length() > 1 && servletPath.endsWith("/") && shouldRemoveTrailingServletPathSlash(request)) { + // On WebSphere, in non-compliant mode, for a "/foo/" case that would be "/foo" + // on all other servlet containers: removing trailing slash, proceeding with + // that remaining slash as final lookup path... + servletPath = servletPath.substring(0, servletPath.length() - 1); + } + return servletPath; + } + + + /** + * Return the request URI for the given request. If this is a forwarded request, + * correctly resolves to the request URI of the original request. + */ + public String getOriginatingRequestUri(HttpServletRequest request) { + String uri = (String) request.getAttribute(WEBSPHERE_URI_ATTRIBUTE); + if (uri == null) { + uri = (String) request.getAttribute(WebUtils.FORWARD_REQUEST_URI_ATTRIBUTE); + if (uri == null) { + uri = request.getRequestURI(); + } + } + return decodeAndCleanUriString(request, uri); + } + + /** + * Return the context path for the given request, detecting an include request + * URL if called within a RequestDispatcher include. + *

As the value returned by {@code request.getContextPath()} is not + * decoded by the servlet container, this method will decode it. + * @param request current HTTP request + * @return the context path + */ + public String getOriginatingContextPath(HttpServletRequest request) { + String contextPath = (String) request.getAttribute(WebUtils.FORWARD_CONTEXT_PATH_ATTRIBUTE); + if (contextPath == null) { + contextPath = request.getContextPath(); + } + return decodeRequestString(request, contextPath); + } + + /** + * Return the servlet path for the given request, detecting an include request + * URL if called within a RequestDispatcher include. + * @param request current HTTP request + * @return the servlet path + */ + public String getOriginatingServletPath(HttpServletRequest request) { + String servletPath = (String) request.getAttribute(WebUtils.FORWARD_SERVLET_PATH_ATTRIBUTE); + if (servletPath == null) { + servletPath = request.getServletPath(); + } + return servletPath; + } + + /** + * Return the query string part of the given request's URL. If this is a forwarded request, + * correctly resolves to the query string of the original request. + * @param request current HTTP request + * @return the query string + */ + public String getOriginatingQueryString(HttpServletRequest request) { + if ((request.getAttribute(WebUtils.FORWARD_REQUEST_URI_ATTRIBUTE) != null) || + (request.getAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE) != null)) { + return (String) request.getAttribute(WebUtils.FORWARD_QUERY_STRING_ATTRIBUTE); + } + else { + return request.getQueryString(); + } + } + + /** + * Decode the supplied URI string and strips any extraneous portion after a ';'. + */ + private String decodeAndCleanUriString(HttpServletRequest request, String uri) { + uri = removeSemicolonContent(uri); + uri = decodeRequestString(request, uri); + uri = getSanitizedPath(uri); + return uri; + } + + /** + * Decode the given source string with a URLDecoder. The encoding will be taken + * from the request, falling back to the default "ISO-8859-1". + *

The default implementation uses {@code URLDecoder.decode(input, enc)}. + * @param request current HTTP request + * @param source the String to decode + * @return the decoded String + * @see WebUtils#DEFAULT_CHARACTER_ENCODING + * @see javax.servlet.ServletRequest#getCharacterEncoding + * @see java.net.URLDecoder#decode(String, String) + * @see java.net.URLDecoder#decode(String) + */ + public String decodeRequestString(HttpServletRequest request, String source) { + if (this.urlDecode) { + return decodeInternal(request, source); + } + return source; + } + + @SuppressWarnings("deprecation") + private String decodeInternal(HttpServletRequest request, String source) { + String enc = determineEncoding(request); + try { + return UriUtils.decode(source, enc); + } + catch (UnsupportedCharsetException ex) { + if (logger.isWarnEnabled()) { + logger.warn("Could not decode request string [" + source + "] with encoding '" + enc + + "': falling back to platform default encoding; exception message: " + ex.getMessage()); + } + return URLDecoder.decode(source); + } + } + + /** + * Determine the encoding for the given request. + * Can be overridden in subclasses. + *

The default implementation checks the request encoding, + * falling back to the default encoding specified for this resolver. + * @param request current HTTP request + * @return the encoding for the request (never {@code null}) + * @see javax.servlet.ServletRequest#getCharacterEncoding() + * @see #setDefaultEncoding + */ + protected String determineEncoding(HttpServletRequest request) { + String enc = request.getCharacterEncoding(); + if (enc == null) { + enc = getDefaultEncoding(); + } + return enc; + } + + /** + * Remove ";" (semicolon) content from the given request URI if the + * {@linkplain #setRemoveSemicolonContent removeSemicolonContent} + * property is set to "true". Note that "jsessionid" is always removed. + * @param requestUri the request URI string to remove ";" content from + * @return the updated URI string + */ + public String removeSemicolonContent(String requestUri) { + return (this.removeSemicolonContent ? + removeSemicolonContentInternal(requestUri) : removeJsessionid(requestUri)); + } + + private String removeSemicolonContentInternal(String requestUri) { + int semicolonIndex = requestUri.indexOf(';'); + while (semicolonIndex != -1) { + int slashIndex = requestUri.indexOf('/', semicolonIndex); + String start = requestUri.substring(0, semicolonIndex); + requestUri = (slashIndex != -1) ? start + requestUri.substring(slashIndex) : start; + semicolonIndex = requestUri.indexOf(';', semicolonIndex); + } + return requestUri; + } + + private String removeJsessionid(String requestUri) { + String key = ";jsessionid="; + int index = requestUri.toLowerCase().indexOf(key); + if (index == -1) { + return requestUri; + } + String start = requestUri.substring(0, index); + for (int i = index + key.length(); i < requestUri.length(); i++) { + char c = requestUri.charAt(i); + if (c == ';' || c == '/') { + return start + requestUri.substring(i); + } + } + return start; + } + + /** + * Decode the given URI path variables via {@link #decodeRequestString} unless + * {@link #setUrlDecode} is set to {@code true} in which case it is assumed + * the URL path from which the variables were extracted is already decoded + * through a call to {@link #getLookupPathForRequest(HttpServletRequest)}. + * @param request current HTTP request + * @param vars the URI variables extracted from the URL path + * @return the same Map or a new Map instance + */ + public Map decodePathVariables(HttpServletRequest request, Map vars) { + if (this.urlDecode) { + return vars; + } + else { + Map decodedVars = new LinkedHashMap<>(vars.size()); + vars.forEach((key, value) -> decodedVars.put(key, decodeInternal(request, value))); + return decodedVars; + } + } + + /** + * Decode the given matrix variables via {@link #decodeRequestString} unless + * {@link #setUrlDecode} is set to {@code true} in which case it is assumed + * the URL path from which the variables were extracted is already decoded + * through a call to {@link #getLookupPathForRequest(HttpServletRequest)}. + * @param request current HTTP request + * @param vars the URI variables extracted from the URL path + * @return the same Map or a new Map instance + */ + public MultiValueMap decodeMatrixVariables( + HttpServletRequest request, MultiValueMap vars) { + + if (this.urlDecode) { + return vars; + } + else { + MultiValueMap decodedVars = new LinkedMultiValueMap<>(vars.size()); + vars.forEach((key, values) -> { + for (String value : values) { + decodedVars.add(key, decodeInternal(request, value)); + } + }); + return decodedVars; + } + } + + private boolean shouldRemoveTrailingServletPathSlash(HttpServletRequest request) { + if (request.getAttribute(WEBSPHERE_URI_ATTRIBUTE) == null) { + // Regular servlet container: behaves as expected in any case, + // so the trailing slash is the result of a "/" url-pattern mapping. + // Don't remove that slash. + return false; + } + Boolean flagToUse = websphereComplianceFlag; + if (flagToUse == null) { + ClassLoader classLoader = UrlPathHelper.class.getClassLoader(); + String className = "com.ibm.ws.webcontainer.WebContainer"; + String methodName = "getWebContainerProperties"; + String propName = "com.ibm.ws.webcontainer.removetrailingservletpathslash"; + boolean flag = false; + try { + Class cl = classLoader.loadClass(className); + Properties prop = (Properties) cl.getMethod(methodName).invoke(null); + flag = Boolean.parseBoolean(prop.getProperty(propName)); + } + catch (Throwable ex) { + if (logger.isDebugEnabled()) { + logger.debug("Could not introspect WebSphere web container properties: " + ex); + } + } + flagToUse = flag; + websphereComplianceFlag = flag; + } + // Don't bother if WebSphere is configured to be fully Servlet compliant. + // However, if it is not compliant, do remove the improper trailing slash! + return !flagToUse; + } + + + /** + * Shared, read-only instance with defaults. The following apply: + *

    + *
  • {@code alwaysUseFullPath=false} + *
  • {@code urlDecode=true} + *
  • {@code removeSemicolon=true} + *
  • {@code defaultEncoding=}{@link WebUtils#DEFAULT_CHARACTER_ENCODING} + *
+ */ + public static final UrlPathHelper defaultInstance = new UrlPathHelper(); + + static { + defaultInstance.setReadOnly(); + } + + + /** + * Shared, read-only instance for the full, encoded path. The following apply: + *
    + *
  • {@code alwaysUseFullPath=true} + *
  • {@code urlDecode=false} + *
  • {@code removeSemicolon=false} + *
  • {@code defaultEncoding=}{@link WebUtils#DEFAULT_CHARACTER_ENCODING} + *
+ */ + public static final UrlPathHelper rawPathInstance = new UrlPathHelper() { + + @Override + public String removeSemicolonContent(String requestUri) { + return requestUri; + } + }; + + static { + rawPathInstance.setAlwaysUseFullPath(true); + rawPathInstance.setUrlDecode(false); + rawPathInstance.setRemoveSemicolonContent(false); + rawPathInstance.setReadOnly(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/WebAppRootListener.java b/spring-web/src/main/java/org/springframework/web/util/WebAppRootListener.java new file mode 100644 index 0000000000000000000000000000000000000000..bd340eea48ba269f62cfc0b3ca0e1e364a7119e8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/WebAppRootListener.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import javax.servlet.ServletContextEvent; +import javax.servlet.ServletContextListener; + +/** + * Listener that sets a system property to the web application root directory. + * The key of the system property can be defined with the "webAppRootKey" init + * parameter at the servlet context level (i.e. context-param in web.xml), + * the default key is "webapp.root". + * + *

Can be used for toolkits that support substitution with system properties + * (i.e. System.getProperty values), like log4j's "${key}" syntax within log + * file locations. + * + *

Note: This listener should be placed before ContextLoaderListener in {@code web.xml}, + * at least when used for log4j. Log4jConfigListener sets the system property + * implicitly, so there's no need for this listener in addition to it. + * + *

WARNING: Some containers, e.g. Tomcat, do NOT keep system properties separate + * per web app. You have to use unique "webAppRootKey" context-params per web app + * then, to avoid clashes. Other containers like Resin do isolate each web app's + * system properties: Here you can use the default key (i.e. no "webAppRootKey" + * context-param at all) without worrying. + * + *

WARNING: The WAR file containing the web application needs to be expanded + * to allow for setting the web app root system property. This is by default not + * the case when a WAR file gets deployed to WebLogic, for example. Do not use + * this listener in such an environment! + * + * @author Juergen Hoeller + * @since 18.04.2003 + * @see WebUtils#setWebAppRootSystemProperty + * @see System#getProperty + */ +public class WebAppRootListener implements ServletContextListener { + + @Override + public void contextInitialized(ServletContextEvent event) { + WebUtils.setWebAppRootSystemProperty(event.getServletContext()); + } + + @Override + public void contextDestroyed(ServletContextEvent event) { + WebUtils.removeWebAppRootSystemProperty(event.getServletContext()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..253f4935f0afd005c3f5986eab63b3d2b4a49a00 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java @@ -0,0 +1,837 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.io.File; +import java.io.FileNotFoundException; +import java.net.URI; +import java.util.Collection; +import java.util.Enumeration; +import java.util.Map; +import java.util.StringTokenizer; +import java.util.TreeMap; + +import javax.servlet.ServletContext; +import javax.servlet.ServletRequest; +import javax.servlet.ServletRequestWrapper; +import javax.servlet.ServletResponse; +import javax.servlet.ServletResponseWrapper; +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpRequest; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; + +/** + * Miscellaneous utilities for web applications. + * Used by various framework classes. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @author Sebastien Deleuze + */ +public abstract class WebUtils { + + /** + * Standard Servlet 2.3+ spec request attribute for include request URI. + *

If included via a {@code RequestDispatcher}, the current resource will see the + * originating request. Its own request URI is exposed as a request attribute. + */ + public static final String INCLUDE_REQUEST_URI_ATTRIBUTE = "javax.servlet.include.request_uri"; + + /** + * Standard Servlet 2.3+ spec request attribute for include context path. + *

If included via a {@code RequestDispatcher}, the current resource will see the + * originating context path. Its own context path is exposed as a request attribute. + */ + public static final String INCLUDE_CONTEXT_PATH_ATTRIBUTE = "javax.servlet.include.context_path"; + + /** + * Standard Servlet 2.3+ spec request attribute for include servlet path. + *

If included via a {@code RequestDispatcher}, the current resource will see the + * originating servlet path. Its own servlet path is exposed as a request attribute. + */ + public static final String INCLUDE_SERVLET_PATH_ATTRIBUTE = "javax.servlet.include.servlet_path"; + + /** + * Standard Servlet 2.3+ spec request attribute for include path info. + *

If included via a {@code RequestDispatcher}, the current resource will see the + * originating path info. Its own path info is exposed as a request attribute. + */ + public static final String INCLUDE_PATH_INFO_ATTRIBUTE = "javax.servlet.include.path_info"; + + /** + * Standard Servlet 2.3+ spec request attribute for include query string. + *

If included via a {@code RequestDispatcher}, the current resource will see the + * originating query string. Its own query string is exposed as a request attribute. + */ + public static final String INCLUDE_QUERY_STRING_ATTRIBUTE = "javax.servlet.include.query_string"; + + /** + * Standard Servlet 2.4+ spec request attribute for forward request URI. + *

If forwarded to via a RequestDispatcher, the current resource will see its + * own request URI. The originating request URI is exposed as a request attribute. + */ + public static final String FORWARD_REQUEST_URI_ATTRIBUTE = "javax.servlet.forward.request_uri"; + + /** + * Standard Servlet 2.4+ spec request attribute for forward context path. + *

If forwarded to via a RequestDispatcher, the current resource will see its + * own context path. The originating context path is exposed as a request attribute. + */ + public static final String FORWARD_CONTEXT_PATH_ATTRIBUTE = "javax.servlet.forward.context_path"; + + /** + * Standard Servlet 2.4+ spec request attribute for forward servlet path. + *

If forwarded to via a RequestDispatcher, the current resource will see its + * own servlet path. The originating servlet path is exposed as a request attribute. + */ + public static final String FORWARD_SERVLET_PATH_ATTRIBUTE = "javax.servlet.forward.servlet_path"; + + /** + * Standard Servlet 2.4+ spec request attribute for forward path info. + *

If forwarded to via a RequestDispatcher, the current resource will see its + * own path ingo. The originating path info is exposed as a request attribute. + */ + public static final String FORWARD_PATH_INFO_ATTRIBUTE = "javax.servlet.forward.path_info"; + + /** + * Standard Servlet 2.4+ spec request attribute for forward query string. + *

If forwarded to via a RequestDispatcher, the current resource will see its + * own query string. The originating query string is exposed as a request attribute. + */ + public static final String FORWARD_QUERY_STRING_ATTRIBUTE = "javax.servlet.forward.query_string"; + + /** + * Standard Servlet 2.3+ spec request attribute for error page status code. + *

To be exposed to JSPs that are marked as error pages, when forwarding + * to them directly rather than through the servlet container's error page + * resolution mechanism. + */ + public static final String ERROR_STATUS_CODE_ATTRIBUTE = "javax.servlet.error.status_code"; + + /** + * Standard Servlet 2.3+ spec request attribute for error page exception type. + *

To be exposed to JSPs that are marked as error pages, when forwarding + * to them directly rather than through the servlet container's error page + * resolution mechanism. + */ + public static final String ERROR_EXCEPTION_TYPE_ATTRIBUTE = "javax.servlet.error.exception_type"; + + /** + * Standard Servlet 2.3+ spec request attribute for error page message. + *

To be exposed to JSPs that are marked as error pages, when forwarding + * to them directly rather than through the servlet container's error page + * resolution mechanism. + */ + public static final String ERROR_MESSAGE_ATTRIBUTE = "javax.servlet.error.message"; + + /** + * Standard Servlet 2.3+ spec request attribute for error page exception. + *

To be exposed to JSPs that are marked as error pages, when forwarding + * to them directly rather than through the servlet container's error page + * resolution mechanism. + */ + public static final String ERROR_EXCEPTION_ATTRIBUTE = "javax.servlet.error.exception"; + + /** + * Standard Servlet 2.3+ spec request attribute for error page request URI. + *

To be exposed to JSPs that are marked as error pages, when forwarding + * to them directly rather than through the servlet container's error page + * resolution mechanism. + */ + public static final String ERROR_REQUEST_URI_ATTRIBUTE = "javax.servlet.error.request_uri"; + + /** + * Standard Servlet 2.3+ spec request attribute for error page servlet name. + *

To be exposed to JSPs that are marked as error pages, when forwarding + * to them directly rather than through the servlet container's error page + * resolution mechanism. + */ + public static final String ERROR_SERVLET_NAME_ATTRIBUTE = "javax.servlet.error.servlet_name"; + + /** + * Prefix of the charset clause in a content type String: ";charset=". + */ + public static final String CONTENT_TYPE_CHARSET_PREFIX = ";charset="; + + /** + * Default character encoding to use when {@code request.getCharacterEncoding} + * returns {@code null}, according to the Servlet spec. + * @see ServletRequest#getCharacterEncoding + */ + public static final String DEFAULT_CHARACTER_ENCODING = "ISO-8859-1"; + + /** + * Standard Servlet spec context attribute that specifies a temporary + * directory for the current web application, of type {@code java.io.File}. + */ + public static final String TEMP_DIR_CONTEXT_ATTRIBUTE = "javax.servlet.context.tempdir"; + + /** + * HTML escape parameter at the servlet context level + * (i.e. a context-param in {@code web.xml}): "defaultHtmlEscape". + */ + public static final String HTML_ESCAPE_CONTEXT_PARAM = "defaultHtmlEscape"; + + /** + * Use of response encoding for HTML escaping parameter at the servlet context level + * (i.e. a context-param in {@code web.xml}): "responseEncodedHtmlEscape". + * @since 4.1.2 + */ + public static final String RESPONSE_ENCODED_HTML_ESCAPE_CONTEXT_PARAM = "responseEncodedHtmlEscape"; + + /** + * Web app root key parameter at the servlet context level + * (i.e. a context-param in {@code web.xml}): "webAppRootKey". + */ + public static final String WEB_APP_ROOT_KEY_PARAM = "webAppRootKey"; + + /** Default web app root key: "webapp.root". */ + public static final String DEFAULT_WEB_APP_ROOT_KEY = "webapp.root"; + + /** Name suffixes in case of image buttons. */ + public static final String[] SUBMIT_IMAGE_SUFFIXES = {".x", ".y"}; + + /** Key for the mutex session attribute. */ + public static final String SESSION_MUTEX_ATTRIBUTE = WebUtils.class.getName() + ".MUTEX"; + + + /** + * Set a system property to the web application root directory. + * The key of the system property can be defined with the "webAppRootKey" + * context-param in {@code web.xml}. Default is "webapp.root". + *

Can be used for tools that support substitution with {@code System.getProperty} + * values, like log4j's "${key}" syntax within log file locations. + * @param servletContext the servlet context of the web application + * @throws IllegalStateException if the system property is already set, + * or if the WAR file is not expanded + * @see #WEB_APP_ROOT_KEY_PARAM + * @see #DEFAULT_WEB_APP_ROOT_KEY + * @see WebAppRootListener + */ + public static void setWebAppRootSystemProperty(ServletContext servletContext) throws IllegalStateException { + Assert.notNull(servletContext, "ServletContext must not be null"); + String root = servletContext.getRealPath("/"); + if (root == null) { + throw new IllegalStateException( + "Cannot set web app root system property when WAR file is not expanded"); + } + String param = servletContext.getInitParameter(WEB_APP_ROOT_KEY_PARAM); + String key = (param != null ? param : DEFAULT_WEB_APP_ROOT_KEY); + String oldValue = System.getProperty(key); + if (oldValue != null && !StringUtils.pathEquals(oldValue, root)) { + throw new IllegalStateException("Web app root system property already set to different value: '" + + key + "' = [" + oldValue + "] instead of [" + root + "] - " + + "Choose unique values for the 'webAppRootKey' context-param in your web.xml files!"); + } + System.setProperty(key, root); + servletContext.log("Set web app root system property: '" + key + "' = [" + root + "]"); + } + + /** + * Remove the system property that points to the web app root directory. + * To be called on shutdown of the web application. + * @param servletContext the servlet context of the web application + * @see #setWebAppRootSystemProperty + */ + public static void removeWebAppRootSystemProperty(ServletContext servletContext) { + Assert.notNull(servletContext, "ServletContext must not be null"); + String param = servletContext.getInitParameter(WEB_APP_ROOT_KEY_PARAM); + String key = (param != null ? param : DEFAULT_WEB_APP_ROOT_KEY); + System.getProperties().remove(key); + } + + /** + * Return whether default HTML escaping is enabled for the web application, + * i.e. the value of the "defaultHtmlEscape" context-param in {@code web.xml} + * (if any). + *

This method differentiates between no param specified at all and + * an actual boolean value specified, allowing to have a context-specific + * default in case of no setting at the global level. + * @param servletContext the servlet context of the web application + * @return whether default HTML escaping is enabled for the given application + * ({@code null} = no explicit default) + */ + @Nullable + public static Boolean getDefaultHtmlEscape(@Nullable ServletContext servletContext) { + if (servletContext == null) { + return null; + } + String param = servletContext.getInitParameter(HTML_ESCAPE_CONTEXT_PARAM); + return (StringUtils.hasText(param) ? Boolean.valueOf(param) : null); + } + + /** + * Return whether response encoding should be used when HTML escaping characters, + * thus only escaping XML markup significant characters with UTF-* encodings. + * This option is enabled for the web application with a ServletContext param, + * i.e. the value of the "responseEncodedHtmlEscape" context-param in {@code web.xml} + * (if any). + *

This method differentiates between no param specified at all and + * an actual boolean value specified, allowing to have a context-specific + * default in case of no setting at the global level. + * @param servletContext the servlet context of the web application + * @return whether response encoding is to be used for HTML escaping + * ({@code null} = no explicit default) + * @since 4.1.2 + */ + @Nullable + public static Boolean getResponseEncodedHtmlEscape(@Nullable ServletContext servletContext) { + if (servletContext == null) { + return null; + } + String param = servletContext.getInitParameter(RESPONSE_ENCODED_HTML_ESCAPE_CONTEXT_PARAM); + return (StringUtils.hasText(param) ? Boolean.valueOf(param) : null); + } + + /** + * Return the temporary directory for the current web application, + * as provided by the servlet container. + * @param servletContext the servlet context of the web application + * @return the File representing the temporary directory + */ + public static File getTempDir(ServletContext servletContext) { + Assert.notNull(servletContext, "ServletContext must not be null"); + return (File) servletContext.getAttribute(TEMP_DIR_CONTEXT_ATTRIBUTE); + } + + /** + * Return the real path of the given path within the web application, + * as provided by the servlet container. + *

Prepends a slash if the path does not already start with a slash, + * and throws a FileNotFoundException if the path cannot be resolved to + * a resource (in contrast to ServletContext's {@code getRealPath}, + * which returns null). + * @param servletContext the servlet context of the web application + * @param path the path within the web application + * @return the corresponding real path + * @throws FileNotFoundException if the path cannot be resolved to a resource + * @see javax.servlet.ServletContext#getRealPath + */ + public static String getRealPath(ServletContext servletContext, String path) throws FileNotFoundException { + Assert.notNull(servletContext, "ServletContext must not be null"); + // Interpret location as relative to the web application root directory. + if (!path.startsWith("/")) { + path = "/" + path; + } + String realPath = servletContext.getRealPath(path); + if (realPath == null) { + throw new FileNotFoundException( + "ServletContext resource [" + path + "] cannot be resolved to absolute file path - " + + "web application archive not expanded?"); + } + return realPath; + } + + /** + * Determine the session id of the given request, if any. + * @param request current HTTP request + * @return the session id, or {@code null} if none + */ + @Nullable + public static String getSessionId(HttpServletRequest request) { + Assert.notNull(request, "Request must not be null"); + HttpSession session = request.getSession(false); + return (session != null ? session.getId() : null); + } + + /** + * Check the given request for a session attribute of the given name. + * Returns null if there is no session or if the session has no such attribute. + * Does not create a new session if none has existed before! + * @param request current HTTP request + * @param name the name of the session attribute + * @return the value of the session attribute, or {@code null} if not found + */ + @Nullable + public static Object getSessionAttribute(HttpServletRequest request, String name) { + Assert.notNull(request, "Request must not be null"); + HttpSession session = request.getSession(false); + return (session != null ? session.getAttribute(name) : null); + } + + /** + * Check the given request for a session attribute of the given name. + * Throws an exception if there is no session or if the session has no such + * attribute. Does not create a new session if none has existed before! + * @param request current HTTP request + * @param name the name of the session attribute + * @return the value of the session attribute, or {@code null} if not found + * @throws IllegalStateException if the session attribute could not be found + */ + public static Object getRequiredSessionAttribute(HttpServletRequest request, String name) + throws IllegalStateException { + + Object attr = getSessionAttribute(request, name); + if (attr == null) { + throw new IllegalStateException("No session attribute '" + name + "' found"); + } + return attr; + } + + /** + * Set the session attribute with the given name to the given value. + * Removes the session attribute if value is null, if a session existed at all. + * Does not create a new session if not necessary! + * @param request current HTTP request + * @param name the name of the session attribute + * @param value the value of the session attribute + */ + public static void setSessionAttribute(HttpServletRequest request, String name, @Nullable Object value) { + Assert.notNull(request, "Request must not be null"); + if (value != null) { + request.getSession().setAttribute(name, value); + } + else { + HttpSession session = request.getSession(false); + if (session != null) { + session.removeAttribute(name); + } + } + } + + /** + * Return the best available mutex for the given session: + * that is, an object to synchronize on for the given session. + *

Returns the session mutex attribute if available; usually, + * this means that the HttpSessionMutexListener needs to be defined + * in {@code web.xml}. Falls back to the HttpSession itself + * if no mutex attribute found. + *

The session mutex is guaranteed to be the same object during + * the entire lifetime of the session, available under the key defined + * by the {@code SESSION_MUTEX_ATTRIBUTE} constant. It serves as a + * safe reference to synchronize on for locking on the current session. + *

In many cases, the HttpSession reference itself is a safe mutex + * as well, since it will always be the same object reference for the + * same active logical session. However, this is not guaranteed across + * different servlet containers; the only 100% safe way is a session mutex. + * @param session the HttpSession to find a mutex for + * @return the mutex object (never {@code null}) + * @see #SESSION_MUTEX_ATTRIBUTE + * @see HttpSessionMutexListener + */ + public static Object getSessionMutex(HttpSession session) { + Assert.notNull(session, "Session must not be null"); + Object mutex = session.getAttribute(SESSION_MUTEX_ATTRIBUTE); + if (mutex == null) { + mutex = session; + } + return mutex; + } + + + /** + * Return an appropriate request object of the specified type, if available, + * unwrapping the given request as far as necessary. + * @param request the servlet request to introspect + * @param requiredType the desired type of request object + * @return the matching request object, or {@code null} if none + * of that type is available + */ + @SuppressWarnings("unchecked") + @Nullable + public static T getNativeRequest(ServletRequest request, @Nullable Class requiredType) { + if (requiredType != null) { + if (requiredType.isInstance(request)) { + return (T) request; + } + else if (request instanceof ServletRequestWrapper) { + return getNativeRequest(((ServletRequestWrapper) request).getRequest(), requiredType); + } + } + return null; + } + + /** + * Return an appropriate response object of the specified type, if available, + * unwrapping the given response as far as necessary. + * @param response the servlet response to introspect + * @param requiredType the desired type of response object + * @return the matching response object, or {@code null} if none + * of that type is available + */ + @SuppressWarnings("unchecked") + @Nullable + public static T getNativeResponse(ServletResponse response, @Nullable Class requiredType) { + if (requiredType != null) { + if (requiredType.isInstance(response)) { + return (T) response; + } + else if (response instanceof ServletResponseWrapper) { + return getNativeResponse(((ServletResponseWrapper) response).getResponse(), requiredType); + } + } + return null; + } + + /** + * Determine whether the given request is an include request, + * that is, not a top-level HTTP request coming in from the outside. + *

Checks the presence of the "javax.servlet.include.request_uri" + * request attribute. Could check any request attribute that is only + * present in an include request. + * @param request current servlet request + * @return whether the given request is an include request + */ + public static boolean isIncludeRequest(ServletRequest request) { + return (request.getAttribute(INCLUDE_REQUEST_URI_ATTRIBUTE) != null); + } + + /** + * Expose the Servlet spec's error attributes as {@link javax.servlet.http.HttpServletRequest} + * attributes under the keys defined in the Servlet 2.3 specification, for error pages that + * are rendered directly rather than through the Servlet container's error page resolution: + * {@code javax.servlet.error.status_code}, + * {@code javax.servlet.error.exception_type}, + * {@code javax.servlet.error.message}, + * {@code javax.servlet.error.exception}, + * {@code javax.servlet.error.request_uri}, + * {@code javax.servlet.error.servlet_name}. + *

Does not override values if already present, to respect attribute values + * that have been exposed explicitly before. + *

Exposes status code 200 by default. Set the "javax.servlet.error.status_code" + * attribute explicitly (before or after) in order to expose a different status code. + * @param request current servlet request + * @param ex the exception encountered + * @param servletName the name of the offending servlet + */ + public static void exposeErrorRequestAttributes(HttpServletRequest request, Throwable ex, + @Nullable String servletName) { + + exposeRequestAttributeIfNotPresent(request, ERROR_STATUS_CODE_ATTRIBUTE, HttpServletResponse.SC_OK); + exposeRequestAttributeIfNotPresent(request, ERROR_EXCEPTION_TYPE_ATTRIBUTE, ex.getClass()); + exposeRequestAttributeIfNotPresent(request, ERROR_MESSAGE_ATTRIBUTE, ex.getMessage()); + exposeRequestAttributeIfNotPresent(request, ERROR_EXCEPTION_ATTRIBUTE, ex); + exposeRequestAttributeIfNotPresent(request, ERROR_REQUEST_URI_ATTRIBUTE, request.getRequestURI()); + if (servletName != null) { + exposeRequestAttributeIfNotPresent(request, ERROR_SERVLET_NAME_ATTRIBUTE, servletName); + } + } + + /** + * Expose the specified request attribute if not already present. + * @param request current servlet request + * @param name the name of the attribute + * @param value the suggested value of the attribute + */ + private static void exposeRequestAttributeIfNotPresent(ServletRequest request, String name, Object value) { + if (request.getAttribute(name) == null) { + request.setAttribute(name, value); + } + } + + /** + * Clear the Servlet spec's error attributes as {@link javax.servlet.http.HttpServletRequest} + * attributes under the keys defined in the Servlet 2.3 specification: + * {@code javax.servlet.error.status_code}, + * {@code javax.servlet.error.exception_type}, + * {@code javax.servlet.error.message}, + * {@code javax.servlet.error.exception}, + * {@code javax.servlet.error.request_uri}, + * {@code javax.servlet.error.servlet_name}. + * @param request current servlet request + */ + public static void clearErrorRequestAttributes(HttpServletRequest request) { + request.removeAttribute(ERROR_STATUS_CODE_ATTRIBUTE); + request.removeAttribute(ERROR_EXCEPTION_TYPE_ATTRIBUTE); + request.removeAttribute(ERROR_MESSAGE_ATTRIBUTE); + request.removeAttribute(ERROR_EXCEPTION_ATTRIBUTE); + request.removeAttribute(ERROR_REQUEST_URI_ATTRIBUTE); + request.removeAttribute(ERROR_SERVLET_NAME_ATTRIBUTE); + } + + /** + * Retrieve the first cookie with the given name. Note that multiple + * cookies can have the same name but different paths or domains. + * @param request current servlet request + * @param name cookie name + * @return the first cookie with the given name, or {@code null} if none is found + */ + @Nullable + public static Cookie getCookie(HttpServletRequest request, String name) { + Assert.notNull(request, "Request must not be null"); + Cookie[] cookies = request.getCookies(); + if (cookies != null) { + for (Cookie cookie : cookies) { + if (name.equals(cookie.getName())) { + return cookie; + } + } + } + return null; + } + + /** + * Check if a specific input type="submit" parameter was sent in the request, + * either via a button (directly with name) or via an image (name + ".x" or + * name + ".y"). + * @param request current HTTP request + * @param name name of the parameter + * @return if the parameter was sent + * @see #SUBMIT_IMAGE_SUFFIXES + */ + public static boolean hasSubmitParameter(ServletRequest request, String name) { + Assert.notNull(request, "Request must not be null"); + if (request.getParameter(name) != null) { + return true; + } + for (String suffix : SUBMIT_IMAGE_SUFFIXES) { + if (request.getParameter(name + suffix) != null) { + return true; + } + } + return false; + } + + /** + * Obtain a named parameter from the given request parameters. + *

See {@link #findParameterValue(java.util.Map, String)} + * for a description of the lookup algorithm. + * @param request current HTTP request + * @param name the logical name of the request parameter + * @return the value of the parameter, or {@code null} + * if the parameter does not exist in given request + */ + @Nullable + public static String findParameterValue(ServletRequest request, String name) { + return findParameterValue(request.getParameterMap(), name); + } + + /** + * Obtain a named parameter from the given request parameters. + *

This method will try to obtain a parameter value using the + * following algorithm: + *

    + *
  1. Try to get the parameter value using just the given logical name. + * This handles parameters of the form logicalName = value. For normal + * parameters, e.g. submitted using a hidden HTML form field, this will return + * the requested value.
  2. + *
  3. Try to obtain the parameter value from the parameter name, where the + * parameter name in the request is of the form logicalName_value = xyz + * with "_" being the configured delimiter. This deals with parameter values + * submitted using an HTML form submit button.
  4. + *
  5. If the value obtained in the previous step has a ".x" or ".y" suffix, + * remove that. This handles cases where the value was submitted using an + * HTML form image button. In this case the parameter in the request would + * actually be of the form logicalName_value.x = 123.
  6. + *
+ * @param parameters the available parameter map + * @param name the logical name of the request parameter + * @return the value of the parameter, or {@code null} + * if the parameter does not exist in given request + */ + @Nullable + public static String findParameterValue(Map parameters, String name) { + // First try to get it as a normal name=value parameter + Object value = parameters.get(name); + if (value instanceof String[]) { + String[] values = (String[]) value; + return (values.length > 0 ? values[0] : null); + } + else if (value != null) { + return value.toString(); + } + // If no value yet, try to get it as a name_value=xyz parameter + String prefix = name + "_"; + for (String paramName : parameters.keySet()) { + if (paramName.startsWith(prefix)) { + // Support images buttons, which would submit parameters as name_value.x=123 + for (String suffix : SUBMIT_IMAGE_SUFFIXES) { + if (paramName.endsWith(suffix)) { + return paramName.substring(prefix.length(), paramName.length() - suffix.length()); + } + } + return paramName.substring(prefix.length()); + } + } + // We couldn't find the parameter value... + return null; + } + + /** + * Return a map containing all parameters with the given prefix. + * Maps single values to String and multiple values to String array. + *

For example, with a prefix of "spring_", "spring_param1" and + * "spring_param2" result in a Map with "param1" and "param2" as keys. + * @param request the HTTP request in which to look for parameters + * @param prefix the beginning of parameter names + * (if this is null or the empty string, all parameters will match) + * @return map containing request parameters without the prefix, + * containing either a String or a String array as values + * @see javax.servlet.ServletRequest#getParameterNames + * @see javax.servlet.ServletRequest#getParameterValues + * @see javax.servlet.ServletRequest#getParameterMap + */ + public static Map getParametersStartingWith(ServletRequest request, @Nullable String prefix) { + Assert.notNull(request, "Request must not be null"); + Enumeration paramNames = request.getParameterNames(); + Map params = new TreeMap<>(); + if (prefix == null) { + prefix = ""; + } + while (paramNames != null && paramNames.hasMoreElements()) { + String paramName = paramNames.nextElement(); + if ("".equals(prefix) || paramName.startsWith(prefix)) { + String unprefixed = paramName.substring(prefix.length()); + String[] values = request.getParameterValues(paramName); + if (values == null || values.length == 0) { + // Do nothing, no values found at all. + } + else if (values.length > 1) { + params.put(unprefixed, values); + } + else { + params.put(unprefixed, values[0]); + } + } + } + return params; + } + + /** + * Parse the given string with matrix variables. An example string would look + * like this {@code "q1=a;q1=b;q2=a,b,c"}. The resulting map would contain + * keys {@code "q1"} and {@code "q2"} with values {@code ["a","b"]} and + * {@code ["a","b","c"]} respectively. + * @param matrixVariables the unparsed matrix variables string + * @return a map with matrix variable names and values (never {@code null}) + * @since 3.2 + */ + public static MultiValueMap parseMatrixVariables(String matrixVariables) { + MultiValueMap result = new LinkedMultiValueMap<>(); + if (!StringUtils.hasText(matrixVariables)) { + return result; + } + StringTokenizer pairs = new StringTokenizer(matrixVariables, ";"); + while (pairs.hasMoreTokens()) { + String pair = pairs.nextToken(); + int index = pair.indexOf('='); + if (index != -1) { + String name = pair.substring(0, index); + if (name.equalsIgnoreCase("jsessionid")) { + continue; + } + String rawValue = pair.substring(index + 1); + for (String value : StringUtils.commaDelimitedListToStringArray(rawValue)) { + result.add(name, value); + } + } + else { + result.add(pair, ""); + } + } + return result; + } + + /** + * Check the given request origin against a list of allowed origins. + * A list containing "*" means that all origins are allowed. + * An empty list means only same origin is allowed. + * + *

Note: as of 5.1 this method ignores + * {@code "Forwarded"} and {@code "X-Forwarded-*"} headers that specify the + * client-originated address. Consider using the {@code ForwardedHeaderFilter} + * to extract and use, or to discard such headers. + * + * @return {@code true} if the request origin is valid, {@code false} otherwise + * @since 4.1.5 + * @see RFC 6454: The Web Origin Concept + */ + public static boolean isValidOrigin(HttpRequest request, Collection allowedOrigins) { + Assert.notNull(request, "Request must not be null"); + Assert.notNull(allowedOrigins, "Allowed origins must not be null"); + + String origin = request.getHeaders().getOrigin(); + if (origin == null || allowedOrigins.contains("*")) { + return true; + } + else if (CollectionUtils.isEmpty(allowedOrigins)) { + return isSameOrigin(request); + } + else { + return allowedOrigins.contains(origin); + } + } + + /** + * Check if the request is a same-origin one, based on {@code Origin}, {@code Host}, + * {@code Forwarded}, {@code X-Forwarded-Proto}, {@code X-Forwarded-Host} and + * {@code X-Forwarded-Port} headers. + * + *

Note: as of 5.1 this method ignores + * {@code "Forwarded"} and {@code "X-Forwarded-*"} headers that specify the + * client-originated address. Consider using the {@code ForwardedHeaderFilter} + * to extract and use, or to discard such headers. + + * @return {@code true} if the request is a same-origin one, {@code false} in case + * of cross-origin request + * @since 4.2 + */ + public static boolean isSameOrigin(HttpRequest request) { + HttpHeaders headers = request.getHeaders(); + String origin = headers.getOrigin(); + if (origin == null) { + return true; + } + + String scheme; + String host; + int port; + if (request instanceof ServletServerHttpRequest) { + // Build more efficiently if we can: we only need scheme, host, port for origin comparison + HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); + scheme = servletRequest.getScheme(); + host = servletRequest.getServerName(); + port = servletRequest.getServerPort(); + } + else { + URI uri = request.getURI(); + scheme = uri.getScheme(); + host = uri.getHost(); + port = uri.getPort(); + } + + UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); + return (ObjectUtils.nullSafeEquals(scheme, originUrl.getScheme()) && + ObjectUtils.nullSafeEquals(host, originUrl.getHost()) && + getPort(scheme, port) == getPort(originUrl.getScheme(), originUrl.getPort())); + } + + private static int getPort(@Nullable String scheme, int port) { + if (port == -1) { + if ("http".equals(scheme) || "ws".equals(scheme)) { + port = 80; + } + else if ("https".equals(scheme) || "wss".equals(scheme)) { + port = 443; + } + } + return port; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/package-info.java b/spring-web/src/main/java/org/springframework/web/util/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..1f9c37a7a52dd05e87dba4829d6d8c8af34b1536 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/package-info.java @@ -0,0 +1,10 @@ +/** + * Miscellaneous web utility classes, such as HTML escaping, + * Log4j initialization, and cookie handling. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.util; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/org/springframework/web/util/pattern/CaptureTheRestPathElement.java b/spring-web/src/main/java/org/springframework/web/util/pattern/CaptureTheRestPathElement.java new file mode 100644 index 0000000000000000000000000000000000000000..b1a9ad8e488e12c717e78625a7a676c5be67aac5 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/pattern/CaptureTheRestPathElement.java @@ -0,0 +1,123 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util.pattern; + +import java.util.List; + +import org.springframework.http.server.PathContainer.Element; +import org.springframework.http.server.PathContainer.PathSegment; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.util.pattern.PathPattern.MatchingContext; + +/** + * A path element representing capturing the rest of a path. In the pattern + * '/foo/{*foobar}' the /{*foobar} is represented as a {@link CaptureTheRestPathElement}. + * + * @author Andy Clement + * @since 5.0 + */ +class CaptureTheRestPathElement extends PathElement { + + private final String variableName; + + + /** + * Create a new {@link CaptureTheRestPathElement} instance. + * @param pos position of the path element within the path pattern text + * @param captureDescriptor a character array containing contents like '{' '*' 'a' 'b' '}' + * @param separator the separator used in the path pattern + */ + CaptureTheRestPathElement(int pos, char[] captureDescriptor, char separator) { + super(pos, separator); + this.variableName = new String(captureDescriptor, 2, captureDescriptor.length - 3); + } + + + @Override + public boolean matches(int pathIndex, MatchingContext matchingContext) { + // No need to handle 'match start' checking as this captures everything + // anyway and cannot be followed by anything else + // assert next == null + + // If there is more data, it must start with the separator + if (pathIndex < matchingContext.pathLength && !matchingContext.isSeparator(pathIndex)) { + return false; + } + if (matchingContext.determineRemainingPath) { + matchingContext.remainingPathIndex = matchingContext.pathLength; + } + if (matchingContext.extractingVariables) { + // Collect the parameters from all the remaining segments + MultiValueMap parametersCollector = null; + for (int i = pathIndex; i < matchingContext.pathLength; i++) { + Element element = matchingContext.pathElements.get(i); + if (element instanceof PathSegment) { + MultiValueMap parameters = ((PathSegment) element).parameters(); + if (!parameters.isEmpty()) { + if (parametersCollector == null) { + parametersCollector = new LinkedMultiValueMap<>(); + } + parametersCollector.addAll(parameters); + } + } + } + matchingContext.set(this.variableName, pathToString(pathIndex, matchingContext.pathElements), + parametersCollector == null?NO_PARAMETERS:parametersCollector); + } + return true; + } + + private String pathToString(int fromSegment, List pathElements) { + StringBuilder buf = new StringBuilder(); + for (int i = fromSegment, max = pathElements.size(); i < max; i++) { + Element element = pathElements.get(i); + if (element instanceof PathSegment) { + buf.append(((PathSegment)element).valueToMatch()); + } + else { + buf.append(element.value()); + } + } + return buf.toString(); + } + + @Override + public int getNormalizedLength() { + return 1; + } + + @Override + public int getWildcardCount() { + return 0; + } + + @Override + public int getCaptureCount() { + return 1; + } + + + public String toString() { + return "CaptureTheRest(/{*" + this.variableName + "})"; + } + + @Override + public char[] getChars() { + return ("/{*"+this.variableName+"}").toCharArray(); + } +} diff --git a/spring-web/src/main/java/org/springframework/web/util/pattern/CaptureVariablePathElement.java b/spring-web/src/main/java/org/springframework/web/util/pattern/CaptureVariablePathElement.java new file mode 100644 index 0000000000000000000000000000000000000000..88ad319fb726110d1c468aabe99a2172f430f127 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/pattern/CaptureVariablePathElement.java @@ -0,0 +1,169 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util.pattern; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.springframework.http.server.PathContainer.PathSegment; +import org.springframework.lang.Nullable; + +/** + * A path element representing capturing a piece of the path as a variable. In the pattern + * '/foo/{bar}/goo' the {bar} is represented as a {@link CaptureVariablePathElement}. There + * must be at least one character to bind to the variable. + * + * @author Andy Clement + * @since 5.0 + */ +class CaptureVariablePathElement extends PathElement { + + private final String variableName; + + @Nullable + private Pattern constraintPattern; + + + /** + * Create a new {@link CaptureVariablePathElement} instance. + * @param pos the position in the pattern of this capture element + * @param captureDescriptor is of the form {AAAAA[:pattern]} + */ + CaptureVariablePathElement(int pos, char[] captureDescriptor, boolean caseSensitive, char separator) { + super(pos, separator); + int colon = -1; + for (int i = 0; i < captureDescriptor.length; i++) { + if (captureDescriptor[i] == ':') { + colon = i; + break; + } + } + if (colon == -1) { + // no constraint + this.variableName = new String(captureDescriptor, 1, captureDescriptor.length - 2); + } + else { + this.variableName = new String(captureDescriptor, 1, colon - 1); + if (caseSensitive) { + this.constraintPattern = Pattern.compile( + new String(captureDescriptor, colon + 1, captureDescriptor.length - colon - 2)); + } + else { + this.constraintPattern = Pattern.compile( + new String(captureDescriptor, colon + 1, captureDescriptor.length - colon - 2), + Pattern.CASE_INSENSITIVE); + } + } + } + + + @Override + public boolean matches(int pathIndex, PathPattern.MatchingContext matchingContext) { + if (pathIndex >= matchingContext.pathLength) { + // no more path left to match this element + return false; + } + String candidateCapture = matchingContext.pathElementValue(pathIndex); + if (candidateCapture.length() == 0) { + return false; + } + + if (this.constraintPattern != null) { + // TODO possible optimization - only regex match if rest of pattern matches? + // Benefit likely to vary pattern to pattern + Matcher matcher = this.constraintPattern.matcher(candidateCapture); + if (matcher.groupCount() != 0) { + throw new IllegalArgumentException( + "No capture groups allowed in the constraint regex: " + this.constraintPattern.pattern()); + } + if (!matcher.matches()) { + return false; + } + } + + boolean match = false; + pathIndex++; + if (isNoMorePattern()) { + if (matchingContext.determineRemainingPath) { + matchingContext.remainingPathIndex = pathIndex; + match = true; + } + else { + // Needs to be at least one character #SPR15264 + match = (pathIndex == matchingContext.pathLength); + if (!match && matchingContext.isMatchOptionalTrailingSeparator()) { + match = //(nextPos > candidateIndex) && + (pathIndex + 1) == matchingContext.pathLength && + matchingContext.isSeparator(pathIndex); + } + } + } + else { + if (this.next != null) { + match = this.next.matches(pathIndex, matchingContext); + } + } + + if (match && matchingContext.extractingVariables) { + matchingContext.set(this.variableName, candidateCapture, + ((PathSegment)matchingContext.pathElements.get(pathIndex-1)).parameters()); + } + return match; + } + + public String getVariableName() { + return this.variableName; + } + + @Override + public int getNormalizedLength() { + return 1; + } + + @Override + public int getWildcardCount() { + return 0; + } + + @Override + public int getCaptureCount() { + return 1; + } + + @Override + public int getScore() { + return CAPTURE_VARIABLE_WEIGHT; + } + + + public String toString() { + return "CaptureVariable({" + this.variableName + + (this.constraintPattern != null ? ":" + this.constraintPattern.pattern() : "") + "})"; + } + + public char[] getChars() { + StringBuilder b = new StringBuilder(); + b.append("{"); + b.append(this.variableName); + if (this.constraintPattern != null) { + b.append(":").append(this.constraintPattern.pattern()); + } + b.append("}"); + return b.toString().toCharArray(); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/pattern/InternalPathPatternParser.java b/spring-web/src/main/java/org/springframework/web/util/pattern/InternalPathPatternParser.java new file mode 100644 index 0000000000000000000000000000000000000000..607d1fc42fd815729a7a7a2097cf18677f381c37 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/pattern/InternalPathPatternParser.java @@ -0,0 +1,419 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util.pattern; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.PatternSyntaxException; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.util.pattern.PatternParseException.PatternMessage; + +/** + * Parser for URI template patterns. It breaks the path pattern into a number of + * {@link PathElement PathElements} in a linked list. Instances are reusable but are not thread-safe. + * + * @author Andy Clement + * @since 5.0 + */ +class InternalPathPatternParser { + + private final PathPatternParser parser; + + // The input data for parsing + private char[] pathPatternData = new char[0]; + + // The length of the input data + private int pathPatternLength; + + // Current parsing position + int pos; + + // How many ? characters in a particular path element + private int singleCharWildcardCount; + + // Is the path pattern using * characters in a particular path element + private boolean wildcard = false; + + // Is the construct {*...} being used in a particular path element + private boolean isCaptureTheRestVariable = false; + + // Has the parser entered a {...} variable capture block in a particular + // path element + private boolean insideVariableCapture = false; + + // How many variable captures are occurring in a particular path element + private int variableCaptureCount = 0; + + // Start of the most recent path element in a particular path element + private int pathElementStart; + + // Start of the most recent variable capture in a particular path element + private int variableCaptureStart; + + // Variables captures in this path pattern + @Nullable + private List capturedVariableNames; + + // The head of the path element chain currently being built + @Nullable + private PathElement headPE; + + // The most recently constructed path element in the chain + @Nullable + private PathElement currentPE; + + + /** + * Package private constructor for use in {@link PathPatternParser#parse}. + * @param parentParser reference back to the stateless, public parser + */ + InternalPathPatternParser(PathPatternParser parentParser) { + this.parser = parentParser; + } + + + /** + * Package private delegate for {@link PathPatternParser#parse(String)}. + */ + public PathPattern parse(String pathPattern) throws PatternParseException { + Assert.notNull(pathPattern, "Path pattern must not be null"); + + this.pathPatternData = pathPattern.toCharArray(); + this.pathPatternLength = this.pathPatternData.length; + this.headPE = null; + this.currentPE = null; + this.capturedVariableNames = null; + this.pathElementStart = -1; + this.pos = 0; + resetPathElementState(); + + while (this.pos < this.pathPatternLength) { + char ch = this.pathPatternData[this.pos]; + if (ch == this.parser.getSeparator()) { + if (this.pathElementStart != -1) { + pushPathElement(createPathElement()); + } + if (peekDoubleWildcard()) { + pushPathElement(new WildcardTheRestPathElement(this.pos, this.parser.getSeparator())); + this.pos += 2; + } + else { + pushPathElement(new SeparatorPathElement(this.pos, this.parser.getSeparator())); + } + } + else { + if (this.pathElementStart == -1) { + this.pathElementStart = this.pos; + } + if (ch == '?') { + this.singleCharWildcardCount++; + } + else if (ch == '{') { + if (this.insideVariableCapture) { + throw new PatternParseException(this.pos, this.pathPatternData, + PatternMessage.ILLEGAL_NESTED_CAPTURE); + } + // If we enforced that adjacent captures weren't allowed, + // this would do it (this would be an error: /foo/{bar}{boo}/) + // } else if (pos > 0 && pathPatternData[pos - 1] == '}') { + // throw new PatternParseException(pos, pathPatternData, + // PatternMessage.CANNOT_HAVE_ADJACENT_CAPTURES); + this.insideVariableCapture = true; + this.variableCaptureStart = this.pos; + } + else if (ch == '}') { + if (!this.insideVariableCapture) { + throw new PatternParseException(this.pos, this.pathPatternData, + PatternMessage.MISSING_OPEN_CAPTURE); + } + this.insideVariableCapture = false; + if (this.isCaptureTheRestVariable && (this.pos + 1) < this.pathPatternLength) { + throw new PatternParseException(this.pos + 1, this.pathPatternData, + PatternMessage.NO_MORE_DATA_EXPECTED_AFTER_CAPTURE_THE_REST); + } + this.variableCaptureCount++; + } + else if (ch == ':') { + if (this.insideVariableCapture && !this.isCaptureTheRestVariable) { + skipCaptureRegex(); + this.insideVariableCapture = false; + this.variableCaptureCount++; + } + } + else if (ch == '*') { + if (this.insideVariableCapture && this.variableCaptureStart == this.pos - 1) { + this.isCaptureTheRestVariable = true; + } + this.wildcard = true; + } + // Check that the characters used for captured variable names are like java identifiers + if (this.insideVariableCapture) { + if ((this.variableCaptureStart + 1 + (this.isCaptureTheRestVariable ? 1 : 0)) == this.pos && + !Character.isJavaIdentifierStart(ch)) { + throw new PatternParseException(this.pos, this.pathPatternData, + PatternMessage.ILLEGAL_CHARACTER_AT_START_OF_CAPTURE_DESCRIPTOR, + Character.toString(ch)); + + } + else if ((this.pos > (this.variableCaptureStart + 1 + (this.isCaptureTheRestVariable ? 1 : 0)) && + !Character.isJavaIdentifierPart(ch) && ch != '-')) { + throw new PatternParseException(this.pos, this.pathPatternData, + PatternMessage.ILLEGAL_CHARACTER_IN_CAPTURE_DESCRIPTOR, + Character.toString(ch)); + } + } + } + this.pos++; + } + if (this.pathElementStart != -1) { + pushPathElement(createPathElement()); + } + return new PathPattern(pathPattern, this.parser, this.headPE); + } + + /** + * Just hit a ':' and want to jump over the regex specification for this + * variable. pos will be pointing at the ':', we want to skip until the }. + *

+ * Nested {...} pairs don't have to be escaped: /abc/{var:x{1,2}}/def + *

An escaped } will not be treated as the end of the regex: /abc/{var:x\\{y:}/def + *

A separator that should not indicate the end of the regex can be escaped: + */ + private void skipCaptureRegex() { + this.pos++; + int regexStart = this.pos; + int curlyBracketDepth = 0; // how deep in nested {...} pairs + boolean previousBackslash = false; + + while (this.pos < this.pathPatternLength) { + char ch = this.pathPatternData[this.pos]; + if (ch == '\\' && !previousBackslash) { + this.pos++; + previousBackslash = true; + continue; + } + if (ch == '{' && !previousBackslash) { + curlyBracketDepth++; + } + else if (ch == '}' && !previousBackslash) { + if (curlyBracketDepth == 0) { + if (regexStart == this.pos) { + throw new PatternParseException(regexStart, this.pathPatternData, + PatternMessage.MISSING_REGEX_CONSTRAINT); + } + return; + } + curlyBracketDepth--; + } + if (ch == this.parser.getSeparator() && !previousBackslash) { + throw new PatternParseException(this.pos, this.pathPatternData, + PatternMessage.MISSING_CLOSE_CAPTURE); + } + this.pos++; + previousBackslash = false; + } + + throw new PatternParseException(this.pos - 1, this.pathPatternData, + PatternMessage.MISSING_CLOSE_CAPTURE); + } + + /** + * After processing a separator, a quick peek whether it is followed by + * (and only before the end of the pattern or the next separator). + */ + private boolean peekDoubleWildcard() { + if ((this.pos + 2) >= this.pathPatternLength) { + return false; + } + if (this.pathPatternData[this.pos + 1] != '*' || this.pathPatternData[this.pos + 2] != '*') { + return false; + } + return (this.pos + 3 == this.pathPatternLength); + } + + /** + * Push a path element to the chain being build. + * @param newPathElement the new path element to add + */ + private void pushPathElement(PathElement newPathElement) { + if (newPathElement instanceof CaptureTheRestPathElement) { + // There must be a separator ahead of this thing + // currentPE SHOULD be a SeparatorPathElement + if (this.currentPE == null) { + this.headPE = newPathElement; + this.currentPE = newPathElement; + } + else if (this.currentPE instanceof SeparatorPathElement) { + PathElement peBeforeSeparator = this.currentPE.prev; + if (peBeforeSeparator == null) { + // /{*foobar} is at the start + this.headPE = newPathElement; + newPathElement.prev = null; + } + else { + peBeforeSeparator.next = newPathElement; + newPathElement.prev = peBeforeSeparator; + } + this.currentPE = newPathElement; + } + else { + throw new IllegalStateException("Expected SeparatorPathElement but was " + this.currentPE); + } + } + else { + if (this.headPE == null) { + this.headPE = newPathElement; + this.currentPE = newPathElement; + } + else if (this.currentPE != null) { + this.currentPE.next = newPathElement; + newPathElement.prev = this.currentPE; + this.currentPE = newPathElement; + } + } + + resetPathElementState(); + } + + private char[] getPathElementText() { + char[] pathElementText = new char[this.pos - this.pathElementStart]; + System.arraycopy(this.pathPatternData, this.pathElementStart, pathElementText, 0, + this.pos - this.pathElementStart); + return pathElementText; + } + + /** + * Used the knowledge built up whilst processing since the last path element to determine what kind of path + * element to create. + * @return the new path element + */ + private PathElement createPathElement() { + if (this.insideVariableCapture) { + throw new PatternParseException(this.pos, this.pathPatternData, PatternMessage.MISSING_CLOSE_CAPTURE); + } + + PathElement newPE = null; + + if (this.variableCaptureCount > 0) { + if (this.variableCaptureCount == 1 && this.pathElementStart == this.variableCaptureStart && + this.pathPatternData[this.pos - 1] == '}') { + if (this.isCaptureTheRestVariable) { + // It is {*....} + newPE = new CaptureTheRestPathElement( + this.pathElementStart, getPathElementText(), this.parser.getSeparator()); + } + else { + // It is a full capture of this element (possibly with constraint), for example: /foo/{abc}/ + try { + newPE = new CaptureVariablePathElement(this.pathElementStart, getPathElementText(), + this.parser.isCaseSensitive(), this.parser.getSeparator()); + } + catch (PatternSyntaxException pse) { + throw new PatternParseException(pse, + findRegexStart(this.pathPatternData, this.pathElementStart) + pse.getIndex(), + this.pathPatternData, PatternMessage.REGEX_PATTERN_SYNTAX_EXCEPTION); + } + recordCapturedVariable(this.pathElementStart, + ((CaptureVariablePathElement) newPE).getVariableName()); + } + } + else { + if (this.isCaptureTheRestVariable) { + throw new PatternParseException(this.pathElementStart, this.pathPatternData, + PatternMessage.CAPTURE_ALL_IS_STANDALONE_CONSTRUCT); + } + RegexPathElement newRegexSection = new RegexPathElement(this.pathElementStart, + getPathElementText(), this.parser.isCaseSensitive(), + this.pathPatternData, this.parser.getSeparator()); + for (String variableName : newRegexSection.getVariableNames()) { + recordCapturedVariable(this.pathElementStart, variableName); + } + newPE = newRegexSection; + } + } + else { + if (this.wildcard) { + if (this.pos - 1 == this.pathElementStart) { + newPE = new WildcardPathElement(this.pathElementStart, this.parser.getSeparator()); + } + else { + newPE = new RegexPathElement(this.pathElementStart, getPathElementText(), + this.parser.isCaseSensitive(), this.pathPatternData, this.parser.getSeparator()); + } + } + else if (this.singleCharWildcardCount != 0) { + newPE = new SingleCharWildcardedPathElement(this.pathElementStart, getPathElementText(), + this.singleCharWildcardCount, this.parser.isCaseSensitive(), this.parser.getSeparator()); + } + else { + newPE = new LiteralPathElement(this.pathElementStart, getPathElementText(), + this.parser.isCaseSensitive(), this.parser.getSeparator()); + } + } + + return newPE; + } + + /** + * For a path element representing a captured variable, locate the constraint pattern. + * Assumes there is a constraint pattern. + * @param data a complete path expression, e.g. /aaa/bbb/{ccc:...} + * @param offset the start of the capture pattern of interest + * @return the index of the character after the ':' within + * the pattern expression relative to the start of the whole expression + */ + private int findRegexStart(char[] data, int offset) { + int pos = offset; + while (pos < data.length) { + if (data[pos] == ':') { + return pos + 1; + } + pos++; + } + return -1; + } + + /** + * Reset all the flags and position markers computed during path element processing. + */ + private void resetPathElementState() { + this.pathElementStart = -1; + this.singleCharWildcardCount = 0; + this.insideVariableCapture = false; + this.variableCaptureCount = 0; + this.wildcard = false; + this.isCaptureTheRestVariable = false; + this.variableCaptureStart = -1; + } + + /** + * Record a new captured variable. If it clashes with an existing one then report an error. + */ + private void recordCapturedVariable(int pos, String variableName) { + if (this.capturedVariableNames == null) { + this.capturedVariableNames = new ArrayList<>(); + } + if (this.capturedVariableNames.contains(variableName)) { + throw new PatternParseException(pos, this.pathPatternData, + PatternMessage.ILLEGAL_DOUBLE_CAPTURE, variableName); + } + this.capturedVariableNames.add(variableName); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/pattern/LiteralPathElement.java b/spring-web/src/main/java/org/springframework/web/util/pattern/LiteralPathElement.java new file mode 100644 index 0000000000000000000000000000000000000000..ca2c239a4dbbe24cd16e0067d63d2702f9bfbc02 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/pattern/LiteralPathElement.java @@ -0,0 +1,125 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util.pattern; + +import org.springframework.http.server.PathContainer; +import org.springframework.http.server.PathContainer.Element; +import org.springframework.http.server.PathContainer.PathSegment; +import org.springframework.web.util.pattern.PathPattern.MatchingContext; + +/** + * A literal path element. In the pattern '/foo/bar/goo' there are three + * literal path elements 'foo', 'bar' and 'goo'. + * + * @author Andy Clement + */ +class LiteralPathElement extends PathElement { + + private char[] text; + + private int len; + + private boolean caseSensitive; + + + public LiteralPathElement(int pos, char[] literalText, boolean caseSensitive, char separator) { + super(pos, separator); + this.len = literalText.length; + this.caseSensitive = caseSensitive; + if (caseSensitive) { + this.text = literalText; + } + else { + // Force all the text lower case to make matching faster + this.text = new char[literalText.length]; + for (int i = 0; i < this.len; i++) { + this.text[i] = Character.toLowerCase(literalText[i]); + } + } + } + + + @Override + public boolean matches(int pathIndex, MatchingContext matchingContext) { + if (pathIndex >= matchingContext.pathLength) { + // no more path left to match this element + return false; + } + Element element = matchingContext.pathElements.get(pathIndex); + if (!(element instanceof PathContainer.PathSegment)) { + return false; + } + String value = ((PathSegment)element).valueToMatch(); + if (value.length() != this.len) { + // Not enough data to match this path element + return false; + } + + char[] data = ((PathContainer.PathSegment)element).valueToMatchAsChars(); + if (this.caseSensitive) { + for (int i = 0; i < this.len; i++) { + if (data[i] != this.text[i]) { + return false; + } + } + } + else { + for (int i = 0; i < this.len; i++) { + // TODO revisit performance if doing a lot of case insensitive matching + if (Character.toLowerCase(data[i]) != this.text[i]) { + return false; + } + } + } + + pathIndex++; + if (isNoMorePattern()) { + if (matchingContext.determineRemainingPath) { + matchingContext.remainingPathIndex = pathIndex; + return true; + } + else { + if (pathIndex == matchingContext.pathLength) { + return true; + } + else { + return (matchingContext.isMatchOptionalTrailingSeparator() && + (pathIndex + 1) == matchingContext.pathLength && + matchingContext.isSeparator(pathIndex)); + } + } + } + else { + return (this.next != null && this.next.matches(pathIndex, matchingContext)); + } + } + + @Override + public int getNormalizedLength() { + return this.len; + } + + public char[] getChars() { + return this.text; + } + + + public String toString() { + return "Literal(" + String.valueOf(this.text) + ")"; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/pattern/PathElement.java b/spring-web/src/main/java/org/springframework/web/util/pattern/PathElement.java new file mode 100644 index 0000000000000000000000000000000000000000..df4330172645148025758995493907aa90f5bc84 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/pattern/PathElement.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util.pattern; + +import org.springframework.lang.Nullable; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.util.pattern.PathPattern.MatchingContext; + +/** + * Common supertype for the Ast nodes created to represent a path pattern. + * + * @author Andy Clement + * @since 5.0 + */ +abstract class PathElement { + + // Score related + protected static final int WILDCARD_WEIGHT = 100; + + protected static final int CAPTURE_VARIABLE_WEIGHT = 1; + + protected static final MultiValueMap NO_PARAMETERS = new LinkedMultiValueMap<>(); + + // Position in the pattern where this path element starts + protected final int pos; + + // The separator used in this path pattern + protected final char separator; + + // The next path element in the chain + @Nullable + protected PathElement next; + + // The previous path element in the chain + @Nullable + protected PathElement prev; + + + /** + * Create a new path element. + * @param pos the position where this path element starts in the pattern data + * @param separator the separator in use in the path pattern + */ + PathElement(int pos, char separator) { + this.pos = pos; + this.separator = separator; + } + + + /** + * Attempt to match this path element. + * @param candidatePos the current position within the candidate path + * @param matchingContext encapsulates context for the match including the candidate + * @return {@code true} if it matches, otherwise {@code false} + */ + public abstract boolean matches(int candidatePos, MatchingContext matchingContext); + + /** + * Return the length of the path element where captures are considered to be one character long. + * @return the normalized length + */ + public abstract int getNormalizedLength(); + + public abstract char[] getChars(); + + /** + * Return the number of variables captured by the path element. + */ + public int getCaptureCount() { + return 0; + } + + /** + * Return the number of wildcard elements (*, ?) in the path element. + */ + public int getWildcardCount() { + return 0; + } + + /** + * Return the score for this PathElement, combined score is used to compare parsed patterns. + */ + public int getScore() { + return 0; + } + + /** + * Return if the there are no more PathElements in the pattern. + * @return {@code true} if the there are no more elements + */ + protected final boolean isNoMorePattern() { + return this.next == null; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/pattern/PathPattern.java b/spring-web/src/main/java/org/springframework/web/util/pattern/PathPattern.java new file mode 100644 index 0000000000000000000000000000000000000000..acdab2eba2b907061ddb316703f885a8a27e9b39 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/pattern/PathPattern.java @@ -0,0 +1,705 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util.pattern; + +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.springframework.http.server.PathContainer; +import org.springframework.http.server.PathContainer.Element; +import org.springframework.http.server.PathContainer.Separator; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * Representation of a parsed path pattern. Includes a chain of path elements + * for fast matching and accumulates computed state for quick comparison of + * patterns. + * + *

{@code PathPattern} matches URL paths using the following rules:
+ *

    + *
  • {@code ?} matches one character
  • + *
  • {@code *} matches zero or more characters within a path segment
  • + *
  • {@code **} matches zero or more path segments until the end of the path
  • + *
  • {spring} matches a path segment and captures it as a variable named "spring"
  • + *
  • {spring:[a-z]+} matches the regexp {@code [a-z]+} as a path variable named "spring"
  • + *
  • {*spring} matches zero or more path segments until the end of the path + * and captures it as a variable named "spring"
  • + *
+ * + *

Examples

+ *
    + *
  • {@code /pages/t?st.html} — matches {@code /pages/test.html} as well as + * {@code /pages/tXst.html} but not {@code /pages/toast.html}
  • + *
  • {@code /resources/*.png} — matches all {@code .png} files in the + * {@code resources} directory
  • + *
  • /resources/** — matches all files + * underneath the {@code /resources/} path, including {@code /resources/image.png} + * and {@code /resources/css/spring.css}
  • + *
  • /resources/{*path} — matches all files + * underneath the {@code /resources/} path and captures their relative path in + * a variable named "path"; {@code /resources/image.png} will match with + * "spring" → "/image.png", and {@code /resources/css/spring.css} will match + * with "spring" → "/css/spring.css"
  • + *
  • /resources/{filename:\\w+}.dat will match {@code /resources/spring.dat} + * and assign the value {@code "spring"} to the {@code filename} variable
  • + *
+ * + * @author Andy Clement + * @author Rossen Stoyanchev + * @since 5.0 + * @see PathContainer + */ +public class PathPattern implements Comparable { + + private static final PathContainer EMPTY_PATH = PathContainer.parsePath(""); + + /** + * Comparator that sorts patterns by specificity as follows: + *
    + *
  1. Null instances are last. + *
  2. Catch-all patterns are last. + *
  3. If both patterns are catch-all, consider the length (longer wins). + *
  4. Compare wildcard and captured variable count (lower wins). + *
  5. Consider length (longer wins) + *
+ */ + public static final Comparator SPECIFICITY_COMPARATOR = + Comparator.nullsLast( + Comparator. + comparingInt(p -> p.isCatchAll() ? 1 : 0) + .thenComparingInt(p -> p.isCatchAll() ? scoreByNormalizedLength(p) : 0) + .thenComparingInt(PathPattern::getScore) + .thenComparingInt(PathPattern::scoreByNormalizedLength) + ); + + + /** The text of the parsed pattern. */ + private final String patternString; + + /** The parser used to construct this pattern. */ + private final PathPatternParser parser; + + /** The separator used when parsing the pattern. */ + private final char separator; + + /** If this pattern has no trailing slash, allow candidates to include one and still match successfully. */ + private final boolean matchOptionalTrailingSeparator; + + /** Will this match candidates in a case sensitive way? (case sensitivity at parse time). */ + private final boolean caseSensitive; + + /** First path element in the parsed chain of path elements for this pattern. */ + @Nullable + private final PathElement head; + + /** How many variables are captured in this pattern. */ + private int capturedVariableCount; + + /** + * The normalized length is trying to measure the 'active' part of the pattern. It is computed + * by assuming all captured variables have a normalized length of 1. Effectively this means changing + * your variable name lengths isn't going to change the length of the active part of the pattern. + * Useful when comparing two patterns. + */ + private int normalizedLength; + + /** + * Does the pattern end with '<separator>'. + */ + private boolean endsWithSeparatorWildcard = false; + + /** + * Score is used to quickly compare patterns. Different pattern components are given different + * weights. A 'lower score' is more specific. Current weights: + *
    + *
  • Captured variables are worth 1 + *
  • Wildcard is worth 100 + *
+ */ + private int score; + + /** Does the pattern end with {*...}. */ + private boolean catchAll = false; + + + PathPattern(String patternText, PathPatternParser parser, @Nullable PathElement head) { + this.patternString = patternText; + this.parser = parser; + this.separator = parser.getSeparator(); + this.matchOptionalTrailingSeparator = parser.isMatchOptionalTrailingSeparator(); + this.caseSensitive = parser.isCaseSensitive(); + this.head = head; + + // Compute fields for fast comparison + PathElement elem = head; + while (elem != null) { + this.capturedVariableCount += elem.getCaptureCount(); + this.normalizedLength += elem.getNormalizedLength(); + this.score += elem.getScore(); + if (elem instanceof CaptureTheRestPathElement || elem instanceof WildcardTheRestPathElement) { + this.catchAll = true; + } + if (elem instanceof SeparatorPathElement && elem.next != null && + elem.next instanceof WildcardPathElement && elem.next.next == null) { + this.endsWithSeparatorWildcard = true; + } + elem = elem.next; + } + } + + + /** + * Return the original String that was parsed to create this PathPattern. + */ + public String getPatternString() { + return this.patternString; + } + + /** + * Whether this pattern matches the given path. + * @param pathContainer the candidate path to attempt to match against + * @return {@code true} if the path matches this pattern + */ + public boolean matches(PathContainer pathContainer) { + if (this.head == null) { + return !hasLength(pathContainer) || + (this.matchOptionalTrailingSeparator && pathContainerIsJustSeparator(pathContainer)); + } + else if (!hasLength(pathContainer)) { + if (this.head instanceof WildcardTheRestPathElement || this.head instanceof CaptureTheRestPathElement) { + pathContainer = EMPTY_PATH; // Will allow CaptureTheRest to bind the variable to empty + } + else { + return false; + } + } + MatchingContext matchingContext = new MatchingContext(pathContainer, false); + return this.head.matches(0, matchingContext); + } + + /** + * Match this pattern to the given URI path and return extracted URI template + * variables as well as path parameters (matrix variables). + * @param pathContainer the candidate path to attempt to match against + * @return info object with the extracted variables, or {@code null} for no match + */ + @Nullable + public PathMatchInfo matchAndExtract(PathContainer pathContainer) { + if (this.head == null) { + return hasLength(pathContainer) && + !(this.matchOptionalTrailingSeparator && pathContainerIsJustSeparator(pathContainer)) + ? null : PathMatchInfo.EMPTY; + } + else if (!hasLength(pathContainer)) { + if (this.head instanceof WildcardTheRestPathElement || this.head instanceof CaptureTheRestPathElement) { + pathContainer = EMPTY_PATH; // Will allow CaptureTheRest to bind the variable to empty + } + else { + return null; + } + } + MatchingContext matchingContext = new MatchingContext(pathContainer, true); + return this.head.matches(0, matchingContext) ? matchingContext.getPathMatchResult() : null; + } + + /** + * Match the beginning of the given path and return the remaining portion + * not covered by this pattern. This is useful for matching nested routes + * where the path is matched incrementally at each level. + * @param pathContainer the candidate path to attempt to match against + * @return info object with the match result or {@code null} for no match + */ + @Nullable + public PathRemainingMatchInfo matchStartOfPath(PathContainer pathContainer) { + if (this.head == null) { + return new PathRemainingMatchInfo(pathContainer); + } + else if (!hasLength(pathContainer)) { + return null; + } + + MatchingContext matchingContext = new MatchingContext(pathContainer, true); + matchingContext.setMatchAllowExtraPath(); + boolean matches = this.head.matches(0, matchingContext); + if (!matches) { + return null; + } + else { + PathRemainingMatchInfo info; + if (matchingContext.remainingPathIndex == pathContainer.elements().size()) { + info = new PathRemainingMatchInfo(EMPTY_PATH, matchingContext.getPathMatchResult()); + } + else { + info = new PathRemainingMatchInfo(pathContainer.subPath(matchingContext.remainingPathIndex), + matchingContext.getPathMatchResult()); + } + return info; + } + } + + /** + * Determine the pattern-mapped part for the given path. + *

For example:

    + *
  • '{@code /docs/cvs/commit.html}' and '{@code /docs/cvs/commit.html} → ''
  • + *
  • '{@code /docs/*}' and '{@code /docs/cvs/commit}' → '{@code cvs/commit}'
  • + *
  • '{@code /docs/cvs/*.html}' and '{@code /docs/cvs/commit.html} → '{@code commit.html}'
  • + *
  • '{@code /docs/**}' and '{@code /docs/cvs/commit} → '{@code cvs/commit}'
  • + *
+ *

Notes: + *

    + *
  • Assumes that {@link #matches} returns {@code true} for + * the same path but does not enforce this. + *
  • Duplicate occurrences of separators within the returned result are removed + *
  • Leading and trailing separators are removed from the returned result + *
+ * @param path a path that matches this pattern + * @return the subset of the path that is matched by pattern or "" if none + * of it is matched by pattern elements + */ + public PathContainer extractPathWithinPattern(PathContainer path) { + List pathElements = path.elements(); + int pathElementsCount = pathElements.size(); + + int startIndex = 0; + // Find first path element that is not a separator or a literal (i.e. the first pattern based element) + PathElement elem = this.head; + while (elem != null) { + if (elem.getWildcardCount() != 0 || elem.getCaptureCount() != 0) { + break; + } + elem = elem.next; + startIndex++; + } + if (elem == null) { + // There is no pattern piece + return PathContainer.parsePath(""); + } + + // Skip leading separators that would be in the result + while (startIndex < pathElementsCount && (pathElements.get(startIndex) instanceof Separator)) { + startIndex++; + } + + int endIndex = pathElements.size(); + // Skip trailing separators that would be in the result + while (endIndex > 0 && (pathElements.get(endIndex - 1) instanceof Separator)) { + endIndex--; + } + + boolean multipleAdjacentSeparators = false; + for (int i = startIndex; i < (endIndex - 1); i++) { + if ((pathElements.get(i) instanceof Separator) && (pathElements.get(i+1) instanceof Separator)) { + multipleAdjacentSeparators=true; + break; + } + } + + PathContainer resultPath = null; + if (multipleAdjacentSeparators) { + // Need to rebuild the path without the duplicate adjacent separators + StringBuilder buf = new StringBuilder(); + int i = startIndex; + while (i < endIndex) { + Element e = pathElements.get(i++); + buf.append(e.value()); + if (e instanceof Separator) { + while (i < endIndex && (pathElements.get(i) instanceof Separator)) { + i++; + } + } + } + resultPath = PathContainer.parsePath(buf.toString()); + } + else if (startIndex >= endIndex) { + resultPath = PathContainer.parsePath(""); + } + else { + resultPath = path.subPath(startIndex, endIndex); + } + return resultPath; + } + + /** + * Compare this pattern with a supplied pattern: return -1,0,+1 if this pattern + * is more specific, the same or less specific than the supplied pattern. + * The aim is to sort more specific patterns first. + */ + @Override + public int compareTo(@Nullable PathPattern otherPattern) { + int result = SPECIFICITY_COMPARATOR.compare(this, otherPattern); + return (result == 0 && otherPattern != null ? + this.patternString.compareTo(otherPattern.patternString) : result); + } + + /** + * Combine this pattern with another. Currently does not produce a new PathPattern, just produces a new string. + */ + public PathPattern combine(PathPattern pattern2string) { + // If one of them is empty the result is the other. If both empty the result is "" + if (!StringUtils.hasLength(this.patternString)) { + if (!StringUtils.hasLength(pattern2string.patternString)) { + return this.parser.parse(""); + } + else { + return pattern2string; + } + } + else if (!StringUtils.hasLength(pattern2string.patternString)) { + return this; + } + + // /* + /hotel => /hotel + // /*.* + /*.html => /*.html + // However: + // /usr + /user => /usr/user + // /{foo} + /bar => /{foo}/bar + if (!this.patternString.equals(pattern2string.patternString) && this.capturedVariableCount == 0 && + matches(PathContainer.parsePath(pattern2string.patternString))) { + return pattern2string; + } + + // /hotels/* + /booking => /hotels/booking + // /hotels/* + booking => /hotels/booking + if (this.endsWithSeparatorWildcard) { + return this.parser.parse(concat( + this.patternString.substring(0, this.patternString.length() - 2), + pattern2string.patternString)); + } + + // /hotels + /booking => /hotels/booking + // /hotels + booking => /hotels/booking + int starDotPos1 = this.patternString.indexOf("*."); // Are there any file prefix/suffix things to consider? + if (this.capturedVariableCount != 0 || starDotPos1 == -1 || this.separator == '.') { + return this.parser.parse(concat(this.patternString, pattern2string.patternString)); + } + + // /*.html + /hotel => /hotel.html + // /*.html + /hotel.* => /hotel.html + String firstExtension = this.patternString.substring(starDotPos1 + 1); // looking for the first extension + String p2string = pattern2string.patternString; + int dotPos2 = p2string.indexOf('.'); + String file2 = (dotPos2 == -1 ? p2string : p2string.substring(0, dotPos2)); + String secondExtension = (dotPos2 == -1 ? "" : p2string.substring(dotPos2)); + boolean firstExtensionWild = (firstExtension.equals(".*") || firstExtension.equals("")); + boolean secondExtensionWild = (secondExtension.equals(".*") || secondExtension.equals("")); + if (!firstExtensionWild && !secondExtensionWild) { + throw new IllegalArgumentException( + "Cannot combine patterns: " + this.patternString + " and " + pattern2string); + } + return this.parser.parse(file2 + (firstExtensionWild ? secondExtension : firstExtension)); + } + + public boolean equals(Object other) { + if (!(other instanceof PathPattern)) { + return false; + } + PathPattern otherPattern = (PathPattern) other; + return (this.patternString.equals(otherPattern.getPatternString()) && + this.separator == otherPattern.getSeparator() && + this.caseSensitive == otherPattern.caseSensitive); + } + + public int hashCode() { + return (this.patternString.hashCode() + this.separator) * 17 + (this.caseSensitive ? 1 : 0); + } + + public String toString() { + return this.patternString; + } + + int getScore() { + return this.score; + } + + boolean isCatchAll() { + return this.catchAll; + } + + /** + * The normalized length is trying to measure the 'active' part of the pattern. It is computed + * by assuming all capture variables have a normalized length of 1. Effectively this means changing + * your variable name lengths isn't going to change the length of the active part of the pattern. + * Useful when comparing two patterns. + */ + int getNormalizedLength() { + return this.normalizedLength; + } + + char getSeparator() { + return this.separator; + } + + int getCapturedVariableCount() { + return this.capturedVariableCount; + } + + String toChainString() { + StringBuilder buf = new StringBuilder(); + PathElement pe = this.head; + while (pe != null) { + buf.append(pe.toString()).append(" "); + pe = pe.next; + } + return buf.toString().trim(); + } + + /** + * Return the string form of the pattern built from walking the path element chain. + * @return the string form of the pattern + */ + String computePatternString() { + StringBuilder buf = new StringBuilder(); + PathElement pe = this.head; + while (pe != null) { + buf.append(pe.getChars()); + pe = pe.next; + } + return buf.toString(); + } + + @Nullable + PathElement getHeadSection() { + return this.head; + } + + /** + * Join two paths together including a separator if necessary. + * Extraneous separators are removed (if the first path + * ends with one and the second path starts with one). + * @param path1 first path + * @param path2 second path + * @return joined path that may include separator if necessary + */ + private String concat(String path1, String path2) { + boolean path1EndsWithSeparator = (path1.charAt(path1.length() - 1) == this.separator); + boolean path2StartsWithSeparator = (path2.charAt(0) == this.separator); + if (path1EndsWithSeparator && path2StartsWithSeparator) { + return path1 + path2.substring(1); + } + else if (path1EndsWithSeparator || path2StartsWithSeparator) { + return path1 + path2; + } + else { + return path1 + this.separator + path2; + } + } + + /** + * Return if the container is not null and has more than zero elements. + * @param container a path container + * @return {@code true} has more than zero elements + */ + private boolean hasLength(@Nullable PathContainer container) { + return container != null && container.elements().size() > 0; + } + + private static int scoreByNormalizedLength(PathPattern pattern) { + return -pattern.getNormalizedLength(); + } + + private boolean pathContainerIsJustSeparator(PathContainer pathContainer) { + return pathContainer.value().length() == 1 && + pathContainer.value().charAt(0) == this.separator; + } + + /** + * Holder for URI variables and path parameters (matrix variables) extracted + * based on the pattern for a given matched path. + */ + public static class PathMatchInfo { + + private static final PathMatchInfo EMPTY = + new PathMatchInfo(Collections.emptyMap(), Collections.emptyMap()); + + + private final Map uriVariables; + + private final Map> matrixVariables; + + + PathMatchInfo(Map uriVars, + @Nullable Map> matrixVars) { + + this.uriVariables = Collections.unmodifiableMap(uriVars); + this.matrixVariables = matrixVars != null ? + Collections.unmodifiableMap(matrixVars) : Collections.emptyMap(); + } + + + /** + * Return the extracted URI variables. + */ + public Map getUriVariables() { + return this.uriVariables; + } + + /** + * Return maps of matrix variables per path segment, keyed off by URI + * variable name. + */ + public Map> getMatrixVariables() { + return this.matrixVariables; + } + + @Override + public String toString() { + return "PathMatchInfo[uriVariables=" + this.uriVariables + ", " + + "matrixVariables=" + this.matrixVariables + "]"; + } + } + + + /** + * Holder for the result of a match on the start of a pattern. + * Provides access to the remaining path not matched to the pattern as well + * as any variables bound in that first part that was matched. + */ + public static class PathRemainingMatchInfo { + + private final PathContainer pathRemaining; + + private final PathMatchInfo pathMatchInfo; + + + PathRemainingMatchInfo(PathContainer pathRemaining) { + this(pathRemaining, PathMatchInfo.EMPTY); + } + + PathRemainingMatchInfo(PathContainer pathRemaining, PathMatchInfo pathMatchInfo) { + this.pathRemaining = pathRemaining; + this.pathMatchInfo = pathMatchInfo; + } + + /** + * Return the part of a path that was not matched by a pattern. + */ + public PathContainer getPathRemaining() { + return this.pathRemaining; + } + + /** + * Return variables that were bound in the part of the path that was + * successfully matched or an empty map. + */ + public Map getUriVariables() { + return this.pathMatchInfo.getUriVariables(); + } + + /** + * Return the path parameters for each bound variable. + */ + public Map> getMatrixVariables() { + return this.pathMatchInfo.getMatrixVariables(); + } + } + + + /** + * Encapsulates context when attempting a match. Includes some fixed state like the + * candidate currently being considered for a match but also some accumulators for + * extracted variables. + */ + class MatchingContext { + + final PathContainer candidate; + + final List pathElements; + + final int pathLength; + + @Nullable + private Map extractedUriVariables; + + @Nullable + private Map> extractedMatrixVariables; + + boolean extractingVariables; + + boolean determineRemainingPath = false; + + // if determineRemaining is true, this is set to the position in + // the candidate where the pattern finished matching - i.e. it + // points to the remaining path that wasn't consumed + int remainingPathIndex; + + public MatchingContext(PathContainer pathContainer, boolean extractVariables) { + this.candidate = pathContainer; + this.pathElements = pathContainer.elements(); + this.pathLength = this.pathElements.size(); + this.extractingVariables = extractVariables; + } + + public void setMatchAllowExtraPath() { + this.determineRemainingPath = true; + } + + public boolean isMatchOptionalTrailingSeparator() { + return matchOptionalTrailingSeparator; + } + + public void set(String key, String value, MultiValueMap parameters) { + if (this.extractedUriVariables == null) { + this.extractedUriVariables = new HashMap<>(); + } + this.extractedUriVariables.put(key, value); + + if (!parameters.isEmpty()) { + if (this.extractedMatrixVariables == null) { + this.extractedMatrixVariables = new HashMap<>(); + } + this.extractedMatrixVariables.put(key, CollectionUtils.unmodifiableMultiValueMap(parameters)); + } + } + + public PathMatchInfo getPathMatchResult() { + if (this.extractedUriVariables == null) { + return PathMatchInfo.EMPTY; + } + else { + return new PathMatchInfo(this.extractedUriVariables, this.extractedMatrixVariables); + } + } + + /** + * Return if element at specified index is a separator. + * @param pathIndex possible index of a separator + * @return {@code true} if element is a separator + */ + boolean isSeparator(int pathIndex) { + return this.pathElements.get(pathIndex) instanceof Separator; + } + + /** + * Return the decoded value of the specified element. + * @param pathIndex path element index + * @return the decoded value + */ + String pathElementValue(int pathIndex) { + Element element = (pathIndex < this.pathLength) ? this.pathElements.get(pathIndex) : null; + if (element instanceof PathContainer.PathSegment) { + return ((PathContainer.PathSegment)element).valueToMatch(); + } + return ""; + } + } +} diff --git a/spring-web/src/main/java/org/springframework/web/util/pattern/PathPatternParser.java b/spring-web/src/main/java/org/springframework/web/util/pattern/PathPatternParser.java new file mode 100644 index 0000000000000000000000000000000000000000..78625c653cdd8d4aef0cbe5afc4180927f8cdcae --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/pattern/PathPatternParser.java @@ -0,0 +1,103 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util.pattern; + +/** + * Parser for URI path patterns producing {@link PathPattern} instances that can + * then be matched to requests. + * + *

The {@link PathPatternParser} and {@link PathPattern} are specifically + * designed for use with HTTP URL paths in web applications where a large number + * of URI path patterns, continuously matched against incoming requests, + * motivates the need for efficient matching. + * + *

For details of the path pattern syntax see {@link PathPattern}. + * + * @author Andy Clement + * @since 5.0 + */ +public class PathPatternParser { + + private boolean matchOptionalTrailingSeparator = true; + + private boolean caseSensitive = true; + + + /** + * Whether a {@link PathPattern} produced by this parser should should + * automatically match request paths with a trailing slash. + * + *

If set to {@code true} a {@code PathPattern} without a trailing slash + * will also match request paths with a trailing slash. If set to + * {@code false} a {@code PathPattern} will only match request paths with + * a trailing slash. + * + *

The default is {@code true}. + */ + public void setMatchOptionalTrailingSeparator(boolean matchOptionalTrailingSeparator) { + this.matchOptionalTrailingSeparator = matchOptionalTrailingSeparator; + } + + /** + * Whether optional trailing slashing match is enabled. + */ + public boolean isMatchOptionalTrailingSeparator() { + return this.matchOptionalTrailingSeparator; + } + + /** + * Whether path pattern matching should be case-sensitive. + *

The default is {@code true}. + */ + public void setCaseSensitive(boolean caseSensitive) { + this.caseSensitive = caseSensitive; + } + + /** + * Whether case-sensitive pattern matching is enabled. + */ + public boolean isCaseSensitive() { + return this.caseSensitive; + } + + /** + * Accessor used for the separator to use. + *

Currently not exposed for configuration with URI path patterns and + * mainly for use in InternalPathPatternParser and PathPattern. If required + * in the future, a similar option would also need to be exposed in + * {@link org.springframework.http.server.PathContainer PathContainer}. + */ + char getSeparator() { + return '/'; + } + + + /** + * Process the path pattern content, a character at a time, breaking it into + * path elements around separator boundaries and verifying the structure at each + * stage. Produces a PathPattern object that can be used for fast matching + * against paths. Each invocation of this method delegates to a new instance of + * the {@link InternalPathPatternParser} because that class is not thread-safe. + * @param pathPattern the input path pattern, e.g. /foo/{bar} + * @return a PathPattern for quickly matching paths against request paths + * @throws PatternParseException in case of parse errors + */ + public PathPattern parse(String pathPattern) throws PatternParseException { + return new InternalPathPatternParser(this).parse(pathPattern); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/pattern/PatternParseException.java b/spring-web/src/main/java/org/springframework/web/util/pattern/PatternParseException.java new file mode 100644 index 0000000000000000000000000000000000000000..cfd62eccf45672962bc991503e95dc6b5c990d77 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/pattern/PatternParseException.java @@ -0,0 +1,121 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util.pattern; + +import java.text.MessageFormat; + +/** + * Exception that is thrown when there is a problem with the pattern being parsed. + * + * @author Andy Clement + * @since 5.0 + */ +@SuppressWarnings("serial") +public class PatternParseException extends IllegalArgumentException { + + private final int position; + + private final char[] pattern; + + private final PatternMessage messageType; + + private final Object[] inserts; + + + PatternParseException(int pos, char[] pattern, PatternMessage messageType, Object... inserts) { + super(messageType.formatMessage(inserts)); + this.position = pos; + this.pattern = pattern; + this.messageType = messageType; + this.inserts = inserts; + } + + PatternParseException(Throwable cause, int pos, char[] pattern, PatternMessage messageType, Object... inserts) { + super(messageType.formatMessage(inserts), cause); + this.position = pos; + this.pattern = pattern; + this.messageType = messageType; + this.inserts = inserts; + } + + + /** + * Return a formatted message with inserts applied. + */ + @Override + public String getMessage() { + return this.messageType.formatMessage(this.inserts); + } + + /** + * Return a detailed message that includes the original pattern text + * with a pointer to the error position, as well as the error message. + */ + public String toDetailedString() { + StringBuilder buf = new StringBuilder(); + buf.append(this.pattern).append('\n'); + for (int i = 0; i < this.position; i++) { + buf.append(' '); + } + buf.append("^\n"); + buf.append(getMessage()); + return buf.toString(); + } + + public int getPosition() { + return this.position; + } + + public PatternMessage getMessageType() { + return this.messageType; + } + + public Object[] getInserts() { + return this.inserts; + } + + + /** + * The messages that can be included in a {@link PatternParseException} when there is a parse failure. + */ + public enum PatternMessage { + + MISSING_CLOSE_CAPTURE("Expected close capture character after variable name '}'"), + MISSING_OPEN_CAPTURE("Missing preceding open capture character before variable name'{'"), + ILLEGAL_NESTED_CAPTURE("Not allowed to nest variable captures"), + CANNOT_HAVE_ADJACENT_CAPTURES("Adjacent captures are not allowed"), + ILLEGAL_CHARACTER_AT_START_OF_CAPTURE_DESCRIPTOR("Char ''{0}'' not allowed at start of captured variable name"), + ILLEGAL_CHARACTER_IN_CAPTURE_DESCRIPTOR("Char ''{0}'' is not allowed in a captured variable name"), + NO_MORE_DATA_EXPECTED_AFTER_CAPTURE_THE_REST("No more pattern data allowed after '{*...}' pattern element"), + BADLY_FORMED_CAPTURE_THE_REST("Expected form when capturing the rest of the path is simply '{*...}'"), + MISSING_REGEX_CONSTRAINT("Missing regex constraint on capture"), + ILLEGAL_DOUBLE_CAPTURE("Not allowed to capture ''{0}'' twice in the same pattern"), + REGEX_PATTERN_SYNTAX_EXCEPTION("Exception occurred in regex pattern compilation"), + CAPTURE_ALL_IS_STANDALONE_CONSTRUCT("'{*...}' can only be preceded by a path separator"); + + private final String message; + + PatternMessage(String message) { + this.message = message; + } + + public String formatMessage(Object... inserts) { + return MessageFormat.format(this.message, inserts); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/pattern/RegexPathElement.java b/spring-web/src/main/java/org/springframework/web/util/pattern/RegexPathElement.java new file mode 100644 index 0000000000000000000000000000000000000000..e821f615e8d2c5d38716755c53c16807c625f2ac --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/pattern/RegexPathElement.java @@ -0,0 +1,215 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util.pattern; + +import java.util.LinkedList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.springframework.http.server.PathContainer.PathSegment; +import org.springframework.web.util.pattern.PathPattern.MatchingContext; + +/** + * A regex path element. Used to represent any complicated element of the path. + * For example in '/foo/*_*/*_{foobar}' both *_* and *_{foobar} + * are {@link RegexPathElement} path elements. Derived from the general + * {@link org.springframework.util.AntPathMatcher} approach. + * + * @author Andy Clement + * @since 5.0 + */ +class RegexPathElement extends PathElement { + + private static final Pattern GLOB_PATTERN = Pattern.compile("\\?|\\*|\\{((?:\\{[^/]+?\\}|[^/{}]|\\\\[{}])+?)\\}"); + + private static final String DEFAULT_VARIABLE_PATTERN = "(.*)"; + + + private final char[] regex; + + private final boolean caseSensitive; + + private final Pattern pattern; + + private int wildcardCount; + + private final List variableNames = new LinkedList<>(); + + + RegexPathElement(int pos, char[] regex, boolean caseSensitive, char[] completePattern, char separator) { + super(pos, separator); + this.regex = regex; + this.caseSensitive = caseSensitive; + this.pattern = buildPattern(regex, completePattern); + } + + + public Pattern buildPattern(char[] regex, char[] completePattern) { + StringBuilder patternBuilder = new StringBuilder(); + String text = new String(regex); + Matcher matcher = GLOB_PATTERN.matcher(text); + int end = 0; + + while (matcher.find()) { + patternBuilder.append(quote(text, end, matcher.start())); + String match = matcher.group(); + if ("?".equals(match)) { + patternBuilder.append('.'); + } + else if ("*".equals(match)) { + patternBuilder.append(".*"); + int pos = matcher.start(); + if (pos < 1 || text.charAt(pos-1) != '.') { + // To be compatible with the AntPathMatcher comparator, + // '.*' is not considered a wildcard usage + this.wildcardCount++; + } + } + else if (match.startsWith("{") && match.endsWith("}")) { + int colonIdx = match.indexOf(':'); + if (colonIdx == -1) { + patternBuilder.append(DEFAULT_VARIABLE_PATTERN); + String variableName = matcher.group(1); + if (this.variableNames.contains(variableName)) { + throw new PatternParseException(this.pos, completePattern, + PatternParseException.PatternMessage.ILLEGAL_DOUBLE_CAPTURE, variableName); + } + this.variableNames.add(variableName); + } + else { + String variablePattern = match.substring(colonIdx + 1, match.length() - 1); + patternBuilder.append('('); + patternBuilder.append(variablePattern); + patternBuilder.append(')'); + String variableName = match.substring(1, colonIdx); + if (this.variableNames.contains(variableName)) { + throw new PatternParseException(this.pos, completePattern, + PatternParseException.PatternMessage.ILLEGAL_DOUBLE_CAPTURE, variableName); + } + this.variableNames.add(variableName); + } + } + end = matcher.end(); + } + + patternBuilder.append(quote(text, end, text.length())); + if (this.caseSensitive) { + return Pattern.compile(patternBuilder.toString()); + } + else { + return Pattern.compile(patternBuilder.toString(), Pattern.CASE_INSENSITIVE); + } + } + + public List getVariableNames() { + return this.variableNames; + } + + private String quote(String s, int start, int end) { + if (start == end) { + return ""; + } + return Pattern.quote(s.substring(start, end)); + } + + @Override + public boolean matches(int pathIndex, MatchingContext matchingContext) { + String textToMatch = matchingContext.pathElementValue(pathIndex); + Matcher matcher = this.pattern.matcher(textToMatch); + boolean matches = matcher.matches(); + + if (matches) { + if (isNoMorePattern()) { + if (matchingContext.determineRemainingPath && + (this.variableNames.isEmpty() || textToMatch.length() > 0)) { + matchingContext.remainingPathIndex = pathIndex + 1; + matches = true; + } + else { + // No more pattern, is there more data? + // If pattern is capturing variables there must be some actual data to bind to them + matches = (pathIndex + 1) >= matchingContext.pathLength + && (this.variableNames.isEmpty() || textToMatch.length() > 0); + if (!matches && matchingContext.isMatchOptionalTrailingSeparator()) { + matches = (this.variableNames.isEmpty() + || textToMatch.length() > 0) + && (pathIndex + 2) >= matchingContext.pathLength + && matchingContext.isSeparator(pathIndex + 1); + } + } + } + else { + matches = (this.next != null && this.next.matches(pathIndex + 1, matchingContext)); + } + } + + if (matches && matchingContext.extractingVariables) { + // Process captures + if (this.variableNames.size() != matcher.groupCount()) { // SPR-8455 + throw new IllegalArgumentException("The number of capturing groups in the pattern segment " + + this.pattern + " does not match the number of URI template variables it defines, " + + "which can occur if capturing groups are used in a URI template regex. " + + "Use non-capturing groups instead."); + } + for (int i = 1; i <= matcher.groupCount(); i++) { + String name = this.variableNames.get(i - 1); + String value = matcher.group(i); + matchingContext.set(name, value, + (i == this.variableNames.size())? + ((PathSegment)matchingContext.pathElements.get(pathIndex)).parameters(): + NO_PARAMETERS); + } + } + return matches; + } + + @Override + public int getNormalizedLength() { + int varsLength = 0; + for (String variableName : this.variableNames) { + varsLength += variableName.length(); + } + return (this.regex.length - varsLength - this.variableNames.size()); + } + + @Override + public int getCaptureCount() { + return this.variableNames.size(); + } + + @Override + public int getWildcardCount() { + return this.wildcardCount; + } + + @Override + public int getScore() { + return (getCaptureCount() * CAPTURE_VARIABLE_WEIGHT + getWildcardCount() * WILDCARD_WEIGHT); + } + + + @Override + public String toString() { + return "Regex(" + String.valueOf(this.regex) + ")"; + } + + @Override + public char[] getChars() { + return this.regex; + } +} diff --git a/spring-web/src/main/java/org/springframework/web/util/pattern/SeparatorPathElement.java b/spring-web/src/main/java/org/springframework/web/util/pattern/SeparatorPathElement.java new file mode 100644 index 0000000000000000000000000000000000000000..43968a1b4c7dafefdb8ed992e8e0f8412130b5aa --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/pattern/SeparatorPathElement.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util.pattern; + +import org.springframework.web.util.pattern.PathPattern.MatchingContext; + +/** + * A separator path element. In the pattern '/foo/bar' the two occurrences + * of '/' will be represented by a SeparatorPathElement (if the default + * separator of '/' is being used). + * + * @author Andy Clement + * @since 5.0 + */ +class SeparatorPathElement extends PathElement { + + SeparatorPathElement(int pos, char separator) { + super(pos, separator); + } + + + /** + * Matching a separator is easy, basically the character at candidateIndex + * must be the separator. + */ + @Override + public boolean matches(int pathIndex, MatchingContext matchingContext) { + if (pathIndex < matchingContext.pathLength && matchingContext.isSeparator(pathIndex)) { + if (isNoMorePattern()) { + if (matchingContext.determineRemainingPath) { + matchingContext.remainingPathIndex = pathIndex + 1; + return true; + } + else { + return (pathIndex + 1 == matchingContext.pathLength); + } + } + else { + pathIndex++; + return (this.next != null && this.next.matches(pathIndex, matchingContext)); + } + } + return false; + } + + @Override + public int getNormalizedLength() { + return 1; + } + + public String toString() { + return "Separator(" + this.separator + ")"; + } + + public char[] getChars() { + return new char[] {this.separator}; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/pattern/SingleCharWildcardedPathElement.java b/spring-web/src/main/java/org/springframework/web/util/pattern/SingleCharWildcardedPathElement.java new file mode 100644 index 0000000000000000000000000000000000000000..b6d8327d58a241ce342d43ea6e6bf5ce761c9d07 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/pattern/SingleCharWildcardedPathElement.java @@ -0,0 +1,138 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util.pattern; + +import org.springframework.http.server.PathContainer.Element; +import org.springframework.http.server.PathContainer.PathSegment; +import org.springframework.web.util.pattern.PathPattern.MatchingContext; + +/** + * A literal path element that does includes the single character wildcard '?' one + * or more times (to basically many any character at that position). + * + * @author Andy Clement + * @since 5.0 + */ +class SingleCharWildcardedPathElement extends PathElement { + + private final char[] text; + + private final int len; + + private final int questionMarkCount; + + private final boolean caseSensitive; + + + public SingleCharWildcardedPathElement( + int pos, char[] literalText, int questionMarkCount, boolean caseSensitive, char separator) { + + super(pos, separator); + this.len = literalText.length; + this.questionMarkCount = questionMarkCount; + this.caseSensitive = caseSensitive; + if (caseSensitive) { + this.text = literalText; + } + else { + this.text = new char[literalText.length]; + for (int i = 0; i < this.len; i++) { + this.text[i] = Character.toLowerCase(literalText[i]); + } + } + } + + + @Override + public boolean matches(int pathIndex, MatchingContext matchingContext) { + if (pathIndex >= matchingContext.pathLength) { + // no more path left to match this element + return false; + } + + Element element = matchingContext.pathElements.get(pathIndex); + if (!(element instanceof PathSegment)) { + return false; + } + String value = ((PathSegment)element).valueToMatch(); + if (value.length() != this.len) { + // Not enough data to match this path element + return false; + } + + char[] data = ((PathSegment)element).valueToMatchAsChars(); + if (this.caseSensitive) { + for (int i = 0; i < this.len; i++) { + char ch = this.text[i]; + if ((ch != '?') && (ch != data[i])) { + return false; + } + } + } + else { + for (int i = 0; i < this.len; i++) { + char ch = this.text[i]; + // TODO revisit performance if doing a lot of case insensitive matching + if ((ch != '?') && (ch != Character.toLowerCase(data[i]))) { + return false; + } + } + } + + pathIndex++; + if (isNoMorePattern()) { + if (matchingContext.determineRemainingPath) { + matchingContext.remainingPathIndex = pathIndex; + return true; + } + else { + if (pathIndex == matchingContext.pathLength) { + return true; + } + else { + return (matchingContext.isMatchOptionalTrailingSeparator() && + (pathIndex + 1) == matchingContext.pathLength && + matchingContext.isSeparator(pathIndex)); + } + } + } + else { + return (this.next != null && this.next.matches(pathIndex, matchingContext)); + } + } + + @Override + public int getWildcardCount() { + return this.questionMarkCount; + } + + @Override + public int getNormalizedLength() { + return this.len; + } + + + public String toString() { + return "SingleCharWildcarded(" + String.valueOf(this.text) + ")"; + } + + @Override + public char[] getChars() { + return this.text; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/pattern/WildcardPathElement.java b/spring-web/src/main/java/org/springframework/web/util/pattern/WildcardPathElement.java new file mode 100644 index 0000000000000000000000000000000000000000..9b1101a47d0dca41d84e7f151e0a527a984739aa --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/pattern/WildcardPathElement.java @@ -0,0 +1,108 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util.pattern; + +import org.springframework.http.server.PathContainer; +import org.springframework.http.server.PathContainer.Element; +import org.springframework.web.util.pattern.PathPattern.MatchingContext; + +/** + * A wildcard path element. In the pattern '/foo/*/goo' the * is + * represented by a WildcardPathElement. Within a path it matches at least + * one character but at the end of a path it can match zero characters. + * + * @author Andy Clement + * @since 5.0 + */ +class WildcardPathElement extends PathElement { + + public WildcardPathElement(int pos, char separator) { + super(pos, separator); + } + + + /** + * Matching on a WildcardPathElement is quite straight forward. Scan the + * candidate from the candidateIndex onwards for the next separator or the end of the + * candidate. + */ + @Override + public boolean matches(int pathIndex, MatchingContext matchingContext) { + String segmentData = null; + // Assert if it exists it is a segment + if (pathIndex < matchingContext.pathLength) { + Element element = matchingContext.pathElements.get(pathIndex); + if (!(element instanceof PathContainer.PathSegment)) { + // Should not match a separator + return false; + } + segmentData = ((PathContainer.PathSegment)element).valueToMatch(); + pathIndex++; + } + + if (isNoMorePattern()) { + if (matchingContext.determineRemainingPath) { + matchingContext.remainingPathIndex = pathIndex; + return true; + } + else { + if (pathIndex == matchingContext.pathLength) { + // and the path data has run out too + return true; + } + else { + return (matchingContext.isMatchOptionalTrailingSeparator() && // if optional slash is on... + segmentData != null && segmentData.length() > 0 && // and there is at least one character to match the *... + (pathIndex + 1) == matchingContext.pathLength && // and the next path element is the end of the candidate... + matchingContext.isSeparator(pathIndex)); // and the final element is a separator + } + } + } + else { + // Within a path (e.g. /aa/*/bb) there must be at least one character to match the wildcard + if (segmentData == null || segmentData.length() == 0) { + return false; + } + return (this.next != null && this.next.matches(pathIndex, matchingContext)); + } + } + + @Override + public int getNormalizedLength() { + return 1; + } + + @Override + public int getWildcardCount() { + return 1; + } + + @Override + public int getScore() { + return WILDCARD_WEIGHT; + } + + + public String toString() { + return "Wildcard(*)"; + } + + @Override + public char[] getChars() { + return new char[] {'*'}; + } +} diff --git a/spring-web/src/main/java/org/springframework/web/util/pattern/WildcardTheRestPathElement.java b/spring-web/src/main/java/org/springframework/web/util/pattern/WildcardTheRestPathElement.java new file mode 100644 index 0000000000000000000000000000000000000000..b494a72af75e8db4697ab21cef229dc4723ffde6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/pattern/WildcardTheRestPathElement.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util.pattern; + +/** + * A path element representing wildcarding the rest of a path. In the pattern + * '/foo/**' the /** is represented as a {@link WildcardTheRestPathElement}. + * + * @author Andy Clement + * @since 5.0 + */ +class WildcardTheRestPathElement extends PathElement { + + WildcardTheRestPathElement(int pos, char separator) { + super(pos, separator); + } + + + @Override + public boolean matches(int pathIndex, PathPattern.MatchingContext matchingContext) { + // If there is more data, it must start with the separator + if (pathIndex < matchingContext.pathLength && !matchingContext.isSeparator(pathIndex)) { + return false; + } + if (matchingContext.determineRemainingPath) { + matchingContext.remainingPathIndex = matchingContext.pathLength; + } + return true; + } + + @Override + public int getNormalizedLength() { + return 1; + } + + @Override + public int getWildcardCount() { + return 1; + } + + + public String toString() { + return "WildcardTheRest(" + this.separator + "**)"; + } + + @Override + public char[] getChars() { + return (this.separator+"**").toCharArray(); + } +} diff --git a/spring-web/src/main/java/org/springframework/web/util/pattern/package-info.java b/spring-web/src/main/java/org/springframework/web/util/pattern/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..f6a5f3fce8b11454d3cbb3243df443c2f676c23c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/util/pattern/package-info.java @@ -0,0 +1,14 @@ +/** + * Dedicated support for matching HTTP request paths. + * + *

{@link org.springframework.web.util.pattern.PathPatternParser} is used to + * parse URI path patterns into + * {@link org.springframework.web.util.pattern.PathPattern org.springframework.web.util.pattern.PathPatterns} that can then be + * used for matching purposes at request time. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.util.pattern; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/main/java/overview.html b/spring-web/src/main/java/overview.html new file mode 100644 index 0000000000000000000000000000000000000000..cef93c70025f426755b5c8bba2c1672a5d677126 --- /dev/null +++ b/spring-web/src/main/java/overview.html @@ -0,0 +1,7 @@ + + +

+Spring's core web support packages, for any kind of web environment. +

+ + \ No newline at end of file diff --git a/spring-web/src/main/kotlin/org/springframework/web/client/RestOperationsExtensions.kt b/spring-web/src/main/kotlin/org/springframework/web/client/RestOperationsExtensions.kt new file mode 100644 index 0000000000000000000000000000000000000000..e4d6278e74b7b1ae6bbc12f9c341104d484481ef --- /dev/null +++ b/spring-web/src/main/kotlin/org/springframework/web/client/RestOperationsExtensions.kt @@ -0,0 +1,290 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client + +import org.springframework.core.ParameterizedTypeReference +import org.springframework.http.HttpEntity +import org.springframework.http.HttpMethod +import org.springframework.http.RequestEntity +import org.springframework.http.ResponseEntity +import java.net.URI + +/** + * Extension for [RestOperations.getForObject] providing a `getForObject(...)` + * variant leveraging Kotlin reified type parameters. Like the original Java method, this + * extension is subject to type erasure. Use [exchange] if you need to retain actual + * generic type arguments. + * + * @author Jon Schneider + * @author Sebastien Deleuze + * @since 5.0 + */ +@Throws(RestClientException::class) +inline fun RestOperations.getForObject(url: String, vararg uriVariables: Any): T? = + getForObject(url, T::class.java, *uriVariables) + +/** + * Extension for [RestOperations.getForObject] providing a `getForObject(...)` + * variant leveraging Kotlin reified type parameters. Like the original Java method, this + * extension is subject to type erasure. Use [exchange] if you need to retain actual + * generic type arguments. + * + * @author Jon Schneider + * @author Sebastien Deleuze + * @since 5.0 + */ +@Throws(RestClientException::class) +inline fun RestOperations.getForObject(url: String, uriVariables: Map): T? = + getForObject(url, T::class.java, uriVariables) + +/** + * Extension for [RestOperations.getForObject] providing a `getForObject(...)` + * variant leveraging Kotlin reified type parameters. Like the original Java method, this + * extension is subject to type erasure. Use [exchange] if you need to retain actual + * generic type arguments. + * + * @author Jon Schneider + * @author Sebastien Deleuze + * @since 5.0 + */ +@Throws(RestClientException::class) +inline fun RestOperations.getForObject(url: URI): T? = + getForObject(url, T::class.java) + +/** + * Extension for [RestOperations.getForEntity] providing a `getForEntity(...)` + * variant leveraging Kotlin reified type parameters. Like the original Java method, this + * extension is subject to type erasure. Use [exchange] if you need to retain actual + * generic type arguments. + * + * @author Sebastien Deleuze + * @since 5.0.2 + */ +@Throws(RestClientException::class) +inline fun RestOperations.getForEntity(url: URI): ResponseEntity = + getForEntity(url, T::class.java) + +/** + * Extension for [RestOperations.getForEntity] providing a `getForEntity(...)` + * variant leveraging Kotlin reified type parameters. Like the original Java method, this + * extension is subject to type erasure. Use [exchange] if you need to retain actual + * generic type arguments. + * + * @author Jon Schneider + * @author Sebastien Deleuze + * @since 5.0 + */ +@Throws(RestClientException::class) +inline fun RestOperations.getForEntity(url: String, vararg uriVariables: Any): ResponseEntity = + getForEntity(url, T::class.java, *uriVariables) + +/** + * Extension for [RestOperations.getForEntity] providing a `getForEntity(...)` + * variant leveraging Kotlin reified type parameters. Like the original Java method, this + * extension is subject to type erasure. Use [exchange] if you need to retain actual + * generic type arguments. + * + * @author Sebastien Deleuze + * @since 5.0.2 + */ +@Throws(RestClientException::class) +inline fun RestOperations.getForEntity(url: String, uriVariables: Map): ResponseEntity = + getForEntity(url, T::class.java, uriVariables) + +/** + * Extension for [RestOperations.patchForObject] providing a `patchForObject(...)` + * variant leveraging Kotlin reified type parameters. Like the original Java method, this + * extension is subject to type erasure. Use [exchange] if you need to retain actual + * generic type arguments. + * + * @author Sebastien Deleuze + * @since 5.0.2 + */ +@Throws(RestClientException::class) +inline fun RestOperations.patchForObject(url: String, request: Any? = null, + vararg uriVariables: Any): T? = + patchForObject(url, request, T::class.java, *uriVariables) + +/** + * Extension for [RestOperations.patchForObject] providing a `patchForObject(...)` + * variant leveraging Kotlin reified type parameters. Like the original Java method, this + * extension is subject to type erasure. Use [exchange] if you need to retain actual + * generic type arguments. + * + * @author Sebastien Deleuze + * @since 5.0.2 + */ +@Throws(RestClientException::class) +inline fun RestOperations.patchForObject(url: String, request: Any? = null, + uriVariables: Map): T? = + patchForObject(url, request, T::class.java, uriVariables) + +/** + * Extension for [RestOperations.patchForObject] providing a `patchForObject(...)` + * variant leveraging Kotlin reified type parameters. Like the original Java method, this + * extension is subject to type erasure. Use [exchange] if you need to retain actual + * generic type arguments. + * + * @author Sebastien Deleuze + * @since 5.0.2 + */ +@Throws(RestClientException::class) +inline fun RestOperations.patchForObject(url: URI, request: Any? = null): T? = + patchForObject(url, request, T::class.java) + +/** + * Extension for [RestOperations.postForObject] providing a `postForObject(...)` + * variant leveraging Kotlin reified type parameters. Like the original Java method, this + * extension is subject to type erasure. Use [exchange] if you need to retain actual + * generic type arguments. + * + * @author Jon Schneider + * @author Sebastien Deleuze + * @since 5.0 + */ +@Throws(RestClientException::class) +inline fun RestOperations.postForObject(url: String, request: Any? = null, + vararg uriVariables: Any): T? = + postForObject(url, request, T::class.java, *uriVariables) + +/** + * Extension for [RestOperations.postForObject] providing a `postForObject(...)` + * variant leveraging Kotlin reified type parameters. Like the original Java method, this + * extension is subject to type erasure. Use [exchange] if you need to retain actual + * generic type arguments. + * + * @author Jon Schneider + * @author Sebastien Deleuze + * @since 5.0 + */ +@Throws(RestClientException::class) +inline fun RestOperations.postForObject(url: String, request: Any? = null, + uriVariables: Map): T? = + postForObject(url, request, T::class.java, uriVariables) + +/** + * Extension for [RestOperations.postForObject] providing a `postForObject(...)` + * variant leveraging Kotlin reified type parameters. Like the original Java method, this + * extension is subject to type erasure. Use [exchange] if you need to retain actual + * generic type arguments. + * + * @author Jon Schneider + * @author Sebastien Deleuze + * @since 5.0 + */ +@Throws(RestClientException::class) +inline fun RestOperations.postForObject(url: URI, request: Any? = null): T? = + postForObject(url, request, T::class.java) + +/** + * Extension for [RestOperations.postForEntity] providing a `postForEntity(...)` + * variant leveraging Kotlin reified type parameters. Like the original Java method, this + * extension is subject to type erasure. Use [exchange] if you need to retain actual + * generic type arguments. + * + * @author Jon Schneider + * @author Sebastien Deleuze + * @since 5.0 + */ +@Throws(RestClientException::class) +inline fun RestOperations.postForEntity(url: String, request: Any? = null, + vararg uriVariables: Any): ResponseEntity = + postForEntity(url, request, T::class.java, *uriVariables) + +/** + * Extension for [RestOperations.postForEntity] providing a `postForEntity(...)` + * variant leveraging Kotlin reified type parameters. Like the original Java method, this + * extension is subject to type erasure. Use [exchange] if you need to retain actual + * generic type arguments. + * + * @author Jon Schneider + * @author Sebastien Deleuze + * @since 5.0 + */ +@Throws(RestClientException::class) +inline fun RestOperations.postForEntity(url: String, request: Any? = null, + uriVariables: Map): ResponseEntity = + postForEntity(url, request, T::class.java, uriVariables) + +/** + * Extension for [RestOperations.postForEntity] providing a `postForEntity(...)` + * variant leveraging Kotlin reified type parameters. Like the original Java method, this + * extension is subject to type erasure. Use [exchange] if you need to retain actual + * generic type arguments. + * + * @author Jon Schneider + * @author Sebastien Deleuze + * @since 5.0 + */ +@Throws(RestClientException::class) +inline fun RestOperations.postForEntity(url: URI, request: Any? = null): ResponseEntity = + postForEntity(url, request, T::class.java) + +/** + * Extension for [RestOperations.exchange] providing an `exchange(...)` + * variant leveraging Kotlin reified type parameters. This extension is not subject to + * type erasure and retains actual generic type arguments. + * + * @author Jon Schneider + * @author Sebastien Deleuze + * @since 5.0 + */ +@Throws(RestClientException::class) +inline fun RestOperations.exchange(url: String, method: HttpMethod, + requestEntity: HttpEntity<*>? = null, vararg uriVariables: Any): ResponseEntity = + exchange(url, method, requestEntity, object : ParameterizedTypeReference() {}, *uriVariables) + +/** + * Extension for [RestOperations.exchange] providing an `exchange(...)` + * variant leveraging Kotlin reified type parameters. This extension is not subject to + * type erasure and retains actual generic type arguments. + * + * @author Jon Schneider + * @author Sebastien Deleuze + * @since 5.0 + */ +@Throws(RestClientException::class) +inline fun RestOperations.exchange(url: String, method: HttpMethod, + requestEntity: HttpEntity<*>? = null, uriVariables: Map): ResponseEntity = + exchange(url, method, requestEntity, object : ParameterizedTypeReference() {}, uriVariables) + +/** + * Extension for [RestOperations.exchange] providing an `exchange(...)` + * variant leveraging Kotlin reified type parameters. This extension is not subject to + * type erasure and retains actual generic type arguments. + * + * @author Jon Schneider + * @author Sebastien Deleuze + * @since 5.0 + */ +@Throws(RestClientException::class) +inline fun RestOperations.exchange(url: URI, method: HttpMethod, + requestEntity: HttpEntity<*>? = null): ResponseEntity = + exchange(url, method, requestEntity, object : ParameterizedTypeReference() {}) + +/** + * Extension for [RestOperations.exchange] providing an `exchange(...)` + * variant leveraging Kotlin reified type parameters. This extension is not subject to + * type erasure and retains actual generic type arguments. + * + * @author Jon Schneider + * @author Sebastien Deleuze + * @since 5.0 + */ +@Throws(RestClientException::class) +inline fun RestOperations.exchange(requestEntity: RequestEntity<*>): ResponseEntity = + exchange(requestEntity, object : ParameterizedTypeReference() {}) diff --git a/spring-web/src/main/resources/META-INF/services/javax.servlet.ServletContainerInitializer b/spring-web/src/main/resources/META-INF/services/javax.servlet.ServletContainerInitializer new file mode 100644 index 0000000000000000000000000000000000000000..b0ab3c3f0068640bffc7b4a3dc46b8bae367743e --- /dev/null +++ b/spring-web/src/main/resources/META-INF/services/javax.servlet.ServletContainerInitializer @@ -0,0 +1 @@ +org.springframework.web.SpringServletContainerInitializer \ No newline at end of file diff --git a/spring-web/src/main/resources/META-INF/web-fragment.xml b/spring-web/src/main/resources/META-INF/web-fragment.xml new file mode 100644 index 0000000000000000000000000000000000000000..ef85d0866ae466f0a9e3579a77d0159d2c17ae29 --- /dev/null +++ b/spring-web/src/main/resources/META-INF/web-fragment.xml @@ -0,0 +1,10 @@ + + + + spring_web + + + diff --git a/spring-web/src/main/resources/org/springframework/http/codec/CodecConfigurer.properties b/spring-web/src/main/resources/org/springframework/http/codec/CodecConfigurer.properties new file mode 100644 index 0000000000000000000000000000000000000000..dedf3138c1d149dc1efa59418b26c7362aa8b8d0 --- /dev/null +++ b/spring-web/src/main/resources/org/springframework/http/codec/CodecConfigurer.properties @@ -0,0 +1,5 @@ +# Default CodecConfigurer implementation classes for static Client/ServerCodecConfigurer.create() calls. +# Not meant to be customized by application developers; simply instantiate custom impl classes instead. + +org.springframework.http.codec.ClientCodecConfigurer=org.springframework.http.codec.support.DefaultClientCodecConfigurer +org.springframework.http.codec.ServerCodecConfigurer=org.springframework.http.codec.support.DefaultServerCodecConfigurer diff --git a/spring-web/src/main/resources/org/springframework/http/mime.types b/spring-web/src/main/resources/org/springframework/http/mime.types new file mode 100644 index 0000000000000000000000000000000000000000..597425c1184600deb352e79c149ae8a0efc78637 --- /dev/null +++ b/spring-web/src/main/resources/org/springframework/http/mime.types @@ -0,0 +1,1855 @@ +# This file maps Internet media types to unique file extension(s). +# Although created for httpd, this file is used by many software systems +# and has been placed in the public domain for unlimited redistribution. +# +# The table below contains both registered and (common) unregistered types. +# A type that has no unique extension can be ignored -- they are listed +# here to guide configurations toward known types and to make it easier to +# identify "new" types. File extensions are also commonly used to indicate +# content languages and encodings, so choose them carefully. +# +# Internet media types should be registered as described in RFC 4288. +# The registry is at . +# +# This file was retrieved from https://svn.apache.org/viewvc/httpd/httpd/trunk/docs/conf/mime.types?revision=1752884&view=co +# +# MIME type (lowercased) Extensions +# ============================================ ========== +# application/1d-interleaved-parityfec +# application/3gpdash-qoe-report+xml +# application/3gpp-ims+xml +# application/a2l +# application/activemessage +# application/alto-costmap+json +# application/alto-costmapfilter+json +# application/alto-directory+json +# application/alto-endpointcost+json +# application/alto-endpointcostparams+json +# application/alto-endpointprop+json +# application/alto-endpointpropparams+json +# application/alto-error+json +# application/alto-networkmap+json +# application/alto-networkmapfilter+json +# application/aml +application/andrew-inset ez +# application/applefile +application/applixware aw +# application/atf +# application/atfx +application/atom+xml atom +application/atomcat+xml atomcat +# application/atomdeleted+xml +# application/atomicmail +application/atomsvc+xml atomsvc +# application/atxml +# application/auth-policy+xml +# application/bacnet-xdd+zip +# application/batch-smtp +# application/beep+xml +# application/calendar+json +# application/calendar+xml +# application/call-completion +# application/cals-1840 +# application/cbor +# application/ccmp+xml +application/ccxml+xml ccxml +# application/cdfx+xml +application/cdmi-capability cdmia +application/cdmi-container cdmic +application/cdmi-domain cdmid +application/cdmi-object cdmio +application/cdmi-queue cdmiq +# application/cdni +# application/cea +# application/cea-2018+xml +# application/cellml+xml +# application/cfw +# application/cms +# application/cnrp+xml +# application/coap-group+json +# application/commonground +# application/conference-info+xml +# application/cpl+xml +# application/csrattrs +# application/csta+xml +# application/cstadata+xml +# application/csvm+json +application/cu-seeme cu +# application/cybercash +# application/dash+xml +# application/dashdelta +application/davmount+xml davmount +# application/dca-rft +# application/dcd +# application/dec-dx +# application/dialog-info+xml +# application/dicom +# application/dii +# application/dit +# application/dns +application/docbook+xml dbk +# application/dskpp+xml +application/dssc+der dssc +application/dssc+xml xdssc +# application/dvcs +application/ecmascript ecma +# application/edi-consent +# application/edi-x12 +# application/edifact +# application/efi +# application/emergencycalldata.comment+xml +# application/emergencycalldata.deviceinfo+xml +# application/emergencycalldata.providerinfo+xml +# application/emergencycalldata.serviceinfo+xml +# application/emergencycalldata.subscriberinfo+xml +application/emma+xml emma +# application/emotionml+xml +# application/encaprtp +# application/epp+xml +application/epub+zip epub +# application/eshop +# application/example +application/exi exi +# application/fastinfoset +# application/fastsoap +# application/fdt+xml +# application/fits +# application/font-sfnt +application/font-tdpfr pfr +application/font-woff woff +# application/framework-attributes+xml +# application/geo+json +application/gml+xml gml +application/gpx+xml gpx +application/gxf gxf +# application/gzip +# application/h224 +# application/held+xml +# application/http +application/hyperstudio stk +# application/ibe-key-request+xml +# application/ibe-pkg-reply+xml +# application/ibe-pp-data +# application/iges +# application/im-iscomposing+xml +# application/index +# application/index.cmd +# application/index.obj +# application/index.response +# application/index.vnd +application/inkml+xml ink inkml +# application/iotp +application/ipfix ipfix +# application/ipp +# application/isup +# application/its+xml +application/java-archive jar +application/java-serialized-object ser +application/java-vm class +application/javascript js +# application/jose +# application/jose+json +# application/jrd+json +application/json json +# application/json-patch+json +# application/json-seq +application/jsonml+json jsonml +# application/jwk+json +# application/jwk-set+json +# application/jwt +# application/kpml-request+xml +# application/kpml-response+xml +# application/ld+json +# application/lgr+xml +# application/link-format +# application/load-control+xml +application/lost+xml lostxml +# application/lostsync+xml +# application/lxf +application/mac-binhex40 hqx +application/mac-compactpro cpt +# application/macwriteii +application/mads+xml mads +application/marc mrc +application/marcxml+xml mrcx +application/mathematica ma nb mb +application/mathml+xml mathml +# application/mathml-content+xml +# application/mathml-presentation+xml +# application/mbms-associated-procedure-description+xml +# application/mbms-deregister+xml +# application/mbms-envelope+xml +# application/mbms-msk+xml +# application/mbms-msk-response+xml +# application/mbms-protection-description+xml +# application/mbms-reception-report+xml +# application/mbms-register+xml +# application/mbms-register-response+xml +# application/mbms-schedule+xml +# application/mbms-user-service-description+xml +application/mbox mbox +# application/media-policy-dataset+xml +# application/media_control+xml +application/mediaservercontrol+xml mscml +# application/merge-patch+json +application/metalink+xml metalink +application/metalink4+xml meta4 +application/mets+xml mets +# application/mf4 +# application/mikey +application/mods+xml mods +# application/moss-keys +# application/moss-signature +# application/mosskey-data +# application/mosskey-request +application/mp21 m21 mp21 +application/mp4 mp4s +# application/mpeg4-generic +# application/mpeg4-iod +# application/mpeg4-iod-xmt +# application/mrb-consumer+xml +# application/mrb-publish+xml +# application/msc-ivr+xml +# application/msc-mixer+xml +application/msword doc dot +application/mxf mxf +# application/nasdata +# application/news-checkgroups +# application/news-groupinfo +# application/news-transmission +# application/nlsml+xml +# application/nss +# application/ocsp-request +# application/ocsp-response +application/octet-stream bin dms lrf mar so dist distz pkg bpk dump elc deploy +application/oda oda +# application/odx +application/oebps-package+xml opf +application/ogg ogx +application/omdoc+xml omdoc +application/onenote onetoc onetoc2 onetmp onepkg +application/oxps oxps +# application/p2p-overlay+xml +# application/parityfec +application/patch-ops-error+xml xer +application/pdf pdf +# application/pdx +application/pgp-encrypted pgp +# application/pgp-keys +application/pgp-signature asc sig +application/pics-rules prf +# application/pidf+xml +# application/pidf-diff+xml +application/pkcs10 p10 +# application/pkcs12 +application/pkcs7-mime p7m p7c +application/pkcs7-signature p7s +application/pkcs8 p8 +application/pkix-attr-cert ac +application/pkix-cert cer +application/pkix-crl crl +application/pkix-pkipath pkipath +application/pkixcmp pki +application/pls+xml pls +# application/poc-settings+xml +application/postscript ai eps ps +# application/ppsp-tracker+json +# application/problem+json +# application/problem+xml +# application/provenance+xml +# application/prs.alvestrand.titrax-sheet +application/prs.cww cww +# application/prs.hpub+zip +# application/prs.nprend +# application/prs.plucker +# application/prs.rdf-xml-crypt +# application/prs.xsf+xml +application/pskc+xml pskcxml +# application/qsig +# application/raptorfec +# application/rdap+json +application/rdf+xml rdf +application/reginfo+xml rif +application/relax-ng-compact-syntax rnc +# application/remote-printing +# application/reputon+json +application/resource-lists+xml rl +application/resource-lists-diff+xml rld +# application/rfc+xml +# application/riscos +# application/rlmi+xml +application/rls-services+xml rs +application/rpki-ghostbusters gbr +application/rpki-manifest mft +application/rpki-roa roa +# application/rpki-updown +application/rsd+xml rsd +application/rss+xml rss +application/rtf rtf +# application/rtploopback +# application/rtx +# application/samlassertion+xml +# application/samlmetadata+xml +application/sbml+xml sbml +# application/scaip+xml +# application/scim+json +application/scvp-cv-request scq +application/scvp-cv-response scs +application/scvp-vp-request spq +application/scvp-vp-response spp +application/sdp sdp +# application/sep+xml +# application/sep-exi +# application/session-info +# application/set-payment +application/set-payment-initiation setpay +# application/set-registration +application/set-registration-initiation setreg +# application/sgml +# application/sgml-open-catalog +application/shf+xml shf +# application/sieve +# application/simple-filter+xml +# application/simple-message-summary +# application/simplesymbolcontainer +# application/slate +# application/smil +application/smil+xml smi smil +# application/smpte336m +# application/soap+fastinfoset +# application/soap+xml +application/sparql-query rq +application/sparql-results+xml srx +# application/spirits-event+xml +# application/sql +application/srgs gram +application/srgs+xml grxml +application/sru+xml sru +application/ssdl+xml ssdl +application/ssml+xml ssml +# application/tamp-apex-update +# application/tamp-apex-update-confirm +# application/tamp-community-update +# application/tamp-community-update-confirm +# application/tamp-error +# application/tamp-sequence-adjust +# application/tamp-sequence-adjust-confirm +# application/tamp-status-query +# application/tamp-status-response +# application/tamp-update +# application/tamp-update-confirm +application/tei+xml tei teicorpus +application/thraud+xml tfi +# application/timestamp-query +# application/timestamp-reply +application/timestamped-data tsd +# application/ttml+xml +# application/tve-trigger +# application/ulpfec +# application/urc-grpsheet+xml +# application/urc-ressheet+xml +# application/urc-targetdesc+xml +# application/urc-uisocketdesc+xml +# application/vcard+json +# application/vcard+xml +# application/vemmi +# application/vividence.scriptfile +# application/vnd.3gpp-prose+xml +# application/vnd.3gpp-prose-pc3ch+xml +# application/vnd.3gpp.access-transfer-events+xml +# application/vnd.3gpp.bsf+xml +# application/vnd.3gpp.mid-call+xml +application/vnd.3gpp.pic-bw-large plb +application/vnd.3gpp.pic-bw-small psb +application/vnd.3gpp.pic-bw-var pvb +# application/vnd.3gpp.sms +# application/vnd.3gpp.sms+xml +# application/vnd.3gpp.srvcc-ext+xml +# application/vnd.3gpp.srvcc-info+xml +# application/vnd.3gpp.state-and-event-info+xml +# application/vnd.3gpp.ussd+xml +# application/vnd.3gpp2.bcmcsinfo+xml +# application/vnd.3gpp2.sms +application/vnd.3gpp2.tcap tcap +# application/vnd.3lightssoftware.imagescal +application/vnd.3m.post-it-notes pwn +application/vnd.accpac.simply.aso aso +application/vnd.accpac.simply.imp imp +application/vnd.acucobol acu +application/vnd.acucorp atc acutc +application/vnd.adobe.air-application-installer-package+zip air +# application/vnd.adobe.flash.movie +application/vnd.adobe.formscentral.fcdt fcdt +application/vnd.adobe.fxp fxp fxpl +# application/vnd.adobe.partial-upload +application/vnd.adobe.xdp+xml xdp +application/vnd.adobe.xfdf xfdf +# application/vnd.aether.imp +# application/vnd.ah-barcode +application/vnd.ahead.space ahead +application/vnd.airzip.filesecure.azf azf +application/vnd.airzip.filesecure.azs azs +application/vnd.amazon.ebook azw +# application/vnd.amazon.mobi8-ebook +application/vnd.americandynamics.acc acc +application/vnd.amiga.ami ami +# application/vnd.amundsen.maze+xml +application/vnd.android.package-archive apk +# application/vnd.anki +application/vnd.anser-web-certificate-issue-initiation cii +application/vnd.anser-web-funds-transfer-initiation fti +application/vnd.antix.game-component atx +# application/vnd.apache.thrift.binary +# application/vnd.apache.thrift.compact +# application/vnd.apache.thrift.json +# application/vnd.api+json +application/vnd.apple.installer+xml mpkg +application/vnd.apple.mpegurl m3u8 +# application/vnd.arastra.swi +application/vnd.aristanetworks.swi swi +# application/vnd.artsquare +application/vnd.astraea-software.iota iota +application/vnd.audiograph aep +# application/vnd.autopackage +# application/vnd.avistar+xml +# application/vnd.balsamiq.bmml+xml +# application/vnd.balsamiq.bmpr +# application/vnd.bekitzur-stech+json +# application/vnd.biopax.rdf+xml +application/vnd.blueice.multipass mpm +# application/vnd.bluetooth.ep.oob +# application/vnd.bluetooth.le.oob +application/vnd.bmi bmi +application/vnd.businessobjects rep +# application/vnd.cab-jscript +# application/vnd.canon-cpdl +# application/vnd.canon-lips +# application/vnd.cendio.thinlinc.clientconf +# application/vnd.century-systems.tcp_stream +application/vnd.chemdraw+xml cdxml +# application/vnd.chess-pgn +application/vnd.chipnuts.karaoke-mmd mmd +application/vnd.cinderella cdy +# application/vnd.cirpack.isdn-ext +# application/vnd.citationstyles.style+xml +application/vnd.claymore cla +application/vnd.cloanto.rp9 rp9 +application/vnd.clonk.c4group c4g c4d c4f c4p c4u +application/vnd.cluetrust.cartomobile-config c11amc +application/vnd.cluetrust.cartomobile-config-pkg c11amz +# application/vnd.coffeescript +# application/vnd.collection+json +# application/vnd.collection.doc+json +# application/vnd.collection.next+json +# application/vnd.comicbook+zip +# application/vnd.commerce-battelle +application/vnd.commonspace csp +application/vnd.contact.cmsg cdbcmsg +# application/vnd.coreos.ignition+json +application/vnd.cosmocaller cmc +application/vnd.crick.clicker clkx +application/vnd.crick.clicker.keyboard clkk +application/vnd.crick.clicker.palette clkp +application/vnd.crick.clicker.template clkt +application/vnd.crick.clicker.wordbank clkw +application/vnd.criticaltools.wbs+xml wbs +application/vnd.ctc-posml pml +# application/vnd.ctct.ws+xml +# application/vnd.cups-pdf +# application/vnd.cups-postscript +application/vnd.cups-ppd ppd +# application/vnd.cups-raster +# application/vnd.cups-raw +# application/vnd.curl +application/vnd.curl.car car +application/vnd.curl.pcurl pcurl +# application/vnd.cyan.dean.root+xml +# application/vnd.cybank +application/vnd.dart dart +application/vnd.data-vision.rdz rdz +# application/vnd.debian.binary-package +application/vnd.dece.data uvf uvvf uvd uvvd +application/vnd.dece.ttml+xml uvt uvvt +application/vnd.dece.unspecified uvx uvvx +application/vnd.dece.zip uvz uvvz +application/vnd.denovo.fcselayout-link fe_launch +# application/vnd.desmume.movie +# application/vnd.dir-bi.plate-dl-nosuffix +# application/vnd.dm.delegation+xml +application/vnd.dna dna +# application/vnd.document+json +application/vnd.dolby.mlp mlp +# application/vnd.dolby.mobile.1 +# application/vnd.dolby.mobile.2 +# application/vnd.doremir.scorecloud-binary-document +application/vnd.dpgraph dpg +application/vnd.dreamfactory dfac +# application/vnd.drive+json +application/vnd.ds-keypoint kpxx +# application/vnd.dtg.local +# application/vnd.dtg.local.flash +# application/vnd.dtg.local.html +application/vnd.dvb.ait ait +# application/vnd.dvb.dvbj +# application/vnd.dvb.esgcontainer +# application/vnd.dvb.ipdcdftnotifaccess +# application/vnd.dvb.ipdcesgaccess +# application/vnd.dvb.ipdcesgaccess2 +# application/vnd.dvb.ipdcesgpdd +# application/vnd.dvb.ipdcroaming +# application/vnd.dvb.iptv.alfec-base +# application/vnd.dvb.iptv.alfec-enhancement +# application/vnd.dvb.notif-aggregate-root+xml +# application/vnd.dvb.notif-container+xml +# application/vnd.dvb.notif-generic+xml +# application/vnd.dvb.notif-ia-msglist+xml +# application/vnd.dvb.notif-ia-registration-request+xml +# application/vnd.dvb.notif-ia-registration-response+xml +# application/vnd.dvb.notif-init+xml +# application/vnd.dvb.pfr +application/vnd.dvb.service svc +# application/vnd.dxr +application/vnd.dynageo geo +# application/vnd.dzr +# application/vnd.easykaraoke.cdgdownload +# application/vnd.ecdis-update +application/vnd.ecowin.chart mag +# application/vnd.ecowin.filerequest +# application/vnd.ecowin.fileupdate +# application/vnd.ecowin.series +# application/vnd.ecowin.seriesrequest +# application/vnd.ecowin.seriesupdate +# application/vnd.emclient.accessrequest+xml +application/vnd.enliven nml +# application/vnd.enphase.envoy +# application/vnd.eprints.data+xml +application/vnd.epson.esf esf +application/vnd.epson.msf msf +application/vnd.epson.quickanime qam +application/vnd.epson.salt slt +application/vnd.epson.ssf ssf +# application/vnd.ericsson.quickcall +application/vnd.eszigno3+xml es3 et3 +# application/vnd.etsi.aoc+xml +# application/vnd.etsi.asic-e+zip +# application/vnd.etsi.asic-s+zip +# application/vnd.etsi.cug+xml +# application/vnd.etsi.iptvcommand+xml +# application/vnd.etsi.iptvdiscovery+xml +# application/vnd.etsi.iptvprofile+xml +# application/vnd.etsi.iptvsad-bc+xml +# application/vnd.etsi.iptvsad-cod+xml +# application/vnd.etsi.iptvsad-npvr+xml +# application/vnd.etsi.iptvservice+xml +# application/vnd.etsi.iptvsync+xml +# application/vnd.etsi.iptvueprofile+xml +# application/vnd.etsi.mcid+xml +# application/vnd.etsi.mheg5 +# application/vnd.etsi.overload-control-policy-dataset+xml +# application/vnd.etsi.pstn+xml +# application/vnd.etsi.sci+xml +# application/vnd.etsi.simservs+xml +# application/vnd.etsi.timestamp-token +# application/vnd.etsi.tsl+xml +# application/vnd.etsi.tsl.der +# application/vnd.eudora.data +application/vnd.ezpix-album ez2 +application/vnd.ezpix-package ez3 +# application/vnd.f-secure.mobile +# application/vnd.fastcopy-disk-image +application/vnd.fdf fdf +application/vnd.fdsn.mseed mseed +application/vnd.fdsn.seed seed dataless +# application/vnd.ffsns +# application/vnd.filmit.zfc +# application/vnd.fints +# application/vnd.firemonkeys.cloudcell +application/vnd.flographit gph +application/vnd.fluxtime.clip ftc +# application/vnd.font-fontforge-sfd +application/vnd.framemaker fm frame maker book +application/vnd.frogans.fnc fnc +application/vnd.frogans.ltf ltf +application/vnd.fsc.weblaunch fsc +application/vnd.fujitsu.oasys oas +application/vnd.fujitsu.oasys2 oa2 +application/vnd.fujitsu.oasys3 oa3 +application/vnd.fujitsu.oasysgp fg5 +application/vnd.fujitsu.oasysprs bh2 +# application/vnd.fujixerox.art-ex +# application/vnd.fujixerox.art4 +application/vnd.fujixerox.ddd ddd +application/vnd.fujixerox.docuworks xdw +application/vnd.fujixerox.docuworks.binder xbd +# application/vnd.fujixerox.docuworks.container +# application/vnd.fujixerox.hbpl +# application/vnd.fut-misnet +application/vnd.fuzzysheet fzs +application/vnd.genomatix.tuxedo txd +# application/vnd.geo+json +# application/vnd.geocube+xml +application/vnd.geogebra.file ggb +application/vnd.geogebra.tool ggt +application/vnd.geometry-explorer gex gre +application/vnd.geonext gxt +application/vnd.geoplan g2w +application/vnd.geospace g3w +# application/vnd.gerber +# application/vnd.globalplatform.card-content-mgt +# application/vnd.globalplatform.card-content-mgt-response +application/vnd.gmx gmx +application/vnd.google-earth.kml+xml kml +application/vnd.google-earth.kmz kmz +# application/vnd.gov.sk.e-form+xml +# application/vnd.gov.sk.e-form+zip +# application/vnd.gov.sk.xmldatacontainer+xml +application/vnd.grafeq gqf gqs +# application/vnd.gridmp +application/vnd.groove-account gac +application/vnd.groove-help ghf +application/vnd.groove-identity-message gim +application/vnd.groove-injector grv +application/vnd.groove-tool-message gtm +application/vnd.groove-tool-template tpl +application/vnd.groove-vcard vcg +# application/vnd.hal+json +application/vnd.hal+xml hal +application/vnd.handheld-entertainment+xml zmm +application/vnd.hbci hbci +# application/vnd.hcl-bireports +# application/vnd.hdt +# application/vnd.heroku+json +application/vnd.hhe.lesson-player les +application/vnd.hp-hpgl hpgl +application/vnd.hp-hpid hpid +application/vnd.hp-hps hps +application/vnd.hp-jlyt jlt +application/vnd.hp-pcl pcl +application/vnd.hp-pclxl pclxl +# application/vnd.httphone +application/vnd.hydrostatix.sof-data sfd-hdstx +# application/vnd.hyperdrive+json +# application/vnd.hzn-3d-crossword +# application/vnd.ibm.afplinedata +# application/vnd.ibm.electronic-media +application/vnd.ibm.minipay mpy +application/vnd.ibm.modcap afp listafp list3820 +application/vnd.ibm.rights-management irm +application/vnd.ibm.secure-container sc +application/vnd.iccprofile icc icm +# application/vnd.ieee.1905 +application/vnd.igloader igl +application/vnd.immervision-ivp ivp +application/vnd.immervision-ivu ivu +# application/vnd.ims.imsccv1p1 +# application/vnd.ims.imsccv1p2 +# application/vnd.ims.imsccv1p3 +# application/vnd.ims.lis.v2.result+json +# application/vnd.ims.lti.v2.toolconsumerprofile+json +# application/vnd.ims.lti.v2.toolproxy+json +# application/vnd.ims.lti.v2.toolproxy.id+json +# application/vnd.ims.lti.v2.toolsettings+json +# application/vnd.ims.lti.v2.toolsettings.simple+json +# application/vnd.informedcontrol.rms+xml +# application/vnd.informix-visionary +# application/vnd.infotech.project +# application/vnd.infotech.project+xml +# application/vnd.innopath.wamp.notification +application/vnd.insors.igm igm +application/vnd.intercon.formnet xpw xpx +application/vnd.intergeo i2g +# application/vnd.intertrust.digibox +# application/vnd.intertrust.nncp +application/vnd.intu.qbo qbo +application/vnd.intu.qfx qfx +# application/vnd.iptc.g2.catalogitem+xml +# application/vnd.iptc.g2.conceptitem+xml +# application/vnd.iptc.g2.knowledgeitem+xml +# application/vnd.iptc.g2.newsitem+xml +# application/vnd.iptc.g2.newsmessage+xml +# application/vnd.iptc.g2.packageitem+xml +# application/vnd.iptc.g2.planningitem+xml +application/vnd.ipunplugged.rcprofile rcprofile +application/vnd.irepository.package+xml irp +application/vnd.is-xpr xpr +application/vnd.isac.fcs fcs +application/vnd.jam jam +# application/vnd.japannet-directory-service +# application/vnd.japannet-jpnstore-wakeup +# application/vnd.japannet-payment-wakeup +# application/vnd.japannet-registration +# application/vnd.japannet-registration-wakeup +# application/vnd.japannet-setstore-wakeup +# application/vnd.japannet-verification +# application/vnd.japannet-verification-wakeup +application/vnd.jcp.javame.midlet-rms rms +application/vnd.jisp jisp +application/vnd.joost.joda-archive joda +# application/vnd.jsk.isdn-ngn +application/vnd.kahootz ktz ktr +application/vnd.kde.karbon karbon +application/vnd.kde.kchart chrt +application/vnd.kde.kformula kfo +application/vnd.kde.kivio flw +application/vnd.kde.kontour kon +application/vnd.kde.kpresenter kpr kpt +application/vnd.kde.kspread ksp +application/vnd.kde.kword kwd kwt +application/vnd.kenameaapp htke +application/vnd.kidspiration kia +application/vnd.kinar kne knp +application/vnd.koan skp skd skt skm +application/vnd.kodak-descriptor sse +application/vnd.las.las+xml lasxml +# application/vnd.liberty-request+xml +application/vnd.llamagraphics.life-balance.desktop lbd +application/vnd.llamagraphics.life-balance.exchange+xml lbe +application/vnd.lotus-1-2-3 123 +application/vnd.lotus-approach apr +application/vnd.lotus-freelance pre +application/vnd.lotus-notes nsf +application/vnd.lotus-organizer org +application/vnd.lotus-screencam scm +application/vnd.lotus-wordpro lwp +application/vnd.macports.portpkg portpkg +# application/vnd.mapbox-vector-tile +# application/vnd.marlin.drm.actiontoken+xml +# application/vnd.marlin.drm.conftoken+xml +# application/vnd.marlin.drm.license+xml +# application/vnd.marlin.drm.mdcf +# application/vnd.mason+json +# application/vnd.maxmind.maxmind-db +application/vnd.mcd mcd +application/vnd.medcalcdata mc1 +application/vnd.mediastation.cdkey cdkey +# application/vnd.meridian-slingshot +application/vnd.mfer mwf +application/vnd.mfmp mfm +# application/vnd.micro+json +application/vnd.micrografx.flo flo +application/vnd.micrografx.igx igx +# application/vnd.microsoft.portable-executable +# application/vnd.miele+json +application/vnd.mif mif +# application/vnd.minisoft-hp3000-save +# application/vnd.mitsubishi.misty-guard.trustweb +application/vnd.mobius.daf daf +application/vnd.mobius.dis dis +application/vnd.mobius.mbk mbk +application/vnd.mobius.mqy mqy +application/vnd.mobius.msl msl +application/vnd.mobius.plc plc +application/vnd.mobius.txf txf +application/vnd.mophun.application mpn +application/vnd.mophun.certificate mpc +# application/vnd.motorola.flexsuite +# application/vnd.motorola.flexsuite.adsi +# application/vnd.motorola.flexsuite.fis +# application/vnd.motorola.flexsuite.gotap +# application/vnd.motorola.flexsuite.kmr +# application/vnd.motorola.flexsuite.ttc +# application/vnd.motorola.flexsuite.wem +# application/vnd.motorola.iprm +application/vnd.mozilla.xul+xml xul +# application/vnd.ms-3mfdocument +application/vnd.ms-artgalry cil +# application/vnd.ms-asf +application/vnd.ms-cab-compressed cab +# application/vnd.ms-color.iccprofile +application/vnd.ms-excel xls xlm xla xlc xlt xlw +application/vnd.ms-excel.addin.macroenabled.12 xlam +application/vnd.ms-excel.sheet.binary.macroenabled.12 xlsb +application/vnd.ms-excel.sheet.macroenabled.12 xlsm +application/vnd.ms-excel.template.macroenabled.12 xltm +application/vnd.ms-fontobject eot +application/vnd.ms-htmlhelp chm +application/vnd.ms-ims ims +application/vnd.ms-lrm lrm +# application/vnd.ms-office.activex+xml +application/vnd.ms-officetheme thmx +# application/vnd.ms-opentype +# application/vnd.ms-package.obfuscated-opentype +application/vnd.ms-pki.seccat cat +application/vnd.ms-pki.stl stl +# application/vnd.ms-playready.initiator+xml +application/vnd.ms-powerpoint ppt pps pot +application/vnd.ms-powerpoint.addin.macroenabled.12 ppam +application/vnd.ms-powerpoint.presentation.macroenabled.12 pptm +application/vnd.ms-powerpoint.slide.macroenabled.12 sldm +application/vnd.ms-powerpoint.slideshow.macroenabled.12 ppsm +application/vnd.ms-powerpoint.template.macroenabled.12 potm +# application/vnd.ms-printdevicecapabilities+xml +# application/vnd.ms-printing.printticket+xml +# application/vnd.ms-printschematicket+xml +application/vnd.ms-project mpp mpt +# application/vnd.ms-tnef +# application/vnd.ms-windows.devicepairing +# application/vnd.ms-windows.nwprinting.oob +# application/vnd.ms-windows.printerpairing +# application/vnd.ms-windows.wsd.oob +# application/vnd.ms-wmdrm.lic-chlg-req +# application/vnd.ms-wmdrm.lic-resp +# application/vnd.ms-wmdrm.meter-chlg-req +# application/vnd.ms-wmdrm.meter-resp +application/vnd.ms-word.document.macroenabled.12 docm +application/vnd.ms-word.template.macroenabled.12 dotm +application/vnd.ms-works wps wks wcm wdb +application/vnd.ms-wpl wpl +application/vnd.ms-xpsdocument xps +# application/vnd.msa-disk-image +application/vnd.mseq mseq +# application/vnd.msign +# application/vnd.multiad.creator +# application/vnd.multiad.creator.cif +# application/vnd.music-niff +application/vnd.musician mus +application/vnd.muvee.style msty +application/vnd.mynfc taglet +# application/vnd.ncd.control +# application/vnd.ncd.reference +# application/vnd.nervana +# application/vnd.netfpx +application/vnd.neurolanguage.nlu nlu +# application/vnd.nintendo.nitro.rom +# application/vnd.nintendo.snes.rom +application/vnd.nitf ntf nitf +application/vnd.noblenet-directory nnd +application/vnd.noblenet-sealer nns +application/vnd.noblenet-web nnw +# application/vnd.nokia.catalogs +# application/vnd.nokia.conml+wbxml +# application/vnd.nokia.conml+xml +# application/vnd.nokia.iptv.config+xml +# application/vnd.nokia.isds-radio-presets +# application/vnd.nokia.landmark+wbxml +# application/vnd.nokia.landmark+xml +# application/vnd.nokia.landmarkcollection+xml +# application/vnd.nokia.n-gage.ac+xml +application/vnd.nokia.n-gage.data ngdat +application/vnd.nokia.n-gage.symbian.install n-gage +# application/vnd.nokia.ncd +# application/vnd.nokia.pcd+wbxml +# application/vnd.nokia.pcd+xml +application/vnd.nokia.radio-preset rpst +application/vnd.nokia.radio-presets rpss +application/vnd.novadigm.edm edm +application/vnd.novadigm.edx edx +application/vnd.novadigm.ext ext +# application/vnd.ntt-local.content-share +# application/vnd.ntt-local.file-transfer +# application/vnd.ntt-local.ogw_remote-access +# application/vnd.ntt-local.sip-ta_remote +# application/vnd.ntt-local.sip-ta_tcp_stream +application/vnd.oasis.opendocument.chart odc +application/vnd.oasis.opendocument.chart-template otc +application/vnd.oasis.opendocument.database odb +application/vnd.oasis.opendocument.formula odf +application/vnd.oasis.opendocument.formula-template odft +application/vnd.oasis.opendocument.graphics odg +application/vnd.oasis.opendocument.graphics-template otg +application/vnd.oasis.opendocument.image odi +application/vnd.oasis.opendocument.image-template oti +application/vnd.oasis.opendocument.presentation odp +application/vnd.oasis.opendocument.presentation-template otp +application/vnd.oasis.opendocument.spreadsheet ods +application/vnd.oasis.opendocument.spreadsheet-template ots +application/vnd.oasis.opendocument.text odt +application/vnd.oasis.opendocument.text-master odm +application/vnd.oasis.opendocument.text-template ott +application/vnd.oasis.opendocument.text-web oth +# application/vnd.obn +# application/vnd.oftn.l10n+json +# application/vnd.oipf.contentaccessdownload+xml +# application/vnd.oipf.contentaccessstreaming+xml +# application/vnd.oipf.cspg-hexbinary +# application/vnd.oipf.dae.svg+xml +# application/vnd.oipf.dae.xhtml+xml +# application/vnd.oipf.mippvcontrolmessage+xml +# application/vnd.oipf.pae.gem +# application/vnd.oipf.spdiscovery+xml +# application/vnd.oipf.spdlist+xml +# application/vnd.oipf.ueprofile+xml +# application/vnd.oipf.userprofile+xml +application/vnd.olpc-sugar xo +# application/vnd.oma-scws-config +# application/vnd.oma-scws-http-request +# application/vnd.oma-scws-http-response +# application/vnd.oma.bcast.associated-procedure-parameter+xml +# application/vnd.oma.bcast.drm-trigger+xml +# application/vnd.oma.bcast.imd+xml +# application/vnd.oma.bcast.ltkm +# application/vnd.oma.bcast.notification+xml +# application/vnd.oma.bcast.provisioningtrigger +# application/vnd.oma.bcast.sgboot +# application/vnd.oma.bcast.sgdd+xml +# application/vnd.oma.bcast.sgdu +# application/vnd.oma.bcast.simple-symbol-container +# application/vnd.oma.bcast.smartcard-trigger+xml +# application/vnd.oma.bcast.sprov+xml +# application/vnd.oma.bcast.stkm +# application/vnd.oma.cab-address-book+xml +# application/vnd.oma.cab-feature-handler+xml +# application/vnd.oma.cab-pcc+xml +# application/vnd.oma.cab-subs-invite+xml +# application/vnd.oma.cab-user-prefs+xml +# application/vnd.oma.dcd +# application/vnd.oma.dcdc +application/vnd.oma.dd2+xml dd2 +# application/vnd.oma.drm.risd+xml +# application/vnd.oma.group-usage-list+xml +# application/vnd.oma.lwm2m+json +# application/vnd.oma.lwm2m+tlv +# application/vnd.oma.pal+xml +# application/vnd.oma.poc.detailed-progress-report+xml +# application/vnd.oma.poc.final-report+xml +# application/vnd.oma.poc.groups+xml +# application/vnd.oma.poc.invocation-descriptor+xml +# application/vnd.oma.poc.optimized-progress-report+xml +# application/vnd.oma.push +# application/vnd.oma.scidm.messages+xml +# application/vnd.oma.xcap-directory+xml +# application/vnd.omads-email+xml +# application/vnd.omads-file+xml +# application/vnd.omads-folder+xml +# application/vnd.omaloc-supl-init +# application/vnd.onepager +# application/vnd.openblox.game+xml +# application/vnd.openblox.game-binary +# application/vnd.openeye.oeb +application/vnd.openofficeorg.extension oxt +# application/vnd.openxmlformats-officedocument.custom-properties+xml +# application/vnd.openxmlformats-officedocument.customxmlproperties+xml +# application/vnd.openxmlformats-officedocument.drawing+xml +# application/vnd.openxmlformats-officedocument.drawingml.chart+xml +# application/vnd.openxmlformats-officedocument.drawingml.chartshapes+xml +# application/vnd.openxmlformats-officedocument.drawingml.diagramcolors+xml +# application/vnd.openxmlformats-officedocument.drawingml.diagramdata+xml +# application/vnd.openxmlformats-officedocument.drawingml.diagramlayout+xml +# application/vnd.openxmlformats-officedocument.drawingml.diagramstyle+xml +# application/vnd.openxmlformats-officedocument.extended-properties+xml +# application/vnd.openxmlformats-officedocument.presentationml.commentauthors+xml +# application/vnd.openxmlformats-officedocument.presentationml.comments+xml +# application/vnd.openxmlformats-officedocument.presentationml.handoutmaster+xml +# application/vnd.openxmlformats-officedocument.presentationml.notesmaster+xml +# application/vnd.openxmlformats-officedocument.presentationml.notesslide+xml +application/vnd.openxmlformats-officedocument.presentationml.presentation pptx +# application/vnd.openxmlformats-officedocument.presentationml.presentation.main+xml +# application/vnd.openxmlformats-officedocument.presentationml.presprops+xml +application/vnd.openxmlformats-officedocument.presentationml.slide sldx +# application/vnd.openxmlformats-officedocument.presentationml.slide+xml +# application/vnd.openxmlformats-officedocument.presentationml.slidelayout+xml +# application/vnd.openxmlformats-officedocument.presentationml.slidemaster+xml +application/vnd.openxmlformats-officedocument.presentationml.slideshow ppsx +# application/vnd.openxmlformats-officedocument.presentationml.slideshow.main+xml +# application/vnd.openxmlformats-officedocument.presentationml.slideupdateinfo+xml +# application/vnd.openxmlformats-officedocument.presentationml.tablestyles+xml +# application/vnd.openxmlformats-officedocument.presentationml.tags+xml +application/vnd.openxmlformats-officedocument.presentationml.template potx +# application/vnd.openxmlformats-officedocument.presentationml.template.main+xml +# application/vnd.openxmlformats-officedocument.presentationml.viewprops+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.calcchain+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.chartsheet+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.comments+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.connections+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.dialogsheet+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.externallink+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.pivotcachedefinition+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.pivotcacherecords+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.pivottable+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.querytable+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.revisionheaders+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.revisionlog+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.sharedstrings+xml +application/vnd.openxmlformats-officedocument.spreadsheetml.sheet xlsx +# application/vnd.openxmlformats-officedocument.spreadsheetml.sheet.main+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.sheetmetadata+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.styles+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.table+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.tablesinglecells+xml +application/vnd.openxmlformats-officedocument.spreadsheetml.template xltx +# application/vnd.openxmlformats-officedocument.spreadsheetml.template.main+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.usernames+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.volatiledependencies+xml +# application/vnd.openxmlformats-officedocument.spreadsheetml.worksheet+xml +# application/vnd.openxmlformats-officedocument.theme+xml +# application/vnd.openxmlformats-officedocument.themeoverride+xml +# application/vnd.openxmlformats-officedocument.vmldrawing +# application/vnd.openxmlformats-officedocument.wordprocessingml.comments+xml +application/vnd.openxmlformats-officedocument.wordprocessingml.document docx +# application/vnd.openxmlformats-officedocument.wordprocessingml.document.glossary+xml +# application/vnd.openxmlformats-officedocument.wordprocessingml.document.main+xml +# application/vnd.openxmlformats-officedocument.wordprocessingml.endnotes+xml +# application/vnd.openxmlformats-officedocument.wordprocessingml.fonttable+xml +# application/vnd.openxmlformats-officedocument.wordprocessingml.footer+xml +# application/vnd.openxmlformats-officedocument.wordprocessingml.footnotes+xml +# application/vnd.openxmlformats-officedocument.wordprocessingml.numbering+xml +# application/vnd.openxmlformats-officedocument.wordprocessingml.settings+xml +# application/vnd.openxmlformats-officedocument.wordprocessingml.styles+xml +application/vnd.openxmlformats-officedocument.wordprocessingml.template dotx +# application/vnd.openxmlformats-officedocument.wordprocessingml.template.main+xml +# application/vnd.openxmlformats-officedocument.wordprocessingml.websettings+xml +# application/vnd.openxmlformats-package.core-properties+xml +# application/vnd.openxmlformats-package.digital-signature-xmlsignature+xml +# application/vnd.openxmlformats-package.relationships+xml +# application/vnd.oracle.resource+json +# application/vnd.orange.indata +# application/vnd.osa.netdeploy +application/vnd.osgeo.mapguide.package mgp +# application/vnd.osgi.bundle +application/vnd.osgi.dp dp +application/vnd.osgi.subsystem esa +# application/vnd.otps.ct-kip+xml +# application/vnd.oxli.countgraph +# application/vnd.pagerduty+json +application/vnd.palm pdb pqa oprc +# application/vnd.panoply +# application/vnd.paos.xml +application/vnd.pawaafile paw +# application/vnd.pcos +application/vnd.pg.format str +application/vnd.pg.osasli ei6 +# application/vnd.piaccess.application-licence +application/vnd.picsel efif +application/vnd.pmi.widget wg +# application/vnd.poc.group-advertisement+xml +application/vnd.pocketlearn plf +application/vnd.powerbuilder6 pbd +# application/vnd.powerbuilder6-s +# application/vnd.powerbuilder7 +# application/vnd.powerbuilder7-s +# application/vnd.powerbuilder75 +# application/vnd.powerbuilder75-s +# application/vnd.preminet +application/vnd.previewsystems.box box +application/vnd.proteus.magazine mgz +application/vnd.publishare-delta-tree qps +application/vnd.pvi.ptid1 ptid +# application/vnd.pwg-multiplexed +# application/vnd.pwg-xhtml-print+xml +# application/vnd.qualcomm.brew-app-res +# application/vnd.quarantainenet +application/vnd.quark.quarkxpress qxd qxt qwd qwt qxl qxb +# application/vnd.quobject-quoxdocument +# application/vnd.radisys.moml+xml +# application/vnd.radisys.msml+xml +# application/vnd.radisys.msml-audit+xml +# application/vnd.radisys.msml-audit-conf+xml +# application/vnd.radisys.msml-audit-conn+xml +# application/vnd.radisys.msml-audit-dialog+xml +# application/vnd.radisys.msml-audit-stream+xml +# application/vnd.radisys.msml-conf+xml +# application/vnd.radisys.msml-dialog+xml +# application/vnd.radisys.msml-dialog-base+xml +# application/vnd.radisys.msml-dialog-fax-detect+xml +# application/vnd.radisys.msml-dialog-fax-sendrecv+xml +# application/vnd.radisys.msml-dialog-group+xml +# application/vnd.radisys.msml-dialog-speech+xml +# application/vnd.radisys.msml-dialog-transform+xml +# application/vnd.rainstor.data +# application/vnd.rapid +# application/vnd.rar +application/vnd.realvnc.bed bed +application/vnd.recordare.musicxml mxl +application/vnd.recordare.musicxml+xml musicxml +# application/vnd.renlearn.rlprint +application/vnd.rig.cryptonote cryptonote +application/vnd.rim.cod cod +application/vnd.rn-realmedia rm +application/vnd.rn-realmedia-vbr rmvb +application/vnd.route66.link66+xml link66 +# application/vnd.rs-274x +# application/vnd.ruckus.download +# application/vnd.s3sms +application/vnd.sailingtracker.track st +# application/vnd.sbm.cid +# application/vnd.sbm.mid2 +# application/vnd.scribus +# application/vnd.sealed.3df +# application/vnd.sealed.csf +# application/vnd.sealed.doc +# application/vnd.sealed.eml +# application/vnd.sealed.mht +# application/vnd.sealed.net +# application/vnd.sealed.ppt +# application/vnd.sealed.tiff +# application/vnd.sealed.xls +# application/vnd.sealedmedia.softseal.html +# application/vnd.sealedmedia.softseal.pdf +application/vnd.seemail see +application/vnd.sema sema +application/vnd.semd semd +application/vnd.semf semf +application/vnd.shana.informed.formdata ifm +application/vnd.shana.informed.formtemplate itp +application/vnd.shana.informed.interchange iif +application/vnd.shana.informed.package ipk +application/vnd.simtech-mindmapper twd twds +# application/vnd.siren+json +application/vnd.smaf mmf +# application/vnd.smart.notebook +application/vnd.smart.teacher teacher +# application/vnd.software602.filler.form+xml +# application/vnd.software602.filler.form-xml-zip +application/vnd.solent.sdkm+xml sdkm sdkd +application/vnd.spotfire.dxp dxp +application/vnd.spotfire.sfs sfs +# application/vnd.sss-cod +# application/vnd.sss-dtf +# application/vnd.sss-ntf +application/vnd.stardivision.calc sdc +application/vnd.stardivision.draw sda +application/vnd.stardivision.impress sdd +application/vnd.stardivision.math smf +application/vnd.stardivision.writer sdw vor +application/vnd.stardivision.writer-global sgl +application/vnd.stepmania.package smzip +application/vnd.stepmania.stepchart sm +# application/vnd.street-stream +# application/vnd.sun.wadl+xml +application/vnd.sun.xml.calc sxc +application/vnd.sun.xml.calc.template stc +application/vnd.sun.xml.draw sxd +application/vnd.sun.xml.draw.template std +application/vnd.sun.xml.impress sxi +application/vnd.sun.xml.impress.template sti +application/vnd.sun.xml.math sxm +application/vnd.sun.xml.writer sxw +application/vnd.sun.xml.writer.global sxg +application/vnd.sun.xml.writer.template stw +application/vnd.sus-calendar sus susp +application/vnd.svd svd +# application/vnd.swiftview-ics +application/vnd.symbian.install sis sisx +application/vnd.syncml+xml xsm +application/vnd.syncml.dm+wbxml bdm +application/vnd.syncml.dm+xml xdm +# application/vnd.syncml.dm.notification +# application/vnd.syncml.dmddf+wbxml +# application/vnd.syncml.dmddf+xml +# application/vnd.syncml.dmtnds+wbxml +# application/vnd.syncml.dmtnds+xml +# application/vnd.syncml.ds.notification +application/vnd.tao.intent-module-archive tao +application/vnd.tcpdump.pcap pcap cap dmp +# application/vnd.tmd.mediaflex.api+xml +# application/vnd.tml +application/vnd.tmobile-livetv tmo +application/vnd.trid.tpt tpt +application/vnd.triscape.mxs mxs +application/vnd.trueapp tra +# application/vnd.truedoc +# application/vnd.ubisoft.webplayer +application/vnd.ufdl ufd ufdl +application/vnd.uiq.theme utz +application/vnd.umajin umj +application/vnd.unity unityweb +application/vnd.uoml+xml uoml +# application/vnd.uplanet.alert +# application/vnd.uplanet.alert-wbxml +# application/vnd.uplanet.bearer-choice +# application/vnd.uplanet.bearer-choice-wbxml +# application/vnd.uplanet.cacheop +# application/vnd.uplanet.cacheop-wbxml +# application/vnd.uplanet.channel +# application/vnd.uplanet.channel-wbxml +# application/vnd.uplanet.list +# application/vnd.uplanet.list-wbxml +# application/vnd.uplanet.listcmd +# application/vnd.uplanet.listcmd-wbxml +# application/vnd.uplanet.signal +# application/vnd.uri-map +# application/vnd.valve.source.material +application/vnd.vcx vcx +# application/vnd.vd-study +# application/vnd.vectorworks +# application/vnd.vel+json +# application/vnd.verimatrix.vcas +# application/vnd.vidsoft.vidconference +application/vnd.visio vsd vst vss vsw +application/vnd.visionary vis +# application/vnd.vividence.scriptfile +application/vnd.vsf vsf +# application/vnd.wap.sic +# application/vnd.wap.slc +application/vnd.wap.wbxml wbxml +application/vnd.wap.wmlc wmlc +application/vnd.wap.wmlscriptc wmlsc +application/vnd.webturbo wtb +# application/vnd.wfa.p2p +# application/vnd.wfa.wsc +# application/vnd.windows.devicepairing +# application/vnd.wmc +# application/vnd.wmf.bootstrap +# application/vnd.wolfram.mathematica +# application/vnd.wolfram.mathematica.package +application/vnd.wolfram.player nbp +application/vnd.wordperfect wpd +application/vnd.wqd wqd +# application/vnd.wrq-hp3000-labelled +application/vnd.wt.stf stf +# application/vnd.wv.csp+wbxml +# application/vnd.wv.csp+xml +# application/vnd.wv.ssp+xml +# application/vnd.xacml+json +application/vnd.xara xar +application/vnd.xfdl xfdl +# application/vnd.xfdl.webform +# application/vnd.xmi+xml +# application/vnd.xmpie.cpkg +# application/vnd.xmpie.dpkg +# application/vnd.xmpie.plan +# application/vnd.xmpie.ppkg +# application/vnd.xmpie.xlim +application/vnd.yamaha.hv-dic hvd +application/vnd.yamaha.hv-script hvs +application/vnd.yamaha.hv-voice hvp +application/vnd.yamaha.openscoreformat osf +application/vnd.yamaha.openscoreformat.osfpvg+xml osfpvg +# application/vnd.yamaha.remote-setup +application/vnd.yamaha.smaf-audio saf +application/vnd.yamaha.smaf-phrase spf +# application/vnd.yamaha.through-ngn +# application/vnd.yamaha.tunnel-udpencap +# application/vnd.yaoweme +application/vnd.yellowriver-custom-menu cmp +application/vnd.zul zir zirz +application/vnd.zzazz.deck+xml zaz +application/voicexml+xml vxml +# application/vq-rtcpxr +# application/watcherinfo+xml +# application/whoispp-query +# application/whoispp-response +application/widget wgt +application/winhlp hlp +# application/wita +# application/wordperfect5.1 +application/wsdl+xml wsdl +application/wspolicy+xml wspolicy +application/x-7z-compressed 7z +application/x-abiword abw +application/x-ace-compressed ace +# application/x-amf +application/x-apple-diskimage dmg +application/x-authorware-bin aab x32 u32 vox +application/x-authorware-map aam +application/x-authorware-seg aas +application/x-bcpio bcpio +application/x-bittorrent torrent +application/x-blorb blb blorb +application/x-bzip bz +application/x-bzip2 bz2 boz +application/x-cbr cbr cba cbt cbz cb7 +application/x-cdlink vcd +application/x-cfs-compressed cfs +application/x-chat chat +application/x-chess-pgn pgn +# application/x-compress +application/x-conference nsc +application/x-cpio cpio +application/x-csh csh +application/x-debian-package deb udeb +application/x-dgc-compressed dgc +application/x-director dir dcr dxr cst cct cxt w3d fgd swa +application/x-doom wad +application/x-dtbncx+xml ncx +application/x-dtbook+xml dtb +application/x-dtbresource+xml res +application/x-dvi dvi +application/x-envoy evy +application/x-eva eva +application/x-font-bdf bdf +# application/x-font-dos +# application/x-font-framemaker +application/x-font-ghostscript gsf +# application/x-font-libgrx +application/x-font-linux-psf psf +application/x-font-otf otf +application/x-font-pcf pcf +application/x-font-snf snf +# application/x-font-speedo +# application/x-font-sunos-news +application/x-font-ttf ttf ttc +application/x-font-type1 pfa pfb pfm afm +# application/x-font-vfont +application/x-freearc arc +application/x-futuresplash spl +application/x-gca-compressed gca +application/x-glulx ulx +application/x-gnumeric gnumeric +application/x-gramps-xml gramps +application/x-gtar gtar +# application/x-gzip +application/x-hdf hdf +application/x-install-instructions install +application/x-iso9660-image iso +application/x-java-jnlp-file jnlp +application/x-latex latex +application/x-lzh-compressed lzh lha +application/x-mie mie +application/x-mobipocket-ebook prc mobi +application/x-ms-application application +application/x-ms-shortcut lnk +application/x-ms-wmd wmd +application/x-ms-wmz wmz +application/x-ms-xbap xbap +application/x-msaccess mdb +application/x-msbinder obd +application/x-mscardfile crd +application/x-msclip clp +application/x-msdownload exe dll com bat msi +application/x-msmediaview mvb m13 m14 +application/x-msmetafile wmf wmz emf emz +application/x-msmoney mny +application/x-mspublisher pub +application/x-msschedule scd +application/x-msterminal trm +application/x-mswrite wri +application/x-netcdf nc cdf +application/x-nzb nzb +application/x-pkcs12 p12 pfx +application/x-pkcs7-certificates p7b spc +application/x-pkcs7-certreqresp p7r +application/x-rar-compressed rar +application/x-research-info-systems ris +application/x-sh sh +application/x-shar shar +application/x-shockwave-flash swf +application/x-silverlight-app xap +application/x-sql sql +application/x-stuffit sit +application/x-stuffitx sitx +application/x-subrip srt +application/x-sv4cpio sv4cpio +application/x-sv4crc sv4crc +application/x-t3vm-image t3 +application/x-tads gam +application/x-tar tar +application/x-tcl tcl +application/x-tex tex +application/x-tex-tfm tfm +application/x-texinfo texinfo texi +application/x-tgif obj +application/x-ustar ustar +application/x-wais-source src +# application/x-www-form-urlencoded +application/x-x509-ca-cert der crt +application/x-xfig fig +application/x-xliff+xml xlf +application/x-xpinstall xpi +application/x-xz xz +application/x-zmachine z1 z2 z3 z4 z5 z6 z7 z8 +# application/x400-bp +# application/xacml+xml +application/xaml+xml xaml +# application/xcap-att+xml +# application/xcap-caps+xml +application/xcap-diff+xml xdf +# application/xcap-el+xml +# application/xcap-error+xml +# application/xcap-ns+xml +# application/xcon-conference-info+xml +# application/xcon-conference-info-diff+xml +application/xenc+xml xenc +application/xhtml+xml xhtml xht +# application/xhtml-voice+xml +application/xml xml xsl +application/xml-dtd dtd +# application/xml-external-parsed-entity +# application/xml-patch+xml +# application/xmpp+xml +application/xop+xml xop +application/xproc+xml xpl +application/xslt+xml xslt +application/xspf+xml xspf +application/xv+xml mxml xhvml xvml xvm +application/yang yang +application/yin+xml yin +application/zip zip +# application/zlib +# audio/1d-interleaved-parityfec +# audio/32kadpcm +# audio/3gpp +# audio/3gpp2 +# audio/ac3 +audio/adpcm adp +# audio/amr +# audio/amr-wb +# audio/amr-wb+ +# audio/aptx +# audio/asc +# audio/atrac-advanced-lossless +# audio/atrac-x +# audio/atrac3 +audio/basic au snd +# audio/bv16 +# audio/bv32 +# audio/clearmode +# audio/cn +# audio/dat12 +# audio/dls +# audio/dsr-es201108 +# audio/dsr-es202050 +# audio/dsr-es202211 +# audio/dsr-es202212 +# audio/dv +# audio/dvi4 +# audio/eac3 +# audio/encaprtp +# audio/evrc +# audio/evrc-qcp +# audio/evrc0 +# audio/evrc1 +# audio/evrcb +# audio/evrcb0 +# audio/evrcb1 +# audio/evrcnw +# audio/evrcnw0 +# audio/evrcnw1 +# audio/evrcwb +# audio/evrcwb0 +# audio/evrcwb1 +# audio/evs +# audio/example +# audio/fwdred +# audio/g711-0 +# audio/g719 +# audio/g722 +# audio/g7221 +# audio/g723 +# audio/g726-16 +# audio/g726-24 +# audio/g726-32 +# audio/g726-40 +# audio/g728 +# audio/g729 +# audio/g7291 +# audio/g729d +# audio/g729e +# audio/gsm +# audio/gsm-efr +# audio/gsm-hr-08 +# audio/ilbc +# audio/ip-mr_v2.5 +# audio/isac +# audio/l16 +# audio/l20 +# audio/l24 +# audio/l8 +# audio/lpc +audio/midi mid midi kar rmi +# audio/mobile-xmf +audio/mp4 m4a mp4a +# audio/mp4a-latm +# audio/mpa +# audio/mpa-robust +audio/mpeg mpga mp2 mp2a mp3 m2a m3a +# audio/mpeg4-generic +# audio/musepack +audio/ogg oga ogg spx +# audio/opus +# audio/parityfec +# audio/pcma +# audio/pcma-wb +# audio/pcmu +# audio/pcmu-wb +# audio/prs.sid +# audio/qcelp +# audio/raptorfec +# audio/red +# audio/rtp-enc-aescm128 +# audio/rtp-midi +# audio/rtploopback +# audio/rtx +audio/s3m s3m +audio/silk sil +# audio/smv +# audio/smv-qcp +# audio/smv0 +# audio/sp-midi +# audio/speex +# audio/t140c +# audio/t38 +# audio/telephone-event +# audio/tone +# audio/uemclip +# audio/ulpfec +# audio/vdvi +# audio/vmr-wb +# audio/vnd.3gpp.iufp +# audio/vnd.4sb +# audio/vnd.audiokoz +# audio/vnd.celp +# audio/vnd.cisco.nse +# audio/vnd.cmles.radio-events +# audio/vnd.cns.anp1 +# audio/vnd.cns.inf1 +audio/vnd.dece.audio uva uvva +audio/vnd.digital-winds eol +# audio/vnd.dlna.adts +# audio/vnd.dolby.heaac.1 +# audio/vnd.dolby.heaac.2 +# audio/vnd.dolby.mlp +# audio/vnd.dolby.mps +# audio/vnd.dolby.pl2 +# audio/vnd.dolby.pl2x +# audio/vnd.dolby.pl2z +# audio/vnd.dolby.pulse.1 +audio/vnd.dra dra +audio/vnd.dts dts +audio/vnd.dts.hd dtshd +# audio/vnd.dvb.file +# audio/vnd.everad.plj +# audio/vnd.hns.audio +audio/vnd.lucent.voice lvp +audio/vnd.ms-playready.media.pya pya +# audio/vnd.nokia.mobile-xmf +# audio/vnd.nortel.vbk +audio/vnd.nuera.ecelp4800 ecelp4800 +audio/vnd.nuera.ecelp7470 ecelp7470 +audio/vnd.nuera.ecelp9600 ecelp9600 +# audio/vnd.octel.sbc +# audio/vnd.qcelp +# audio/vnd.rhetorex.32kadpcm +audio/vnd.rip rip +# audio/vnd.sealedmedia.softseal.mpeg +# audio/vnd.vmx.cvsd +# audio/vorbis +# audio/vorbis-config +audio/webm weba +audio/x-aac aac +audio/x-aiff aif aiff aifc +audio/x-caf caf +audio/x-flac flac +audio/x-matroska mka +audio/x-mpegurl m3u +audio/x-ms-wax wax +audio/x-ms-wma wma +audio/x-pn-realaudio ram ra +audio/x-pn-realaudio-plugin rmp +# audio/x-tta +audio/x-wav wav +audio/xm xm +chemical/x-cdx cdx +chemical/x-cif cif +chemical/x-cmdf cmdf +chemical/x-cml cml +chemical/x-csml csml +# chemical/x-pdb +chemical/x-xyz xyz +image/bmp bmp +image/cgm cgm +# image/dicom-rle +# image/emf +# image/example +# image/fits +image/g3fax g3 +image/gif gif +image/ief ief +# image/jls +# image/jp2 +image/jpeg jpeg jpg jpe +# image/jpm +# image/jpx +image/ktx ktx +# image/naplps +image/png png +image/prs.btif btif +# image/prs.pti +# image/pwg-raster +image/sgi sgi +image/svg+xml svg svgz +# image/t38 +image/tiff tiff tif +# image/tiff-fx +image/vnd.adobe.photoshop psd +# image/vnd.airzip.accelerator.azv +# image/vnd.cns.inf2 +image/vnd.dece.graphic uvi uvvi uvg uvvg +image/vnd.djvu djvu djv +image/vnd.dvb.subtitle sub +image/vnd.dwg dwg +image/vnd.dxf dxf +image/vnd.fastbidsheet fbs +image/vnd.fpx fpx +image/vnd.fst fst +image/vnd.fujixerox.edmics-mmr mmr +image/vnd.fujixerox.edmics-rlc rlc +# image/vnd.globalgraphics.pgb +# image/vnd.microsoft.icon +# image/vnd.mix +# image/vnd.mozilla.apng +image/vnd.ms-modi mdi +image/vnd.ms-photo wdp +image/vnd.net-fpx npx +# image/vnd.radiance +# image/vnd.sealed.png +# image/vnd.sealedmedia.softseal.gif +# image/vnd.sealedmedia.softseal.jpg +# image/vnd.svf +# image/vnd.tencent.tap +# image/vnd.valve.source.texture +image/vnd.wap.wbmp wbmp +image/vnd.xiff xif +# image/vnd.zbrush.pcx +image/webp webp +# image/wmf +image/x-3ds 3ds +image/x-cmu-raster ras +image/x-cmx cmx +image/x-freehand fh fhc fh4 fh5 fh7 +image/x-icon ico +image/x-mrsid-image sid +image/x-pcx pcx +image/x-pict pic pct +image/x-portable-anymap pnm +image/x-portable-bitmap pbm +image/x-portable-graymap pgm +image/x-portable-pixmap ppm +image/x-rgb rgb +image/x-tga tga +image/x-xbitmap xbm +image/x-xpixmap xpm +image/x-xwindowdump xwd +# message/cpim +# message/delivery-status +# message/disposition-notification +# message/example +# message/external-body +# message/feedback-report +# message/global +# message/global-delivery-status +# message/global-disposition-notification +# message/global-headers +# message/http +# message/imdn+xml +# message/news +# message/partial +message/rfc822 eml mime +# message/s-http +# message/sip +# message/sipfrag +# message/tracking-status +# message/vnd.si.simp +# message/vnd.wfa.wsc +# model/example +# model/gltf+json +model/iges igs iges +model/mesh msh mesh silo +model/vnd.collada+xml dae +model/vnd.dwf dwf +# model/vnd.flatland.3dml +model/vnd.gdl gdl +# model/vnd.gs-gdl +# model/vnd.gs.gdl +model/vnd.gtw gtw +# model/vnd.moml+xml +model/vnd.mts mts +# model/vnd.opengex +# model/vnd.parasolid.transmit.binary +# model/vnd.parasolid.transmit.text +# model/vnd.rosette.annotated-data-model +# model/vnd.valve.source.compiled-map +model/vnd.vtu vtu +model/vrml wrl vrml +model/x3d+binary x3db x3dbz +# model/x3d+fastinfoset +model/x3d+vrml x3dv x3dvz +model/x3d+xml x3d x3dz +# model/x3d-vrml +# multipart/alternative +# multipart/appledouble +# multipart/byteranges +# multipart/digest +# multipart/encrypted +# multipart/example +# multipart/form-data +# multipart/header-set +# multipart/mixed +# multipart/parallel +# multipart/related +# multipart/report +# multipart/signed +# multipart/voice-message +# multipart/x-mixed-replace +# text/1d-interleaved-parityfec +text/cache-manifest appcache +text/calendar ics ifb +text/css css +text/csv csv +# text/csv-schema +# text/directory +# text/dns +# text/ecmascript +# text/encaprtp +# text/enriched +# text/example +# text/fwdred +# text/grammar-ref-list +text/html html htm +# text/javascript +# text/jcr-cnd +# text/markdown +# text/mizar +text/n3 n3 +# text/parameters +# text/parityfec +text/plain txt text conf def list log in +# text/provenance-notation +# text/prs.fallenstein.rst +text/prs.lines.tag dsc +# text/prs.prop.logic +# text/raptorfec +# text/red +# text/rfc822-headers +text/richtext rtx +# text/rtf +# text/rtp-enc-aescm128 +# text/rtploopback +# text/rtx +text/sgml sgml sgm +# text/t140 +text/tab-separated-values tsv +text/troff t tr roff man me ms +text/turtle ttl +# text/ulpfec +text/uri-list uri uris urls +text/vcard vcard +# text/vnd.a +# text/vnd.abc +text/vnd.curl curl +text/vnd.curl.dcurl dcurl +text/vnd.curl.mcurl mcurl +text/vnd.curl.scurl scurl +# text/vnd.debian.copyright +# text/vnd.dmclientscript +text/vnd.dvb.subtitle sub +# text/vnd.esmertec.theme-descriptor +text/vnd.fly fly +text/vnd.fmi.flexstor flx +text/vnd.graphviz gv +text/vnd.in3d.3dml 3dml +text/vnd.in3d.spot spot +# text/vnd.iptc.newsml +# text/vnd.iptc.nitf +# text/vnd.latex-z +# text/vnd.motorola.reflex +# text/vnd.ms-mediapackage +# text/vnd.net2phone.commcenter.command +# text/vnd.radisys.msml-basic-layout +# text/vnd.si.uricatalogue +text/vnd.sun.j2me.app-descriptor jad +# text/vnd.trolltech.linguist +# text/vnd.wap.si +# text/vnd.wap.sl +text/vnd.wap.wml wml +text/vnd.wap.wmlscript wmls +text/x-asm s asm +text/x-c c cc cxx cpp h hh dic +text/x-fortran f for f77 f90 +text/x-java-source java +text/x-nfo nfo +text/x-opml opml +text/x-pascal p pas +text/x-setext etx +text/x-sfv sfv +text/x-uuencode uu +text/x-vcalendar vcs +text/x-vcard vcf +# text/xml +# text/xml-external-parsed-entity +# video/1d-interleaved-parityfec +video/3gpp 3gp +# video/3gpp-tt +video/3gpp2 3g2 +# video/bmpeg +# video/bt656 +# video/celb +# video/dv +# video/encaprtp +# video/example +video/h261 h261 +video/h263 h263 +# video/h263-1998 +# video/h263-2000 +video/h264 h264 +# video/h264-rcdo +# video/h264-svc +# video/h265 +# video/iso.segment +video/jpeg jpgv +# video/jpeg2000 +video/jpm jpm jpgm +video/mj2 mj2 mjp2 +# video/mp1s +# video/mp2p +# video/mp2t +video/mp4 mp4 mp4v mpg4 +# video/mp4v-es +video/mpeg mpeg mpg mpe m1v m2v +# video/mpeg4-generic +# video/mpv +# video/nv +video/ogg ogv +# video/parityfec +# video/pointer +video/quicktime qt mov +# video/raptorfec +# video/raw +# video/rtp-enc-aescm128 +# video/rtploopback +# video/rtx +# video/smpte292m +# video/ulpfec +# video/vc1 +# video/vnd.cctv +video/vnd.dece.hd uvh uvvh +video/vnd.dece.mobile uvm uvvm +# video/vnd.dece.mp4 +video/vnd.dece.pd uvp uvvp +video/vnd.dece.sd uvs uvvs +video/vnd.dece.video uvv uvvv +# video/vnd.directv.mpeg +# video/vnd.directv.mpeg-tts +# video/vnd.dlna.mpeg-tts +video/vnd.dvb.file dvb +video/vnd.fvt fvt +# video/vnd.hns.video +# video/vnd.iptvforum.1dparityfec-1010 +# video/vnd.iptvforum.1dparityfec-2005 +# video/vnd.iptvforum.2dparityfec-1010 +# video/vnd.iptvforum.2dparityfec-2005 +# video/vnd.iptvforum.ttsavc +# video/vnd.iptvforum.ttsmpeg2 +# video/vnd.motorola.video +# video/vnd.motorola.videop +video/vnd.mpegurl mxu m4u +video/vnd.ms-playready.media.pyv pyv +# video/vnd.nokia.interleaved-multimedia +# video/vnd.nokia.videovoip +# video/vnd.objectvideo +# video/vnd.radgamettools.bink +# video/vnd.radgamettools.smacker +# video/vnd.sealed.mpeg1 +# video/vnd.sealed.mpeg4 +# video/vnd.sealed.swf +# video/vnd.sealedmedia.softseal.mov +video/vnd.uvvu.mp4 uvu uvvu +video/vnd.vivo viv +# video/vp8 +video/webm webm +video/x-f4v f4v +video/x-fli fli +video/x-flv flv +video/x-m4v m4v +video/x-matroska mkv mk3d mks +video/x-mng mng +video/x-ms-asf asf asx +video/x-ms-vob vob +video/x-ms-wm wm +video/x-ms-wmv wmv +video/x-ms-wmx wmx +video/x-ms-wvx wvx +video/x-msvideo avi +video/x-sgi-movie movie +video/x-smv smv +x-conference/x-cooltalk ice diff --git a/spring-web/src/main/resources/org/springframework/web/context/ContextLoader.properties b/spring-web/src/main/resources/org/springframework/web/context/ContextLoader.properties new file mode 100644 index 0000000000000000000000000000000000000000..6cd24b294afb12aac7fcc2be1893c42d59cc8b7a --- /dev/null +++ b/spring-web/src/main/resources/org/springframework/web/context/ContextLoader.properties @@ -0,0 +1,5 @@ +# Default WebApplicationContext implementation class for ContextLoader. +# Used as fallback when no explicit context implementation has been specified as context-param. +# Not meant to be customized by application developers. + +org.springframework.web.context.WebApplicationContext=org.springframework.web.context.support.XmlWebApplicationContext diff --git a/spring-web/src/main/resources/org/springframework/web/util/HtmlCharacterEntityReferences.properties b/spring-web/src/main/resources/org/springframework/web/util/HtmlCharacterEntityReferences.properties new file mode 100644 index 0000000000000000000000000000000000000000..75d3015792ab1fb9646ab5f0728f1ece4fe1d301 --- /dev/null +++ b/spring-web/src/main/resources/org/springframework/web/util/HtmlCharacterEntityReferences.properties @@ -0,0 +1,265 @@ +# Character Entity References defined by the HTML 4.0 standard. +# A complete description of the HTML 4.0 character set can be found at: +# http://www.w3.org/TR/html4/charset.html + +# Character entity references for ISO 8859-1 characters + +160 = nbsp +161 = iexcl +162 = cent +163 = pound +164 = curren +165 = yen +166 = brvbar +167 = sect +168 = uml +169 = copy +170 = ordf +171 = laquo +172 = not +173 = shy +174 = reg +175 = macr +176 = deg +177 = plusmn +178 = sup2 +179 = sup3 +180 = acute +181 = micro +182 = para +183 = middot +184 = cedil +185 = sup1 +186 = ordm +187 = raquo +188 = frac14 +189 = frac12 +190 = frac34 +191 = iquest +192 = Agrave +193 = Aacute +194 = Acirc +195 = Atilde +196 = Auml +197 = Aring +198 = AElig +199 = Ccedil +200 = Egrave +201 = Eacute +202 = Ecirc +203 = Euml +204 = Igrave +205 = Iacute +206 = Icirc +207 = Iuml +208 = ETH +209 = Ntilde +210 = Ograve +211 = Oacute +212 = Ocirc +213 = Otilde +214 = Ouml +215 = times +216 = Oslash +217 = Ugrave +218 = Uacute +219 = Ucirc +220 = Uuml +221 = Yacute +222 = THORN +223 = szlig +224 = agrave +225 = aacute +226 = acirc +227 = atilde +228 = auml +229 = aring +230 = aelig +231 = ccedil +232 = egrave +233 = eacute +234 = ecirc +235 = euml +236 = igrave +237 = iacute +238 = icirc +239 = iuml +240 = eth +241 = ntilde +242 = ograve +243 = oacute +244 = ocirc +245 = otilde +246 = ouml +247 = divide +248 = oslash +249 = ugrave +250 = uacute +251 = ucirc +252 = uuml +253 = yacute +254 = thorn +255 = yuml + +# Character entity references for symbols, mathematical symbols, and Greek letters + +402 = fnof +913 = Alpha +914 = Beta +915 = Gamma +916 = Delta +917 = Epsilon +918 = Zeta +919 = Eta +920 = Theta +921 = Iota +922 = Kappa +923 = Lambda +924 = Mu +925 = Nu +926 = Xi +927 = Omicron +928 = Pi +929 = Rho +931 = Sigma +932 = Tau +933 = Upsilon +934 = Phi +935 = Chi +936 = Psi +937 = Omega +945 = alpha +946 = beta +947 = gamma +948 = delta +949 = epsilon +950 = zeta +951 = eta +952 = theta +953 = iota +954 = kappa +955 = lambda +956 = mu +957 = nu +958 = xi +959 = omicron +960 = pi +961 = rho +962 = sigmaf +963 = sigma +964 = tau +965 = upsilon +966 = phi +967 = chi +968 = psi +969 = omega +977 = thetasym +978 = upsih +982 = piv +8226 = bull +8230 = hellip +8242 = prime +8243 = Prime +8254 = oline +8260 = frasl +8472 = weierp +8465 = image +8476 = real +8482 = trade +8501 = alefsym +8592 = larr +8593 = uarr +8594 = rarr +8595 = darr +8596 = harr +8629 = crarr +8656 = lArr +8657 = uArr +8658 = rArr +8659 = dArr +8660 = hArr +8704 = forall +8706 = part +8707 = exist +8709 = empty +8711 = nabla +8712 = isin +8713 = notin +8715 = ni +8719 = prod +8721 = sum +8722 = minus +8727 = lowast +8730 = radic +8733 = prop +8734 = infin +8736 = ang +8743 = and +8744 = or +8745 = cap +8746 = cup +8747 = int +8756 = there4 +8764 = sim +8773 = cong +8776 = asymp +8800 = ne +8801 = equiv +8804 = le +8805 = ge +8834 = sub +8835 = sup +8836 = nsub +8838 = sube +8839 = supe +8853 = oplus +8855 = otimes +8869 = perp +8901 = sdot +8968 = lceil +8969 = rceil +8970 = lfloor +8971 = rfloor +9001 = lang +9002 = rang +9674 = loz +9824 = spades +9827 = clubs +9829 = hearts +9830 = diams + +# Character entity references for markup-significant and internationalization characters + +34 = quot +38 = amp +39 = #39 +60 = lt +62 = gt +338 = OElig +339 = oelig +352 = Scaron +353 = scaron +376 = Yuml +710 = circ +732 = tilde +8194 = ensp +8195 = emsp +8201 = thinsp +8204 = zwnj +8205 = zwj +8206 = lrm +8207 = rlm +8211 = ndash +8212 = mdash +8216 = lsquo +8217 = rsquo +8218 = sbquo +8220 = ldquo +8221 = rdquo +8222 = bdquo +8224 = dagger +8225 = Dagger +8240 = permil +8249 = lsaquo +8250 = rsaquo +8364 = euro diff --git a/spring-web/src/test/java/org/springframework/core/task/MockRunnable.java b/spring-web/src/test/java/org/springframework/core/task/MockRunnable.java new file mode 100644 index 0000000000000000000000000000000000000000..c0433762822899ee379c0c7c87a973b50c653a2d --- /dev/null +++ b/spring-web/src/test/java/org/springframework/core/task/MockRunnable.java @@ -0,0 +1,35 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.core.task; + +/** + * @author Juergen Hoeller + */ +public class MockRunnable implements Runnable { + + private boolean executed = false; + + @Override + public void run() { + this.executed = true; + } + + public boolean wasExecuted() { + return this.executed; + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/CacheControlTests.java b/spring-web/src/test/java/org/springframework/http/CacheControlTests.java new file mode 100644 index 0000000000000000000000000000000000000000..34a3a9349d3f83fce908a605b39785644c28af9e --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/CacheControlTests.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import org.hamcrest.Matchers; +import org.junit.Test; + +import static org.junit.Assert.*; + +import java.util.concurrent.TimeUnit; + +/** + * @author Brian Clozel + */ +public class CacheControlTests { + + @Test + public void emptyCacheControl() throws Exception { + CacheControl cc = CacheControl.empty(); + assertThat(cc.getHeaderValue(), Matchers.nullValue()); + } + + @Test + public void maxAge() throws Exception { + CacheControl cc = CacheControl.maxAge(1, TimeUnit.HOURS); + assertThat(cc.getHeaderValue(), Matchers.equalTo("max-age=3600")); + } + + @Test + public void maxAgeAndDirectives() throws Exception { + CacheControl cc = CacheControl.maxAge(3600, TimeUnit.SECONDS).cachePublic().noTransform(); + assertThat(cc.getHeaderValue(), Matchers.equalTo("max-age=3600, no-transform, public")); + } + + @Test + public void maxAgeAndSMaxAge() throws Exception { + CacheControl cc = CacheControl.maxAge(1, TimeUnit.HOURS).sMaxAge(30, TimeUnit.MINUTES); + assertThat(cc.getHeaderValue(), Matchers.equalTo("max-age=3600, s-maxage=1800")); + } + + @Test + public void noCachePrivate() throws Exception { + CacheControl cc = CacheControl.noCache().cachePrivate(); + assertThat(cc.getHeaderValue(), Matchers.equalTo("no-cache, private")); + } + + @Test + public void noStore() throws Exception { + CacheControl cc = CacheControl.noStore(); + assertThat(cc.getHeaderValue(), Matchers.equalTo("no-store")); + } + + @Test + public void staleIfError() throws Exception { + CacheControl cc = CacheControl.maxAge(1, TimeUnit.HOURS).staleIfError(2, TimeUnit.HOURS); + assertThat(cc.getHeaderValue(), Matchers.equalTo("max-age=3600, stale-if-error=7200")); + } + + @Test + public void staleWhileRevalidate() throws Exception { + CacheControl cc = CacheControl.maxAge(1, TimeUnit.HOURS).staleWhileRevalidate(2, TimeUnit.HOURS); + assertThat(cc.getHeaderValue(), Matchers.equalTo("max-age=3600, stale-while-revalidate=7200")); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/ContentDispositionTests.java b/spring-web/src/test/java/org/springframework/http/ContentDispositionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e0072e09e464bb357530f18296f52279380b44d5 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/ContentDispositionTests.java @@ -0,0 +1,234 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.lang.reflect.Method; +import java.nio.charset.StandardCharsets; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import org.junit.Test; + +import org.springframework.util.ReflectionUtils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import static org.springframework.http.ContentDisposition.builder; + +/** + * Unit tests for {@link ContentDisposition} + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + */ +public class ContentDispositionTests { + + private static DateTimeFormatter formatter = DateTimeFormatter.RFC_1123_DATE_TIME; + + + @Test + public void parse() { + assertEquals(builder("form-data").name("foo").filename("foo.txt").size(123L).build(), + parse("form-data; name=\"foo\"; filename=\"foo.txt\"; size=123")); + } + + @Test + public void parseFilenameUnquoted() { + assertEquals(builder("form-data").filename("unquoted").build(), + parse("form-data; filename=unquoted")); + } + + @Test // SPR-16091 + public void parseFilenameWithSemicolon() { + assertEquals(builder("attachment").filename("filename with ; semicolon.txt").build(), + parse("attachment; filename=\"filename with ; semicolon.txt\"")); + } + + @Test + public void parseEncodedFilename() { + assertEquals(builder("form-data").name("name").filename("中文.txt", StandardCharsets.UTF_8).build(), + parse("form-data; name=\"name\"; filename*=UTF-8''%E4%B8%AD%E6%96%87.txt")); + } + + @Test // gh-24112 + public void parseEncodedFilenameWithPaddedCharset() { + assertEquals(builder("attachment").filename("some-file.zip", StandardCharsets.UTF_8).build(), + parse("attachment; filename*= UTF-8''some-file.zip")); + } + + @Test + public void parseEncodedFilenameWithoutCharset() { + assertEquals(builder("form-data").name("name").filename("test.txt").build(), + parse("form-data; name=\"name\"; filename*=test.txt")); + } + + @Test(expected = IllegalArgumentException.class) + public void parseEncodedFilenameWithInvalidCharset() { + parse("form-data; name=\"name\"; filename*=UTF-16''test.txt"); + } + + @Test + public void parseEncodedFilenameWithInvalidName() { + + Consumer tester = input -> { + try { + parse(input); + fail(); + } + catch (IllegalArgumentException ex) { + // expected + } + }; + + tester.accept("form-data; name=\"name\"; filename*=UTF-8''%A"); + tester.accept("form-data; name=\"name\"; filename*=UTF-8''%A.txt"); + } + + @Test // gh-23077 + public void parseWithEscapedQuote() { + + BiConsumer tester = (description, filename) -> + assertEquals(description, + builder("form-data").name("file").filename(filename).size(123L).build(), + parse("form-data; name=\"file\"; filename=\"" + filename + "\"; size=123")); + + tester.accept("Escaped quotes should be ignored", + "\\\"The Twilight Zone\\\".txt"); + + tester.accept("Escaped quotes preceded by escaped backslashes should be ignored", + "\\\\\\\"The Twilight Zone\\\\\\\".txt"); + + tester.accept("Escaped backslashes should not suppress quote", + "The Twilight Zone \\\\"); + + tester.accept("Escaped backslashes should not suppress quote", + "The Twilight Zone \\\\\\\\"); + } + + @Test + public void parseWithExtraSemicolons() { + assertEquals(builder("form-data").name("foo").filename("foo.txt").size(123L).build(), + parse("form-data; name=\"foo\";; ; filename=\"foo.txt\"; size=123")); + } + + @Test + public void parseDates() { + assertEquals( + builder("attachment") + .creationDate(ZonedDateTime.parse("Mon, 12 Feb 2007 10:15:30 -0500", formatter)) + .modificationDate(ZonedDateTime.parse("Tue, 13 Feb 2007 10:15:30 -0500", formatter)) + .readDate(ZonedDateTime.parse("Wed, 14 Feb 2007 10:15:30 -0500", formatter)).build(), + parse("attachment; creation-date=\"Mon, 12 Feb 2007 10:15:30 -0500\"; " + + "modification-date=\"Tue, 13 Feb 2007 10:15:30 -0500\"; " + + "read-date=\"Wed, 14 Feb 2007 10:15:30 -0500\"")); + } + + @Test + public void parseIgnoresInvalidDates() { + assertEquals( + builder("attachment") + .readDate(ZonedDateTime.parse("Wed, 14 Feb 2007 10:15:30 -0500", formatter)) + .build(), + parse("attachment; creation-date=\"-1\"; " + + "modification-date=\"-1\"; " + + "read-date=\"Wed, 14 Feb 2007 10:15:30 -0500\"")); + } + + @Test(expected = IllegalArgumentException.class) + public void parseEmpty() { + parse(""); + } + + @Test(expected = IllegalArgumentException.class) + public void parseNoType() { + parse(";"); + } + + @Test(expected = IllegalArgumentException.class) + public void parseInvalidParameter() { + parse("foo;bar"); + } + + private static ContentDisposition parse(String input) { + return ContentDisposition.parse(input); + } + + + @Test + public void format() { + assertEquals("form-data; name=\"foo\"; filename=\"foo.txt\"; size=123", + builder("form-data").name("foo").filename("foo.txt").size(123L).build().toString()); + } + + @Test + public void formatWithEncodedFilename() { + assertEquals("form-data; name=\"name\"; filename*=UTF-8''%E4%B8%AD%E6%96%87.txt", + builder("form-data").name("name").filename("中文.txt", StandardCharsets.UTF_8).build().toString()); + } + + @Test + public void formatWithEncodedFilenameUsingUsAscii() { + assertEquals("form-data; name=\"name\"; filename=\"test.txt\"", + builder("form-data") + .name("name") + .filename("test.txt", StandardCharsets.US_ASCII) + .build() + .toString()); + } + + @Test // gh-24220 + public void formatWithFilenameWithQuotes() { + + BiConsumer tester = (input, output) -> { + + assertEquals("form-data; filename=\"" + output + "\"", + builder("form-data").filename(input).build().toString()); + + assertEquals("form-data; filename=\"" + output + "\"", + builder("form-data").filename(input, StandardCharsets.US_ASCII).build().toString()); + }; + + String filename = "\"foo.txt"; + tester.accept(filename, "\\" + filename); + + filename = "\\\"foo.txt"; + tester.accept(filename, filename); + + filename = "\\\\\"foo.txt"; + tester.accept(filename, "\\" + filename); + + filename = "\\\\\\\"foo.txt"; + tester.accept(filename, filename); + + filename = "\\\\\\\\\"foo.txt"; + tester.accept(filename, "\\" + filename); + + tester.accept("\"\"foo.txt", "\\\"\\\"foo.txt"); + tester.accept("\"\"\"foo.txt", "\\\"\\\"\\\"foo.txt"); + + tester.accept("foo.txt\\", "foo.txt"); + tester.accept("foo.txt\\\\", "foo.txt\\\\"); + tester.accept("foo.txt\\\\\\", "foo.txt\\\\"); + } + + @Test(expected = IllegalArgumentException.class) + public void formatWithEncodedFilenameUsingInvalidCharset() { + builder("form-data").name("name").filename("test.txt", StandardCharsets.UTF_16).build().toString(); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/HttpEntityTests.java b/spring-web/src/test/java/org/springframework/http/HttpEntityTests.java new file mode 100644 index 0000000000000000000000000000000000000000..69b35cd81b7f252a458b001f884786893c5db4e1 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/HttpEntityTests.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.net.URI; + +import org.junit.Test; + +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + */ +public class HttpEntityTests { + + @Test + public void noHeaders() { + String body = "foo"; + HttpEntity entity = new HttpEntity<>(body); + assertSame(body, entity.getBody()); + assertTrue(entity.getHeaders().isEmpty()); + } + + @Test + public void httpHeaders() { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + String body = "foo"; + HttpEntity entity = new HttpEntity<>(body, headers); + assertEquals(body, entity.getBody()); + assertEquals(MediaType.TEXT_PLAIN, entity.getHeaders().getContentType()); + assertEquals("text/plain", entity.getHeaders().getFirst("Content-Type")); + } + + @Test + public void multiValueMap() { + MultiValueMap map = new LinkedMultiValueMap<>(); + map.set("Content-Type", "text/plain"); + String body = "foo"; + HttpEntity entity = new HttpEntity<>(body, map); + assertEquals(body, entity.getBody()); + assertEquals(MediaType.TEXT_PLAIN, entity.getHeaders().getContentType()); + assertEquals("text/plain", entity.getHeaders().getFirst("Content-Type")); + } + + @Test + public void testEquals() { + MultiValueMap map1 = new LinkedMultiValueMap<>(); + map1.set("Content-Type", "text/plain"); + + MultiValueMap map2 = new LinkedMultiValueMap<>(); + map2.set("Content-Type", "application/json"); + + assertTrue(new HttpEntity<>().equals(new HttpEntity())); + assertFalse(new HttpEntity<>(map1).equals(new HttpEntity())); + assertFalse(new HttpEntity<>().equals(new HttpEntity(map2))); + + assertTrue(new HttpEntity<>(map1).equals(new HttpEntity(map1))); + assertFalse(new HttpEntity<>(map1).equals(new HttpEntity(map2))); + + assertTrue(new HttpEntity(null, null).equals(new HttpEntity(null, null))); + assertFalse(new HttpEntity<>("foo", null).equals(new HttpEntity(null, null))); + assertFalse(new HttpEntity(null, null).equals(new HttpEntity<>("bar", null))); + + assertTrue(new HttpEntity<>("foo", map1).equals(new HttpEntity("foo", map1))); + assertFalse(new HttpEntity<>("foo", map1).equals(new HttpEntity("bar", map1))); + } + + @Test + public void responseEntity() { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + String body = "foo"; + HttpEntity httpEntity = new HttpEntity<>(body, headers); + ResponseEntity responseEntity = new ResponseEntity<>(body, headers, HttpStatus.OK); + ResponseEntity responseEntity2 = new ResponseEntity<>(body, headers, HttpStatus.OK); + + assertEquals(body, responseEntity.getBody()); + assertEquals(MediaType.TEXT_PLAIN, responseEntity.getHeaders().getContentType()); + assertEquals("text/plain", responseEntity.getHeaders().getFirst("Content-Type")); + assertEquals("text/plain", responseEntity.getHeaders().getFirst("Content-Type")); + + assertFalse(httpEntity.equals(responseEntity)); + assertFalse(responseEntity.equals(httpEntity)); + assertTrue(responseEntity.equals(responseEntity2)); + assertTrue(responseEntity2.equals(responseEntity)); + } + + @Test + public void requestEntity() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + String body = "foo"; + HttpEntity httpEntity = new HttpEntity<>(body, headers); + RequestEntity requestEntity = new RequestEntity<>(body, headers, HttpMethod.GET, new URI("/")); + RequestEntity requestEntity2 = new RequestEntity<>(body, headers, HttpMethod.GET, new URI("/")); + + assertEquals(body, requestEntity.getBody()); + assertEquals(MediaType.TEXT_PLAIN, requestEntity.getHeaders().getContentType()); + assertEquals("text/plain", requestEntity.getHeaders().getFirst("Content-Type")); + assertEquals("text/plain", requestEntity.getHeaders().getFirst("Content-Type")); + + assertFalse(httpEntity.equals(requestEntity)); + assertFalse(requestEntity.equals(httpEntity)); + assertTrue(requestEntity.equals(requestEntity2)); + assertTrue(requestEntity2.equals(requestEntity)); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/HttpHeadersTests.java b/spring-web/src/test/java/org/springframework/http/HttpHeadersTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c4cfc59044401ea08989e8eade7dff43969d97e9 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/HttpHeadersTests.java @@ -0,0 +1,689 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; +import java.util.Calendar; +import java.util.Collections; +import java.util.EnumSet; +import java.util.GregorianCalendar; +import java.util.List; +import java.util.Locale; +import java.util.Map.Entry; +import java.util.TimeZone; + +import org.hamcrest.Matchers; +import org.junit.Test; + +import static java.util.stream.Collectors.toList; +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; + +/** + * Unit tests for {@link org.springframework.http.HttpHeaders}. + * + * @author Arjen Poutsma + * @author Sebastien Deleuze + * @author Brian Clozel + * @author Juergen Hoeller + * @author Sam Brannen + */ +public class HttpHeadersTests { + + private final HttpHeaders headers = new HttpHeaders(); + + + @Test + public void getFirst() { + headers.add(HttpHeaders.CACHE_CONTROL, "max-age=1000, public"); + headers.add(HttpHeaders.CACHE_CONTROL, "s-maxage=1000"); + assertThat(headers.getFirst(HttpHeaders.CACHE_CONTROL), is("max-age=1000, public")); + } + + @Test + public void accept() { + MediaType mediaType1 = new MediaType("text", "html"); + MediaType mediaType2 = new MediaType("text", "plain"); + List mediaTypes = new ArrayList<>(2); + mediaTypes.add(mediaType1); + mediaTypes.add(mediaType2); + headers.setAccept(mediaTypes); + assertEquals("Invalid Accept header", mediaTypes, headers.getAccept()); + assertEquals("Invalid Accept header", "text/html, text/plain", headers.getFirst("Accept")); + } + + @Test // SPR-9655 + public void acceptWithMultipleHeaderValues() { + headers.add("Accept", "text/html"); + headers.add("Accept", "text/plain"); + List expected = Arrays.asList(new MediaType("text", "html"), new MediaType("text", "plain")); + assertEquals("Invalid Accept header", expected, headers.getAccept()); + } + + @Test // SPR-14506 + public void acceptWithMultipleCommaSeparatedHeaderValues() { + headers.add("Accept", "text/html,text/pdf"); + headers.add("Accept", "text/plain,text/csv"); + List expected = Arrays.asList(new MediaType("text", "html"), new MediaType("text", "pdf"), + new MediaType("text", "plain"), new MediaType("text", "csv")); + assertEquals("Invalid Accept header", expected, headers.getAccept()); + } + + @Test + public void acceptCharsets() { + Charset charset1 = StandardCharsets.UTF_8; + Charset charset2 = StandardCharsets.ISO_8859_1; + List charsets = new ArrayList<>(2); + charsets.add(charset1); + charsets.add(charset2); + headers.setAcceptCharset(charsets); + assertEquals("Invalid Accept header", charsets, headers.getAcceptCharset()); + assertEquals("Invalid Accept header", "utf-8, iso-8859-1", headers.getFirst("Accept-Charset")); + } + + @Test + public void acceptCharsetWildcard() { + headers.set("Accept-Charset", "ISO-8859-1,utf-8;q=0.7,*;q=0.7"); + assertEquals("Invalid Accept header", Arrays.asList(StandardCharsets.ISO_8859_1, StandardCharsets.UTF_8), + headers.getAcceptCharset()); + } + + @Test + public void allow() { + EnumSet methods = EnumSet.of(HttpMethod.GET, HttpMethod.POST); + headers.setAllow(methods); + assertEquals("Invalid Allow header", methods, headers.getAllow()); + assertEquals("Invalid Allow header", "GET,POST", headers.getFirst("Allow")); + } + + @Test + public void contentLength() { + long length = 42L; + headers.setContentLength(length); + assertEquals("Invalid Content-Length header", length, headers.getContentLength()); + assertEquals("Invalid Content-Length header", "42", headers.getFirst("Content-Length")); + } + + @Test + public void contentType() { + MediaType contentType = new MediaType("text", "html", StandardCharsets.UTF_8); + headers.setContentType(contentType); + assertEquals("Invalid Content-Type header", contentType, headers.getContentType()); + assertEquals("Invalid Content-Type header", "text/html;charset=UTF-8", headers.getFirst("Content-Type")); + } + + @Test + public void location() throws URISyntaxException { + URI location = new URI("https://www.example.com/hotels"); + headers.setLocation(location); + assertEquals("Invalid Location header", location, headers.getLocation()); + assertEquals("Invalid Location header", "https://www.example.com/hotels", headers.getFirst("Location")); + } + + @Test + public void eTag() { + String eTag = "\"v2.6\""; + headers.setETag(eTag); + assertEquals("Invalid ETag header", eTag, headers.getETag()); + assertEquals("Invalid ETag header", "\"v2.6\"", headers.getFirst("ETag")); + } + + @Test + public void host() { + InetSocketAddress host = InetSocketAddress.createUnresolved("localhost", 8080); + headers.setHost(host); + assertEquals("Invalid Host header", host, headers.getHost()); + assertEquals("Invalid Host header", "localhost:8080", headers.getFirst("Host")); + } + + @Test + public void hostNoPort() { + InetSocketAddress host = InetSocketAddress.createUnresolved("localhost", 0); + headers.setHost(host); + assertEquals("Invalid Host header", host, headers.getHost()); + assertEquals("Invalid Host header", "localhost", headers.getFirst("Host")); + } + + @Test + public void ipv6Host() { + InetSocketAddress host = InetSocketAddress.createUnresolved("[::1]", 0); + headers.setHost(host); + assertEquals("Invalid Host header", host, headers.getHost()); + assertEquals("Invalid Host header", "[::1]", headers.getFirst("Host")); + } + + @Test(expected = IllegalArgumentException.class) + public void illegalETag() { + String eTag = "v2.6"; + headers.setETag(eTag); + assertEquals("Invalid ETag header", eTag, headers.getETag()); + assertEquals("Invalid ETag header", "\"v2.6\"", headers.getFirst("ETag")); + } + + @Test + public void ifMatch() { + String ifMatch = "\"v2.6\""; + headers.setIfMatch(ifMatch); + assertEquals("Invalid If-Match header", ifMatch, headers.getIfMatch().get(0)); + assertEquals("Invalid If-Match header", "\"v2.6\"", headers.getFirst("If-Match")); + } + + @Test(expected = IllegalArgumentException.class) + public void ifMatchIllegalHeader() { + headers.setIfMatch("Illegal"); + headers.getIfMatch(); + } + + @Test + public void ifMatchMultipleHeaders() { + headers.add(HttpHeaders.IF_MATCH, "\"v2,0\""); + headers.add(HttpHeaders.IF_MATCH, "W/\"v2,1\", \"v2,2\""); + assertEquals("Invalid If-Match header", "\"v2,0\"", headers.get(HttpHeaders.IF_MATCH).get(0)); + assertEquals("Invalid If-Match header", "W/\"v2,1\", \"v2,2\"", headers.get(HttpHeaders.IF_MATCH).get(1)); + assertThat(headers.getIfMatch(), Matchers.contains("\"v2,0\"", "W/\"v2,1\"", "\"v2,2\"")); + } + + @Test + public void ifNoneMatch() { + String ifNoneMatch = "\"v2.6\""; + headers.setIfNoneMatch(ifNoneMatch); + assertEquals("Invalid If-None-Match header", ifNoneMatch, headers.getIfNoneMatch().get(0)); + assertEquals("Invalid If-None-Match header", "\"v2.6\"", headers.getFirst("If-None-Match")); + } + + @Test + public void ifNoneMatchWildCard() { + String ifNoneMatch = "*"; + headers.setIfNoneMatch(ifNoneMatch); + assertEquals("Invalid If-None-Match header", ifNoneMatch, headers.getIfNoneMatch().get(0)); + assertEquals("Invalid If-None-Match header", "*", headers.getFirst("If-None-Match")); + } + + @Test + public void ifNoneMatchList() { + String ifNoneMatch1 = "\"v2.6\""; + String ifNoneMatch2 = "\"v2.7\", \"v2.8\""; + List ifNoneMatchList = new ArrayList<>(2); + ifNoneMatchList.add(ifNoneMatch1); + ifNoneMatchList.add(ifNoneMatch2); + headers.setIfNoneMatch(ifNoneMatchList); + assertThat(headers.getIfNoneMatch(), Matchers.contains("\"v2.6\"", "\"v2.7\"", "\"v2.8\"")); + assertEquals("Invalid If-None-Match header", "\"v2.6\", \"v2.7\", \"v2.8\"", headers.getFirst("If-None-Match")); + } + + @Test + public void date() { + Calendar calendar = new GregorianCalendar(2008, 11, 18, 11, 20); + calendar.setTimeZone(TimeZone.getTimeZone("CET")); + long date = calendar.getTimeInMillis(); + headers.setDate(date); + assertEquals("Invalid Date header", date, headers.getDate()); + assertEquals("Invalid Date header", "Thu, 18 Dec 2008 10:20:00 GMT", headers.getFirst("date")); + + // RFC 850 + headers.set("Date", "Thu, 18 Dec 2008 10:20:00 GMT"); + assertEquals("Invalid Date header", date, headers.getDate()); + } + + @Test(expected = IllegalArgumentException.class) + public void dateInvalid() { + headers.set("Date", "Foo Bar Baz"); + headers.getDate(); + } + + @Test + public void dateOtherLocale() { + Locale defaultLocale = Locale.getDefault(); + try { + Locale.setDefault(new Locale("nl", "nl")); + Calendar calendar = new GregorianCalendar(2008, 11, 18, 11, 20); + calendar.setTimeZone(TimeZone.getTimeZone("CET")); + long date = calendar.getTimeInMillis(); + headers.setDate(date); + assertEquals("Invalid Date header", "Thu, 18 Dec 2008 10:20:00 GMT", headers.getFirst("date")); + assertEquals("Invalid Date header", date, headers.getDate()); + } + finally { + Locale.setDefault(defaultLocale); + } + } + + @Test + public void lastModified() { + Calendar calendar = new GregorianCalendar(2008, 11, 18, 11, 20); + calendar.setTimeZone(TimeZone.getTimeZone("CET")); + long date = calendar.getTimeInMillis(); + headers.setLastModified(date); + assertEquals("Invalid Last-Modified header", date, headers.getLastModified()); + assertEquals("Invalid Last-Modified header", "Thu, 18 Dec 2008 10:20:00 GMT", + headers.getFirst("last-modified")); + } + + @Test + public void expiresLong() { + Calendar calendar = new GregorianCalendar(2008, 11, 18, 11, 20); + calendar.setTimeZone(TimeZone.getTimeZone("CET")); + long date = calendar.getTimeInMillis(); + headers.setExpires(date); + assertEquals("Invalid Expires header", date, headers.getExpires()); + assertEquals("Invalid Expires header", "Thu, 18 Dec 2008 10:20:00 GMT", headers.getFirst("expires")); + } + + @Test + public void expiresZonedDateTime() { + ZonedDateTime zonedDateTime = ZonedDateTime.of(2008, 12, 18, 10, 20, 0, 0, ZoneId.of("GMT")); + headers.setExpires(zonedDateTime); + assertEquals("Invalid Expires header", zonedDateTime.toInstant().toEpochMilli(), headers.getExpires()); + assertEquals("Invalid Expires header", "Thu, 18 Dec 2008 10:20:00 GMT", headers.getFirst("expires")); + } + + @Test // SPR-10648 (example is from INT-3063) + public void expiresInvalidDate() { + headers.set("Expires", "-1"); + assertEquals(-1, headers.getExpires()); + } + + @Test + public void ifModifiedSince() { + Calendar calendar = new GregorianCalendar(2008, 11, 18, 11, 20); + calendar.setTimeZone(TimeZone.getTimeZone("CET")); + long date = calendar.getTimeInMillis(); + headers.setIfModifiedSince(date); + assertEquals("Invalid If-Modified-Since header", date, headers.getIfModifiedSince()); + assertEquals("Invalid If-Modified-Since header", "Thu, 18 Dec 2008 10:20:00 GMT", + headers.getFirst("if-modified-since")); + } + + @Test // SPR-14144 + public void invalidIfModifiedSinceHeader() { + headers.set(HttpHeaders.IF_MODIFIED_SINCE, "0"); + assertEquals(-1, headers.getIfModifiedSince()); + + headers.set(HttpHeaders.IF_MODIFIED_SINCE, "-1"); + assertEquals(-1, headers.getIfModifiedSince()); + + headers.set(HttpHeaders.IF_MODIFIED_SINCE, "XXX"); + assertEquals(-1, headers.getIfModifiedSince()); + } + + @Test + public void pragma() { + String pragma = "no-cache"; + headers.setPragma(pragma); + assertEquals("Invalid Pragma header", pragma, headers.getPragma()); + assertEquals("Invalid Pragma header", "no-cache", headers.getFirst("pragma")); + } + + @Test + public void cacheControl() { + headers.setCacheControl("no-cache"); + assertEquals("Invalid Cache-Control header", "no-cache", headers.getCacheControl()); + assertEquals("Invalid Cache-Control header", "no-cache", headers.getFirst("cache-control")); + } + + @Test + public void cacheControlBuilder() { + headers.setCacheControl(CacheControl.noCache()); + assertEquals("Invalid Cache-Control header", "no-cache", headers.getCacheControl()); + assertEquals("Invalid Cache-Control header", "no-cache", headers.getFirst("cache-control")); + } + + @Test + public void cacheControlEmpty() { + headers.setCacheControl(CacheControl.empty()); + assertNull("Invalid Cache-Control header", headers.getCacheControl()); + assertNull("Invalid Cache-Control header", headers.getFirst("cache-control")); + } + + @Test + public void cacheControlAllValues() { + headers.add(HttpHeaders.CACHE_CONTROL, "max-age=1000, public"); + headers.add(HttpHeaders.CACHE_CONTROL, "s-maxage=1000"); + assertEquals("max-age=1000, public, s-maxage=1000", headers.getCacheControl()); + } + + @Test + public void contentDisposition() { + ContentDisposition disposition = headers.getContentDisposition(); + assertNotNull(disposition); + assertEquals("Invalid Content-Disposition header", ContentDisposition.empty(), headers.getContentDisposition()); + + disposition = ContentDisposition.builder("attachment").name("foo").filename("foo.txt").size(123L).build(); + headers.setContentDisposition(disposition); + assertEquals("Invalid Content-Disposition header", disposition, headers.getContentDisposition()); + } + + @Test // SPR-11917 + public void getAllowEmptySet() { + headers.setAllow(Collections.emptySet()); + assertThat(headers.getAllow(), Matchers.emptyCollectionOf(HttpMethod.class)); + } + + @Test + public void accessControlAllowCredentials() { + assertFalse(headers.getAccessControlAllowCredentials()); + headers.setAccessControlAllowCredentials(false); + assertFalse(headers.getAccessControlAllowCredentials()); + headers.setAccessControlAllowCredentials(true); + assertTrue(headers.getAccessControlAllowCredentials()); + } + + @Test + public void accessControlAllowHeaders() { + List allowedHeaders = headers.getAccessControlAllowHeaders(); + assertThat(allowedHeaders, Matchers.emptyCollectionOf(String.class)); + headers.setAccessControlAllowHeaders(Arrays.asList("header1", "header2")); + allowedHeaders = headers.getAccessControlAllowHeaders(); + assertEquals(allowedHeaders, Arrays.asList("header1", "header2")); + } + + @Test + public void accessControlAllowHeadersMultipleValues() { + List allowedHeaders = headers.getAccessControlAllowHeaders(); + assertThat(allowedHeaders, Matchers.emptyCollectionOf(String.class)); + headers.add(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS, "header1, header2"); + headers.add(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS, "header3"); + allowedHeaders = headers.getAccessControlAllowHeaders(); + assertEquals(Arrays.asList("header1", "header2", "header3"), allowedHeaders); + } + + @Test + public void accessControlAllowMethods() { + List allowedMethods = headers.getAccessControlAllowMethods(); + assertThat(allowedMethods, Matchers.emptyCollectionOf(HttpMethod.class)); + headers.setAccessControlAllowMethods(Arrays.asList(HttpMethod.GET, HttpMethod.POST)); + allowedMethods = headers.getAccessControlAllowMethods(); + assertEquals(allowedMethods, Arrays.asList(HttpMethod.GET, HttpMethod.POST)); + } + + @Test + public void accessControlAllowOrigin() { + assertNull(headers.getAccessControlAllowOrigin()); + headers.setAccessControlAllowOrigin("*"); + assertEquals("*", headers.getAccessControlAllowOrigin()); + } + + @Test + public void accessControlExposeHeaders() { + List exposedHeaders = headers.getAccessControlExposeHeaders(); + assertThat(exposedHeaders, Matchers.emptyCollectionOf(String.class)); + headers.setAccessControlExposeHeaders(Arrays.asList("header1", "header2")); + exposedHeaders = headers.getAccessControlExposeHeaders(); + assertEquals(exposedHeaders, Arrays.asList("header1", "header2")); + } + + @Test + public void accessControlMaxAge() { + assertEquals(-1, headers.getAccessControlMaxAge()); + headers.setAccessControlMaxAge(3600); + assertEquals(3600, headers.getAccessControlMaxAge()); + } + + @Test + public void accessControlRequestHeaders() { + List requestHeaders = headers.getAccessControlRequestHeaders(); + assertThat(requestHeaders, Matchers.emptyCollectionOf(String.class)); + headers.setAccessControlRequestHeaders(Arrays.asList("header1", "header2")); + requestHeaders = headers.getAccessControlRequestHeaders(); + assertEquals(requestHeaders, Arrays.asList("header1", "header2")); + } + + @Test + public void accessControlRequestMethod() { + assertNull(headers.getAccessControlRequestMethod()); + headers.setAccessControlRequestMethod(HttpMethod.POST); + assertEquals(HttpMethod.POST, headers.getAccessControlRequestMethod()); + } + + @Test + public void acceptLanguage() { + String headerValue = "fr-ch, fr;q=0.9, en-*;q=0.8, de;q=0.7, *;q=0.5"; + headers.setAcceptLanguage(Locale.LanguageRange.parse(headerValue)); + assertEquals(headerValue, headers.getFirst(HttpHeaders.ACCEPT_LANGUAGE)); + + List expectedRanges = Arrays.asList( + new Locale.LanguageRange("fr-ch"), + new Locale.LanguageRange("fr", 0.9), + new Locale.LanguageRange("en-*", 0.8), + new Locale.LanguageRange("de", 0.7), + new Locale.LanguageRange("*", 0.5) + ); + assertEquals(expectedRanges, headers.getAcceptLanguage()); + assertEquals(Locale.forLanguageTag("fr-ch"), headers.getAcceptLanguageAsLocales().get(0)); + + headers.setAcceptLanguageAsLocales(Collections.singletonList(Locale.FRANCE)); + assertEquals(Locale.FRANCE, headers.getAcceptLanguageAsLocales().get(0)); + } + + @Test // SPR-15603 + public void acceptLanguageWithEmptyValue() throws Exception { + this.headers.set(HttpHeaders.ACCEPT_LANGUAGE, ""); + assertEquals(Collections.emptyList(), this.headers.getAcceptLanguageAsLocales()); + } + + @Test + public void contentLanguage() { + headers.setContentLanguage(Locale.FRANCE); + assertEquals(Locale.FRANCE, headers.getContentLanguage()); + assertEquals("fr-FR", headers.getFirst(HttpHeaders.CONTENT_LANGUAGE)); + } + + @Test + public void contentLanguageSerialized() { + headers.set(HttpHeaders.CONTENT_LANGUAGE, "de, en_CA"); + assertEquals("Expected one (first) locale", Locale.GERMAN, headers.getContentLanguage()); + } + + @Test + public void firstDate() { + headers.setDate(HttpHeaders.DATE, 1496370120000L); + assertThat(headers.getFirstDate(HttpHeaders.DATE), is(1496370120000L)); + + headers.clear(); + + headers.add(HttpHeaders.DATE, "Fri, 02 Jun 2017 02:22:00 GMT"); + headers.add(HttpHeaders.DATE, "Sat, 18 Dec 2010 10:20:00 GMT"); + assertThat(headers.getFirstDate(HttpHeaders.DATE), is(1496370120000L)); + } + + @Test + public void firstZonedDateTime() { + ZonedDateTime date = ZonedDateTime.of(2017, 6, 2, 2, 22, 0, 0, ZoneId.of("GMT")); + headers.setZonedDateTime(HttpHeaders.DATE, date); + assertThat(headers.getFirst(HttpHeaders.DATE), is("Fri, 02 Jun 2017 02:22:00 GMT")); + assertTrue(headers.getFirstZonedDateTime(HttpHeaders.DATE).isEqual(date)); + + headers.clear(); + headers.add(HttpHeaders.DATE, "Fri, 02 Jun 2017 02:22:00 GMT"); + headers.add(HttpHeaders.DATE, "Sat, 18 Dec 2010 10:20:00 GMT"); + assertTrue(headers.getFirstZonedDateTime(HttpHeaders.DATE).isEqual(date)); + + // obsolete RFC 850 format + headers.clear(); + headers.set(HttpHeaders.DATE, "Friday, 02-Jun-17 02:22:00 GMT"); + assertTrue(headers.getFirstZonedDateTime(HttpHeaders.DATE).isEqual(date)); + + // ANSI C's asctime() format + headers.clear(); + headers.set(HttpHeaders.DATE, "Fri Jun 02 02:22:00 2017"); + assertTrue(headers.getFirstZonedDateTime(HttpHeaders.DATE).isEqual(date)); + } + + @Test + public void basicAuth() { + String username = "foo"; + String password = "bar"; + headers.setBasicAuth(username, password); + String authorization = headers.getFirst(HttpHeaders.AUTHORIZATION); + assertNotNull(authorization); + assertTrue(authorization.startsWith("Basic ")); + byte[] result = Base64.getDecoder().decode(authorization.substring(6).getBytes(StandardCharsets.ISO_8859_1)); + assertEquals("foo:bar", new String(result, StandardCharsets.ISO_8859_1)); + } + + @Test(expected = IllegalArgumentException.class) + public void basicAuthIllegalChar() { + String username = "foo"; + String password = "\u03BB"; + headers.setBasicAuth(username, password); + } + + @Test + public void bearerAuth() { + String token = "foo"; + + headers.setBearerAuth(token); + String authorization = headers.getFirst(HttpHeaders.AUTHORIZATION); + assertEquals("Bearer foo", authorization); + } + + @Test // https://github.com/spring-projects/spring-framework/issues/23633 + public void keySetRemove() { + // Given + headers.add("Alpha", "apple"); + headers.add("Bravo", "banana"); + assertEquals(2, headers.size()); + assertTrue("Alpha should be present", headers.containsKey("Alpha")); + assertTrue("Bravo should be present", headers.containsKey("Bravo")); + assertArrayEquals(new String[] {"Alpha", "Bravo"}, headers.keySet().toArray()); + + // When + boolean removed = headers.keySet().remove("Alpha"); + + // Then + assertTrue(removed); + assertFalse(headers.keySet().remove("Alpha")); + assertEquals(1, headers.size()); + assertFalse("Alpha should have been removed", headers.containsKey("Alpha")); + assertTrue("Bravo should be present", headers.containsKey("Bravo")); + assertArrayEquals(new String[] {"Bravo"}, headers.keySet().toArray()); + assertEquals(Collections.singletonMap("Bravo", Arrays.asList("banana")).entrySet(), headers.entrySet()); + } + + @Test + public void keySetOperations() { + headers.add("Alpha", "apple"); + headers.add("Bravo", "banana"); + assertEquals(2, headers.size()); + + // size() + assertEquals(2, headers.keySet().size()); + + // contains() + assertTrue("Alpha should be present", headers.keySet().contains("Alpha")); + assertTrue("alpha should be present", headers.keySet().contains("alpha")); + assertTrue("Bravo should be present", headers.keySet().contains("Bravo")); + assertTrue("BRAVO should be present", headers.keySet().contains("BRAVO")); + assertFalse("Charlie should not be present", headers.keySet().contains("Charlie")); + + // toArray() + assertArrayEquals(new String[] {"Alpha", "Bravo"}, headers.keySet().toArray()); + + // spliterator() via stream() + assertEquals(Arrays.asList("Alpha", "Bravo"), headers.keySet().stream().collect(toList())); + + // iterator() + List results = new ArrayList<>(); + headers.keySet().iterator().forEachRemaining(results::add); + assertEquals(Arrays.asList("Alpha", "Bravo"), results); + + // remove() + assertTrue(headers.keySet().remove("Alpha")); + assertEquals(1, headers.size()); + assertFalse(headers.keySet().remove("Alpha")); + + // clear() + headers.keySet().clear(); + assertEquals(0, headers.size()); + + // Unsupported operations + unsupported(() -> headers.keySet().add("x")); + unsupported(() -> headers.keySet().addAll(Collections.singleton("enigma"))); + } + + private static void unsupported(Runnable runnable) { + try { + runnable.run(); + fail("should have thrown an UnsupportedOperationException"); + } + catch (UnsupportedOperationException e) { + // expected + } + } + + @Test + public void removalFromKeySetRemovesEntryFromUnderlyingMap() { + String headerName = "MyHeader"; + String headerValue = "value"; + + assertTrue(headers.isEmpty()); + headers.add(headerName, headerValue); + assertTrue(headers.containsKey(headerName)); + headers.keySet().removeIf(key -> key.equals(headerName)); + assertTrue(headers.isEmpty()); + headers.add(headerName, headerValue); + assertEquals(headerValue, headers.get(headerName).get(0)); + } + + @Test + public void removalFromEntrySetRemovesEntryFromUnderlyingMap() { + String headerName = "MyHeader"; + String headerValue = "value"; + + assertTrue(headers.isEmpty()); + headers.add(headerName, headerValue); + assertTrue(headers.containsKey(headerName)); + headers.entrySet().removeIf(entry -> entry.getKey().equals(headerName)); + assertTrue(headers.isEmpty()); + headers.add(headerName, headerValue); + assertEquals(headerValue, headers.get(headerName).get(0)); + } + + @Test + public void readOnlyHttpHeadersRetainEntrySetOrder() { + headers.add("aardvark", "enigma"); + headers.add("beaver", "enigma"); + headers.add("cat", "enigma"); + headers.add("dog", "enigma"); + headers.add("elephant", "enigma"); + + String[] expectedKeys = new String[] { "aardvark", "beaver", "cat", "dog", "elephant" }; + + assertArrayEquals(expectedKeys, headers.entrySet().stream().map(Entry::getKey).toArray()); + + HttpHeaders readOnlyHttpHeaders = HttpHeaders.readOnlyHttpHeaders(headers); + assertArrayEquals(expectedKeys, readOnlyHttpHeaders.entrySet().stream().map(Entry::getKey).toArray()); + } + + @Test // gh-25034 + public void equalsUnwrapsHttpHeaders() { + HttpHeaders headers1 = new HttpHeaders(); + HttpHeaders headers2 = new HttpHeaders(new HttpHeaders(headers1)); + + assertEquals(headers1, headers2); + assertEquals(headers2, headers1); + } +} diff --git a/spring-web/src/test/java/org/springframework/http/HttpRangeTests.java b/spring-web/src/test/java/org/springframework/http/HttpRangeTests.java new file mode 100644 index 0000000000000000000000000000000000000000..163faa94eeeafaee0be8b961de9d265b9a8bc5ce --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/HttpRangeTests.java @@ -0,0 +1,195 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; + +import org.junit.Test; + +import org.springframework.core.io.ByteArrayResource; +import org.springframework.core.io.InputStreamResource; +import org.springframework.core.io.support.ResourceRegion; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * Unit tests for {@link HttpRange}. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + */ +public class HttpRangeTests { + + @Test(expected = IllegalArgumentException.class) + public void invalidFirstPosition() { + HttpRange.createByteRange(-1); + } + + @Test(expected = IllegalArgumentException.class) + public void invalidLastLessThanFirst() { + HttpRange.createByteRange(10, 9); + } + + @Test(expected = IllegalArgumentException.class) + public void invalidSuffixLength() { + HttpRange.createSuffixRange(-1); + } + + @Test + public void byteRange() { + HttpRange range = HttpRange.createByteRange(0, 499); + assertEquals(0, range.getRangeStart(1000)); + assertEquals(499, range.getRangeEnd(1000)); + } + + @Test + public void byteRangeWithoutLastPosition() { + HttpRange range = HttpRange.createByteRange(9500); + assertEquals(9500, range.getRangeStart(10000)); + assertEquals(9999, range.getRangeEnd(10000)); + } + + @Test + public void byteRangeOfZeroLength() { + HttpRange range = HttpRange.createByteRange(9500, 9500); + assertEquals(9500, range.getRangeStart(10000)); + assertEquals(9500, range.getRangeEnd(10000)); + } + + @Test + public void suffixRange() { + HttpRange range = HttpRange.createSuffixRange(500); + assertEquals(500, range.getRangeStart(1000)); + assertEquals(999, range.getRangeEnd(1000)); + } + + @Test + public void suffixRangeShorterThanRepresentation() { + HttpRange range = HttpRange.createSuffixRange(500); + assertEquals(0, range.getRangeStart(350)); + assertEquals(349, range.getRangeEnd(350)); + } + + @Test + public void parseRanges() { + List ranges = HttpRange.parseRanges("bytes=0-0,500-,-1"); + assertEquals(3, ranges.size()); + assertEquals(0, ranges.get(0).getRangeStart(1000)); + assertEquals(0, ranges.get(0).getRangeEnd(1000)); + assertEquals(500, ranges.get(1).getRangeStart(1000)); + assertEquals(999, ranges.get(1).getRangeEnd(1000)); + assertEquals(999, ranges.get(2).getRangeStart(1000)); + assertEquals(999, ranges.get(2).getRangeEnd(1000)); + } + + @Test + public void parseRangesValidations() { + + // 1. At limit.. + StringBuilder sb = new StringBuilder("bytes=0-0"); + for (int i=0; i < 99; i++) { + sb.append(",").append(i).append("-").append(i + 1); + } + List ranges = HttpRange.parseRanges(sb.toString()); + assertEquals(100, ranges.size()); + + // 2. Above limit.. + sb = new StringBuilder("bytes=0-0"); + for (int i=0; i < 100; i++) { + sb.append(",").append(i).append("-").append(i + 1); + } + try { + HttpRange.parseRanges(sb.toString()); + fail(); + } + catch (IllegalArgumentException ex) { + // Expected + } + } + + @Test + public void rangeToString() { + List ranges = new ArrayList<>(); + ranges.add(HttpRange.createByteRange(0, 499)); + ranges.add(HttpRange.createByteRange(9500)); + ranges.add(HttpRange.createSuffixRange(500)); + assertEquals("Invalid Range header", "bytes=0-499, 9500-, -500", HttpRange.toString(ranges)); + } + + @Test + public void toResourceRegion() { + byte[] bytes = "Spring Framework".getBytes(StandardCharsets.UTF_8); + ByteArrayResource resource = new ByteArrayResource(bytes); + HttpRange range = HttpRange.createByteRange(0, 5); + ResourceRegion region = range.toResourceRegion(resource); + assertEquals(resource, region.getResource()); + assertEquals(0L, region.getPosition()); + assertEquals(6L, region.getCount()); + } + + @Test(expected = IllegalArgumentException.class) + public void toResourceRegionInputStreamResource() { + InputStreamResource resource = mock(InputStreamResource.class); + HttpRange range = HttpRange.createByteRange(0, 9); + range.toResourceRegion(resource); + } + + @Test(expected = IllegalArgumentException.class) + public void toResourceRegionIllegalLength() { + ByteArrayResource resource = mock(ByteArrayResource.class); + given(resource.contentLength()).willReturn(-1L); + HttpRange range = HttpRange.createByteRange(0, 9); + range.toResourceRegion(resource); + } + + @Test(expected = IllegalArgumentException.class) + @SuppressWarnings("unchecked") + public void toResourceRegionExceptionLength() throws IOException { + InputStreamResource resource = mock(InputStreamResource.class); + given(resource.contentLength()).willThrow(IOException.class); + HttpRange range = HttpRange.createByteRange(0, 9); + range.toResourceRegion(resource); + } + + @Test + public void toResourceRegionsValidations() { + byte[] bytes = "12345".getBytes(StandardCharsets.UTF_8); + ByteArrayResource resource = new ByteArrayResource(bytes); + + // 1. Below length + List ranges = HttpRange.parseRanges("bytes=0-1,2-3"); + List regions = HttpRange.toResourceRegions(ranges, resource); + assertEquals(2, regions.size()); + + // 2. At length + ranges = HttpRange.parseRanges("bytes=0-1,2-4"); + try { + HttpRange.toResourceRegions(ranges, resource); + fail(); + } + catch (IllegalArgumentException ex) { + // Expected.. + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/HttpStatusTests.java b/spring-web/src/test/java/org/springframework/http/HttpStatusTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e64413442acb90980fc3f6b6e94f84c62314a0cc --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/HttpStatusTests.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.util.LinkedHashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** @author Arjen Poutsma */ +public class HttpStatusTests { + + private Map statusCodes = new LinkedHashMap<>(); + + @Before + public void createStatusCodes() { + statusCodes.put(100, "CONTINUE"); + statusCodes.put(101, "SWITCHING_PROTOCOLS"); + statusCodes.put(102, "PROCESSING"); + statusCodes.put(103, "CHECKPOINT"); + + statusCodes.put(200, "OK"); + statusCodes.put(201, "CREATED"); + statusCodes.put(202, "ACCEPTED"); + statusCodes.put(203, "NON_AUTHORITATIVE_INFORMATION"); + statusCodes.put(204, "NO_CONTENT"); + statusCodes.put(205, "RESET_CONTENT"); + statusCodes.put(206, "PARTIAL_CONTENT"); + statusCodes.put(207, "MULTI_STATUS"); + statusCodes.put(208, "ALREADY_REPORTED"); + statusCodes.put(226, "IM_USED"); + + statusCodes.put(300, "MULTIPLE_CHOICES"); + statusCodes.put(301, "MOVED_PERMANENTLY"); + statusCodes.put(302, "FOUND"); + statusCodes.put(303, "SEE_OTHER"); + statusCodes.put(304, "NOT_MODIFIED"); + statusCodes.put(305, "USE_PROXY"); + statusCodes.put(307, "TEMPORARY_REDIRECT"); + statusCodes.put(308, "PERMANENT_REDIRECT"); + + statusCodes.put(400, "BAD_REQUEST"); + statusCodes.put(401, "UNAUTHORIZED"); + statusCodes.put(402, "PAYMENT_REQUIRED"); + statusCodes.put(403, "FORBIDDEN"); + statusCodes.put(404, "NOT_FOUND"); + statusCodes.put(405, "METHOD_NOT_ALLOWED"); + statusCodes.put(406, "NOT_ACCEPTABLE"); + statusCodes.put(407, "PROXY_AUTHENTICATION_REQUIRED"); + statusCodes.put(408, "REQUEST_TIMEOUT"); + statusCodes.put(409, "CONFLICT"); + statusCodes.put(410, "GONE"); + statusCodes.put(411, "LENGTH_REQUIRED"); + statusCodes.put(412, "PRECONDITION_FAILED"); + statusCodes.put(413, "PAYLOAD_TOO_LARGE"); + statusCodes.put(414, "URI_TOO_LONG"); + statusCodes.put(415, "UNSUPPORTED_MEDIA_TYPE"); + statusCodes.put(416, "REQUESTED_RANGE_NOT_SATISFIABLE"); + statusCodes.put(417, "EXPECTATION_FAILED"); + statusCodes.put(418, "I_AM_A_TEAPOT"); + statusCodes.put(419, "INSUFFICIENT_SPACE_ON_RESOURCE"); + statusCodes.put(420, "METHOD_FAILURE"); + statusCodes.put(421, "DESTINATION_LOCKED"); + statusCodes.put(422, "UNPROCESSABLE_ENTITY"); + statusCodes.put(423, "LOCKED"); + statusCodes.put(424, "FAILED_DEPENDENCY"); + statusCodes.put(426, "UPGRADE_REQUIRED"); + statusCodes.put(428, "PRECONDITION_REQUIRED"); + statusCodes.put(429, "TOO_MANY_REQUESTS"); + statusCodes.put(431, "REQUEST_HEADER_FIELDS_TOO_LARGE"); + statusCodes.put(451, "UNAVAILABLE_FOR_LEGAL_REASONS"); + + statusCodes.put(500, "INTERNAL_SERVER_ERROR"); + statusCodes.put(501, "NOT_IMPLEMENTED"); + statusCodes.put(502, "BAD_GATEWAY"); + statusCodes.put(503, "SERVICE_UNAVAILABLE"); + statusCodes.put(504, "GATEWAY_TIMEOUT"); + statusCodes.put(505, "HTTP_VERSION_NOT_SUPPORTED"); + statusCodes.put(506, "VARIANT_ALSO_NEGOTIATES"); + statusCodes.put(507, "INSUFFICIENT_STORAGE"); + statusCodes.put(508, "LOOP_DETECTED"); + statusCodes.put(509, "BANDWIDTH_LIMIT_EXCEEDED"); + statusCodes.put(510, "NOT_EXTENDED"); + statusCodes.put(511, "NETWORK_AUTHENTICATION_REQUIRED"); + } + + @Test + public void fromMapToEnum() { + for (Map.Entry entry : statusCodes.entrySet()) { + int value = entry.getKey(); + HttpStatus status = HttpStatus.valueOf(value); + assertEquals("Invalid value", value, status.value()); + assertEquals("Invalid name for [" + value + "]", entry.getValue(), status.name()); + } + } + + @Test + public void fromEnumToMap() { + + for (HttpStatus status : HttpStatus.values()) { + int value = status.value(); + if (value == 302 || value == 413 || value == 414) { + continue; + } + assertTrue("Map has no value for [" + value + "]", statusCodes.containsKey(value)); + assertEquals("Invalid name for [" + value + "]", statusCodes.get(value), status.name()); + } + } +} diff --git a/spring-web/src/test/java/org/springframework/http/MediaTypeFactoryTests.java b/spring-web/src/test/java/org/springframework/http/MediaTypeFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..56eae3bdfa7d746c8b2e28d61a8f7c58cb359e0a --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/MediaTypeFactoryTests.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import org.junit.Test; + +import org.springframework.core.io.Resource; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + */ +public class MediaTypeFactoryTests { + + @Test + public void getMediaType() { + assertEquals(MediaType.APPLICATION_XML, MediaTypeFactory.getMediaType("file.xml").get()); + assertEquals(MediaType.parseMediaType("application/javascript"), MediaTypeFactory.getMediaType("file.js").get()); + assertEquals(MediaType.parseMediaType("text/css"), MediaTypeFactory.getMediaType("file.css").get()); + assertFalse(MediaTypeFactory.getMediaType("file.foobar").isPresent()); + } + + @Test + public void nullParameter() { + assertFalse(MediaTypeFactory.getMediaType((String) null).isPresent()); + assertFalse(MediaTypeFactory.getMediaType((Resource) null).isPresent()); + assertTrue(MediaTypeFactory.getMediaTypes(null).isEmpty()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/MediaTypeTests.java b/spring-web/src/test/java/org/springframework/http/MediaTypeTests.java new file mode 100644 index 0000000000000000000000000000000000000000..142b01440f35dca22b4ec1cd70b478325f1e7937 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/MediaTypeTests.java @@ -0,0 +1,446 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Random; + +import org.junit.Test; + +import org.springframework.core.convert.ConversionService; +import org.springframework.core.convert.support.DefaultConversionService; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + * @author Juergen Hoeller + */ +public class MediaTypeTests { + + @Test + public void testToString() throws Exception { + MediaType mediaType = new MediaType("text", "plain", 0.7); + String result = mediaType.toString(); + assertEquals("Invalid toString() returned", "text/plain;q=0.7", result); + } + + @Test(expected = IllegalArgumentException.class) + public void slashInType() { + new MediaType("text/plain"); + } + + @Test(expected = IllegalArgumentException.class) + public void slashInSubtype() { + new MediaType("text", "/"); + } + + @Test + public void getDefaultQualityValue() { + MediaType mediaType = new MediaType("text", "plain"); + assertEquals("Invalid quality value", 1, mediaType.getQualityValue(), 0D); + } + + @Test + public void parseMediaType() throws Exception { + String s = "audio/*; q=0.2"; + MediaType mediaType = MediaType.parseMediaType(s); + assertEquals("Invalid type", "audio", mediaType.getType()); + assertEquals("Invalid subtype", "*", mediaType.getSubtype()); + assertEquals("Invalid quality factor", 0.2D, mediaType.getQualityValue(), 0D); + } + + @Test(expected = InvalidMediaTypeException.class) + public void parseMediaTypeNoSubtype() { + MediaType.parseMediaType("audio"); + } + + @Test(expected = InvalidMediaTypeException.class) + public void parseMediaTypeNoSubtypeSlash() { + MediaType.parseMediaType("audio/"); + } + + @Test(expected = InvalidMediaTypeException.class) + public void parseMediaTypeTypeRange() { + MediaType.parseMediaType("*/json"); + } + + @Test(expected = InvalidMediaTypeException.class) + public void parseMediaTypeIllegalType() { + MediaType.parseMediaType("audio(/basic"); + } + + @Test(expected = InvalidMediaTypeException.class) + public void parseMediaTypeIllegalSubtype() { + MediaType.parseMediaType("audio/basic)"); + } + + @Test(expected = InvalidMediaTypeException.class) + public void parseMediaTypeEmptyParameterAttribute() { + MediaType.parseMediaType("audio/*;=value"); + } + + @Test(expected = InvalidMediaTypeException.class) + public void parseMediaTypeEmptyParameterValue() { + MediaType.parseMediaType("audio/*;attr="); + } + + @Test(expected = InvalidMediaTypeException.class) + public void parseMediaTypeIllegalParameterAttribute() { + MediaType.parseMediaType("audio/*;attr<=value"); + } + + @Test(expected = InvalidMediaTypeException.class) + public void parseMediaTypeIllegalParameterValue() { + MediaType.parseMediaType("audio/*;attr=v>alue"); + } + + @Test(expected = InvalidMediaTypeException.class) + public void parseMediaTypeIllegalQualityFactor() { + MediaType.parseMediaType("audio/basic;q=1.1"); + } + + @Test(expected = InvalidMediaTypeException.class) + public void parseMediaTypeIllegalCharset() { + MediaType.parseMediaType("text/html; charset=foo-bar"); + } + + @Test + public void parseURLConnectionMediaType() throws Exception { + String s = "*; q=.2"; + MediaType mediaType = MediaType.parseMediaType(s); + assertEquals("Invalid type", "*", mediaType.getType()); + assertEquals("Invalid subtype", "*", mediaType.getSubtype()); + assertEquals("Invalid quality factor", 0.2D, mediaType.getQualityValue(), 0D); + } + + @Test + public void parseMediaTypes() throws Exception { + String s = "text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c"; + List mediaTypes = MediaType.parseMediaTypes(s); + assertNotNull("No media types returned", mediaTypes); + assertEquals("Invalid amount of media types", 4, mediaTypes.size()); + + mediaTypes = MediaType.parseMediaTypes(""); + assertNotNull("No media types returned", mediaTypes); + assertEquals("Invalid amount of media types", 0, mediaTypes.size()); + } + + @Test // gh-23241 + public void parseMediaTypesWithTrailingComma() { + List mediaTypes = MediaType.parseMediaTypes("text/plain, text/html, "); + assertNotNull("No media types returned", mediaTypes); + assertEquals("Incorrect number of media types", 2, mediaTypes.size()); + } + + @Test + public void compareTo() { + MediaType audioBasic = new MediaType("audio", "basic"); + MediaType audio = new MediaType("audio"); + MediaType audioWave = new MediaType("audio", "wave"); + MediaType audioBasicLevel = new MediaType("audio", "basic", Collections.singletonMap("level", "1")); + MediaType audioBasic07 = new MediaType("audio", "basic", 0.7); + + // equal + assertEquals("Invalid comparison result", 0, audioBasic.compareTo(audioBasic)); + assertEquals("Invalid comparison result", 0, audio.compareTo(audio)); + assertEquals("Invalid comparison result", 0, audioBasicLevel.compareTo(audioBasicLevel)); + + assertTrue("Invalid comparison result", audioBasicLevel.compareTo(audio) > 0); + + List expected = new ArrayList<>(); + expected.add(audio); + expected.add(audioBasic); + expected.add(audioBasicLevel); + expected.add(audioBasic07); + expected.add(audioWave); + + List result = new ArrayList<>(expected); + Random rnd = new Random(); + // shuffle & sort 10 times + for (int i = 0; i < 10; i++) { + Collections.shuffle(result, rnd); + Collections.sort(result); + + for (int j = 0; j < result.size(); j++) { + assertSame("Invalid media type at " + j + ", run " + i, expected.get(j), result.get(j)); + } + } + } + + @Test + public void compareToConsistentWithEquals() { + MediaType m1 = MediaType.parseMediaType("text/html; q=0.7; charset=iso-8859-1"); + MediaType m2 = MediaType.parseMediaType("text/html; charset=iso-8859-1; q=0.7"); + + assertEquals("Media types not equal", m1, m2); + assertEquals("compareTo() not consistent with equals", 0, m1.compareTo(m2)); + assertEquals("compareTo() not consistent with equals", 0, m2.compareTo(m1)); + + m1 = MediaType.parseMediaType("text/html; q=0.7; charset=iso-8859-1"); + m2 = MediaType.parseMediaType("text/html; Q=0.7; charset=iso-8859-1"); + assertEquals("Media types not equal", m1, m2); + assertEquals("compareTo() not consistent with equals", 0, m1.compareTo(m2)); + assertEquals("compareTo() not consistent with equals", 0, m2.compareTo(m1)); + } + + @Test + public void compareToCaseSensitivity() { + MediaType m1 = new MediaType("audio", "basic"); + MediaType m2 = new MediaType("Audio", "Basic"); + assertEquals("Invalid comparison result", 0, m1.compareTo(m2)); + assertEquals("Invalid comparison result", 0, m2.compareTo(m1)); + + m1 = new MediaType("audio", "basic", Collections.singletonMap("foo", "bar")); + m2 = new MediaType("audio", "basic", Collections.singletonMap("Foo", "bar")); + assertEquals("Invalid comparison result", 0, m1.compareTo(m2)); + assertEquals("Invalid comparison result", 0, m2.compareTo(m1)); + + m1 = new MediaType("audio", "basic", Collections.singletonMap("foo", "bar")); + m2 = new MediaType("audio", "basic", Collections.singletonMap("foo", "Bar")); + assertTrue("Invalid comparison result", m1.compareTo(m2) != 0); + assertTrue("Invalid comparison result", m2.compareTo(m1) != 0); + + + } + + @Test + public void specificityComparator() throws Exception { + MediaType audioBasic = new MediaType("audio", "basic"); + MediaType audioWave = new MediaType("audio", "wave"); + MediaType audio = new MediaType("audio"); + MediaType audio03 = new MediaType("audio", "*", 0.3); + MediaType audio07 = new MediaType("audio", "*", 0.7); + MediaType audioBasicLevel = new MediaType("audio", "basic", Collections.singletonMap("level", "1")); + MediaType textHtml = new MediaType("text", "html"); + MediaType allXml = new MediaType("application", "*+xml"); + MediaType all = MediaType.ALL; + + Comparator comp = MediaType.SPECIFICITY_COMPARATOR; + + // equal + assertEquals("Invalid comparison result", 0, comp.compare(audioBasic,audioBasic)); + assertEquals("Invalid comparison result", 0, comp.compare(audio, audio)); + assertEquals("Invalid comparison result", 0, comp.compare(audio07, audio07)); + assertEquals("Invalid comparison result", 0, comp.compare(audio03, audio03)); + assertEquals("Invalid comparison result", 0, comp.compare(audioBasicLevel, audioBasicLevel)); + + // specific to unspecific + assertTrue("Invalid comparison result", comp.compare(audioBasic, audio) < 0); + assertTrue("Invalid comparison result", comp.compare(audioBasic, all) < 0); + assertTrue("Invalid comparison result", comp.compare(audio, all) < 0); + assertTrue("Invalid comparison result", comp.compare(MediaType.APPLICATION_XHTML_XML, allXml) < 0); + + // unspecific to specific + assertTrue("Invalid comparison result", comp.compare(audio, audioBasic) > 0); + assertTrue("Invalid comparison result", comp.compare(allXml, MediaType.APPLICATION_XHTML_XML) > 0); + assertTrue("Invalid comparison result", comp.compare(all, audioBasic) > 0); + assertTrue("Invalid comparison result", comp.compare(all, audio) > 0); + + // qualifiers + assertTrue("Invalid comparison result", comp.compare(audio, audio07) < 0); + assertTrue("Invalid comparison result", comp.compare(audio07, audio) > 0); + assertTrue("Invalid comparison result", comp.compare(audio07, audio03) < 0); + assertTrue("Invalid comparison result", comp.compare(audio03, audio07) > 0); + assertTrue("Invalid comparison result", comp.compare(audio03, all) < 0); + assertTrue("Invalid comparison result", comp.compare(all, audio03) > 0); + + // other parameters + assertTrue("Invalid comparison result", comp.compare(audioBasic, audioBasicLevel) > 0); + assertTrue("Invalid comparison result", comp.compare(audioBasicLevel, audioBasic) < 0); + + // different types + assertEquals("Invalid comparison result", 0, comp.compare(audioBasic, textHtml)); + assertEquals("Invalid comparison result", 0, comp.compare(textHtml, audioBasic)); + + // different subtypes + assertEquals("Invalid comparison result", 0, comp.compare(audioBasic, audioWave)); + assertEquals("Invalid comparison result", 0, comp.compare(audioWave, audioBasic)); + } + + @Test + public void sortBySpecificityRelated() { + MediaType audioBasic = new MediaType("audio", "basic"); + MediaType audio = new MediaType("audio"); + MediaType audio03 = new MediaType("audio", "*", 0.3); + MediaType audio07 = new MediaType("audio", "*", 0.7); + MediaType audioBasicLevel = new MediaType("audio", "basic", Collections.singletonMap("level", "1")); + MediaType all = MediaType.ALL; + + List expected = new ArrayList<>(); + expected.add(audioBasicLevel); + expected.add(audioBasic); + expected.add(audio); + expected.add(audio07); + expected.add(audio03); + expected.add(all); + + List result = new ArrayList<>(expected); + Random rnd = new Random(); + // shuffle & sort 10 times + for (int i = 0; i < 10; i++) { + Collections.shuffle(result, rnd); + MediaType.sortBySpecificity(result); + + for (int j = 0; j < result.size(); j++) { + assertSame("Invalid media type at " + j, expected.get(j), result.get(j)); + } + } + } + + @Test + public void sortBySpecificityUnrelated() { + MediaType audioBasic = new MediaType("audio", "basic"); + MediaType audioWave = new MediaType("audio", "wave"); + MediaType textHtml = new MediaType("text", "html"); + + List expected = new ArrayList<>(); + expected.add(textHtml); + expected.add(audioBasic); + expected.add(audioWave); + + List result = new ArrayList<>(expected); + MediaType.sortBySpecificity(result); + + for (int i = 0; i < result.size(); i++) { + assertSame("Invalid media type at " + i, expected.get(i), result.get(i)); + } + + } + + @Test + public void qualityComparator() throws Exception { + MediaType audioBasic = new MediaType("audio", "basic"); + MediaType audioWave = new MediaType("audio", "wave"); + MediaType audio = new MediaType("audio"); + MediaType audio03 = new MediaType("audio", "*", 0.3); + MediaType audio07 = new MediaType("audio", "*", 0.7); + MediaType audioBasicLevel = new MediaType("audio", "basic", Collections.singletonMap("level", "1")); + MediaType textHtml = new MediaType("text", "html"); + MediaType allXml = new MediaType("application", "*+xml"); + MediaType all = MediaType.ALL; + + Comparator comp = MediaType.QUALITY_VALUE_COMPARATOR; + + // equal + assertEquals("Invalid comparison result", 0, comp.compare(audioBasic,audioBasic)); + assertEquals("Invalid comparison result", 0, comp.compare(audio, audio)); + assertEquals("Invalid comparison result", 0, comp.compare(audio07, audio07)); + assertEquals("Invalid comparison result", 0, comp.compare(audio03, audio03)); + assertEquals("Invalid comparison result", 0, comp.compare(audioBasicLevel, audioBasicLevel)); + + // specific to unspecific + assertTrue("Invalid comparison result", comp.compare(audioBasic, audio) < 0); + assertTrue("Invalid comparison result", comp.compare(audioBasic, all) < 0); + assertTrue("Invalid comparison result", comp.compare(audio, all) < 0); + assertTrue("Invalid comparison result", comp.compare(MediaType.APPLICATION_XHTML_XML, allXml) < 0); + + // unspecific to specific + assertTrue("Invalid comparison result", comp.compare(audio, audioBasic) > 0); + assertTrue("Invalid comparison result", comp.compare(all, audioBasic) > 0); + assertTrue("Invalid comparison result", comp.compare(all, audio) > 0); + assertTrue("Invalid comparison result", comp.compare(allXml, MediaType.APPLICATION_XHTML_XML) > 0); + + // qualifiers + assertTrue("Invalid comparison result", comp.compare(audio, audio07) < 0); + assertTrue("Invalid comparison result", comp.compare(audio07, audio) > 0); + assertTrue("Invalid comparison result", comp.compare(audio07, audio03) < 0); + assertTrue("Invalid comparison result", comp.compare(audio03, audio07) > 0); + assertTrue("Invalid comparison result", comp.compare(audio03, all) > 0); + assertTrue("Invalid comparison result", comp.compare(all, audio03) < 0); + + // other parameters + assertTrue("Invalid comparison result", comp.compare(audioBasic, audioBasicLevel) > 0); + assertTrue("Invalid comparison result", comp.compare(audioBasicLevel, audioBasic) < 0); + + // different types + assertEquals("Invalid comparison result", 0, comp.compare(audioBasic, textHtml)); + assertEquals("Invalid comparison result", 0, comp.compare(textHtml, audioBasic)); + + // different subtypes + assertEquals("Invalid comparison result", 0, comp.compare(audioBasic, audioWave)); + assertEquals("Invalid comparison result", 0, comp.compare(audioWave, audioBasic)); + } + + @Test + public void sortByQualityRelated() { + MediaType audioBasic = new MediaType("audio", "basic"); + MediaType audio = new MediaType("audio"); + MediaType audio03 = new MediaType("audio", "*", 0.3); + MediaType audio07 = new MediaType("audio", "*", 0.7); + MediaType audioBasicLevel = new MediaType("audio", "basic", Collections.singletonMap("level", "1")); + MediaType all = MediaType.ALL; + + List expected = new ArrayList<>(); + expected.add(audioBasicLevel); + expected.add(audioBasic); + expected.add(audio); + expected.add(all); + expected.add(audio07); + expected.add(audio03); + + List result = new ArrayList<>(expected); + Random rnd = new Random(); + // shuffle & sort 10 times + for (int i = 0; i < 10; i++) { + Collections.shuffle(result, rnd); + MediaType.sortByQualityValue(result); + + for (int j = 0; j < result.size(); j++) { + assertSame("Invalid media type at " + j, expected.get(j), result.get(j)); + } + } + } + + @Test + public void sortByQualityUnrelated() { + MediaType audioBasic = new MediaType("audio", "basic"); + MediaType audioWave = new MediaType("audio", "wave"); + MediaType textHtml = new MediaType("text", "html"); + + List expected = new ArrayList<>(); + expected.add(textHtml); + expected.add(audioBasic); + expected.add(audioWave); + + List result = new ArrayList<>(expected); + MediaType.sortBySpecificity(result); + + for (int i = 0; i < result.size(); i++) { + assertSame("Invalid media type at " + i, expected.get(i), result.get(i)); + } + } + + @Test + public void testWithConversionService() { + ConversionService conversionService = new DefaultConversionService(); + assertTrue(conversionService.canConvert(String.class, MediaType.class)); + MediaType mediaType = MediaType.parseMediaType("application/xml"); + assertEquals(mediaType, conversionService.convert("application/xml", MediaType.class)); + } + + @Test + public void isConcrete() { + assertTrue("text/plain not concrete", MediaType.TEXT_PLAIN.isConcrete()); + assertFalse("*/* concrete", MediaType.ALL.isConcrete()); + assertFalse("text/* concrete", new MediaType("text", "*").isConcrete()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/MockHttpInputMessage.java b/spring-web/src/test/java/org/springframework/http/MockHttpInputMessage.java new file mode 100644 index 0000000000000000000000000000000000000000..784ef50cd2457d05a1b5c9fc25647c235e599ee1 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/MockHttpInputMessage.java @@ -0,0 +1,59 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; + +import org.springframework.util.Assert; + +/** + * Mock implementation of {@link HttpInputMessage}. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public class MockHttpInputMessage implements HttpInputMessage { + + private final HttpHeaders headers = new HttpHeaders(); + + private final InputStream body; + + + public MockHttpInputMessage(byte[] contents) { + Assert.notNull(contents, "'contents' must not be null"); + this.body = new ByteArrayInputStream(contents); + } + + public MockHttpInputMessage(InputStream body) { + Assert.notNull(body, "'body' must not be null"); + this.body = body; + } + + + @Override + public HttpHeaders getHeaders() { + return headers; + } + + @Override + public InputStream getBody() throws IOException { + return body; + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/MockHttpOutputMessage.java b/spring-web/src/test/java/org/springframework/http/MockHttpOutputMessage.java new file mode 100644 index 0000000000000000000000000000000000000000..a30b8d3fcc651018dffe9831385767dd0e7bbe05 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/MockHttpOutputMessage.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.Charset; + +import static org.mockito.Mockito.spy; + +/** + * @author Arjen Poutsma + * @author Rossen Stoyanchev + */ +public class MockHttpOutputMessage implements HttpOutputMessage { + + private final HttpHeaders headers = new HttpHeaders(); + + private final ByteArrayOutputStream body = spy(new ByteArrayOutputStream()); + + private boolean headersWritten = false; + + private final HttpHeaders writtenHeaders = new HttpHeaders(); + + + @Override + public HttpHeaders getHeaders() { + return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers); + } + + /** + * Return a copy of the actual headers written at the time of the call to + * getResponseBody, i.e. ignoring any further changes that may have been made to + * the underlying headers, e.g. via a previously obtained instance. + */ + public HttpHeaders getWrittenHeaders() { + return writtenHeaders; + } + + @Override + public OutputStream getBody() throws IOException { + writeHeaders(); + return body; + } + + public byte[] getBodyAsBytes() { + writeHeaders(); + return body.toByteArray(); + } + + public String getBodyAsString(Charset charset) { + byte[] bytes = getBodyAsBytes(); + return new String(bytes, charset); + } + + private void writeHeaders() { + if (this.headersWritten) { + return; + } + this.headersWritten = true; + this.writtenHeaders.putAll(this.headers); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/RequestEntityTests.java b/spring-web/src/test/java/org/springframework/http/RequestEntityTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3a2aa2b65c8c8eea023b63ecc97da9d06b3bec8c --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/RequestEntityTests.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Test; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.web.util.UriTemplate; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link org.springframework.http.RequestEntity}. + * + * @author Arjen Poutsma + */ +public class RequestEntityTests { + + @Test + public void normal() throws URISyntaxException { + String headerName = "My-Custom-Header"; + String headerValue = "HeaderValue"; + URI url = new URI("https://example.com"); + Integer entity = 42; + + RequestEntity requestEntity = + RequestEntity.method(HttpMethod.GET, url) + .header(headerName, headerValue).body(entity); + + assertNotNull(requestEntity); + assertEquals(HttpMethod.GET, requestEntity.getMethod()); + assertTrue(requestEntity.getHeaders().containsKey(headerName)); + assertEquals(headerValue, requestEntity.getHeaders().getFirst(headerName)); + assertEquals(entity, requestEntity.getBody()); + } + + @Test + public void uriVariablesExpansion() throws URISyntaxException { + URI uri = new UriTemplate("https://example.com/{foo}").expand("bar"); + RequestEntity.get(uri).accept(MediaType.TEXT_PLAIN).build(); + + String url = "https://www.{host}.com/{path}"; + String host = "example"; + String path = "foo/bar"; + URI expected = new URI("https://www.example.com/foo/bar"); + + uri = new UriTemplate(url).expand(host, path); + RequestEntity entity = RequestEntity.get(uri).build(); + assertEquals(expected, entity.getUrl()); + + Map uriVariables = new HashMap<>(2); + uriVariables.put("host", host); + uriVariables.put("path", path); + + uri = new UriTemplate(url).expand(uriVariables); + entity = RequestEntity.get(uri).build(); + assertEquals(expected, entity.getUrl()); + } + + @Test + public void get() { + RequestEntity requestEntity = RequestEntity.get(URI.create("https://example.com")).accept( + MediaType.IMAGE_GIF, MediaType.IMAGE_JPEG, MediaType.IMAGE_PNG).build(); + + assertNotNull(requestEntity); + assertEquals(HttpMethod.GET, requestEntity.getMethod()); + assertTrue(requestEntity.getHeaders().containsKey("Accept")); + assertEquals("image/gif, image/jpeg, image/png", requestEntity.getHeaders().getFirst("Accept")); + assertNull(requestEntity.getBody()); + } + + @Test + public void headers() throws URISyntaxException { + MediaType accept = MediaType.TEXT_PLAIN; + long ifModifiedSince = 12345L; + String ifNoneMatch = "\"foo\""; + long contentLength = 67890; + MediaType contentType = MediaType.TEXT_PLAIN; + + RequestEntity responseEntity = RequestEntity.post(new URI("https://example.com")). + accept(accept). + acceptCharset(StandardCharsets.UTF_8). + ifModifiedSince(ifModifiedSince). + ifNoneMatch(ifNoneMatch). + contentLength(contentLength). + contentType(contentType). + build(); + + assertNotNull(responseEntity); + assertEquals(HttpMethod.POST, responseEntity.getMethod()); + assertEquals(new URI("https://example.com"), responseEntity.getUrl()); + HttpHeaders responseHeaders = responseEntity.getHeaders(); + + assertEquals("text/plain", responseHeaders.getFirst("Accept")); + assertEquals("utf-8", responseHeaders.getFirst("Accept-Charset")); + assertEquals("Thu, 01 Jan 1970 00:00:12 GMT", responseHeaders.getFirst("If-Modified-Since")); + assertEquals(ifNoneMatch, responseHeaders.getFirst("If-None-Match")); + assertEquals(String.valueOf(contentLength), responseHeaders.getFirst("Content-Length")); + assertEquals(contentType.toString(), responseHeaders.getFirst("Content-Type")); + + assertNull(responseEntity.getBody()); + } + + @Test + public void methods() throws URISyntaxException { + URI url = new URI("https://example.com"); + + RequestEntity entity = RequestEntity.get(url).build(); + assertEquals(HttpMethod.GET, entity.getMethod()); + + entity = RequestEntity.post(url).build(); + assertEquals(HttpMethod.POST, entity.getMethod()); + + entity = RequestEntity.head(url).build(); + assertEquals(HttpMethod.HEAD, entity.getMethod()); + + entity = RequestEntity.options(url).build(); + assertEquals(HttpMethod.OPTIONS, entity.getMethod()); + + entity = RequestEntity.put(url).build(); + assertEquals(HttpMethod.PUT, entity.getMethod()); + + entity = RequestEntity.patch(url).build(); + assertEquals(HttpMethod.PATCH, entity.getMethod()); + + entity = RequestEntity.delete(url).build(); + assertEquals(HttpMethod.DELETE, entity.getMethod()); + + } + + @Test // SPR-13154 + public void types() throws URISyntaxException { + URI url = new URI("https://example.com"); + List body = Arrays.asList("foo", "bar"); + ParameterizedTypeReference typeReference = new ParameterizedTypeReference>() {}; + + RequestEntity entity = RequestEntity.post(url).body(body, typeReference.getType()); + assertEquals(typeReference.getType(), entity.getType()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/ResponseCookieTests.java b/spring-web/src/test/java/org/springframework/http/ResponseCookieTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a1a021f6cb2c94d6337300ba6ab57c18b2386e67 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/ResponseCookieTests.java @@ -0,0 +1,115 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.util.Arrays; + +import org.hamcrest.Matchers; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +/** + * Unit tests for {@link ResponseCookie}. + * @author Rossen Stoyanchev + */ +public class ResponseCookieTests { + + @Test + public void basic() { + + assertEquals("id=", ResponseCookie.from("id", null).build().toString()); + assertEquals("id=1fWa", ResponseCookie.from("id", "1fWa").build().toString()); + + assertEquals( + "id=1fWa; Path=/path; Domain=abc; " + + "Max-Age=0; Expires=Thu, 01 Jan 1970 00:00:00 GMT; " + + "Secure; HttpOnly; SameSite=None", + ResponseCookie.from("id", "1fWa") + .domain("abc").path("/path").maxAge(0).httpOnly(true).secure(true).sameSite("None") + .build().toString()); + } + + @Test + public void nameChecks() { + + Arrays.asList("id", "i.d.", "i-d", "+id", "i*d", "i$d", "#id") + .forEach(name -> { + ResponseCookie.from(name, "value").build(); + // no exception.. + }); + + Arrays.asList("\"id\"", "id\t", "i\td", "i d", "i;d", "{id}", "[id]", "\"", "id\u0091") + .forEach(name -> { + try { + ResponseCookie.from(name, "value").build(); + } + catch (IllegalArgumentException ex) { + assertThat(ex.getMessage(), Matchers.containsString("RFC2616 token")); + } + }); + } + + @Test + public void valueChecks() { + + Arrays.asList("1fWa", "", null, "1f=Wa", "1f-Wa", "1f/Wa", "1.f.W.a.") + .forEach(value -> { + ResponseCookie.from("id", value).build(); + // no exception.. + }); + + Arrays.asList("1f\tWa", "\t", "1f Wa", "1f;Wa", "\"1fWa", "1f\\Wa", "1f\"Wa", "\"", "1fWa\u0005", "1f\u0091Wa") + .forEach(value -> { + try { + ResponseCookie.from("id", value).build(); + } + catch (IllegalArgumentException ex) { + assertThat(ex.getMessage(), Matchers.containsString("RFC2616 cookie value")); + } + }); + } + + @Test + public void domainChecks() { + + Arrays.asList("abc", "abc.org", "abc-def.org", "abc3.org", ".abc.org") + .forEach(domain -> ResponseCookie.from("n", "v").domain(domain).build()); + + Arrays.asList("-abc.org", "abc.org.", "abc.org-", "-abc.org", "abc.org-") + .forEach(domain -> { + try { + ResponseCookie.from("n", "v").domain(domain).build(); + } + catch (IllegalArgumentException ex) { + assertThat(ex.getMessage(), Matchers.containsString("Invalid first/last char")); + } + }); + + Arrays.asList("abc..org", "abc.-org", "abc-.org") + .forEach(domain -> { + try { + ResponseCookie.from("n", "v").domain(domain).build(); + } + catch (IllegalArgumentException ex) { + assertThat(ex.getMessage(), Matchers.containsString("invalid cookie domain char")); + } + }); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/ResponseEntityTests.java b/spring-web/src/test/java/org/springframework/http/ResponseEntityTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c87cdcf9ae3d0b01ab1205a6b15c9ec07bd0b2cb --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/ResponseEntityTests.java @@ -0,0 +1,303 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import org.hamcrest.Matchers; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + * @author Marcel Overdijk + * @author Kazuki Shimizu + */ +public class ResponseEntityTests { + + @Test + public void normal() { + String headerName = "My-Custom-Header"; + String headerValue1 = "HeaderValue1"; + String headerValue2 = "HeaderValue2"; + Integer entity = 42; + + ResponseEntity responseEntity = + ResponseEntity.status(HttpStatus.OK).header(headerName, headerValue1, headerValue2).body(entity); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.OK, responseEntity.getStatusCode()); + assertTrue(responseEntity.getHeaders().containsKey(headerName)); + List list = responseEntity.getHeaders().get(headerName); + assertEquals(2, list.size()); + assertEquals(headerValue1, list.get(0)); + assertEquals(headerValue2, list.get(1)); + assertEquals(entity, responseEntity.getBody()); + } + + @Test + public void okNoBody() { + ResponseEntity responseEntity = ResponseEntity.ok().build(); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.OK, responseEntity.getStatusCode()); + assertNull(responseEntity.getBody()); + } + + @Test + public void okEntity() { + Integer entity = 42; + ResponseEntity responseEntity = ResponseEntity.ok(entity); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.OK, responseEntity.getStatusCode()); + assertEquals(entity, responseEntity.getBody()); + } + + @Test + public void ofOptional() { + Integer entity = 42; + ResponseEntity responseEntity = ResponseEntity.of(Optional.of(entity)); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.OK, responseEntity.getStatusCode()); + assertEquals(entity, responseEntity.getBody()); + } + + @Test + public void ofEmptyOptional() { + ResponseEntity responseEntity = ResponseEntity.of(Optional.empty()); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.NOT_FOUND, responseEntity.getStatusCode()); + assertNull(responseEntity.getBody()); + } + + @Test + public void createdLocation() throws URISyntaxException { + URI location = new URI("location"); + ResponseEntity responseEntity = ResponseEntity.created(location).build(); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.CREATED, responseEntity.getStatusCode()); + assertTrue(responseEntity.getHeaders().containsKey("Location")); + assertEquals(location.toString(), + responseEntity.getHeaders().getFirst("Location")); + assertNull(responseEntity.getBody()); + + ResponseEntity.created(location).header("MyResponseHeader", "MyValue").body("Hello World"); + } + + @Test + public void acceptedNoBody() throws URISyntaxException { + ResponseEntity responseEntity = ResponseEntity.accepted().build(); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.ACCEPTED, responseEntity.getStatusCode()); + assertNull(responseEntity.getBody()); + } + + @Test // SPR-14939 + public void acceptedNoBodyWithAlternativeBodyType() throws URISyntaxException { + ResponseEntity responseEntity = ResponseEntity.accepted().build(); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.ACCEPTED, responseEntity.getStatusCode()); + assertNull(responseEntity.getBody()); + } + + @Test + public void noContent() throws URISyntaxException { + ResponseEntity responseEntity = ResponseEntity.noContent().build(); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.NO_CONTENT, responseEntity.getStatusCode()); + assertNull(responseEntity.getBody()); + } + + @Test + public void badRequest() throws URISyntaxException { + ResponseEntity responseEntity = ResponseEntity.badRequest().build(); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.BAD_REQUEST, responseEntity.getStatusCode()); + assertNull(responseEntity.getBody()); + } + + @Test + public void notFound() throws URISyntaxException { + ResponseEntity responseEntity = ResponseEntity.notFound().build(); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.NOT_FOUND, responseEntity.getStatusCode()); + assertNull(responseEntity.getBody()); + } + + @Test + public void unprocessableEntity() throws URISyntaxException { + ResponseEntity responseEntity = ResponseEntity.unprocessableEntity().body("error"); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.UNPROCESSABLE_ENTITY, responseEntity.getStatusCode()); + assertEquals("error", responseEntity.getBody()); + } + + @Test + public void headers() throws URISyntaxException { + URI location = new URI("location"); + long contentLength = 67890; + MediaType contentType = MediaType.TEXT_PLAIN; + + ResponseEntity responseEntity = ResponseEntity.ok(). + allow(HttpMethod.GET). + lastModified(12345L). + location(location). + contentLength(contentLength). + contentType(contentType). + build(); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.OK, responseEntity.getStatusCode()); + HttpHeaders responseHeaders = responseEntity.getHeaders(); + + assertEquals("GET", responseHeaders.getFirst("Allow")); + assertEquals("Thu, 01 Jan 1970 00:00:12 GMT", + responseHeaders.getFirst("Last-Modified")); + assertEquals(location.toASCIIString(), + responseHeaders.getFirst("Location")); + assertEquals(String.valueOf(contentLength), responseHeaders.getFirst("Content-Length")); + assertEquals(contentType.toString(), responseHeaders.getFirst("Content-Type")); + + assertNull(responseEntity.getBody()); + } + + @Test + public void Etagheader() throws URISyntaxException { + + ResponseEntity responseEntity = ResponseEntity.ok().eTag("\"foo\"").build(); + assertEquals("\"foo\"", responseEntity.getHeaders().getETag()); + + responseEntity = ResponseEntity.ok().eTag("foo").build(); + assertEquals("\"foo\"", responseEntity.getHeaders().getETag()); + + responseEntity = ResponseEntity.ok().eTag("W/\"foo\"").build(); + assertEquals("W/\"foo\"", responseEntity.getHeaders().getETag()); + } + + @Test + public void headersCopy() { + HttpHeaders customHeaders = new HttpHeaders(); + customHeaders.set("X-CustomHeader", "vale"); + + ResponseEntity responseEntity = ResponseEntity.ok().headers(customHeaders).build(); + HttpHeaders responseHeaders = responseEntity.getHeaders(); + + assertEquals(HttpStatus.OK, responseEntity.getStatusCode()); + assertEquals(1, responseHeaders.size()); + assertEquals(1, responseHeaders.get("X-CustomHeader").size()); + assertEquals("vale", responseHeaders.getFirst("X-CustomHeader")); + + } + + @Test // SPR-12792 + public void headersCopyWithEmptyAndNull() { + ResponseEntity responseEntityWithEmptyHeaders = + ResponseEntity.ok().headers(new HttpHeaders()).build(); + ResponseEntity responseEntityWithNullHeaders = + ResponseEntity.ok().headers(null).build(); + + assertEquals(HttpStatus.OK, responseEntityWithEmptyHeaders.getStatusCode()); + assertTrue(responseEntityWithEmptyHeaders.getHeaders().isEmpty()); + assertEquals(responseEntityWithEmptyHeaders.toString(), responseEntityWithNullHeaders.toString()); + } + + @Test + public void emptyCacheControl() { + Integer entity = 42; + + ResponseEntity responseEntity = + ResponseEntity.status(HttpStatus.OK) + .cacheControl(CacheControl.empty()) + .body(entity); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.OK, responseEntity.getStatusCode()); + assertFalse(responseEntity.getHeaders().containsKey(HttpHeaders.CACHE_CONTROL)); + assertEquals(entity, responseEntity.getBody()); + } + + @Test + public void cacheControl() { + Integer entity = 42; + + ResponseEntity responseEntity = + ResponseEntity.status(HttpStatus.OK) + .cacheControl(CacheControl.maxAge(1, TimeUnit.HOURS).cachePrivate(). + mustRevalidate().proxyRevalidate().sMaxAge(30, TimeUnit.MINUTES)) + .body(entity); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.OK, responseEntity.getStatusCode()); + assertTrue(responseEntity.getHeaders().containsKey(HttpHeaders.CACHE_CONTROL)); + assertEquals(entity, responseEntity.getBody()); + String cacheControlHeader = responseEntity.getHeaders().getCacheControl(); + assertThat(cacheControlHeader, + Matchers.equalTo("max-age=3600, must-revalidate, private, proxy-revalidate, s-maxage=1800")); + } + + @Test + public void cacheControlNoCache() { + Integer entity = 42; + + ResponseEntity responseEntity = + ResponseEntity.status(HttpStatus.OK) + .cacheControl(CacheControl.noStore()) + .body(entity); + + assertNotNull(responseEntity); + assertEquals(HttpStatus.OK, responseEntity.getStatusCode()); + assertTrue(responseEntity.getHeaders().containsKey(HttpHeaders.CACHE_CONTROL)); + assertEquals(entity, responseEntity.getBody()); + + String cacheControlHeader = responseEntity.getHeaders().getCacheControl(); + assertThat(cacheControlHeader, Matchers.equalTo("no-store")); + } + + @Test + public void statusCodeAsInt() { + Integer entity = 42; + ResponseEntity responseEntity = ResponseEntity.status(200).body(entity); + + assertEquals(200, responseEntity.getStatusCode().value()); + assertEquals(entity, responseEntity.getBody()); + } + + @Test + public void customStatusCode() { + Integer entity = 42; + ResponseEntity responseEntity = ResponseEntity.status(299).body(entity); + + assertEquals(299, responseEntity.getStatusCodeValue()); + assertEquals(entity, responseEntity.getBody()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/AbstractAsyncHttpRequestFactoryTestCase.java b/spring-web/src/test/java/org/springframework/http/client/AbstractAsyncHttpRequestFactoryTestCase.java new file mode 100644 index 0000000000000000000000000000000000000000..8ec3a06678ebd001442d761b1d98da1a9e738658 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/AbstractAsyncHttpRequestFactoryTestCase.java @@ -0,0 +1,236 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.io.IOException; +import java.net.URI; +import java.util.Arrays; +import java.util.Locale; +import java.util.concurrent.Future; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.util.FileCopyUtils; +import org.springframework.util.StreamUtils; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.ListenableFutureCallback; + +@SuppressWarnings("deprecation") +public abstract class AbstractAsyncHttpRequestFactoryTestCase extends AbstractMockWebServerTestCase { + + protected AsyncClientHttpRequestFactory factory; + + + @Before + public final void createFactory() throws Exception { + this.factory = createRequestFactory(); + if (this.factory instanceof InitializingBean) { + ((InitializingBean) this.factory).afterPropertiesSet(); + } + } + + @After + public final void destroyFactory() throws Exception { + if (this.factory instanceof DisposableBean) { + ((DisposableBean) this.factory).destroy(); + } + } + + protected abstract AsyncClientHttpRequestFactory createRequestFactory(); + + + @Test + public void status() throws Exception { + URI uri = new URI(baseUrl + "/status/notfound"); + AsyncClientHttpRequest request = this.factory.createAsyncRequest(uri, HttpMethod.GET); + assertEquals("Invalid HTTP method", HttpMethod.GET, request.getMethod()); + assertEquals("Invalid HTTP URI", uri, request.getURI()); + Future futureResponse = request.executeAsync(); + ClientHttpResponse response = futureResponse.get(); + try { + assertEquals("Invalid status code", HttpStatus.NOT_FOUND, response.getStatusCode()); + } + finally { + response.close(); + } + } + + @Test + public void statusCallback() throws Exception { + URI uri = new URI(baseUrl + "/status/notfound"); + AsyncClientHttpRequest request = this.factory.createAsyncRequest(uri, HttpMethod.GET); + assertEquals("Invalid HTTP method", HttpMethod.GET, request.getMethod()); + assertEquals("Invalid HTTP URI", uri, request.getURI()); + ListenableFuture listenableFuture = request.executeAsync(); + listenableFuture.addCallback(new ListenableFutureCallback() { + @Override + public void onSuccess(ClientHttpResponse result) { + try { + assertEquals("Invalid status code", HttpStatus.NOT_FOUND, result.getStatusCode()); + } + catch (IOException ex) { + fail(ex.getMessage()); + } + } + @Override + public void onFailure(Throwable ex) { + fail(ex.getMessage()); + } + }); + ClientHttpResponse response = listenableFuture.get(); + try { + assertEquals("Invalid status code", HttpStatus.NOT_FOUND, response.getStatusCode()); + } + finally { + response.close(); + } + } + + @Test + public void echo() throws Exception { + AsyncClientHttpRequest request = this.factory.createAsyncRequest(new URI(baseUrl + "/echo"), HttpMethod.PUT); + assertEquals("Invalid HTTP method", HttpMethod.PUT, request.getMethod()); + String headerName = "MyHeader"; + String headerValue1 = "value1"; + request.getHeaders().add(headerName, headerValue1); + String headerValue2 = "value2"; + request.getHeaders().add(headerName, headerValue2); + final byte[] body = "Hello World".getBytes("UTF-8"); + request.getHeaders().setContentLength(body.length); + + if (request instanceof StreamingHttpOutputMessage) { + StreamingHttpOutputMessage streamingRequest = (StreamingHttpOutputMessage) request; + streamingRequest.setBody(outputStream -> StreamUtils.copy(body, outputStream)); + } + else { + StreamUtils.copy(body, request.getBody()); + } + + Future futureResponse = request.executeAsync(); + ClientHttpResponse response = futureResponse.get(); + try { + assertEquals("Invalid status code", HttpStatus.OK, response.getStatusCode()); + assertTrue("Header not found", response.getHeaders().containsKey(headerName)); + assertEquals("Header value not found", Arrays.asList(headerValue1, headerValue2), + response.getHeaders().get(headerName)); + byte[] result = FileCopyUtils.copyToByteArray(response.getBody()); + assertTrue("Invalid body", Arrays.equals(body, result)); + } + finally { + response.close(); + } + } + + @Test + public void multipleWrites() throws Exception { + AsyncClientHttpRequest request = this.factory.createAsyncRequest(new URI(baseUrl + "/echo"), HttpMethod.POST); + final byte[] body = "Hello World".getBytes("UTF-8"); + + if (request instanceof StreamingHttpOutputMessage) { + StreamingHttpOutputMessage streamingRequest = (StreamingHttpOutputMessage) request; + streamingRequest.setBody(outputStream -> StreamUtils.copy(body, outputStream)); + } + else { + StreamUtils.copy(body, request.getBody()); + } + + Future futureResponse = request.executeAsync(); + ClientHttpResponse response = futureResponse.get(); + try { + FileCopyUtils.copy(body, request.getBody()); + fail("IllegalStateException expected"); + } + catch (IllegalStateException ex) { + // expected + } + finally { + response.close(); + } + } + + @Test + public void headersAfterExecute() throws Exception { + AsyncClientHttpRequest request = this.factory.createAsyncRequest(new URI(baseUrl + "/echo"), HttpMethod.POST); + request.getHeaders().add("MyHeader", "value"); + byte[] body = "Hello World".getBytes("UTF-8"); + FileCopyUtils.copy(body, request.getBody()); + + Future futureResponse = request.executeAsync(); + ClientHttpResponse response = futureResponse.get(); + try { + request.getHeaders().add("MyHeader", "value"); + fail("UnsupportedOperationException expected"); + } + catch (UnsupportedOperationException ex) { + // expected + } + finally { + response.close(); + } + } + + @Test + public void httpMethods() throws Exception { + assertHttpMethod("get", HttpMethod.GET); + assertHttpMethod("head", HttpMethod.HEAD); + assertHttpMethod("post", HttpMethod.POST); + assertHttpMethod("put", HttpMethod.PUT); + assertHttpMethod("options", HttpMethod.OPTIONS); + assertHttpMethod("delete", HttpMethod.DELETE); + } + + protected void assertHttpMethod(String path, HttpMethod method) throws Exception { + ClientHttpResponse response = null; + try { + AsyncClientHttpRequest request = this.factory.createAsyncRequest(new URI(baseUrl + "/methods/" + path), method); + if (method == HttpMethod.POST || method == HttpMethod.PUT || method == HttpMethod.PATCH) { + // requires a body + request.getBody().write(32); + } + Future futureResponse = request.executeAsync(); + response = futureResponse.get(); + assertEquals("Invalid response status", HttpStatus.OK, response.getStatusCode()); + assertEquals("Invalid method", path.toUpperCase(Locale.ENGLISH), request.getMethod().name()); + } + finally { + if (response != null) { + response.close(); + } + } + } + + @Test + public void cancel() throws Exception { + URI uri = new URI(baseUrl + "/status/notfound"); + AsyncClientHttpRequest request = this.factory.createAsyncRequest(uri, HttpMethod.GET); + Future futureResponse = request.executeAsync(); + futureResponse.cancel(true); + assertTrue(futureResponse.isCancelled()); + } + + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/AbstractHttpRequestFactoryTestCase.java b/spring-web/src/test/java/org/springframework/http/client/AbstractHttpRequestFactoryTestCase.java new file mode 100644 index 0000000000000000000000000000000000000000..852c74039b538ce0938d4ac06880863d0d6a3f29 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/AbstractHttpRequestFactoryTestCase.java @@ -0,0 +1,201 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.net.URI; +import java.util.Arrays; +import java.util.Locale; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.util.FileCopyUtils; +import org.springframework.util.StreamUtils; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + */ +public abstract class AbstractHttpRequestFactoryTestCase extends AbstractMockWebServerTestCase { + + protected ClientHttpRequestFactory factory; + + + @Before + public final void createFactory() throws Exception { + factory = createRequestFactory(); + if (factory instanceof InitializingBean) { + ((InitializingBean) factory).afterPropertiesSet(); + } + } + + @After + public final void destroyFactory() throws Exception { + if (factory instanceof DisposableBean) { + ((DisposableBean) factory).destroy(); + } + } + + + protected abstract ClientHttpRequestFactory createRequestFactory(); + + + @Test + public void status() throws Exception { + URI uri = new URI(baseUrl + "/status/notfound"); + ClientHttpRequest request = factory.createRequest(uri, HttpMethod.GET); + assertEquals("Invalid HTTP method", HttpMethod.GET, request.getMethod()); + assertEquals("Invalid HTTP URI", uri, request.getURI()); + + ClientHttpResponse response = request.execute(); + try { + assertEquals("Invalid status code", HttpStatus.NOT_FOUND, response.getStatusCode()); + } + finally { + response.close(); + } + } + + @Test + public void echo() throws Exception { + ClientHttpRequest request = factory.createRequest(new URI(baseUrl + "/echo"), HttpMethod.PUT); + assertEquals("Invalid HTTP method", HttpMethod.PUT, request.getMethod()); + + String headerName = "MyHeader"; + String headerValue1 = "value1"; + request.getHeaders().add(headerName, headerValue1); + String headerValue2 = "value2"; + request.getHeaders().add(headerName, headerValue2); + final byte[] body = "Hello World".getBytes("UTF-8"); + request.getHeaders().setContentLength(body.length); + + if (request instanceof StreamingHttpOutputMessage) { + StreamingHttpOutputMessage streamingRequest = (StreamingHttpOutputMessage) request; + streamingRequest.setBody(outputStream -> StreamUtils.copy(body, outputStream)); + } + else { + StreamUtils.copy(body, request.getBody()); + } + + ClientHttpResponse response = request.execute(); + try { + assertEquals("Invalid status code", HttpStatus.OK, response.getStatusCode()); + assertTrue("Header not found", response.getHeaders().containsKey(headerName)); + assertEquals("Header value not found", Arrays.asList(headerValue1, headerValue2), + response.getHeaders().get(headerName)); + byte[] result = FileCopyUtils.copyToByteArray(response.getBody()); + assertTrue("Invalid body", Arrays.equals(body, result)); + } + finally { + response.close(); + } + } + + @Test(expected = IllegalStateException.class) + public void multipleWrites() throws Exception { + ClientHttpRequest request = factory.createRequest(new URI(baseUrl + "/echo"), HttpMethod.POST); + + final byte[] body = "Hello World".getBytes("UTF-8"); + if (request instanceof StreamingHttpOutputMessage) { + StreamingHttpOutputMessage streamingRequest = (StreamingHttpOutputMessage) request; + streamingRequest.setBody(outputStream -> { + StreamUtils.copy(body, outputStream); + outputStream.flush(); + outputStream.close(); + }); + } + else { + StreamUtils.copy(body, request.getBody()); + } + + request.execute(); + FileCopyUtils.copy(body, request.getBody()); + } + + @Test(expected = UnsupportedOperationException.class) + public void headersAfterExecute() throws Exception { + ClientHttpRequest request = factory.createRequest(new URI(baseUrl + "/status/ok"), HttpMethod.POST); + + request.getHeaders().add("MyHeader", "value"); + byte[] body = "Hello World".getBytes("UTF-8"); + FileCopyUtils.copy(body, request.getBody()); + + ClientHttpResponse response = request.execute(); + try { + request.getHeaders().add("MyHeader", "value"); + } + finally { + response.close(); + } + } + + @Test + public void httpMethods() throws Exception { + assertHttpMethod("get", HttpMethod.GET); + assertHttpMethod("head", HttpMethod.HEAD); + assertHttpMethod("post", HttpMethod.POST); + assertHttpMethod("put", HttpMethod.PUT); + assertHttpMethod("options", HttpMethod.OPTIONS); + assertHttpMethod("delete", HttpMethod.DELETE); + } + + protected void assertHttpMethod(String path, HttpMethod method) throws Exception { + ClientHttpResponse response = null; + try { + ClientHttpRequest request = factory.createRequest(new URI(baseUrl + "/methods/" + path), method); + if (method == HttpMethod.POST || method == HttpMethod.PUT || method == HttpMethod.PATCH) { + // requires a body + try { + request.getBody().write(32); + } + catch (UnsupportedOperationException ex) { + // probably a streaming request - let's simply ignore it + } + } + response = request.execute(); + assertEquals("Invalid response status", HttpStatus.OK, response.getStatusCode()); + assertEquals("Invalid method", path.toUpperCase(Locale.ENGLISH), request.getMethod().name()); + } + finally { + if (response != null) { + response.close(); + } + } + } + + @Test + public void queryParameters() throws Exception { + URI uri = new URI(baseUrl + "/params?param1=value¶m2=value1¶m2=value2"); + ClientHttpRequest request = factory.createRequest(uri, HttpMethod.GET); + + ClientHttpResponse response = request.execute(); + try { + assertEquals("Invalid status code", HttpStatus.OK, response.getStatusCode()); + } + finally { + response.close(); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTestCase.java b/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTestCase.java new file mode 100644 index 0000000000000000000000000000000000000000..0a496e871a78ba0df1de72f5c0eac44b4bc303ea --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTestCase.java @@ -0,0 +1,97 @@ +package org.springframework.http.client; + +import java.util.Collections; + +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import org.springframework.http.MediaType; +import org.springframework.util.StringUtils; + +import static org.hamcrest.MatcherAssert.assertThat; + +/** + * @author Brian Clozel + */ +public class AbstractMockWebServerTestCase { + + private MockWebServer server; + + protected int port; + + protected String baseUrl; + + protected static final MediaType textContentType = + new MediaType("text", "plain", Collections.singletonMap("charset", "UTF-8")); + + @Before + public void setUp() throws Exception { + this.server = new MockWebServer(); + this.server.setDispatcher(new TestDispatcher()); + this.server.start(); + this.port = this.server.getPort(); + this.baseUrl = "http://localhost:" + this.port; + } + + @After + public void tearDown() throws Exception { + this.server.shutdown(); + } + + protected class TestDispatcher extends Dispatcher { + @Override + public MockResponse dispatch(RecordedRequest request) throws InterruptedException { + try { + if (request.getPath().equals("/echo")) { + assertThat(request.getHeader("Host"), + Matchers.containsString("localhost:" + port)); + MockResponse response = new MockResponse() + .setHeaders(request.getHeaders()) + .setHeader("Content-Length", request.getBody().size()) + .setResponseCode(200) + .setBody(request.getBody()); + request.getBody().flush(); + return response; + } + else if(request.getPath().equals("/status/ok")) { + return new MockResponse(); + } + else if(request.getPath().equals("/status/notfound")) { + return new MockResponse().setResponseCode(404); + } + else if(request.getPath().startsWith("/params")) { + assertThat(request.getPath(), Matchers.containsString("param1=value")); + assertThat(request.getPath(), Matchers.containsString("param2=value1¶m2=value2")); + return new MockResponse(); + } + else if(request.getPath().equals("/methods/post")) { + assertThat(request.getMethod(), Matchers.is("POST")); + String transferEncoding = request.getHeader("Transfer-Encoding"); + if(StringUtils.hasLength(transferEncoding)) { + assertThat(transferEncoding, Matchers.is("chunked")); + } + else { + long contentLength = Long.parseLong(request.getHeader("Content-Length")); + assertThat("Invalid content-length", + request.getBody().size(), Matchers.is(contentLength)); + } + return new MockResponse().setResponseCode(200); + } + else if(request.getPath().startsWith("/methods/")) { + String expectedMethod = request.getPath().replace("/methods/","").toUpperCase(); + assertThat(request.getMethod(), Matchers.is(expectedMethod)); + return new MockResponse(); + } + return new MockResponse().setResponseCode(404); + } + catch (Throwable exc) { + return new MockResponse().setResponseCode(500).setBody(exc.toString()); + } + } + } +} diff --git a/spring-web/src/test/java/org/springframework/http/client/BufferedSimpleAsyncHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/BufferedSimpleAsyncHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d0de70b83cb59132cd5ad24b9ff87ad78ca11805 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/BufferedSimpleAsyncHttpRequestFactoryTests.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.net.ProtocolException; + +import org.junit.Test; + +import org.springframework.core.task.AsyncListenableTaskExecutor; +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.http.HttpMethod; + +public class BufferedSimpleAsyncHttpRequestFactoryTests extends AbstractAsyncHttpRequestFactoryTestCase { + + @SuppressWarnings("deprecation") + @Override + protected AsyncClientHttpRequestFactory createRequestFactory() { + SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); + AsyncListenableTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(); + requestFactory.setTaskExecutor(taskExecutor); + return requestFactory; + } + + @Override + @Test + public void httpMethods() throws Exception { + super.httpMethods(); + try { + assertHttpMethod("patch", HttpMethod.PATCH); + } + catch (ProtocolException ex) { + // Currently HttpURLConnection does not support HTTP PATCH + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/BufferedSimpleHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/BufferedSimpleHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..267fbf1f041a0c661776a7403cd8c78cfab93d2c --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/BufferedSimpleHttpRequestFactoryTests.java @@ -0,0 +1,102 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.net.ProtocolException; +import java.net.URL; + +import org.junit.Test; + +import org.springframework.http.HttpMethod; + +import static org.junit.Assert.*; + +public class BufferedSimpleHttpRequestFactoryTests extends AbstractHttpRequestFactoryTestCase { + + @Override + protected ClientHttpRequestFactory createRequestFactory() { + return new SimpleClientHttpRequestFactory(); + } + + @Override + @Test + public void httpMethods() throws Exception { + try { + assertHttpMethod("patch", HttpMethod.PATCH); + } + catch (ProtocolException ex) { + // Currently HttpURLConnection does not support HTTP PATCH + } + } + + @Test + public void prepareConnectionWithRequestBody() throws Exception { + URL uri = new URL("https://example.com"); + testRequestBodyAllowed(uri, "GET", false); + testRequestBodyAllowed(uri, "HEAD", false); + testRequestBodyAllowed(uri, "OPTIONS", false); + testRequestBodyAllowed(uri, "TRACE", false); + testRequestBodyAllowed(uri, "PUT", true); + testRequestBodyAllowed(uri, "POST", true); + testRequestBodyAllowed(uri, "DELETE", true); + } + + @Test + public void deleteWithoutBodyDoesNotRaiseException() throws Exception { + HttpURLConnection connection = new TestHttpURLConnection(new URL("https://example.com")); + ((SimpleClientHttpRequestFactory) this.factory).prepareConnection(connection, "DELETE"); + SimpleBufferingClientHttpRequest request = new SimpleBufferingClientHttpRequest(connection, false); + request.execute(); + } + + private void testRequestBodyAllowed(URL uri, String httpMethod, boolean allowed) throws IOException { + HttpURLConnection connection = new TestHttpURLConnection(uri); + ((SimpleClientHttpRequestFactory) this.factory).prepareConnection(connection, httpMethod); + assertEquals(allowed, connection.getDoOutput()); + } + + + private static class TestHttpURLConnection extends HttpURLConnection { + + public TestHttpURLConnection(URL uri) { + super(uri); + } + + @Override + public void connect() throws IOException { + } + + @Override + public void disconnect() { + } + + @Override + public boolean usingProxy() { + return false; + } + + @Override + public InputStream getInputStream() throws IOException { + return new ByteArrayInputStream(new byte[0]); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/BufferingClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/BufferingClientHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..6dc785627250a0cfde9b1afe8a1979039a8f7bcb --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/BufferingClientHttpRequestFactoryTests.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2011 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.net.URI; +import java.util.Arrays; + +import org.junit.Test; + +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.util.FileCopyUtils; + +import static org.junit.Assert.*; + +public class BufferingClientHttpRequestFactoryTests extends AbstractHttpRequestFactoryTestCase { + + @Override + protected ClientHttpRequestFactory createRequestFactory() { + return new BufferingClientHttpRequestFactory(new HttpComponentsClientHttpRequestFactory()); + } + + @Test + public void repeatableRead() throws Exception { + ClientHttpRequest request = factory.createRequest(new URI(baseUrl + "/echo"), HttpMethod.PUT); + assertEquals("Invalid HTTP method", HttpMethod.PUT, request.getMethod()); + String headerName = "MyHeader"; + String headerValue1 = "value1"; + request.getHeaders().add(headerName, headerValue1); + String headerValue2 = "value2"; + request.getHeaders().add(headerName, headerValue2); + byte[] body = "Hello World".getBytes("UTF-8"); + request.getHeaders().setContentLength(body.length); + FileCopyUtils.copy(body, request.getBody()); + ClientHttpResponse response = request.execute(); + try { + assertEquals("Invalid status code", HttpStatus.OK, response.getStatusCode()); + assertEquals("Invalid status code", HttpStatus.OK, response.getStatusCode()); + + assertTrue("Header not found", response.getHeaders().containsKey(headerName)); + assertTrue("Header not found", response.getHeaders().containsKey(headerName)); + + assertEquals("Header value not found", Arrays.asList(headerValue1, headerValue2), + response.getHeaders().get(headerName)); + assertEquals("Header value not found", Arrays.asList(headerValue1, headerValue2), + response.getHeaders().get(headerName)); + + byte[] result = FileCopyUtils.copyToByteArray(response.getBody()); + assertTrue("Invalid body", Arrays.equals(body, result)); + FileCopyUtils.copyToByteArray(response.getBody()); + assertTrue("Invalid body", Arrays.equals(body, result)); + } + finally { + response.close(); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/HttpComponentsAsyncClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/HttpComponentsAsyncClientHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..fd2f577152a18e1dbf736a6a7897fc55d87b086e --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/HttpComponentsAsyncClientHttpRequestFactoryTests.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.net.URI; + +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.http.impl.nio.client.CloseableHttpAsyncClient; +import org.apache.http.impl.nio.client.HttpAsyncClientBuilder; +import org.junit.Test; + +import org.springframework.http.HttpMethod; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + * @author Stephane Nicoll + */ +@SuppressWarnings("deprecation") +public class HttpComponentsAsyncClientHttpRequestFactoryTests extends AbstractAsyncHttpRequestFactoryTestCase { + + @Override + protected AsyncClientHttpRequestFactory createRequestFactory() { + return new HttpComponentsAsyncClientHttpRequestFactory(); + } + + + @Override + @Test + public void httpMethods() throws Exception { + super.httpMethods(); + assertHttpMethod("patch", HttpMethod.PATCH); + } + + @Test + public void customHttpAsyncClientUsesItsDefault() throws Exception { + HttpComponentsAsyncClientHttpRequestFactory factory = + new HttpComponentsAsyncClientHttpRequestFactory(); + + URI uri = new URI(baseUrl + "/status/ok"); + HttpComponentsAsyncClientHttpRequest request = (HttpComponentsAsyncClientHttpRequest) + factory.createAsyncRequest(uri, HttpMethod.GET); + + assertNull("No custom config should be set with a custom HttpAsyncClient", + request.getHttpContext().getAttribute(HttpClientContext.REQUEST_CONFIG)); + } + + @Test + public void defaultSettingsOfHttpAsyncClientLostOnExecutorCustomization() throws Exception { + CloseableHttpAsyncClient client = HttpAsyncClientBuilder.create() + .setDefaultRequestConfig(RequestConfig.custom().setConnectTimeout(1234).build()) + .build(); + HttpComponentsAsyncClientHttpRequestFactory factory = new HttpComponentsAsyncClientHttpRequestFactory(client); + + URI uri = new URI(baseUrl + "/status/ok"); + HttpComponentsAsyncClientHttpRequest request = (HttpComponentsAsyncClientHttpRequest) + factory.createAsyncRequest(uri, HttpMethod.GET); + + assertNull("No custom config should be set with a custom HttpClient", + request.getHttpContext().getAttribute(HttpClientContext.REQUEST_CONFIG)); + + factory.setConnectionRequestTimeout(4567); + HttpComponentsAsyncClientHttpRequest request2 = (HttpComponentsAsyncClientHttpRequest) + factory.createAsyncRequest(uri, HttpMethod.GET); + Object requestConfigAttribute = request2.getHttpContext().getAttribute(HttpClientContext.REQUEST_CONFIG); + assertNotNull(requestConfigAttribute); + RequestConfig requestConfig = (RequestConfig) requestConfigAttribute; + + assertEquals(4567, requestConfig.getConnectionRequestTimeout()); + // No way to access the request config of the HTTP client so no way to "merge" our customizations + assertEquals(-1, requestConfig.getConnectTimeout()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/HttpComponentsClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/HttpComponentsClientHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..749d1a0f6f5a1de09412641f848349c06e53b0a9 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/HttpComponentsClientHttpRequestFactoryTests.java @@ -0,0 +1,170 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.net.URI; + +import org.apache.http.HttpEntityEnclosingRequest; +import org.apache.http.client.HttpClient; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.Configurable; +import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClientBuilder; +import org.junit.Test; +import org.springframework.http.HttpMethod; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * @author Stephane Nicoll + */ +public class HttpComponentsClientHttpRequestFactoryTests extends AbstractHttpRequestFactoryTestCase { + + @Override + protected ClientHttpRequestFactory createRequestFactory() { + return new HttpComponentsClientHttpRequestFactory(); + } + + @Override + @Test + public void httpMethods() throws Exception { + super.httpMethods(); + assertHttpMethod("patch", HttpMethod.PATCH); + } + + @Test + public void assertCustomConfig() throws Exception { + HttpClient httpClient = HttpClientBuilder.create().build(); + HttpComponentsClientHttpRequestFactory hrf = new HttpComponentsClientHttpRequestFactory(httpClient); + hrf.setConnectTimeout(1234); + hrf.setConnectionRequestTimeout(4321); + hrf.setReadTimeout(4567); + + URI uri = new URI(baseUrl + "/status/ok"); + HttpComponentsClientHttpRequest request = (HttpComponentsClientHttpRequest) + hrf.createRequest(uri, HttpMethod.GET); + + Object config = request.getHttpContext().getAttribute(HttpClientContext.REQUEST_CONFIG); + assertNotNull("Request config should be set", config); + assertTrue("Wrong request config type" + config.getClass().getName(), + RequestConfig.class.isInstance(config)); + RequestConfig requestConfig = (RequestConfig) config; + assertEquals("Wrong custom connection timeout", 1234, requestConfig.getConnectTimeout()); + assertEquals("Wrong custom connection request timeout", 4321, requestConfig.getConnectionRequestTimeout()); + assertEquals("Wrong custom socket timeout", 4567, requestConfig.getSocketTimeout()); + } + + @Test + public void defaultSettingsOfHttpClientMergedOnExecutorCustomization() throws Exception { + RequestConfig defaultConfig = RequestConfig.custom().setConnectTimeout(1234).build(); + CloseableHttpClient client = mock(CloseableHttpClient.class, + withSettings().extraInterfaces(Configurable.class)); + Configurable configurable = (Configurable) client; + when(configurable.getConfig()).thenReturn(defaultConfig); + + HttpComponentsClientHttpRequestFactory hrf = new HttpComponentsClientHttpRequestFactory(client); + assertSame("Default client configuration is expected", defaultConfig, retrieveRequestConfig(hrf)); + + hrf.setConnectionRequestTimeout(4567); + RequestConfig requestConfig = retrieveRequestConfig(hrf); + assertNotNull(requestConfig); + assertEquals(4567, requestConfig.getConnectionRequestTimeout()); + // Default connection timeout merged + assertEquals(1234, requestConfig.getConnectTimeout()); + } + + @Test + public void localSettingsOverrideClientDefaultSettings() throws Exception { + RequestConfig defaultConfig = RequestConfig.custom() + .setConnectTimeout(1234).setConnectionRequestTimeout(6789).build(); + CloseableHttpClient client = mock(CloseableHttpClient.class, + withSettings().extraInterfaces(Configurable.class)); + Configurable configurable = (Configurable) client; + when(configurable.getConfig()).thenReturn(defaultConfig); + + HttpComponentsClientHttpRequestFactory hrf = new HttpComponentsClientHttpRequestFactory(client); + hrf.setConnectTimeout(5000); + + RequestConfig requestConfig = retrieveRequestConfig(hrf); + assertEquals(5000, requestConfig.getConnectTimeout()); + assertEquals(6789, requestConfig.getConnectionRequestTimeout()); + assertEquals(-1, requestConfig.getSocketTimeout()); + } + + @Test + public void mergeBasedOnCurrentHttpClient() throws Exception { + RequestConfig defaultConfig = RequestConfig.custom() + .setSocketTimeout(1234).build(); + final CloseableHttpClient client = mock(CloseableHttpClient.class, + withSettings().extraInterfaces(Configurable.class)); + Configurable configurable = (Configurable) client; + when(configurable.getConfig()).thenReturn(defaultConfig); + + HttpComponentsClientHttpRequestFactory hrf = new HttpComponentsClientHttpRequestFactory() { + @Override + public HttpClient getHttpClient() { + return client; + } + }; + hrf.setReadTimeout(5000); + + RequestConfig requestConfig = retrieveRequestConfig(hrf); + assertEquals(-1, requestConfig.getConnectTimeout()); + assertEquals(-1, requestConfig.getConnectionRequestTimeout()); + assertEquals(5000, requestConfig.getSocketTimeout()); + + // Update the Http client so that it returns an updated config + RequestConfig updatedDefaultConfig = RequestConfig.custom() + .setConnectTimeout(1234).build(); + when(configurable.getConfig()).thenReturn(updatedDefaultConfig); + hrf.setReadTimeout(7000); + RequestConfig requestConfig2 = retrieveRequestConfig(hrf); + assertEquals(1234, requestConfig2.getConnectTimeout()); + assertEquals(-1, requestConfig2.getConnectionRequestTimeout()); + assertEquals(7000, requestConfig2.getSocketTimeout()); + } + + private RequestConfig retrieveRequestConfig(HttpComponentsClientHttpRequestFactory factory) throws Exception { + URI uri = new URI(baseUrl + "/status/ok"); + HttpComponentsClientHttpRequest request = (HttpComponentsClientHttpRequest) + factory.createRequest(uri, HttpMethod.GET); + return (RequestConfig) request.getHttpContext().getAttribute(HttpClientContext.REQUEST_CONFIG); + } + + @Test + public void createHttpUriRequest() throws Exception { + URI uri = new URI("https://example.com"); + testRequestBodyAllowed(uri, HttpMethod.GET, false); + testRequestBodyAllowed(uri, HttpMethod.HEAD, false); + testRequestBodyAllowed(uri, HttpMethod.OPTIONS, false); + testRequestBodyAllowed(uri, HttpMethod.TRACE, false); + testRequestBodyAllowed(uri, HttpMethod.PUT, true); + testRequestBodyAllowed(uri, HttpMethod.POST, true); + testRequestBodyAllowed(uri, HttpMethod.PATCH, true); + testRequestBodyAllowed(uri, HttpMethod.DELETE, true); + + } + + private void testRequestBodyAllowed(URI uri, HttpMethod method, boolean allowed) { + HttpUriRequest request = ((HttpComponentsClientHttpRequestFactory) this.factory).createHttpUriRequest(method, uri); + assertEquals(allowed, request instanceof HttpEntityEnclosingRequest); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/InterceptingClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/InterceptingClientHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..1afa1bcc5b5482d0f2785162058c966769f75376 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/InterceptingClientHttpRequestFactoryTests.java @@ -0,0 +1,335 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRequest; +import org.springframework.http.HttpStatus; +import org.springframework.http.client.support.HttpRequestWrapper; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + * @author Juergen Hoeller + */ +public class InterceptingClientHttpRequestFactoryTests { + + private RequestFactoryMock requestFactoryMock = new RequestFactoryMock(); + + private RequestMock requestMock = new RequestMock(); + + private ResponseMock responseMock = new ResponseMock(); + + private InterceptingClientHttpRequestFactory requestFactory; + + + @Test + public void basic() throws Exception { + List interceptors = new ArrayList<>(); + interceptors.add(new NoOpInterceptor()); + interceptors.add(new NoOpInterceptor()); + interceptors.add(new NoOpInterceptor()); + requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, interceptors); + + ClientHttpRequest request = requestFactory.createRequest(new URI("https://example.com"), HttpMethod.GET); + ClientHttpResponse response = request.execute(); + + assertTrue(((NoOpInterceptor) interceptors.get(0)).invoked); + assertTrue(((NoOpInterceptor) interceptors.get(1)).invoked); + assertTrue(((NoOpInterceptor) interceptors.get(2)).invoked); + assertTrue(requestMock.executed); + assertSame(responseMock, response); + } + + @Test + public void noExecution() throws Exception { + List interceptors = new ArrayList<>(); + interceptors.add(new ClientHttpRequestInterceptor() { + @Override + public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) + throws IOException { + return responseMock; + } + }); + + interceptors.add(new NoOpInterceptor()); + requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, interceptors); + + ClientHttpRequest request = requestFactory.createRequest(new URI("https://example.com"), HttpMethod.GET); + ClientHttpResponse response = request.execute(); + + assertFalse(((NoOpInterceptor) interceptors.get(1)).invoked); + assertFalse(requestMock.executed); + assertSame(responseMock, response); + } + + @Test + public void changeHeaders() throws Exception { + final String headerName = "Foo"; + final String headerValue = "Bar"; + final String otherValue = "Baz"; + + ClientHttpRequestInterceptor interceptor = new ClientHttpRequestInterceptor() { + @Override + public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) + throws IOException { + HttpRequestWrapper wrapper = new HttpRequestWrapper(request); + wrapper.getHeaders().add(headerName, otherValue); + return execution.execute(wrapper, body); + } + }; + + requestMock = new RequestMock() { + @Override + public ClientHttpResponse execute() throws IOException { + List headerValues = getHeaders().get(headerName); + assertEquals(2, headerValues.size()); + assertEquals(headerValue, headerValues.get(0)); + assertEquals(otherValue, headerValues.get(1)); + return super.execute(); + } + }; + requestMock.getHeaders().add(headerName, headerValue); + + requestFactory = + new InterceptingClientHttpRequestFactory(requestFactoryMock, Collections.singletonList(interceptor)); + + ClientHttpRequest request = requestFactory.createRequest(new URI("https://example.com"), HttpMethod.GET); + request.execute(); + } + + @Test + public void changeURI() throws Exception { + final URI changedUri = new URI("https://example.com/2"); + + ClientHttpRequestInterceptor interceptor = new ClientHttpRequestInterceptor() { + @Override + public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) + throws IOException { + return execution.execute(new HttpRequestWrapper(request) { + @Override + public URI getURI() { + return changedUri; + } + + }, body); + } + }; + + requestFactoryMock = new RequestFactoryMock() { + @Override + public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { + assertEquals(changedUri, uri); + return super.createRequest(uri, httpMethod); + } + }; + + requestFactory = + new InterceptingClientHttpRequestFactory(requestFactoryMock, Collections.singletonList(interceptor)); + + ClientHttpRequest request = requestFactory.createRequest(new URI("https://example.com"), HttpMethod.GET); + request.execute(); + } + + @Test + public void changeMethod() throws Exception { + final HttpMethod changedMethod = HttpMethod.POST; + + ClientHttpRequestInterceptor interceptor = new ClientHttpRequestInterceptor() { + @Override + public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) + throws IOException { + return execution.execute(new HttpRequestWrapper(request) { + @Override + public HttpMethod getMethod() { + return changedMethod; + } + + }, body); + } + }; + + requestFactoryMock = new RequestFactoryMock() { + @Override + public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { + assertEquals(changedMethod, httpMethod); + return super.createRequest(uri, httpMethod); + } + }; + + requestFactory = + new InterceptingClientHttpRequestFactory(requestFactoryMock, Collections.singletonList(interceptor)); + + ClientHttpRequest request = requestFactory.createRequest(new URI("https://example.com"), HttpMethod.GET); + request.execute(); + } + + @Test + public void changeBody() throws Exception { + final byte[] changedBody = "Foo".getBytes(); + + ClientHttpRequestInterceptor interceptor = new ClientHttpRequestInterceptor() { + @Override + public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) + throws IOException { + return execution.execute(request, changedBody); + } + }; + + requestFactory = + new InterceptingClientHttpRequestFactory(requestFactoryMock, Collections.singletonList(interceptor)); + + ClientHttpRequest request = requestFactory.createRequest(new URI("https://example.com"), HttpMethod.GET); + request.execute(); + assertTrue(Arrays.equals(changedBody, requestMock.body.toByteArray())); + } + + + private static class NoOpInterceptor implements ClientHttpRequestInterceptor { + + private boolean invoked = false; + + @Override + public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) + throws IOException { + invoked = true; + return execution.execute(request, body); + } + } + + + private class RequestFactoryMock implements ClientHttpRequestFactory { + + @Override + public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { + requestMock.setURI(uri); + requestMock.setMethod(httpMethod); + return requestMock; + } + + } + + + private class RequestMock implements ClientHttpRequest { + + private URI uri; + + private HttpMethod method; + + private HttpHeaders headers = new HttpHeaders(); + + private ByteArrayOutputStream body = new ByteArrayOutputStream(); + + private boolean executed = false; + + private RequestMock() { + } + + @Override + public URI getURI() { + return uri; + } + + public void setURI(URI uri) { + this.uri = uri; + } + + @Override + public HttpMethod getMethod() { + return method; + } + + @Override + public String getMethodValue() { + return method.name(); + } + + public void setMethod(HttpMethod method) { + this.method = method; + } + + @Override + public HttpHeaders getHeaders() { + return headers; + } + + @Override + public OutputStream getBody() throws IOException { + return body; + } + + @Override + public ClientHttpResponse execute() throws IOException { + executed = true; + return responseMock; + } + } + + + private static class ResponseMock implements ClientHttpResponse { + + private HttpStatus statusCode = HttpStatus.OK; + + private String statusText = ""; + + private HttpHeaders headers = new HttpHeaders(); + + @Override + public HttpStatus getStatusCode() throws IOException { + return statusCode; + } + + @Override + public int getRawStatusCode() throws IOException { + return statusCode.value(); + } + + @Override + public String getStatusText() throws IOException { + return statusText; + } + + @Override + public HttpHeaders getHeaders() { + return headers; + } + + @Override + public InputStream getBody() throws IOException { + return null; + } + + @Override + public void close() { + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/InterceptingStreamingHttpComponentsTests.java b/spring-web/src/test/java/org/springframework/http/client/InterceptingStreamingHttpComponentsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..beb407f679cf34d5cc8708883d9303b4838a352a --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/InterceptingStreamingHttpComponentsTests.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import org.junit.Test; + +import org.springframework.http.HttpMethod; + +/** + * @author Juergen Hoeller + */ +public class InterceptingStreamingHttpComponentsTests extends AbstractHttpRequestFactoryTestCase { + + @Override + protected ClientHttpRequestFactory createRequestFactory() { + HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(); + requestFactory.setBufferRequestBody(false); + return new InterceptingClientHttpRequestFactory(requestFactory, null); + } + + @Override + @Test + public void httpMethods() throws Exception { + assertHttpMethod("patch", HttpMethod.PATCH); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/MultipartBodyBuilderTests.java b/spring-web/src/test/java/org/springframework/http/client/MultipartBodyBuilderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..9393acf38a3259a8067c7861eb73b795bab42282 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/MultipartBodyBuilderTests.java @@ -0,0 +1,116 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import org.junit.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.client.MultipartBodyBuilder.PublisherEntity; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + */ +public class MultipartBodyBuilderTests { + + @Test + public void builder() { + + MultipartBodyBuilder builder = new MultipartBodyBuilder(); + + MultiValueMap multipartData = new LinkedMultiValueMap<>(); + multipartData.add("form field", "form value"); + builder.part("key", multipartData).header("foo", "bar"); + + Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); + builder.part("logo", logo).header("baz", "qux"); + + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.add("foo", "bar"); + HttpEntity entity = new HttpEntity<>("body", entityHeaders); + builder.part("entity", entity).header("baz", "qux"); + + Publisher publisher = Flux.just("foo", "bar", "baz"); + builder.asyncPart("publisherClass", publisher, String.class).header("baz", "qux"); + builder.asyncPart("publisherPtr", publisher, new ParameterizedTypeReference() {}).header("baz", "qux"); + + MultiValueMap> result = builder.build(); + + assertEquals(5, result.size()); + HttpEntity resultEntity = result.getFirst("key"); + assertNotNull(resultEntity); + assertEquals(multipartData, resultEntity.getBody()); + assertEquals("bar", resultEntity.getHeaders().getFirst("foo")); + + resultEntity = result.getFirst("logo"); + assertNotNull(resultEntity); + assertEquals(logo, resultEntity.getBody()); + assertEquals("qux", resultEntity.getHeaders().getFirst("baz")); + + resultEntity = result.getFirst("entity"); + assertNotNull(resultEntity); + assertEquals("body", resultEntity.getBody()); + assertEquals("bar", resultEntity.getHeaders().getFirst("foo")); + assertEquals("qux", resultEntity.getHeaders().getFirst("baz")); + + resultEntity = result.getFirst("publisherClass"); + assertNotNull(resultEntity); + assertEquals(publisher, resultEntity.getBody()); + assertEquals(ResolvableType.forClass(String.class), + ((PublisherEntity) resultEntity).getResolvableType()); + assertEquals("qux", resultEntity.getHeaders().getFirst("baz")); + + resultEntity = result.getFirst("publisherPtr"); + assertNotNull(resultEntity); + assertEquals(publisher, resultEntity.getBody()); + assertEquals(ResolvableType.forClass(String.class), + ((PublisherEntity) resultEntity).getResolvableType()); + assertEquals("qux", resultEntity.getHeaders().getFirst("baz")); + } + + @Test // SPR-16601 + public void publisherEntityAcceptedAsInput() { + + Publisher publisher = Flux.just("foo", "bar", "baz"); + MultipartBodyBuilder builder = new MultipartBodyBuilder(); + builder.asyncPart("publisherClass", publisher, String.class).header("baz", "qux"); + HttpEntity entity = builder.build().getFirst("publisherClass"); + + assertNotNull(entity); + assertEquals(PublisherEntity.class, entity.getClass()); + + // Now build a new MultipartBodyBuilder, as BodyInserters.fromMultipartData would do... + + builder = new MultipartBodyBuilder(); + builder.part("publisherClass", entity); + entity = builder.build().getFirst("publisherClass"); + + assertNotNull(entity); + assertEquals(PublisherEntity.class, entity.getClass()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/Netty4AsyncClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/Netty4AsyncClientHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c6ce0b41eac9371c924fad49c3431b25b92de87c --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/Netty4AsyncClientHttpRequestFactoryTests.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.springframework.http.HttpMethod; + +/** + * @author Arjen Poutsma + */ +public class Netty4AsyncClientHttpRequestFactoryTests extends AbstractAsyncHttpRequestFactoryTestCase { + + private static EventLoopGroup eventLoopGroup; + + + @BeforeClass + public static void createEventLoopGroup() { + eventLoopGroup = new NioEventLoopGroup(); + } + + @AfterClass + public static void shutdownEventLoopGroup() throws InterruptedException { + eventLoopGroup.shutdownGracefully().sync(); + } + + @SuppressWarnings("deprecation") + @Override + protected AsyncClientHttpRequestFactory createRequestFactory() { + return new Netty4ClientHttpRequestFactory(eventLoopGroup); + } + + @Override + @Test + public void httpMethods() throws Exception { + super.httpMethods(); + assertHttpMethod("patch", HttpMethod.PATCH); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/Netty4ClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/Netty4ClientHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3b82f4fb596fed5b06710d10ff79d38f75866bea --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/Netty4ClientHttpRequestFactoryTests.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.springframework.http.HttpMethod; + +/** + * @author Arjen Poutsma + */ +public class Netty4ClientHttpRequestFactoryTests extends AbstractHttpRequestFactoryTestCase { + + private static EventLoopGroup eventLoopGroup; + + + @BeforeClass + public static void createEventLoopGroup() { + eventLoopGroup = new NioEventLoopGroup(); + } + + @AfterClass + public static void shutdownEventLoopGroup() throws InterruptedException { + eventLoopGroup.shutdownGracefully().sync(); + } + + @Override + protected ClientHttpRequestFactory createRequestFactory() { + return new Netty4ClientHttpRequestFactory(eventLoopGroup); + } + + @Override + @Test + public void httpMethods() throws Exception { + super.httpMethods(); + assertHttpMethod("patch", HttpMethod.PATCH); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/NoOutputStreamingBufferedSimpleHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/NoOutputStreamingBufferedSimpleHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..7bc41ea7d4438ed226016dae952b7dea3631ca60 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/NoOutputStreamingBufferedSimpleHttpRequestFactoryTests.java @@ -0,0 +1,29 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + + +public class NoOutputStreamingBufferedSimpleHttpRequestFactoryTests extends AbstractHttpRequestFactoryTestCase { + + @Override + protected ClientHttpRequestFactory createRequestFactory() { + SimpleClientHttpRequestFactory factory = new SimpleClientHttpRequestFactory(); + factory.setOutputStreaming(false); + return factory; + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/NoOutputStreamingStreamingSimpleHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/NoOutputStreamingStreamingSimpleHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b1d0c9368a054901e077d11065ee00334f5981ee --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/NoOutputStreamingStreamingSimpleHttpRequestFactoryTests.java @@ -0,0 +1,29 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + + +public class NoOutputStreamingStreamingSimpleHttpRequestFactoryTests extends AbstractHttpRequestFactoryTestCase { + + @Override + protected ClientHttpRequestFactory createRequestFactory() { + SimpleClientHttpRequestFactory factory = new SimpleClientHttpRequestFactory(); + factory.setBufferRequestBody(false); + factory.setOutputStreaming(false); + return factory; + } +} diff --git a/spring-web/src/test/java/org/springframework/http/client/OkHttp3AsyncClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/OkHttp3AsyncClientHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c2dfeae1b5770894cbddbe3b8fd4470ece8b857e --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/OkHttp3AsyncClientHttpRequestFactoryTests.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import org.junit.Test; + +import org.springframework.http.HttpMethod; + +/** + * @author Roy Clarkson + */ +public class OkHttp3AsyncClientHttpRequestFactoryTests extends AbstractAsyncHttpRequestFactoryTestCase { + + @SuppressWarnings("deprecation") + @Override + protected AsyncClientHttpRequestFactory createRequestFactory() { + return new OkHttp3ClientHttpRequestFactory(); + } + + @Override + @Test + public void httpMethods() throws Exception { + super.httpMethods(); + assertHttpMethod("patch", HttpMethod.PATCH); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/OkHttp3ClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/OkHttp3ClientHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d7d51e6e39e53713cfaca6de310e7cca653e350e --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/OkHttp3ClientHttpRequestFactoryTests.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import org.junit.Test; + +import org.springframework.http.HttpMethod; + +/** + * @author Roy Clarkson + */ +public class OkHttp3ClientHttpRequestFactoryTests extends AbstractHttpRequestFactoryTestCase { + + @Override + protected ClientHttpRequestFactory createRequestFactory() { + return new OkHttp3ClientHttpRequestFactory(); + } + + @Override + @Test + public void httpMethods() throws Exception { + super.httpMethods(); + assertHttpMethod("patch", HttpMethod.PATCH); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/SimpleClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/SimpleClientHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..db96dfb7a24b78472bb03d45569b0ae5641a51e3 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/SimpleClientHttpRequestFactoryTests.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.net.HttpURLConnection; + +import org.junit.Test; + +import org.springframework.http.HttpHeaders; + +import static org.mockito.Mockito.*; + +/** + * @author Stephane Nicoll + */ +public class SimpleClientHttpRequestFactoryTests { + + @Test // SPR-13225 + public void headerWithNullValue() { + HttpURLConnection urlConnection = mock(HttpURLConnection.class); + HttpHeaders headers = new HttpHeaders(); + headers.set("foo", null); + SimpleBufferingClientHttpRequest.addHeaders(urlConnection, headers); + verify(urlConnection, times(1)).addRequestProperty("foo", ""); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/SimpleClientHttpResponseTests.java b/spring-web/src/test/java/org/springframework/http/client/SimpleClientHttpResponseTests.java new file mode 100644 index 0000000000000000000000000000000000000000..622e36151dbdae31cf1fc1b72beee4b4ffbaea3d --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/SimpleClientHttpResponseTests.java @@ -0,0 +1,142 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.nio.charset.StandardCharsets; + +import org.junit.Test; + +import org.springframework.util.StreamUtils; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.any; +import static org.mockito.BDDMockito.*; + +/** + * @author Brian Clozel + * @author Juergen Hoeller + */ +public class SimpleClientHttpResponseTests { + + private final HttpURLConnection connection = mock(HttpURLConnection.class); + + private final SimpleClientHttpResponse response = new SimpleClientHttpResponse(this.connection); + + + @Test // SPR-14040 + public void shouldNotCloseConnectionWhenResponseClosed() throws Exception { + TestByteArrayInputStream is = new TestByteArrayInputStream("Spring".getBytes(StandardCharsets.UTF_8)); + given(this.connection.getErrorStream()).willReturn(null); + given(this.connection.getInputStream()).willReturn(is); + + InputStream responseStream = this.response.getBody(); + assertThat(StreamUtils.copyToString(responseStream, StandardCharsets.UTF_8), is("Spring")); + + this.response.close(); + assertTrue(is.isClosed()); + verify(this.connection, never()).disconnect(); + } + + @Test // SPR-14040 + public void shouldDrainStreamWhenResponseClosed() throws Exception { + byte[] buf = new byte[6]; + TestByteArrayInputStream is = new TestByteArrayInputStream("SpringSpring".getBytes(StandardCharsets.UTF_8)); + given(this.connection.getErrorStream()).willReturn(null); + given(this.connection.getInputStream()).willReturn(is); + + InputStream responseStream = this.response.getBody(); + responseStream.read(buf); + assertThat(new String(buf, StandardCharsets.UTF_8), is("Spring")); + assertThat(is.available(), is(6)); + + this.response.close(); + assertThat(is.available(), is(0)); + assertTrue(is.isClosed()); + verify(this.connection, never()).disconnect(); + } + + @Test // SPR-14040 + public void shouldDrainErrorStreamWhenResponseClosed() throws Exception { + byte[] buf = new byte[6]; + TestByteArrayInputStream is = new TestByteArrayInputStream("SpringSpring".getBytes(StandardCharsets.UTF_8)); + given(this.connection.getErrorStream()).willReturn(is); + + InputStream responseStream = this.response.getBody(); + responseStream.read(buf); + assertThat(new String(buf, StandardCharsets.UTF_8), is("Spring")); + assertThat(is.available(), is(6)); + + this.response.close(); + assertThat(is.available(), is(0)); + assertTrue(is.isClosed()); + verify(this.connection, never()).disconnect(); + } + + @Test // SPR-16773 + public void shouldNotDrainWhenErrorStreamClosed() throws Exception { + InputStream is = mock(InputStream.class); + given(this.connection.getErrorStream()).willReturn(is); + doNothing().when(is).close(); + given(is.read(any())).willThrow(new NullPointerException("from HttpURLConnection#ErrorStream")); + + InputStream responseStream = this.response.getBody(); + responseStream.close(); + this.response.close(); + + verify(is).close(); + } + + @Test // SPR-17181 + public void shouldDrainResponseEvenIfResponseNotRead() throws Exception { + TestByteArrayInputStream is = new TestByteArrayInputStream("SpringSpring".getBytes(StandardCharsets.UTF_8)); + given(this.connection.getErrorStream()).willReturn(null); + given(this.connection.getInputStream()).willReturn(is); + + this.response.close(); + assertThat(is.available(), is(0)); + assertTrue(is.isClosed()); + verify(this.connection, never()).disconnect(); + } + + + private static class TestByteArrayInputStream extends ByteArrayInputStream { + + private boolean closed; + + public TestByteArrayInputStream(byte[] buf) { + super(buf); + this.closed = false; + } + + public boolean isClosed() { + return closed; + } + + @Override + public void close() throws IOException { + super.close(); + this.closed = true; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/StreamingHttpComponentsClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/StreamingHttpComponentsClientHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..877098962227f6e2902fe430e1253d2a2ea4279e --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/StreamingHttpComponentsClientHttpRequestFactoryTests.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import org.junit.Test; + +import org.springframework.http.HttpMethod; + +/** + * @author Arjen Poutsma + */ +public class StreamingHttpComponentsClientHttpRequestFactoryTests extends AbstractHttpRequestFactoryTestCase { + + @Override + protected ClientHttpRequestFactory createRequestFactory() { + HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(); + requestFactory.setBufferRequestBody(false); + return requestFactory; + } + + @Override + @Test + public void httpMethods() throws Exception { + assertHttpMethod("patch", HttpMethod.PATCH); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/StreamingSimpleClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/StreamingSimpleClientHttpRequestFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0c0be4cdb79deed37ca8e69bbab1052b58e52772 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/StreamingSimpleClientHttpRequestFactoryTests.java @@ -0,0 +1,98 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.OutputStream; +import java.net.URI; +import java.util.Collections; +import java.util.Random; + +import org.junit.Ignore; +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + */ +public class StreamingSimpleClientHttpRequestFactoryTests extends AbstractHttpRequestFactoryTestCase { + + @Override + protected ClientHttpRequestFactory createRequestFactory() { + SimpleClientHttpRequestFactory factory = new SimpleClientHttpRequestFactory(); + factory.setBufferRequestBody(false); + return factory; + } + + @Test // SPR-8809 + public void interceptor() throws Exception { + final String headerName = "MyHeader"; + final String headerValue = "MyValue"; + ClientHttpRequestInterceptor interceptor = (request, body, execution) -> { + request.getHeaders().add(headerName, headerValue); + return execution.execute(request, body); + }; + InterceptingClientHttpRequestFactory factory = new InterceptingClientHttpRequestFactory( + createRequestFactory(), Collections.singletonList(interceptor)); + + ClientHttpResponse response = null; + try { + ClientHttpRequest request = factory.createRequest(new URI(baseUrl + "/echo"), HttpMethod.GET); + response = request.execute(); + assertEquals("Invalid response status", HttpStatus.OK, response.getStatusCode()); + HttpHeaders responseHeaders = response.getHeaders(); + assertEquals("Custom header invalid", headerValue, responseHeaders.getFirst(headerName)); + } + finally { + if (response != null) { + response.close(); + } + } + } + + @Test + @Ignore + public void largeFileUpload() throws Exception { + Random rnd = new Random(); + ClientHttpResponse response = null; + try { + ClientHttpRequest request = factory.createRequest(new URI(baseUrl + "/methods/post"), HttpMethod.POST); + final int BUF_SIZE = 4096; + final int ITERATIONS = Integer.MAX_VALUE / BUF_SIZE; + // final int contentLength = ITERATIONS * BUF_SIZE; + // request.getHeaders().setContentLength(contentLength); + OutputStream body = request.getBody(); + for (int i = 0; i < ITERATIONS; i++) { + byte[] buffer = new byte[BUF_SIZE]; + rnd.nextBytes(buffer); + body.write(buffer); + } + response = request.execute(); + assertEquals("Invalid response status", HttpStatus.OK, response.getStatusCode()); + } + finally { + if (response != null) { + response.close(); + } + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/reactive/ReactorResourceFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/reactive/ReactorResourceFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..f5e64f871b38319d9e4c4b210b08989bd58783c9 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/reactive/ReactorResourceFactoryTests.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.http.client.reactive; + +import java.util.concurrent.atomic.AtomicBoolean; + +import org.junit.Test; +import reactor.netty.http.HttpResources; +import reactor.netty.resources.ConnectionProvider; +import reactor.netty.resources.LoopResources; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link ReactorResourceFactory}. + * @author Rossen Stoyanchev + */ +public class ReactorResourceFactoryTests { + + private final ReactorResourceFactory resourceFactory = new ReactorResourceFactory(); + + private final ConnectionProvider connectionProvider = mock(ConnectionProvider.class); + + private final LoopResources loopResources = mock(LoopResources.class); + + + @Test + public void globalResources() throws Exception { + + this.resourceFactory.setUseGlobalResources(true); + this.resourceFactory.afterPropertiesSet(); + + HttpResources globalResources = HttpResources.get(); + assertSame(globalResources, this.resourceFactory.getConnectionProvider()); + assertSame(globalResources, this.resourceFactory.getLoopResources()); + assertFalse(globalResources.isDisposed()); + + this.resourceFactory.destroy(); + + assertTrue(globalResources.isDisposed()); + } + + @Test + public void globalResourcesWithConsumer() throws Exception { + + AtomicBoolean invoked = new AtomicBoolean(false); + + this.resourceFactory.addGlobalResourcesConsumer(httpResources -> invoked.set(true)); + this.resourceFactory.afterPropertiesSet(); + + assertTrue(invoked.get()); + this.resourceFactory.destroy(); + } + + @Test + public void localResources() throws Exception { + + this.resourceFactory.setUseGlobalResources(false); + this.resourceFactory.afterPropertiesSet(); + + ConnectionProvider connectionProvider = this.resourceFactory.getConnectionProvider(); + LoopResources loopResources = this.resourceFactory.getLoopResources(); + + assertNotSame(HttpResources.get(), connectionProvider); + assertNotSame(HttpResources.get(), loopResources); + + // The below does not work since ConnectionPoolProvider simply checks if pool is empty. + // assertFalse(connectionProvider.isDisposed()); + assertFalse(loopResources.isDisposed()); + + this.resourceFactory.destroy(); + + assertTrue(connectionProvider.isDisposed()); + assertTrue(loopResources.isDisposed()); + } + + @Test + public void localResourcesViaSupplier() throws Exception { + + this.resourceFactory.setUseGlobalResources(false); + this.resourceFactory.setConnectionProviderSupplier(() -> this.connectionProvider); + this.resourceFactory.setLoopResourcesSupplier(() -> this.loopResources); + this.resourceFactory.afterPropertiesSet(); + + ConnectionProvider connectionProvider = this.resourceFactory.getConnectionProvider(); + LoopResources loopResources = this.resourceFactory.getLoopResources(); + + assertSame(this.connectionProvider, connectionProvider); + assertSame(this.loopResources, loopResources); + + verifyNoMoreInteractions(this.connectionProvider, this.loopResources); + + this.resourceFactory.destroy(); + + // Managed (destroy disposes).. + verify(this.connectionProvider).disposeLater(); + verify(this.loopResources).disposeLater(); + verifyNoMoreInteractions(this.connectionProvider, this.loopResources); + } + + @Test + public void externalResources() throws Exception { + + this.resourceFactory.setUseGlobalResources(false); + this.resourceFactory.setConnectionProvider(this.connectionProvider); + this.resourceFactory.setLoopResources(this.loopResources); + this.resourceFactory.afterPropertiesSet(); + + ConnectionProvider connectionProvider = this.resourceFactory.getConnectionProvider(); + LoopResources loopResources = this.resourceFactory.getLoopResources(); + + assertSame(this.connectionProvider, connectionProvider); + assertSame(this.loopResources, loopResources); + + verifyNoMoreInteractions(this.connectionProvider, this.loopResources); + + this.resourceFactory.destroy(); + + // Not managed (destroy has no impact).. + verifyNoMoreInteractions(this.connectionProvider, this.loopResources); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/support/BasicAuthorizationInterceptorTests.java b/spring-web/src/test/java/org/springframework/http/client/support/BasicAuthorizationInterceptorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..bee37740d63dac00b3a5b7c05cd1ea87f9195f48 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/support/BasicAuthorizationInterceptorTests.java @@ -0,0 +1,78 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.support; + +import java.net.URI; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.beans.DirectFieldAccessor; +import org.springframework.http.HttpMethod; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestExecution; +import org.springframework.http.client.SimpleClientHttpRequestFactory; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link BasicAuthorizationInterceptor}. + * + * @author Phillip Webb + * @author Stephane Nicoll + */ +public class BasicAuthorizationInterceptorTests { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void createWhenUsernameContainsColonShouldThrowException() { + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("Username must not contain a colon"); + new BasicAuthorizationInterceptor("username:", "password"); + } + + @Test + public void createWhenUsernameIsNullShouldUseEmptyUsername() throws Exception { + BasicAuthorizationInterceptor interceptor = new BasicAuthorizationInterceptor( + null, "password"); + assertEquals("", new DirectFieldAccessor(interceptor).getPropertyValue("username")); + } + + @Test + public void createWhenPasswordIsNullShouldUseEmptyPassword() throws Exception { + BasicAuthorizationInterceptor interceptor = new BasicAuthorizationInterceptor( + "username", null); + assertEquals("", new DirectFieldAccessor(interceptor).getPropertyValue("password")); + } + + @Test + public void interceptShouldAddHeader() throws Exception { + SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); + ClientHttpRequest request = requestFactory.createRequest(new URI("https://example.com"), HttpMethod.GET); + ClientHttpRequestExecution execution = mock(ClientHttpRequestExecution.class); + byte[] body = new byte[] {}; + new BasicAuthorizationInterceptor("spring", "boot").intercept(request, body, + execution); + verify(execution).execute(request, body); + assertEquals("Basic c3ByaW5nOmJvb3Q=", request.getHeaders().getFirst("Authorization")); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/support/InterceptingHttpAccessorTests.java b/spring-web/src/test/java/org/springframework/http/client/support/InterceptingHttpAccessorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..96f45e76db660df9aa3d34d4e68ba70160ea0000 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/support/InterceptingHttpAccessorTests.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.support; + +import java.util.Arrays; +import java.util.List; + +import org.hamcrest.Matchers; +import org.junit.Test; + +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.http.HttpRequest; +import org.springframework.http.client.ClientHttpRequestExecution; +import org.springframework.http.client.ClientHttpRequestInterceptor; +import org.springframework.http.client.ClientHttpResponse; + +import static org.junit.Assert.*; + +/** + * Tests for {@link InterceptingHttpAccessor}. + * + * @author Brian Clozel + */ +public class InterceptingHttpAccessorTests { + + @Test + public void getInterceptors() { + TestInterceptingHttpAccessor accessor = new TestInterceptingHttpAccessor(); + List interceptors = Arrays.asList( + new SecondClientHttpRequestInterceptor(), + new ThirdClientHttpRequestInterceptor(), + new FirstClientHttpRequestInterceptor() + + ); + accessor.setInterceptors(interceptors); + + assertThat(accessor.getInterceptors().get(0), Matchers.instanceOf(FirstClientHttpRequestInterceptor.class)); + assertThat(accessor.getInterceptors().get(1), Matchers.instanceOf(SecondClientHttpRequestInterceptor.class)); + assertThat(accessor.getInterceptors().get(2), Matchers.instanceOf(ThirdClientHttpRequestInterceptor.class)); + } + + + private class TestInterceptingHttpAccessor extends InterceptingHttpAccessor { + } + + + @Order(1) + private class FirstClientHttpRequestInterceptor implements ClientHttpRequestInterceptor { + + @Override + public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) { + return null; + } + } + + + private class SecondClientHttpRequestInterceptor implements ClientHttpRequestInterceptor, Ordered { + + @Override + public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) { + return null; + } + + @Override + public int getOrder() { + return 2; + } + } + + + private class ThirdClientHttpRequestInterceptor implements ClientHttpRequestInterceptor { + + @Override + public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) { + return null; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/support/ProxyFactoryBeanTests.java b/spring-web/src/test/java/org/springframework/http/client/support/ProxyFactoryBeanTests.java new file mode 100644 index 0000000000000000000000000000000000000000..dfe42243995e9a3123e19a01de98a1784bf146a0 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/support/ProxyFactoryBeanTests.java @@ -0,0 +1,75 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client.support; + +import java.net.InetSocketAddress; +import java.net.Proxy; + +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + */ +public class ProxyFactoryBeanTests { + + ProxyFactoryBean factoryBean; + + @Before + public void setUp() { + factoryBean = new ProxyFactoryBean(); + } + + @Test(expected = IllegalArgumentException.class) + public void noType() { + factoryBean.setType(null); + factoryBean.afterPropertiesSet(); + } + + @Test(expected = IllegalArgumentException.class) + public void noHostname() { + factoryBean.setHostname(""); + factoryBean.afterPropertiesSet(); + } + + @Test(expected = IllegalArgumentException.class) + public void noPort() { + factoryBean.setHostname("example.com"); + factoryBean.afterPropertiesSet(); + } + + @Test + public void normal() { + Proxy.Type type = Proxy.Type.HTTP; + factoryBean.setType(type); + String hostname = "example.com"; + factoryBean.setHostname(hostname); + int port = 8080; + factoryBean.setPort(port); + factoryBean.afterPropertiesSet(); + + Proxy result = factoryBean.getObject(); + + assertEquals(type, result.type()); + InetSocketAddress address = (InetSocketAddress) result.address(); + assertEquals(hostname, address.getHostName()); + assertEquals(port, address.getPort()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/CancelWithoutDemandCodecTests.java b/spring-web/src/test/java/org/springframework/http/codec/CancelWithoutDemandCodecTests.java new file mode 100644 index 0000000000000000000000000000000000000000..509e1355422f88ab24254e4aaa304fa960c888ec --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/CancelWithoutDemandCodecTests.java @@ -0,0 +1,230 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.http.codec; + +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.function.Supplier; + +import com.google.protobuf.Message; +import org.junit.After; +import org.junit.Test; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.LeakAwareDataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpOutputMessage; +import org.springframework.http.client.MultipartBodyBuilder; +import org.springframework.http.codec.json.Jackson2JsonEncoder; +import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; +import org.springframework.http.codec.protobuf.ProtobufDecoder; +import org.springframework.http.codec.protobuf.ProtobufEncoder; +import org.springframework.http.codec.xml.Jaxb2XmlEncoder; +import org.springframework.protobuf.Msg; +import org.springframework.protobuf.SecondMsg; +import org.springframework.util.MimeType; + +/** + * Test scenarios for data buffer leaks. + * @author Rossen Stoyanchev + */ +public class CancelWithoutDemandCodecTests { + + private final LeakAwareDataBufferFactory bufferFactory = new LeakAwareDataBufferFactory(); + + + @After + public void tearDown() throws Exception { + this.bufferFactory.checkForLeaks(); + } + + + @Test // gh-22107 + public void cancelWithEncoderHttpMessageWriterAndSingleValue() { + CharSequenceEncoder encoder = CharSequenceEncoder.allMimeTypes(); + HttpMessageWriter writer = new EncoderHttpMessageWriter<>(encoder); + CancellingOutputMessage outputMessage = new CancellingOutputMessage(this.bufferFactory); + + writer.write(Mono.just("foo"), ResolvableType.forType(String.class), MediaType.TEXT_PLAIN, + outputMessage, Collections.emptyMap()).block(Duration.ofSeconds(5)); + } + + @Test // gh-22107 + public void cancelWithJackson() { + Jackson2JsonEncoder encoder = new Jackson2JsonEncoder(); + + Flux flux = encoder.encode(Flux.just(new Pojo("foofoo", "barbar"), new Pojo("bar", "baz")), + this.bufferFactory, ResolvableType.forClass(Pojo.class), + MediaType.APPLICATION_JSON, Collections.emptyMap()); + + BaseSubscriber subscriber = new ZeroDemandSubscriber(); + flux.subscribe(subscriber); // Assume sync execution (e.g. encoding with Flux.just).. + subscriber.cancel(); + } + + @Test // gh-22107 + public void cancelWithJaxb2() { + Jaxb2XmlEncoder encoder = new Jaxb2XmlEncoder(); + + Flux flux = encoder.encode(Mono.just(new Pojo("foo", "bar")), + this.bufferFactory, ResolvableType.forClass(Pojo.class), + MediaType.APPLICATION_XML, Collections.emptyMap()); + + BaseSubscriber subscriber = new ZeroDemandSubscriber(); + flux.subscribe(subscriber); // Assume sync execution (e.g. encoding with Flux.just).. + subscriber.cancel(); + } + + @Test // gh-22543 + public void cancelWithProtobufEncoder() { + ProtobufEncoder encoder = new ProtobufEncoder(); + Msg msg = Msg.newBuilder().setFoo("Foo").setBlah(SecondMsg.newBuilder().setBlah(123).build()).build(); + + Flux flux = encoder.encode(Mono.just(msg), + this.bufferFactory, ResolvableType.forClass(Msg.class), + new MimeType("application", "x-protobuf"), Collections.emptyMap()); + + BaseSubscriber subscriber = new ZeroDemandSubscriber(); + flux.subscribe(subscriber); // Assume sync execution (e.g. encoding with Flux.just).. + subscriber.cancel(); + } + + @Test // gh-22731 + public void cancelWithProtobufDecoder() throws InterruptedException { + ProtobufDecoder decoder = new ProtobufDecoder(); + + Mono input = Mono.fromCallable(() -> { + Msg msg = Msg.newBuilder().setFoo("Foo").build(); + byte[] bytes = msg.toByteArray(); + DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length); + buffer.write(bytes); + return buffer; + }); + + Flux messages = decoder.decode(input, ResolvableType.forType(Msg.class), + new MimeType("application", "x-protobuf"), Collections.emptyMap()); + ZeroDemandMessageSubscriber subscriber = new ZeroDemandMessageSubscriber(); + messages.subscribe(subscriber); + subscriber.cancel(); + } + + @Test // gh-22107 + public void cancelWithMultipartContent() { + MultipartBodyBuilder builder = new MultipartBodyBuilder(); + builder.part("part1", "value1"); + builder.part("part2", "value2"); + + List> writers = ClientCodecConfigurer.create().getWriters(); + MultipartHttpMessageWriter writer = new MultipartHttpMessageWriter(writers); + CancellingOutputMessage outputMessage = new CancellingOutputMessage(this.bufferFactory); + + writer.write(Mono.just(builder.build()), null, MediaType.MULTIPART_FORM_DATA, + outputMessage, Collections.emptyMap()).block(Duration.ofSeconds(5)); + } + + @Test // gh-22107 + public void cancelWithSse() { + ServerSentEvent event = ServerSentEvent.builder().data("bar").id("c42").event("foo").build(); + ServerSentEventHttpMessageWriter writer = new ServerSentEventHttpMessageWriter(new Jackson2JsonEncoder()); + CancellingOutputMessage outputMessage = new CancellingOutputMessage(this.bufferFactory); + + writer.write(Mono.just(event), ResolvableType.forClass(ServerSentEvent.class), MediaType.TEXT_EVENT_STREAM, + outputMessage, Collections.emptyMap()).block(Duration.ofSeconds(5)); + } + + + + private static class CancellingOutputMessage implements ReactiveHttpOutputMessage { + + private final DataBufferFactory bufferFactory; + + + public CancellingOutputMessage(DataBufferFactory bufferFactory) { + this.bufferFactory = bufferFactory; + } + + + @Override + public DataBufferFactory bufferFactory() { + return this.bufferFactory; + } + + @Override + public void beforeCommit(Supplier> action) { + } + + @Override + public boolean isCommitted() { + return false; + } + + @Override + public Mono writeWith(Publisher body) { + Flux flux = Flux.from(body); + BaseSubscriber subscriber = new ZeroDemandSubscriber(); + flux.subscribe(subscriber); // Assume sync execution (e.g. encoding with Flux.just).. + subscriber.cancel(); + return Mono.empty(); + } + + @Override + public Mono writeAndFlushWith(Publisher> body) { + Flux flux = Flux.from(body).concatMap(Flux::from); + BaseSubscriber subscriber = new ZeroDemandSubscriber(); + flux.subscribe(subscriber); // Assume sync execution (e.g. encoding with Flux.just).. + subscriber.cancel(); + return Mono.empty(); + } + + @Override + public Mono setComplete() { + throw new UnsupportedOperationException(); + } + + @Override + public HttpHeaders getHeaders() { + return new HttpHeaders(); + } + } + + + private static class ZeroDemandSubscriber extends BaseSubscriber { + + @Override + protected void hookOnSubscribe(Subscription subscription) { + // Just subscribe without requesting + } + } + + + private static class ZeroDemandMessageSubscriber extends BaseSubscriber { + + @Override + protected void hookOnSubscribe(Subscription subscription) { + // Just subscribe without requesting + } + } +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/EncoderHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/EncoderHttpMessageWriterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8d54ea7157cfa128b55c5838c1db1852093c9532 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/EncoderHttpMessageWriterTests.java @@ -0,0 +1,207 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.MediaType; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.ReflectionUtils; + +import static java.nio.charset.StandardCharsets.*; +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; +import static org.springframework.core.ResolvableType.*; +import static org.springframework.http.MediaType.*; + +/** + * Unit tests for {@link EncoderHttpMessageWriter}. + * @author Rossen Stoyanchev + * @author Brian Clozel + */ +public class EncoderHttpMessageWriterTests { + + private static final Map NO_HINTS = Collections.emptyMap(); + + private static final MediaType TEXT_PLAIN_UTF_8 = new MediaType("text", "plain", UTF_8); + + + @Mock + private HttpMessageEncoder encoder; + + private ArgumentCaptor mediaTypeCaptor; + + private MockServerHttpResponse response; + + + @Before + public void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + this.mediaTypeCaptor = ArgumentCaptor.forClass(MediaType.class); + this.response = new MockServerHttpResponse(); + } + + + @Test + public void getWritableMediaTypes() { + HttpMessageWriter writer = getWriter(MimeTypeUtils.TEXT_HTML, MimeTypeUtils.TEXT_XML); + assertEquals(Arrays.asList(TEXT_HTML, TEXT_XML), writer.getWritableMediaTypes()); + } + + @Test + public void canWrite() { + HttpMessageWriter writer = getWriter(MimeTypeUtils.TEXT_HTML); + when(this.encoder.canEncode(forClass(String.class), TEXT_HTML)).thenReturn(true); + + assertTrue(writer.canWrite(forClass(String.class), TEXT_HTML)); + assertFalse(writer.canWrite(forClass(String.class), TEXT_XML)); + } + + @Test + public void useNegotiatedMediaType() { + HttpMessageWriter writer = getWriter(MimeTypeUtils.ALL); + writer.write(Mono.just("body"), forClass(String.class), TEXT_PLAIN, this.response, NO_HINTS); + + assertEquals(TEXT_PLAIN, response.getHeaders().getContentType()); + assertEquals(TEXT_PLAIN, this.mediaTypeCaptor.getValue()); + } + + @Test + public void useDefaultMediaType() { + testDefaultMediaType(null); + testDefaultMediaType(new MediaType("text", "*")); + testDefaultMediaType(new MediaType("*", "*")); + testDefaultMediaType(MediaType.APPLICATION_OCTET_STREAM); + } + + private void testDefaultMediaType(MediaType negotiatedMediaType) { + + this.mediaTypeCaptor = ArgumentCaptor.forClass(MediaType.class); + + MimeType defaultContentType = MimeTypeUtils.TEXT_XML; + HttpMessageWriter writer = getWriter(defaultContentType); + writer.write(Mono.just("body"), forClass(String.class), negotiatedMediaType, this.response, NO_HINTS); + + assertEquals(defaultContentType, this.response.getHeaders().getContentType()); + assertEquals(defaultContentType, this.mediaTypeCaptor.getValue()); + } + + @Test + public void useDefaultMediaTypeCharset() { + HttpMessageWriter writer = getWriter(TEXT_PLAIN_UTF_8, TEXT_HTML); + writer.write(Mono.just("body"), forClass(String.class), TEXT_HTML, response, NO_HINTS); + + assertEquals(new MediaType("text", "html", UTF_8), this.response.getHeaders().getContentType()); + assertEquals(new MediaType("text", "html", UTF_8), this.mediaTypeCaptor.getValue()); + } + + @Test + public void useNegotiatedMediaTypeCharset() { + MediaType negotiatedMediaType = new MediaType("text", "html", ISO_8859_1); + HttpMessageWriter writer = getWriter(TEXT_PLAIN_UTF_8, TEXT_HTML); + writer.write(Mono.just("body"), forClass(String.class), negotiatedMediaType, this.response, NO_HINTS); + + assertEquals(negotiatedMediaType, this.response.getHeaders().getContentType()); + assertEquals(negotiatedMediaType, this.mediaTypeCaptor.getValue()); + } + + @Test + public void useHttpOutputMessageMediaType() { + MediaType outputMessageMediaType = MediaType.TEXT_HTML; + this.response.getHeaders().setContentType(outputMessageMediaType); + + HttpMessageWriter writer = getWriter(TEXT_PLAIN_UTF_8, TEXT_HTML); + writer.write(Mono.just("body"), forClass(String.class), TEXT_PLAIN, this.response, NO_HINTS); + + assertEquals(outputMessageMediaType, this.response.getHeaders().getContentType()); + assertEquals(outputMessageMediaType, this.mediaTypeCaptor.getValue()); + } + + @Test + public void setContentLengthForMonoBody() { + DefaultDataBufferFactory factory = new DefaultDataBufferFactory(); + DataBuffer buffer = factory.wrap("body".getBytes(StandardCharsets.UTF_8)); + HttpMessageWriter writer = getWriter(Flux.just(buffer), MimeTypeUtils.TEXT_PLAIN); + writer.write(Mono.just("body"), forClass(String.class), TEXT_PLAIN, this.response, NO_HINTS).block(); + + assertEquals(4, this.response.getHeaders().getContentLength()); + } + + @Test // gh-22952 + public void monoBodyDoesNotCancelEncodedFlux() { + Mono inputStream = Mono.just("body") + .doOnCancel(() -> { + throw new AssertionError("Cancel signal not expected"); + }); + new EncoderHttpMessageWriter<>(CharSequenceEncoder.allMimeTypes()) + .write(inputStream, forClass(String.class), TEXT_PLAIN, this.response, NO_HINTS) + .block(); + } + + @Test // SPR-17220 + public void emptyBodyWritten() { + HttpMessageWriter writer = getWriter(MimeTypeUtils.TEXT_PLAIN); + writer.write(Mono.empty(), forClass(String.class), TEXT_PLAIN, this.response, NO_HINTS).block(); + StepVerifier.create(this.response.getBody()).expectComplete(); + assertEquals(0, this.response.getHeaders().getContentLength()); + } + + @Test // gh-22936 + public void isStreamingMediaType() throws InvocationTargetException, IllegalAccessException { + HttpMessageWriter writer = getWriter(TEXT_HTML); + MediaType streamingMediaType = new MediaType(TEXT_PLAIN, Collections.singletonMap("streaming", "true")); + when(this.encoder.getStreamingMediaTypes()).thenReturn(Arrays.asList(streamingMediaType)); + Method method = ReflectionUtils.findMethod(writer.getClass(), "isStreamingMediaType", MediaType.class); + ReflectionUtils.makeAccessible(method); + assertTrue((Boolean) method.invoke(writer, streamingMediaType)); + assertFalse((Boolean) method.invoke(writer, new MediaType(TEXT_PLAIN, Collections.singletonMap("streaming", "false")))); + assertFalse((Boolean) method.invoke(writer, TEXT_HTML)); + } + + private HttpMessageWriter getWriter(MimeType... mimeTypes) { + return getWriter(Flux.empty(), mimeTypes); + } + + private HttpMessageWriter getWriter(Flux encodedStream, MimeType... mimeTypes) { + List typeList = Arrays.asList(mimeTypes); + when(this.encoder.getEncodableMimeTypes()).thenReturn(typeList); + when(this.encoder.encode(any(), any(), any(), this.mediaTypeCaptor.capture(), any())).thenReturn(encodedStream); + return new EncoderHttpMessageWriter<>(this.encoder); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/FormHttpMessageReaderTests.java b/spring-web/src/test/java/org/springframework/http/codec/FormHttpMessageReaderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..10e666a589d1457072884568448b29b993dd3dd8 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/FormHttpMessageReaderTests.java @@ -0,0 +1,138 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import org.junit.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.AbstractLeakCheckingTestCase; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import static org.junit.Assert.*; + +/** + * @author Sebastien Deleuze + */ +public class FormHttpMessageReaderTests extends AbstractLeakCheckingTestCase { + + private final FormHttpMessageReader reader = new FormHttpMessageReader(); + + + @Test + public void canRead() { + assertTrue(this.reader.canRead( + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class), + MediaType.APPLICATION_FORM_URLENCODED)); + + assertTrue(this.reader.canRead( + ResolvableType.forInstance(new LinkedMultiValueMap()), + MediaType.APPLICATION_FORM_URLENCODED)); + + assertFalse(this.reader.canRead( + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Object.class), + MediaType.APPLICATION_FORM_URLENCODED)); + + assertFalse(this.reader.canRead( + ResolvableType.forClassWithGenerics(MultiValueMap.class, Object.class, String.class), + MediaType.APPLICATION_FORM_URLENCODED)); + + assertFalse(this.reader.canRead( + ResolvableType.forClassWithGenerics(Map.class, String.class, String.class), + MediaType.APPLICATION_FORM_URLENCODED)); + + assertFalse(this.reader.canRead( + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class), + MediaType.MULTIPART_FORM_DATA)); + } + + @Test + public void readFormAsMono() { + String body = "name+1=value+1&name+2=value+2%2B1&name+2=value+2%2B2&name+3"; + MockServerHttpRequest request = request(body); + MultiValueMap result = this.reader.readMono(null, request, null).block(); + + assertEquals("Invalid result", 3, result.size()); + assertEquals("Invalid result", "value 1", result.getFirst("name 1")); + List values = result.get("name 2"); + assertEquals("Invalid result", 2, values.size()); + assertEquals("Invalid result", "value 2+1", values.get(0)); + assertEquals("Invalid result", "value 2+2", values.get(1)); + assertNull("Invalid result", result.getFirst("name 3")); + } + + @Test + public void readFormAsFlux() { + String body = "name+1=value+1&name+2=value+2%2B1&name+2=value+2%2B2&name+3"; + MockServerHttpRequest request = request(body); + MultiValueMap result = this.reader.read(null, request, null).single().block(); + + assertEquals("Invalid result", 3, result.size()); + assertEquals("Invalid result", "value 1", result.getFirst("name 1")); + List values = result.get("name 2"); + assertEquals("Invalid result", 2, values.size()); + assertEquals("Invalid result", "value 2+1", values.get(0)); + assertEquals("Invalid result", "value 2+2", values.get(1)); + assertNull("Invalid result", result.getFirst("name 3")); + } + + @Test + public void readFormError() { + DataBuffer fooBuffer = stringBuffer("name=value"); + Flux body = + Flux.just(fooBuffer).concatWith(Flux.error(new RuntimeException())); + MockServerHttpRequest request = request(body); + + Flux> result = this.reader.read(null, request, null); + StepVerifier.create(result) + .expectError() + .verify(); + } + + + private MockServerHttpRequest request(String body) { + return request(Mono.just(stringBuffer(body))); + } + + private MockServerHttpRequest request(Publisher body) { + return MockServerHttpRequest + .method(HttpMethod.GET, "/") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .body(body); + } + + private DataBuffer stringBuffer(String value) { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length); + buffer.write(bytes); + return buffer; + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/FormHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/FormHttpMessageWriterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..7f23f769b3cbaf6c40eceecf912c49d7a5708fc7 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/FormHttpMessageWriterTests.java @@ -0,0 +1,105 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.function.Consumer; + +import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.AbstractLeakCheckingTestCase; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.support.DataBufferTestUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import static org.junit.Assert.*; + +/** + * @author Sebastien Deleuze + */ +public class FormHttpMessageWriterTests extends AbstractLeakCheckingTestCase { + + private final FormHttpMessageWriter writer = new FormHttpMessageWriter(); + + + @Test + public void canWrite() { + assertTrue(this.writer.canWrite( + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class), + MediaType.APPLICATION_FORM_URLENCODED)); + + // No generic information + assertTrue(this.writer.canWrite( + ResolvableType.forInstance(new LinkedMultiValueMap()), + MediaType.APPLICATION_FORM_URLENCODED)); + + assertFalse(this.writer.canWrite( + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Object.class), + null)); + + assertFalse(this.writer.canWrite( + ResolvableType.forClassWithGenerics(MultiValueMap.class, Object.class, String.class), + null)); + + assertFalse(this.writer.canWrite( + ResolvableType.forClassWithGenerics(Map.class, String.class, String.class), + MediaType.APPLICATION_FORM_URLENCODED)); + + assertFalse(this.writer.canWrite( + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class), + MediaType.MULTIPART_FORM_DATA)); + } + + @Test + public void writeForm() { + MultiValueMap body = new LinkedMultiValueMap<>(); + body.set("name 1", "value 1"); + body.add("name 2", "value 2+1"); + body.add("name 2", "value 2+2"); + body.add("name 3", null); + MockServerHttpResponse response = new MockServerHttpResponse(this.bufferFactory); + this.writer.write(Mono.just(body), null, MediaType.APPLICATION_FORM_URLENCODED, response, null).block(); + + String expected = "name+1=value+1&name+2=value+2%2B1&name+2=value+2%2B2&name+3"; + StepVerifier.create(response.getBody()) + .consumeNextWith(stringConsumer(expected)) + .expectComplete() + .verify(); + HttpHeaders headers = response.getHeaders(); + assertEquals("application/x-www-form-urlencoded;charset=UTF-8", headers.getContentType().toString()); + assertEquals(expected.length(), headers.getContentLength()); + } + + private Consumer stringConsumer(String expected) { + return dataBuffer -> { + String value = DataBufferTestUtils.dumpString(dataBuffer, StandardCharsets.UTF_8); + DataBufferUtils.release(dataBuffer); + assertEquals(expected, value); + }; + } + + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/Pojo.java b/spring-web/src/test/java/org/springframework/http/codec/Pojo.java new file mode 100644 index 0000000000000000000000000000000000000000..4db33fefbbe35bc8e9752442ab90ff3c3219c9bc --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/Pojo.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import javax.xml.bind.annotation.XmlRootElement; + +/** + * @author Sebastien Deleuze + */ +@XmlRootElement +public class Pojo { + + private String foo; + + private String bar; + + public Pojo() { + } + + public Pojo(String foo, String bar) { + this.foo = foo; + this.bar = bar; + } + + public String getFoo() { + return this.foo; + } + + public void setFoo(String foo) { + this.foo = foo; + } + + public String getBar() { + return this.bar; + } + + public void setBar(String bar) { + this.bar = bar; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o instanceof Pojo) { + Pojo other = (Pojo) o; + return this.foo.equals(other.foo) && this.bar.equals(other.bar); + } + return false; + } + + @Override + public int hashCode() { + return 31 * foo.hashCode() + bar.hashCode(); + } + + @Override + public String toString() { + return "Pojo[foo='" + this.foo + "\'" + ", bar='" + this.bar + "\']"; + } +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/ResourceHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/ResourceHttpMessageWriterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5b95cc53a10effa3f3dc046ef8efeb63c350b354 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/ResourceHttpMessageWriterTests.java @@ -0,0 +1,153 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Map; + +import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.io.ByteArrayResource; +import org.springframework.core.io.Resource; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpRange; +import org.springframework.http.HttpStatus; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.StringUtils; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.startsWith; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertThat; +import static org.springframework.http.MediaType.TEXT_PLAIN; +import static org.springframework.mock.http.server.reactive.test.MockServerHttpRequest.get; + +/** + * Unit tests for {@link ResourceHttpMessageWriter}. + * + * @author Brian Clozel + * @author Rossen Stoyanchev + */ +public class ResourceHttpMessageWriterTests { + + private static final Map HINTS = Collections.emptyMap(); + + + private final ResourceHttpMessageWriter writer = new ResourceHttpMessageWriter(); + + private final MockServerHttpResponse response = new MockServerHttpResponse(); + + private final Mono input = Mono.just(new ByteArrayResource( + "Spring Framework test resource content.".getBytes(StandardCharsets.UTF_8))); + + + @Test + public void getWritableMediaTypes() throws Exception { + assertThat(this.writer.getWritableMediaTypes(), + containsInAnyOrder(MimeTypeUtils.APPLICATION_OCTET_STREAM, MimeTypeUtils.ALL)); + } + + @Test + public void writeResource() throws Exception { + + testWrite(get("/").build()); + + assertThat(this.response.getHeaders().getContentType(), is(TEXT_PLAIN)); + assertThat(this.response.getHeaders().getContentLength(), is(39L)); + assertThat(this.response.getHeaders().getFirst(HttpHeaders.ACCEPT_RANGES), is("bytes")); + + String content = "Spring Framework test resource content."; + StepVerifier.create(this.response.getBodyAsString()).expectNext(content).expectComplete().verify(); + } + + @Test + public void writeSingleRegion() throws Exception { + + testWrite(get("/").range(of(0, 5)).build()); + + assertThat(this.response.getHeaders().getContentType(), is(TEXT_PLAIN)); + assertThat(this.response.getHeaders().getFirst(HttpHeaders.CONTENT_RANGE), is("bytes 0-5/39")); + assertThat(this.response.getHeaders().getContentLength(), is(6L)); + + StepVerifier.create(this.response.getBodyAsString()).expectNext("Spring").expectComplete().verify(); + } + + @Test + public void writeMultipleRegions() throws Exception { + + testWrite(get("/").range(of(0,5), of(7,15), of(17,20), of(22,38)).build()); + + HttpHeaders headers = this.response.getHeaders(); + String contentType = headers.getContentType().toString(); + String boundary = contentType.substring(30); + + assertThat(contentType, startsWith("multipart/byteranges;boundary=")); + + StepVerifier.create(this.response.getBodyAsString()) + .consumeNextWith(content -> { + String[] actualRanges = StringUtils.tokenizeToStringArray(content, "\r\n", false, true); + String[] expected = new String[] { + "--" + boundary, + "Content-Type: text/plain", + "Content-Range: bytes 0-5/39", + "Spring", + "--" + boundary, + "Content-Type: text/plain", + "Content-Range: bytes 7-15/39", + "Framework", + "--" + boundary, + "Content-Type: text/plain", + "Content-Range: bytes 17-20/39", + "test", + "--" + boundary, + "Content-Type: text/plain", + "Content-Range: bytes 22-38/39", + "resource content.", + "--" + boundary + "--" + }; + assertArrayEquals(expected, actualRanges); + }) + .expectComplete() + .verify(); + } + + @Test + public void invalidRange() throws Exception { + + testWrite(get("/").header(HttpHeaders.RANGE, "invalid").build()); + + assertThat(this.response.getHeaders().getFirst(HttpHeaders.ACCEPT_RANGES), is("bytes")); + assertThat(this.response.getStatusCode(), is(HttpStatus.REQUESTED_RANGE_NOT_SATISFIABLE)); + } + + + private void testWrite(MockServerHttpRequest request) { + Mono mono = this.writer.write(this.input, null, null, TEXT_PLAIN, request, this.response, HINTS); + StepVerifier.create(mono).expectComplete().verify(); + } + + private static HttpRange of(int first, int last) { + return HttpRange.createByteRange(first, last); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageReaderTests.java b/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageReaderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3e3a417fc52d36eb52d757607bc5cb1230b87c96 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageReaderTests.java @@ -0,0 +1,200 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Collections; + +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.AbstractLeakCheckingTestCase; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.MediaType; +import org.springframework.http.codec.json.Jackson2JsonDecoder; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link ServerSentEventHttpMessageReader}. + * + * @author Sebastien Deleuze + */ +public class ServerSentEventHttpMessageReaderTests extends AbstractLeakCheckingTestCase { + + private ServerSentEventHttpMessageReader messageReader = + new ServerSentEventHttpMessageReader(new Jackson2JsonDecoder()); + + + @Test + public void cantRead() { + assertFalse(messageReader.canRead(ResolvableType.forClass(Object.class), new MediaType("foo", "bar"))); + assertFalse(messageReader.canRead(ResolvableType.forClass(Object.class), null)); + } + + @Test + public void canRead() { + assertTrue(messageReader.canRead(ResolvableType.forClass(Object.class), new MediaType("text", "event-stream"))); + assertTrue(messageReader.canRead(ResolvableType.forClass(ServerSentEvent.class), new MediaType("foo", "bar"))); + } + + @Test + public void readServerSentEvents() { + MockServerHttpRequest request = MockServerHttpRequest.post("/") + .body(Mono.just(stringBuffer( + "id:c42\nevent:foo\nretry:123\n:bla\n:bla bla\n:bla bla bla\ndata:bar\n\n" + + "id:c43\nevent:bar\nretry:456\ndata:baz\n\n"))); + + Flux events = this.messageReader + .read(ResolvableType.forClassWithGenerics(ServerSentEvent.class, String.class), + request, Collections.emptyMap()).cast(ServerSentEvent.class); + + StepVerifier.create(events) + .consumeNextWith(event -> { + assertEquals("c42", event.id()); + assertEquals("foo", event.event()); + assertEquals(Duration.ofMillis(123), event.retry()); + assertEquals("bla\nbla bla\nbla bla bla", event.comment()); + assertEquals("bar", event.data()); + }) + .consumeNextWith(event -> { + assertEquals("c43", event.id()); + assertEquals("bar", event.event()); + assertEquals(Duration.ofMillis(456), event.retry()); + assertNull(event.comment()); + assertEquals("baz", event.data()); + }) + .expectComplete() + .verify(); + } + + @Test + public void readServerSentEventsWithMultipleChunks() { + MockServerHttpRequest request = MockServerHttpRequest.post("/") + .body(Flux.just( + stringBuffer("id:c42\nev"), + stringBuffer("ent:foo\nretry:123\n:bla\n:bla bla\n:bla bla bla\ndata:"), + stringBuffer("bar\n\nid:c43\nevent:bar\nretry:456\ndata:baz\n\n"))); + + Flux events = messageReader + .read(ResolvableType.forClassWithGenerics(ServerSentEvent.class, String.class), + request, Collections.emptyMap()).cast(ServerSentEvent.class); + + StepVerifier.create(events) + .consumeNextWith(event -> { + assertEquals("c42", event.id()); + assertEquals("foo", event.event()); + assertEquals(Duration.ofMillis(123), event.retry()); + assertEquals("bla\nbla bla\nbla bla bla", event.comment()); + assertEquals("bar", event.data()); + }) + .consumeNextWith(event -> { + assertEquals("c43", event.id()); + assertEquals("bar", event.event()); + assertEquals(Duration.ofMillis(456), event.retry()); + assertNull(event.comment()); + assertEquals("baz", event.data()); + }) + .expectComplete() + .verify(); + } + + @Test + public void readString() { + MockServerHttpRequest request = MockServerHttpRequest.post("/") + .body(Mono.just(stringBuffer("data:foo\ndata:bar\n\ndata:baz\n\n"))); + + Flux data = messageReader.read(ResolvableType.forClass(String.class), + request, Collections.emptyMap()).cast(String.class); + + StepVerifier.create(data) + .expectNextMatches(elem -> elem.equals("foo\nbar")) + .expectNextMatches(elem -> elem.equals("baz")) + .expectComplete() + .verify(); + } + + @Test + public void readPojo() { + MockServerHttpRequest request = MockServerHttpRequest.post("/") + .body(Mono.just(stringBuffer( + "data:{\"foo\": \"foofoo\", \"bar\": \"barbar\"}\n\n" + + "data:{\"foo\": \"foofoofoo\", \"bar\": \"barbarbar\"}\n\n"))); + + Flux data = messageReader.read(ResolvableType.forClass(Pojo.class), request, + Collections.emptyMap()).cast(Pojo.class); + + StepVerifier.create(data) + .consumeNextWith(pojo -> { + assertEquals("foofoo", pojo.getFoo()); + assertEquals("barbar", pojo.getBar()); + }) + .consumeNextWith(pojo -> { + assertEquals("foofoofoo", pojo.getFoo()); + assertEquals("barbarbar", pojo.getBar()); + }) + .expectComplete() + .verify(); + } + + @Test // SPR-15331 + public void decodeFullContentAsString() { + String body = "data:foo\ndata:bar\n\ndata:baz\n\n"; + MockServerHttpRequest request = MockServerHttpRequest.post("/") + .body(Mono.just(stringBuffer(body))); + + String actual = messageReader + .readMono(ResolvableType.forClass(String.class), request, Collections.emptyMap()) + .cast(String.class) + .block(Duration.ZERO); + + assertEquals(body, actual); + } + + @Test + public void readError() { + Flux body = + Flux.just(stringBuffer("data:foo\ndata:bar\n\ndata:baz\n\n")) + .concatWith(Flux.error(new RuntimeException())); + + MockServerHttpRequest request = MockServerHttpRequest.post("/") + .body(body); + + Flux data = messageReader.read(ResolvableType.forClass(String.class), + request, Collections.emptyMap()).cast(String.class); + + StepVerifier.create(data) + .expectNextMatches(elem -> elem.equals("foo\nbar")) + .expectNextMatches(elem -> elem.equals("baz")) + .expectError() + .verify(); + } + + private DataBuffer stringBuffer(String value) { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length); + buffer.write(bytes); + return buffer; + } + + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageWriterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..10eb6bffee8b860c3dabe13ae243aba49798a5ec --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageWriterTests.java @@ -0,0 +1,245 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Collections; +import java.util.Map; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.Before; +import org.junit.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.AbstractDataBufferAllocatingTestCase; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.support.DataBufferTestUtils; +import org.springframework.http.MediaType; +import org.springframework.http.codec.json.Jackson2JsonEncoder; +import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; + +import static org.junit.Assert.*; +import static org.springframework.core.ResolvableType.forClass; + +/** + * Unit tests for {@link ServerSentEventHttpMessageWriter}. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + */ +@SuppressWarnings("rawtypes") +public class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAllocatingTestCase { + + private static final Map HINTS = Collections.emptyMap(); + + private ServerSentEventHttpMessageWriter messageWriter = + new ServerSentEventHttpMessageWriter(new Jackson2JsonEncoder()); + + private MockServerHttpResponse outputMessage; + + + @Before + public void setUp() { + this.outputMessage = new MockServerHttpResponse(this.bufferFactory); + } + + + + @Test + public void canWrite() { + assertTrue(this.messageWriter.canWrite(forClass(Object.class), null)); + assertFalse(this.messageWriter.canWrite(forClass(Object.class), new MediaType("foo", "bar"))); + + assertTrue(this.messageWriter.canWrite(null, MediaType.TEXT_EVENT_STREAM)); + assertTrue(this.messageWriter.canWrite(forClass(ServerSentEvent.class), new MediaType("foo", "bar"))); + + // SPR-15464 + assertTrue(this.messageWriter.canWrite(ResolvableType.NONE, MediaType.TEXT_EVENT_STREAM)); + assertFalse(this.messageWriter.canWrite(ResolvableType.NONE, new MediaType("foo", "bar"))); + } + + @Test + public void writeServerSentEvent() { + ServerSentEvent event = ServerSentEvent.builder().data("bar").id("c42").event("foo") + .comment("bla\nbla bla\nbla bla bla").retry(Duration.ofMillis(123L)).build(); + + Mono source = Mono.just(event); + testWrite(source, outputMessage, ServerSentEvent.class); + + StepVerifier.create(outputMessage.getBody()) + .consumeNextWith(stringConsumer("id:c42\nevent:foo\nretry:123\n:bla\n:bla bla\n:bla bla bla\ndata:")) + .consumeNextWith(stringConsumer("bar\n")) + .consumeNextWith(stringConsumer("\n")) + .expectComplete() + .verify(); + } + + @Test + public void writeString() { + Flux source = Flux.just("foo", "bar"); + testWrite(source, outputMessage, String.class); + + StepVerifier.create(outputMessage.getBody()) + .consumeNextWith(stringConsumer("data:")) + .consumeNextWith(stringConsumer("foo\n")) + .consumeNextWith(stringConsumer("\n")) + .consumeNextWith(stringConsumer("data:")) + .consumeNextWith(stringConsumer("bar\n")) + .consumeNextWith(stringConsumer("\n")) + .expectComplete() + .verify(); + } + + @Test + public void writeMultiLineString() { + Flux source = Flux.just("foo\nbar", "foo\nbaz"); + testWrite(source, outputMessage, String.class); + + StepVerifier.create(outputMessage.getBody()) + .consumeNextWith(stringConsumer("data:")) + .consumeNextWith(stringConsumer("foo\ndata:bar\n")) + .consumeNextWith(stringConsumer("\n")) + .consumeNextWith(stringConsumer("data:")) + .consumeNextWith(stringConsumer("foo\ndata:baz\n")) + .consumeNextWith(stringConsumer("\n")) + .expectComplete() + .verify(); + } + + @Test // SPR-16516 + public void writeStringWithCustomCharset() { + Flux source = Flux.just("\u00A3"); + Charset charset = StandardCharsets.ISO_8859_1; + MediaType mediaType = new MediaType("text", "event-stream", charset); + testWrite(source, mediaType, outputMessage, String.class); + + assertEquals(mediaType, outputMessage.getHeaders().getContentType()); + StepVerifier.create(outputMessage.getBody()) + .consumeNextWith(stringConsumer("data:")) + .consumeNextWith(dataBuffer -> { + String value = + DataBufferTestUtils.dumpString(dataBuffer, charset); + DataBufferUtils.release(dataBuffer); + assertEquals("\u00A3\n", value); + }) + .consumeNextWith(stringConsumer("\n")) + .expectComplete() + .verify(); + } + + @Test + public void writePojo() { + Flux source = Flux.just(new Pojo("foofoo", "barbar"), new Pojo("foofoofoo", "barbarbar")); + testWrite(source, outputMessage, Pojo.class); + + StepVerifier.create(outputMessage.getBody()) + .consumeNextWith(stringConsumer("data:")) + .consumeNextWith(stringConsumer("{\"foo\":\"foofoo\",\"bar\":\"barbar\"}")) + .consumeNextWith(stringConsumer("\n")) + .consumeNextWith(stringConsumer("\n")) + .consumeNextWith(stringConsumer("data:")) + .consumeNextWith(stringConsumer("{\"foo\":\"foofoofoo\",\"bar\":\"barbarbar\"}")) + .consumeNextWith(stringConsumer("\n")) + .consumeNextWith(stringConsumer("\n")) + .expectComplete() + .verify(); + } + + @Test // SPR-14899 + public void writePojoWithPrettyPrint() { + ObjectMapper mapper = Jackson2ObjectMapperBuilder.json().indentOutput(true).build(); + this.messageWriter = new ServerSentEventHttpMessageWriter(new Jackson2JsonEncoder(mapper)); + + Flux source = Flux.just(new Pojo("foofoo", "barbar"), new Pojo("foofoofoo", "barbarbar")); + testWrite(source, outputMessage, Pojo.class); + + StepVerifier.create(outputMessage.getBody()) + .consumeNextWith(stringConsumer("data:")) + .consumeNextWith(stringConsumer("{\n" + + "data: \"foo\" : \"foofoo\",\n" + + "data: \"bar\" : \"barbar\"\n" + "data:}")) + .consumeNextWith(stringConsumer("\n")) + .consumeNextWith(stringConsumer("\n")) + .consumeNextWith(stringConsumer("data:")) + .consumeNextWith(stringConsumer("{\n" + + "data: \"foo\" : \"foofoofoo\",\n" + + "data: \"bar\" : \"barbarbar\"\n" + "data:}")) + .consumeNextWith(stringConsumer("\n")) + .consumeNextWith(stringConsumer("\n")) + .expectComplete() + .verify(); + } + + @Test // SPR-16516, SPR-16539 + public void writePojoWithCustomEncoding() { + Flux source = Flux.just(new Pojo("foo\uD834\uDD1E", "bar\uD834\uDD1E")); + Charset charset = StandardCharsets.UTF_16LE; + MediaType mediaType = new MediaType("text", "event-stream", charset); + testWrite(source, mediaType, outputMessage, Pojo.class); + + assertEquals(mediaType, outputMessage.getHeaders().getContentType()); + StepVerifier.create(outputMessage.getBody()) + .consumeNextWith(dataBuffer1 -> { + String value1 = + DataBufferTestUtils.dumpString(dataBuffer1, charset); + DataBufferUtils.release(dataBuffer1); + assertEquals("data:", value1); + }) + .consumeNextWith(dataBuffer -> { + String value = DataBufferTestUtils.dumpString(dataBuffer, charset); + DataBufferUtils.release(dataBuffer); + assertEquals("{\"foo\":\"foo\uD834\uDD1E\",\"bar\":\"bar\uD834\uDD1E\"}", value); + }) + .consumeNextWith(dataBuffer2 -> { + String value2 = + DataBufferTestUtils.dumpString(dataBuffer2, charset); + DataBufferUtils.release(dataBuffer2); + assertEquals("\n", value2); + }) + .consumeNextWith(dataBuffer3 -> { + String value3 = + DataBufferTestUtils.dumpString(dataBuffer3, charset); + DataBufferUtils.release(dataBuffer3); + assertEquals("\n", value3); + }) + .expectComplete() + .verify(); + } + + + private void testWrite(Publisher source, MockServerHttpResponse response, Class clazz) { + testWrite(source, MediaType.TEXT_EVENT_STREAM, response, clazz); + } + + private void testWrite( + Publisher source, MediaType mediaType, MockServerHttpResponse response, Class clazz) { + + Mono result = + this.messageWriter.write(source, forClass(clazz), mediaType, response, HINTS); + + StepVerifier.create(result) + .verifyComplete(); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2JsonDecoderTests.java b/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2JsonDecoderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..eea21cce103c1d7bea1413d04a50643d78c3abfa --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2JsonDecoderTests.java @@ -0,0 +1,361 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.json; + +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.deser.std.StdDeserializer; +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.AbstractDecoderTestCase; +import org.springframework.core.codec.CodecException; +import org.springframework.core.codec.DecodingException; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.MediaType; +import org.springframework.http.codec.Pojo; +import org.springframework.util.MimeType; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonMap; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.springframework.core.ResolvableType.forClass; +import static org.springframework.http.MediaType.APPLICATION_JSON; +import static org.springframework.http.MediaType.APPLICATION_JSON_UTF8; +import static org.springframework.http.MediaType.APPLICATION_STREAM_JSON; +import static org.springframework.http.MediaType.APPLICATION_XML; +import static org.springframework.http.codec.json.Jackson2JsonDecoder.JSON_VIEW_HINT; +import static org.springframework.http.codec.json.JacksonViewBean.MyJacksonView1; +import static org.springframework.http.codec.json.JacksonViewBean.MyJacksonView3; + +/** + * Unit tests for {@link Jackson2JsonDecoder}. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + */ +public class Jackson2JsonDecoderTests extends AbstractDecoderTestCase { + + private Pojo pojo1 = new Pojo("f1", "b1"); + + private Pojo pojo2 = new Pojo("f2", "b2"); + + + public Jackson2JsonDecoderTests() { + super(new Jackson2JsonDecoder()); + } + + + @Override + @Test + public void canDecode() { + assertTrue(decoder.canDecode(forClass(Pojo.class), APPLICATION_JSON)); + assertTrue(decoder.canDecode(forClass(Pojo.class), APPLICATION_JSON_UTF8)); + assertTrue(decoder.canDecode(forClass(Pojo.class), APPLICATION_STREAM_JSON)); + assertTrue(decoder.canDecode(forClass(Pojo.class), null)); + + assertFalse(decoder.canDecode(forClass(String.class), null)); + assertFalse(decoder.canDecode(forClass(Pojo.class), APPLICATION_XML)); + assertTrue(this.decoder.canDecode(forClass(Pojo.class), + new MediaType("application", "json", StandardCharsets.UTF_8))); + assertTrue(this.decoder.canDecode(forClass(Pojo.class), + new MediaType("application", "json", StandardCharsets.US_ASCII))); + assertTrue(this.decoder.canDecode(forClass(Pojo.class), + new MediaType("application", "json", StandardCharsets.ISO_8859_1))); + + } + + @Test // SPR-15866 + public void canDecodeWithProvidedMimeType() { + MimeType textJavascript = new MimeType("text", "javascript", StandardCharsets.UTF_8); + Jackson2JsonDecoder decoder = new Jackson2JsonDecoder(new ObjectMapper(), textJavascript); + + assertEquals(Collections.singletonList(textJavascript), decoder.getDecodableMimeTypes()); + } + + @Test(expected = UnsupportedOperationException.class) + public void decodableMimeTypesIsImmutable() { + MimeType textJavascript = new MimeType("text", "javascript", StandardCharsets.UTF_8); + Jackson2JsonDecoder decoder = new Jackson2JsonDecoder(new ObjectMapper(), textJavascript); + + decoder.getMimeTypes().add(new MimeType("text", "ecmascript")); + } + + @Override + @Test + public void decode() { + Flux input = Flux.concat( + stringBuffer("[{\"bar\":\"b1\",\"foo\":\"f1\"},"), + stringBuffer("{\"bar\":\"b2\",\"foo\":\"f2\"}]")); + + testDecodeAll(input, Pojo.class, step -> step + .expectNext(pojo1) + .expectNext(pojo2) + .verifyComplete()); + } + + @Override + public void decodeToMono() { + Flux input = Flux.concat( + stringBuffer("[{\"bar\":\"b1\",\"foo\":\"f1\"},"), + stringBuffer("{\"bar\":\"b2\",\"foo\":\"f2\"}]")); + + ResolvableType elementType = ResolvableType.forClassWithGenerics(List.class, Pojo.class); + + testDecodeToMonoAll(input, elementType, step -> step + .expectNext(asList(new Pojo("f1", "b1"), new Pojo("f2", "b2"))) + .expectComplete() + .verify(), null, null); + } + + + @Test + public void decodeEmptyArrayToFlux() { + Flux input = Flux.from(stringBuffer("[]")); + + testDecode(input, Pojo.class, step -> step.verifyComplete()); + } + + @Test + public void fieldLevelJsonView() { + Flux input = Flux.from( + stringBuffer("{\"withView1\" : \"with\", \"withView2\" : \"with\", \"withoutView\" : \"without\"}")); + ResolvableType elementType = forClass(JacksonViewBean.class); + Map hints = singletonMap(JSON_VIEW_HINT, MyJacksonView1.class); + + testDecode(input, elementType, step -> step + .consumeNextWith(o -> { + JacksonViewBean b = (JacksonViewBean) o; + assertEquals("with", b.getWithView1()); + assertNull(b.getWithView2()); + assertNull(b.getWithoutView()); + }), null, hints); + } + + @Test + public void classLevelJsonView() { + Flux input = Flux.from(stringBuffer( + "{\"withView1\" : \"with\", \"withView2\" : \"with\", \"withoutView\" : \"without\"}")); + ResolvableType elementType = forClass(JacksonViewBean.class); + Map hints = singletonMap(JSON_VIEW_HINT, MyJacksonView3.class); + + testDecode(input, elementType, step -> step + .consumeNextWith(o -> { + JacksonViewBean b = (JacksonViewBean) o; + assertEquals("without", b.getWithoutView()); + assertNull(b.getWithView1()); + assertNull(b.getWithView2()); + }) + .verifyComplete(), null, hints); + } + + @Test + public void invalidData() { + Flux input = + Flux.from(stringBuffer("{\"foofoo\": \"foofoo\", \"barbar\": \"barbar\"")); + testDecode(input, Pojo.class, step -> step + .verifyError(DecodingException.class)); + } + + @Test // gh-22042 + public void decodeWithNullLiteral() { + Flux result = this.decoder.decode(Flux.concat(stringBuffer("null")), + ResolvableType.forType(Pojo.class), MediaType.APPLICATION_JSON, Collections.emptyMap()); + + StepVerifier.create(result).expectComplete().verify(); + } + + @Test + public void noDefaultConstructor() { + Flux input = + Flux.from(stringBuffer("{\"property1\":\"foo\",\"property2\":\"bar\"}")); + ResolvableType elementType = forClass(BeanWithNoDefaultConstructor.class); + Flux flux = new Jackson2JsonDecoder().decode(input, elementType, null, emptyMap()); + StepVerifier.create(flux).verifyError(CodecException.class); + } + + @Test // SPR-15975 + public void customDeserializer() { + Mono input = stringBuffer("{\"test\": 1}"); + + testDecode(input, TestObject.class, step -> step + .consumeNextWith(o -> assertEquals(1, o.getTest())) + .verifyComplete() + ); + } + + @Test + public void bigDecimalFlux() { + Flux input = stringBuffer("[ 1E+2 ]").flux(); + + testDecode(input, BigDecimal.class, step -> step + .expectNext(new BigDecimal("1E+2")) + .verifyComplete() + ); + } + + @Test + public void decodeNonUtf8Encoding() { + Mono input = stringBuffer("{\"foo\":\"bar\"}", StandardCharsets.UTF_16); + + testDecode(input, ResolvableType.forType(new ParameterizedTypeReference>() {}), + step -> step.assertNext(o -> { + Map map = (Map) o; + assertEquals("bar", map.get("foo")); + }) + .verifyComplete(), + MediaType.parseMediaType("application/json; charset=utf-16"), + null); + } + + @Test + @SuppressWarnings("unchecked") + public void decodeNonUnicode() { + Flux input = Flux.concat( + stringBuffer("{\"føø\":\"bår\"}", StandardCharsets.ISO_8859_1) + ); + + testDecode(input, ResolvableType.forType(new ParameterizedTypeReference>() {}), + step -> step.assertNext(o -> { + assertTrue(o instanceof Map); + Map map = (Map) o; + assertEquals(1, map.size()); + assertEquals("bår", map.get("føø")); + }) + .verifyComplete(), + MediaType.parseMediaType("application/json; charset=iso-8859-1"), + null); + } + + @Test + public void decodeMonoNonUtf8Encoding() { + Mono input = stringBuffer("{\"foo\":\"bar\"}", StandardCharsets.UTF_16); + + testDecodeToMono(input, ResolvableType.forType(new ParameterizedTypeReference>() {}), + step -> step.assertNext(o -> { + Map map = (Map) o; + assertEquals("bar", map.get("foo")); + }) + .verifyComplete(), + MediaType.parseMediaType("application/json; charset=utf-16"), + null); + } + + @Test + @SuppressWarnings("unchecked") + public void decodeAscii() { + Flux input = Flux.concat( + stringBuffer("{\"foo\":\"bar\"}", StandardCharsets.US_ASCII) + ); + + testDecode(input, ResolvableType.forType(new ParameterizedTypeReference>() {}), + step -> step.assertNext(o -> { + Map map = (Map) o; + assertEquals("bar", map.get("foo")); + }) + .verifyComplete(), + MediaType.parseMediaType("application/json; charset=us-ascii"), + null); + } + + + private Mono stringBuffer(String value) { + return stringBuffer(value, StandardCharsets.UTF_8); + } + + private Mono stringBuffer(String value, Charset charset) { + return Mono.defer(() -> { + byte[] bytes = value.getBytes(charset); + DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length); + buffer.write(bytes); + return Mono.just(buffer); + }); + } + + + private static class BeanWithNoDefaultConstructor { + + private final String property1; + + private final String property2; + + public BeanWithNoDefaultConstructor(String property1, String property2) { + this.property1 = property1; + this.property2 = property2; + } + + public String getProperty1() { + return this.property1; + } + + public String getProperty2() { + return this.property2; + } + } + + + @JsonDeserialize(using = Deserializer.class) + public static class TestObject { + + private int test; + + public int getTest() { + return this.test; + } + public void setTest(int test) { + this.test = test; + } + } + + + public static class Deserializer extends StdDeserializer { + + private static final long serialVersionUID = 1L; + + protected Deserializer() { + super(TestObject.class); + } + + @Override + public TestObject deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + JsonNode node = p.readValueAsTree(); + TestObject result = new TestObject(); + result.setTest(node.get("test").asInt()); + return result; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2JsonEncoderTests.java b/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2JsonEncoderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..ca237cb26e6019c5518ceabf83857d8672211246 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2JsonEncoderTests.java @@ -0,0 +1,254 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.json; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.annotation.JsonTypeName; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.AbstractEncoderTestCase; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.MediaType; +import org.springframework.http.codec.Pojo; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; + +import static java.util.Collections.singletonMap; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.springframework.http.MediaType.APPLICATION_JSON; +import static org.springframework.http.MediaType.APPLICATION_JSON_UTF8; +import static org.springframework.http.MediaType.APPLICATION_OCTET_STREAM; +import static org.springframework.http.MediaType.APPLICATION_STREAM_JSON; +import static org.springframework.http.MediaType.APPLICATION_XML; +import static org.springframework.http.codec.json.Jackson2JsonEncoder.JSON_VIEW_HINT; +import static org.springframework.http.codec.json.JacksonViewBean.MyJacksonView1; +import static org.springframework.http.codec.json.JacksonViewBean.MyJacksonView3; + +/** + * @author Sebastien Deleuze + */ +public class Jackson2JsonEncoderTests extends AbstractEncoderTestCase { + + + public Jackson2JsonEncoderTests() { + super(new Jackson2JsonEncoder()); + } + + @Override + @Test + public void canEncode() { + ResolvableType pojoType = ResolvableType.forClass(Pojo.class); + assertTrue(this.encoder.canEncode(pojoType, APPLICATION_JSON)); + assertTrue(this.encoder.canEncode(pojoType, APPLICATION_JSON_UTF8)); + assertTrue(this.encoder.canEncode(pojoType, APPLICATION_STREAM_JSON)); + assertTrue(this.encoder.canEncode(pojoType, null)); + + assertTrue(this.encoder.canEncode(ResolvableType.forClass(Pojo.class), + new MediaType("application", "json", StandardCharsets.UTF_8))); + assertTrue(this.encoder.canEncode(ResolvableType.forClass(Pojo.class), + new MediaType("application", "json", StandardCharsets.US_ASCII))); + assertFalse(this.encoder.canEncode(ResolvableType.forClass(Pojo.class), + new MediaType("application", "json", StandardCharsets.ISO_8859_1))); + + // SPR-15464 + assertTrue(this.encoder.canEncode(ResolvableType.NONE, null)); + + // SPR-15910 + assertFalse(this.encoder.canEncode(ResolvableType.forClass(Object.class), APPLICATION_OCTET_STREAM)); + } + + @Override + public void encode() throws Exception { + Flux input = Flux.just(new Pojo("foo", "bar"), + new Pojo("foofoo", "barbar"), + new Pojo("foofoofoo", "barbarbar")); + + testEncodeAll(input, ResolvableType.forClass(Pojo.class), step -> step + .consumeNextWith(expectString("{\"foo\":\"foo\",\"bar\":\"bar\"}\n")) + .consumeNextWith(expectString("{\"foo\":\"foofoo\",\"bar\":\"barbar\"}\n")) + .consumeNextWith(expectString("{\"foo\":\"foofoofoo\",\"bar\":\"barbarbar\"}\n")) + .verifyComplete(), + APPLICATION_STREAM_JSON, null); + } + + @Test // SPR-15866 + public void canEncodeWithCustomMimeType() { + MimeType textJavascript = new MimeType("text", "javascript", StandardCharsets.UTF_8); + Jackson2JsonEncoder encoder = new Jackson2JsonEncoder(new ObjectMapper(), textJavascript); + + assertEquals(Collections.singletonList(textJavascript), encoder.getEncodableMimeTypes()); + } + + @Test(expected = UnsupportedOperationException.class) + public void encodableMimeTypesIsImmutable() { + MimeType textJavascript = new MimeType("text", "javascript", StandardCharsets.UTF_8); + Jackson2JsonEncoder encoder = new Jackson2JsonEncoder(new ObjectMapper(), textJavascript); + + encoder.getMimeTypes().add(new MimeType("text", "ecmascript")); + } + + @Test + public void canNotEncode() { + assertFalse(this.encoder.canEncode(ResolvableType.forClass(String.class), null)); + assertFalse(this.encoder.canEncode(ResolvableType.forClass(Pojo.class), APPLICATION_XML)); + + ResolvableType sseType = ResolvableType.forClass(ServerSentEvent.class); + assertFalse(this.encoder.canEncode(sseType, APPLICATION_JSON)); + } + + @Test + public void encodeNonStream() { + Flux input = Flux.just( + new Pojo("foo", "bar"), + new Pojo("foofoo", "barbar"), + new Pojo("foofoofoo", "barbarbar") + ); + + testEncode(input, Pojo.class, step -> step + .consumeNextWith(expectString("[" + + "{\"foo\":\"foo\",\"bar\":\"bar\"}," + + "{\"foo\":\"foofoo\",\"bar\":\"barbar\"}," + + "{\"foo\":\"foofoofoo\",\"bar\":\"barbarbar\"}]") + .andThen(DataBufferUtils::release)) + .verifyComplete()); + } + + @Test + public void encodeWithType() { + Flux input = Flux.just(new Foo(), new Bar()); + + testEncode(input, ParentClass.class, step -> step + .consumeNextWith(expectString("[{\"type\":\"foo\"},{\"type\":\"bar\"}]") + .andThen(DataBufferUtils::release)) + .verifyComplete()); + } + + + @Test // SPR-15727 + public void encodeAsStreamWithCustomStreamingType() { + MediaType fooMediaType = new MediaType("application", "foo"); + MediaType barMediaType = new MediaType("application", "bar"); + this.encoder.setStreamingMediaTypes(Arrays.asList(fooMediaType, barMediaType)); + Flux input = Flux.just( + new Pojo("foo", "bar"), + new Pojo("foofoo", "barbar"), + new Pojo("foofoofoo", "barbarbar") + ); + + testEncode(input, ResolvableType.forClass(Pojo.class), step -> step + .consumeNextWith(expectString("{\"foo\":\"foo\",\"bar\":\"bar\"}\n") + .andThen(DataBufferUtils::release)) + .consumeNextWith(expectString("{\"foo\":\"foofoo\",\"bar\":\"barbar\"}\n") + .andThen(DataBufferUtils::release)) + .consumeNextWith(expectString("{\"foo\":\"foofoofoo\",\"bar\":\"barbarbar\"}\n") + .andThen(DataBufferUtils::release)) + .verifyComplete(), + barMediaType, null); + } + + @Test + public void fieldLevelJsonView() { + JacksonViewBean bean = new JacksonViewBean(); + bean.setWithView1("with"); + bean.setWithView2("with"); + bean.setWithoutView("without"); + Mono input = Mono.just(bean); + + ResolvableType type = ResolvableType.forClass(JacksonViewBean.class); + Map hints = singletonMap(JSON_VIEW_HINT, MyJacksonView1.class); + + testEncode(input, type, step -> step + .consumeNextWith(expectString("{\"withView1\":\"with\"}") + .andThen(DataBufferUtils::release)) + .verifyComplete(), + null, hints); + } + + @Test + public void classLevelJsonView() { + JacksonViewBean bean = new JacksonViewBean(); + bean.setWithView1("with"); + bean.setWithView2("with"); + bean.setWithoutView("without"); + Mono input = Mono.just(bean); + + ResolvableType type = ResolvableType.forClass(JacksonViewBean.class); + Map hints = singletonMap(JSON_VIEW_HINT, MyJacksonView3.class); + + testEncode(input, type, step -> step + .consumeNextWith(expectString("{\"withoutView\":\"without\"}") + .andThen(DataBufferUtils::release)) + .verifyComplete(), + null, hints); + } + + @Test // gh-22771 + public void encodeWithFlushAfterWriteOff() { + ObjectMapper mapper = new ObjectMapper(); + mapper.configure(SerializationFeature.FLUSH_AFTER_WRITE_VALUE, false); + Jackson2JsonEncoder encoder = new Jackson2JsonEncoder(mapper); + + Flux result = encoder.encode(Flux.just(new Pojo("foo", "bar")), this.bufferFactory, + ResolvableType.forClass(Pojo.class), MimeTypeUtils.APPLICATION_JSON, Collections.emptyMap()); + + StepVerifier.create(result) + .consumeNextWith(expectString("[{\"foo\":\"foo\",\"bar\":\"bar\"}]")) + .expectComplete() + .verify(Duration.ofSeconds(5)); + } + + @Test + public void encodeAscii() { + Mono input = Mono.just(new Pojo("foo", "bar")); + + testEncode(input, ResolvableType.forClass(Pojo.class), step -> step + .consumeNextWith(expectString("{\"foo\":\"foo\",\"bar\":\"bar\"}")) + .verifyComplete(), + new MimeType("application", "json", StandardCharsets.US_ASCII), null); + + } + + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type") + private static class ParentClass { + } + + @JsonTypeName("foo") + private static class Foo extends ParentClass { + } + + @JsonTypeName("bar") + private static class Bar extends ParentClass { + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2SmileDecoderTests.java b/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2SmileDecoderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..84ec4fcda694a677a95f5c53366b99c66d0b3da1 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2SmileDecoderTests.java @@ -0,0 +1,108 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.json; + +import java.util.Arrays; +import java.util.List; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.Test; +import reactor.core.publisher.Flux; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.AbstractDecoderTestCase; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.codec.Pojo; +import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; +import org.springframework.util.MimeType; + +import static org.junit.Assert.*; +import static org.springframework.core.ResolvableType.forClass; +import static org.springframework.http.MediaType.APPLICATION_JSON; + +/** + * Unit tests for {@link Jackson2SmileDecoder}. + * + * @author Sebastien Deleuze + */ +public class Jackson2SmileDecoderTests extends AbstractDecoderTestCase { + + private final static MimeType SMILE_MIME_TYPE = new MimeType("application", "x-jackson-smile"); + private final static MimeType STREAM_SMILE_MIME_TYPE = new MimeType("application", "stream+x-jackson-smile"); + + private Pojo pojo1 = new Pojo("f1", "b1"); + + private Pojo pojo2 = new Pojo("f2", "b2"); + + private ObjectMapper mapper = Jackson2ObjectMapperBuilder.smile().build(); + + public Jackson2SmileDecoderTests() { + super(new Jackson2SmileDecoder()); + } + + @Override + @Test + public void canDecode() { + assertTrue(decoder.canDecode(forClass(Pojo.class), SMILE_MIME_TYPE)); + assertTrue(decoder.canDecode(forClass(Pojo.class), STREAM_SMILE_MIME_TYPE)); + assertTrue(decoder.canDecode(forClass(Pojo.class), null)); + + assertFalse(decoder.canDecode(forClass(String.class), null)); + assertFalse(decoder.canDecode(forClass(Pojo.class), APPLICATION_JSON)); + } + + + @Override + public void decode() { + Flux input = Flux.just(this.pojo1, this.pojo2) + .map(this::writeObject) + .flatMap(this::dataBuffer); + + testDecodeAll(input, Pojo.class, step -> step + .expectNext(pojo1) + .expectNext(pojo2) + .verifyComplete()); + + } + + private byte[] writeObject(Object o) { + try { + return this.mapper.writer().writeValueAsBytes(o); + } + catch (JsonProcessingException e) { + throw new AssertionError(e); + } + + } + + @Override + public void decodeToMono() { + List expected = Arrays.asList(pojo1, pojo2); + + Flux input = Flux.just(expected) + .map(this::writeObject) + .flatMap(this::dataBuffer); + + ResolvableType elementType = ResolvableType.forClassWithGenerics(List.class, Pojo.class); + testDecodeToMono(input, elementType, step -> step + .expectNext(expected) + .expectComplete() + .verify(), null, null); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2SmileEncoderTests.java b/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2SmileEncoderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..46e19e7bf659e45ae3deb237aaeec06d4e72b513 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2SmileEncoderTests.java @@ -0,0 +1,156 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.json; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.List; + +import com.fasterxml.jackson.databind.MappingIterator; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.AbstractEncoderTestCase; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.codec.Pojo; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; +import org.springframework.util.MimeType; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.springframework.core.io.buffer.DataBufferUtils.release; +import static org.springframework.http.MediaType.APPLICATION_XML; + +/** + * Unit tests for {@link Jackson2SmileEncoder}. + * + * @author Sebastien Deleuze + */ +public class Jackson2SmileEncoderTests extends AbstractEncoderTestCase { + + private final static MimeType SMILE_MIME_TYPE = new MimeType("application", "x-jackson-smile"); + private final static MimeType STREAM_SMILE_MIME_TYPE = new MimeType("application", "stream+x-jackson-smile"); + + private final Jackson2SmileEncoder encoder = new Jackson2SmileEncoder(); + + private final ObjectMapper mapper = Jackson2ObjectMapperBuilder.smile().build(); + + public Jackson2SmileEncoderTests() { + super(new Jackson2SmileEncoder()); + + } + + @Override + @Test + public void canEncode() { + ResolvableType pojoType = ResolvableType.forClass(Pojo.class); + assertTrue(this.encoder.canEncode(pojoType, SMILE_MIME_TYPE)); + assertTrue(this.encoder.canEncode(pojoType, STREAM_SMILE_MIME_TYPE)); + assertTrue(this.encoder.canEncode(pojoType, null)); + + // SPR-15464 + assertTrue(this.encoder.canEncode(ResolvableType.NONE, null)); + } + + @Test + public void canNotEncode() { + assertFalse(this.encoder.canEncode(ResolvableType.forClass(String.class), null)); + assertFalse(this.encoder.canEncode(ResolvableType.forClass(Pojo.class), APPLICATION_XML)); + + ResolvableType sseType = ResolvableType.forClass(ServerSentEvent.class); + assertFalse(this.encoder.canEncode(sseType, SMILE_MIME_TYPE)); + } + + @Override + @Test + public void encode() { + List list = Arrays.asList( + new Pojo("foo", "bar"), + new Pojo("foofoo", "barbar"), + new Pojo("foofoofoo", "barbarbar")); + + Flux input = Flux.fromIterable(list); + + testEncode(input, Pojo.class, step -> step + .consumeNextWith(dataBuffer -> { + try { + Object actual = this.mapper.reader().forType(List.class) + .readValue(dataBuffer.asInputStream()); + assertEquals(list, actual); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + finally { + release(dataBuffer); + } + })); + } + + @Test + public void encodeError() throws Exception { + Mono input = Mono.error(new InputException()); + + testEncode(input, Pojo.class, step -> step + .expectError(InputException.class) + .verify()); + + } + + @Test + public void encodeAsStream() throws Exception { + Pojo pojo1 = new Pojo("foo", "bar"); + Pojo pojo2 = new Pojo("foofoo", "barbar"); + Pojo pojo3 = new Pojo("foofoofoo", "barbarbar"); + Flux input = Flux.just(pojo1, pojo2, pojo3); + ResolvableType type = ResolvableType.forClass(Pojo.class); + + Flux result = this.encoder + .encode(input, bufferFactory, type, STREAM_SMILE_MIME_TYPE, null); + + Mono> joined = DataBufferUtils.join(result) + .map(buffer -> { + try { + return this.mapper.reader().forType(Pojo.class).readValues(buffer.asInputStream(true)); + } + catch (IOException ex) { + throw new UncheckedIOException(ex); + } + }); + + StepVerifier.create(joined) + .assertNext(iter -> { + assertTrue(iter.hasNext()); + assertEquals(pojo1, iter.next()); + assertTrue(iter.hasNext()); + assertEquals(pojo2, iter.next()); + assertTrue(iter.hasNext()); + assertEquals(pojo3, iter.next()); + assertFalse(iter.hasNext()); + }) + .verifyComplete(); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2TokenizerTests.java b/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2TokenizerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..9b6e90d1a32801a211c7320f80cef6be04f013cc --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2TokenizerTests.java @@ -0,0 +1,324 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.json; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; + +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; +import com.fasterxml.jackson.core.TreeNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.util.TokenBuffer; +import org.json.JSONException; +import org.junit.Before; +import org.junit.Test; +import org.skyscreamer.jsonassert.JSONAssert; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; + +import org.springframework.core.codec.DecodingException; +import org.springframework.core.io.buffer.AbstractLeakCheckingTestCase; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Juergen Hoeller + */ +public class Jackson2TokenizerTests extends AbstractLeakCheckingTestCase { + + private JsonFactory jsonFactory; + + private ObjectMapper objectMapper; + + + @Before + public void createParser() { + this.jsonFactory = new JsonFactory(); + this.objectMapper = new ObjectMapper(this.jsonFactory); + } + + + @Test + public void doNotTokenizeArrayElements() { + testTokenize( + singletonList("{\"foo\": \"foofoo\", \"bar\": \"barbar\"}"), + singletonList("{\"foo\": \"foofoo\", \"bar\": \"barbar\"}"), false); + + testTokenize( + asList("{\"foo\": \"foofoo\"", + ", \"bar\": \"barbar\"}"), + singletonList("{\"foo\":\"foofoo\",\"bar\":\"barbar\"}"), false); + + testTokenize( + singletonList("[" + + "{\"foo\": \"foofoo\", \"bar\": \"barbar\"}," + + "{\"foo\": \"foofoofoo\", \"bar\": \"barbarbar\"}]"), + singletonList("[" + + "{\"foo\": \"foofoo\", \"bar\": \"barbar\"}," + + "{\"foo\": \"foofoofoo\", \"bar\": \"barbarbar\"}]"), false); + + testTokenize( + singletonList("[{\"foo\": \"bar\"},{\"foo\": \"baz\"}]"), + singletonList("[{\"foo\": \"bar\"},{\"foo\": \"baz\"}]"), false); + + testTokenize( + asList("[" + + "{\"foo\": \"foofoo\", \"bar\"", ": \"barbar\"}," + + "{\"foo\": \"foofoofoo\", \"bar\": \"barbarbar\"}]"), + singletonList("[" + + "{\"foo\": \"foofoo\", \"bar\": \"barbar\"}," + + "{\"foo\": \"foofoofoo\", \"bar\": \"barbarbar\"}]"), false); + + testTokenize( + asList("[", + "{\"id\":1,\"name\":\"Robert\"}", ",", + "{\"id\":2,\"name\":\"Raide\"}", ",", + "{\"id\":3,\"name\":\"Ford\"}", "]"), + singletonList("[" + + "{\"id\":1,\"name\":\"Robert\"}," + + "{\"id\":2,\"name\":\"Raide\"}," + + "{\"id\":3,\"name\":\"Ford\"}]"), false); + + // SPR-16166: top-level JSON values + testTokenize(asList("\"foo", "bar\""),singletonList("\"foobar\""), false); + + testTokenize(asList("12", "34"),singletonList("1234"), false); + + testTokenize(asList("12.", "34"),singletonList("12.34"), false); + + // note that we do not test for null, true, or false, which are also valid top-level values, + // but are unsupported by JSONassert + } + + @Test + public void tokenizeArrayElements() { + testTokenize( + singletonList("{\"foo\": \"foofoo\", \"bar\": \"barbar\"}"), + singletonList("{\"foo\": \"foofoo\", \"bar\": \"barbar\"}"), true); + + testTokenize( + asList("{\"foo\": \"foofoo\"", ", \"bar\": \"barbar\"}"), + singletonList("{\"foo\":\"foofoo\",\"bar\":\"barbar\"}"), true); + + testTokenize( + singletonList("[" + + "{\"foo\": \"foofoo\", \"bar\": \"barbar\"}," + + "{\"foo\": \"foofoofoo\", \"bar\": \"barbarbar\"}]"), + asList( + "{\"foo\": \"foofoo\", \"bar\": \"barbar\"}", + "{\"foo\": \"foofoofoo\", \"bar\": \"barbarbar\"}"), true); + + testTokenize( + singletonList("[{\"foo\": \"bar\"},{\"foo\": \"baz\"}]"), + asList("{\"foo\": \"bar\"}", "{\"foo\": \"baz\"}"), true); + + // SPR-15803: nested array + testTokenize( + singletonList("[" + + "{\"id\":\"0\",\"start\":[-999999999,1,1],\"end\":[999999999,12,31]}," + + "{\"id\":\"1\",\"start\":[-999999999,1,1],\"end\":[999999999,12,31]}," + + "{\"id\":\"2\",\"start\":[-999999999,1,1],\"end\":[999999999,12,31]}" + + "]"), + asList( + "{\"id\":\"0\",\"start\":[-999999999,1,1],\"end\":[999999999,12,31]}", + "{\"id\":\"1\",\"start\":[-999999999,1,1],\"end\":[999999999,12,31]}", + "{\"id\":\"2\",\"start\":[-999999999,1,1],\"end\":[999999999,12,31]}"), true); + + // SPR-15803: nested array, no top-level array + testTokenize( + singletonList("{\"speakerIds\":[\"tastapod\"],\"language\":\"ENGLISH\"}"), + singletonList("{\"speakerIds\":[\"tastapod\"],\"language\":\"ENGLISH\"}"), true); + + testTokenize( + asList("[" + + "{\"foo\": \"foofoo\", \"bar\"", ": \"barbar\"}," + + "{\"foo\": \"foofoofoo\", \"bar\": \"barbarbar\"}]"), + asList( + "{\"foo\": \"foofoo\", \"bar\": \"barbar\"}", + "{\"foo\": \"foofoofoo\", \"bar\": \"barbarbar\"}"), true); + + testTokenize( + asList("[", + "{\"id\":1,\"name\":\"Robert\"}", + ",", + "{\"id\":2,\"name\":\"Raide\"}", + ",", + "{\"id\":3,\"name\":\"Ford\"}", + "]"), + asList("{\"id\":1,\"name\":\"Robert\"}", + "{\"id\":2,\"name\":\"Raide\"}", + "{\"id\":3,\"name\":\"Ford\"}"), true); + + // SPR-16166: top-level JSON values + testTokenize(asList("\"foo", "bar\""),singletonList("\"foobar\""), true); + + testTokenize(asList("12", "34"),singletonList("1234"), true); + + testTokenize(asList("12.", "34"),singletonList("12.34"), true); + + // SPR-16407 + testTokenize(asList("[1", ",2,", "3]"), asList("1", "2", "3"), true); + } + + private void testTokenize(List input, List output, boolean tokenize) { + StepVerifier.FirstStep builder = StepVerifier.create(decode(input, tokenize, -1)); + output.forEach(expected -> builder.assertNext(actual -> { + try { + JSONAssert.assertEquals(expected, actual, true); + } + catch (JSONException ex) { + throw new RuntimeException(ex); + } + })); + builder.verifyComplete(); + } + + @Test + public void testLimit() { + + List source = asList("[", + "{", "\"id\":1,\"name\":\"Dan\"", "},", + "{", "\"id\":2,\"name\":\"Ron\"", "},", + "{", "\"id\":3,\"name\":\"Bartholomew\"", "}", + "]"); + + String expected = String.join("", source); + int maxInMemorySize = expected.length(); + + StepVerifier.create(decode(source, false, maxInMemorySize)) + .expectNext(expected) + .verifyComplete(); + + StepVerifier.create(decode(source, false, maxInMemorySize - 2)) + .verifyError(DataBufferLimitException.class); + } + + @Test + public void testLimitTokenized() { + + List source = asList("[", + "{", "\"id\":1, \"name\":\"Dan\"", "},", + "{", "\"id\":2, \"name\":\"Ron\"", "},", + "{", "\"id\":3, \"name\":\"Bartholomew\"", "}", + "]"); + + String expected = "{\"id\":3,\"name\":\"Bartholomew\"}"; + int maxInMemorySize = expected.length(); + + StepVerifier.create(decode(source, true, maxInMemorySize)) + .expectNext("{\"id\":1,\"name\":\"Dan\"}") + .expectNext("{\"id\":2,\"name\":\"Ron\"}") + .expectNext(expected) + .verifyComplete(); + + StepVerifier.create(decode(source, true, maxInMemorySize - 1)) + .expectNext("{\"id\":1,\"name\":\"Dan\"}") + .expectNext("{\"id\":2,\"name\":\"Ron\"}") + .verifyError(DataBufferLimitException.class); + } + + @Test + public void errorInStream() { + DataBuffer buffer = stringBuffer("{\"id\":1,\"name\":"); + Flux source = Flux.just(buffer).concatWith(Flux.error(new RuntimeException())); + Flux result = Jackson2Tokenizer.tokenize(source, this.jsonFactory, this.objectMapper, true, + false, -1); + + StepVerifier.create(result) + .expectError(RuntimeException.class) + .verify(); + } + + @Test // SPR-16521 + public void jsonEOFExceptionIsWrappedAsDecodingError() { + Flux source = Flux.just(stringBuffer("{\"status\": \"noClosingQuote}")); + Flux tokens = Jackson2Tokenizer.tokenize(source, this.jsonFactory, this.objectMapper, false, + false, -1); + + StepVerifier.create(tokens) + .expectError(DecodingException.class) + .verify(); + } + + @Test + public void useBigDecimalForFloats() { + for (boolean useBigDecimalForFloats : Arrays.asList(false, true)) { + Flux source = Flux.just(stringBuffer("1E+2")); + Flux tokens = + Jackson2Tokenizer.tokenize(source, this.jsonFactory, this.objectMapper, false, + useBigDecimalForFloats, -1); + + StepVerifier.create(tokens) + .assertNext(tokenBuffer -> { + try { + JsonParser parser = tokenBuffer.asParser(); + JsonToken token = parser.nextToken(); + assertEquals(JsonToken.VALUE_NUMBER_FLOAT, token); + JsonParser.NumberType numberType = parser.getNumberType(); + if (useBigDecimalForFloats) { + assertEquals(JsonParser.NumberType.BIG_DECIMAL, numberType); + } + else { + assertEquals(JsonParser.NumberType.DOUBLE, numberType); + } + } + catch (IOException ex) { + fail(ex.getMessage()); + } + }) + .verifyComplete(); + } + } + + private Flux decode(List source, boolean tokenize, int maxInMemorySize) { + + Flux tokens = Jackson2Tokenizer.tokenize( + Flux.fromIterable(source).map(this::stringBuffer), + this.jsonFactory, this.objectMapper, tokenize, false, maxInMemorySize); + + return tokens + .map(tokenBuffer -> { + try { + TreeNode root = this.objectMapper.readTree(tokenBuffer.asParser()); + return this.objectMapper.writeValueAsString(root); + } + catch (IOException ex) { + throw new UncheckedIOException(ex); + } + }); + } + + private DataBuffer stringBuffer(String value) { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length); + buffer.write(bytes); + return buffer; + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/json/JacksonViewBean.java b/spring-web/src/test/java/org/springframework/http/codec/json/JacksonViewBean.java new file mode 100644 index 0000000000000000000000000000000000000000..4309eaf5bfb449374f2e66655375f2c048c6ee83 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/json/JacksonViewBean.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.json; + +import com.fasterxml.jackson.annotation.JsonView; + +/** + * @author Sebastien Deleuze + */ +@JsonView(JacksonViewBean.MyJacksonView3.class) +class JacksonViewBean { + + interface MyJacksonView1 {} + interface MyJacksonView2 {} + interface MyJacksonView3 {} + + @JsonView(MyJacksonView1.class) + private String withView1; + + @JsonView(MyJacksonView2.class) + private String withView2; + + private String withoutView; + + public String getWithView1() { + return withView1; + } + + public void setWithView1(String withView1) { + this.withView1 = withView1; + } + + public String getWithView2() { + return withView2; + } + + public void setWithView2(String withView2) { + this.withView2 = withView2; + } + + public String getWithoutView() { + return withoutView; + } + + public void setWithoutView(String withoutView) { + this.withoutView = withoutView; + } +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e82cae764f0ea766cdb590cbf1d3179cbd55eaff --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java @@ -0,0 +1,286 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.multipart; + +import java.io.IOException; +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.UnicastProcessor; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.StringDecoder; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.core.io.buffer.AbstractLeakCheckingTestCase; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpEntity; +import org.springframework.http.MediaType; +import org.springframework.http.client.MultipartBodyBuilder; +import org.springframework.http.codec.ClientCodecConfigurer; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.util.MultiValueMap; + +import static org.junit.Assert.*; + +/** + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + */ +public class MultipartHttpMessageWriterTests extends AbstractLeakCheckingTestCase { + + private final MultipartHttpMessageWriter writer = + new MultipartHttpMessageWriter(ClientCodecConfigurer.create().getWriters()); + + private MockServerHttpResponse response; + + + @Before + public void setUp() { + this.response = new MockServerHttpResponse(this.bufferFactory); + } + + + @Test + public void canWrite() { + assertTrue(this.writer.canWrite( + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Object.class), + MediaType.MULTIPART_FORM_DATA)); + assertTrue(this.writer.canWrite( + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class), + MediaType.MULTIPART_FORM_DATA)); + + assertFalse(this.writer.canWrite( + ResolvableType.forClassWithGenerics(Map.class, String.class, Object.class), + MediaType.MULTIPART_FORM_DATA)); + assertTrue(this.writer.canWrite( + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Object.class), + MediaType.APPLICATION_FORM_URLENCODED)); + } + + @Test + public void writeMultipart() throws Exception { + + Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); + Resource utf8 = new ClassPathResource("/org/springframework/http/converter/logo.jpg") { + @Override + public String getFilename() { + // SPR-12108 + return "Hall\u00F6le.jpg"; + } + }; + + Publisher publisher = Flux.just("foo", "bar", "baz"); + + MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); + bodyBuilder.part("name 1", "value 1"); + bodyBuilder.part("name 2", "value 2+1"); + bodyBuilder.part("name 2", "value 2+2"); + bodyBuilder.part("logo", logo); + bodyBuilder.part("utf8", utf8); + bodyBuilder.part("json", new Foo("bar"), MediaType.APPLICATION_JSON_UTF8); + bodyBuilder.asyncPart("publisher", publisher, String.class); + Mono>> result = Mono.just(bodyBuilder.build()); + + Map hints = Collections.emptyMap(); + this.writer.write(result, null, MediaType.MULTIPART_FORM_DATA, this.response, hints).block(Duration.ofSeconds(5)); + + MultiValueMap requestParts = parse(hints); + assertEquals(6, requestParts.size()); + + Part part = requestParts.getFirst("name 1"); + assertTrue(part instanceof FormFieldPart); + assertEquals("name 1", part.name()); + assertEquals("value 1", ((FormFieldPart) part).value()); + + List parts2 = requestParts.get("name 2"); + assertEquals(2, parts2.size()); + part = parts2.get(0); + assertTrue(part instanceof FormFieldPart); + assertEquals("name 2", part.name()); + assertEquals("value 2+1", ((FormFieldPart) part).value()); + part = parts2.get(1); + assertTrue(part instanceof FormFieldPart); + assertEquals("name 2", part.name()); + assertEquals("value 2+2", ((FormFieldPart) part).value()); + + part = requestParts.getFirst("logo"); + assertTrue(part instanceof FilePart); + assertEquals("logo", part.name()); + assertEquals("logo.jpg", ((FilePart) part).filename()); + assertEquals(MediaType.IMAGE_JPEG, part.headers().getContentType()); + assertEquals(logo.getFile().length(), part.headers().getContentLength()); + + part = requestParts.getFirst("utf8"); + assertTrue(part instanceof FilePart); + assertEquals("utf8", part.name()); + assertEquals("Hall\u00F6le.jpg", ((FilePart) part).filename()); + assertEquals(MediaType.IMAGE_JPEG, part.headers().getContentType()); + assertEquals(utf8.getFile().length(), part.headers().getContentLength()); + + part = requestParts.getFirst("json"); + assertEquals("json", part.name()); + assertEquals(MediaType.APPLICATION_JSON_UTF8, part.headers().getContentType()); + + String value = StringDecoder.textPlainOnly(false).decodeToMono(part.content(), + ResolvableType.forClass(String.class), MediaType.TEXT_PLAIN, + Collections.emptyMap()).block(Duration.ZERO); + + assertEquals("{\"bar\":\"bar\"}", value); + + part = requestParts.getFirst("publisher"); + assertEquals("publisher", part.name()); + + value = StringDecoder.textPlainOnly(false).decodeToMono(part.content(), + ResolvableType.forClass(String.class), MediaType.TEXT_PLAIN, + Collections.emptyMap()).block(Duration.ZERO); + + assertEquals("foobarbaz", value); + } + + @Test // SPR-16402 + public void singleSubscriberWithResource() throws IOException { + UnicastProcessor processor = UnicastProcessor.create(); + Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); + Mono.just(logo).subscribe(processor); + + MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); + bodyBuilder.asyncPart("logo", processor, Resource.class); + + Mono>> result = Mono.just(bodyBuilder.build()); + + Map hints = Collections.emptyMap(); + this.writer.write(result, null, MediaType.MULTIPART_FORM_DATA, this.response, hints).block(); + + MultiValueMap requestParts = parse(hints); + assertEquals(1, requestParts.size()); + + Part part = requestParts.getFirst("logo"); + assertEquals("logo", part.name()); + assertTrue(part instanceof FilePart); + assertEquals("logo.jpg", ((FilePart) part).filename()); + assertEquals(MediaType.IMAGE_JPEG, part.headers().getContentType()); + assertEquals(logo.getFile().length(), part.headers().getContentLength()); + } + + @Test // SPR-16402 + public void singleSubscriberWithStrings() { + UnicastProcessor processor = UnicastProcessor.create(); + Flux.just("foo", "bar", "baz").subscribe(processor); + + MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); + bodyBuilder.asyncPart("name", processor, String.class); + + Mono>> result = Mono.just(bodyBuilder.build()); + + this.writer.write(result, null, MediaType.MULTIPART_FORM_DATA, this.response, Collections.emptyMap()) + .block(Duration.ofSeconds(5)); + + // Make sure body is consumed to avoid leak reports + this.response.getBodyAsString().block(Duration.ofSeconds(5)); + } + + @Test // SPR-16376 + public void customContentDisposition() throws IOException { + Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); + Flux buffers = DataBufferUtils.read(logo, new DefaultDataBufferFactory(), 1024); + long contentLength = logo.contentLength(); + + MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); + bodyBuilder.part("resource", logo) + .headers(h -> h.setContentDispositionFormData("resource", "spring.jpg")); + bodyBuilder.asyncPart("buffers", buffers, DataBuffer.class) + .headers(h -> { + h.setContentDispositionFormData("buffers", "buffers.jpg"); + h.setContentType(MediaType.IMAGE_JPEG); + h.setContentLength(contentLength); + }); + + MultiValueMap> multipartData = bodyBuilder.build(); + + Map hints = Collections.emptyMap(); + this.writer.write(Mono.just(multipartData), null, MediaType.MULTIPART_FORM_DATA, + this.response, hints).block(); + + MultiValueMap requestParts = parse(hints); + assertEquals(2, requestParts.size()); + + Part part = requestParts.getFirst("resource"); + assertTrue(part instanceof FilePart); + assertEquals("spring.jpg", ((FilePart) part).filename()); + assertEquals(logo.getFile().length(), part.headers().getContentLength()); + + part = requestParts.getFirst("buffers"); + assertTrue(part instanceof FilePart); + assertEquals("buffers.jpg", ((FilePart) part).filename()); + assertEquals(logo.getFile().length(), part.headers().getContentLength()); + } + + private MultiValueMap parse(Map hints) { + MediaType contentType = this.response.getHeaders().getContentType(); + assertNotNull("No boundary found", contentType.getParameter("boundary")); + + // see if Synchronoss NIO Multipart can read what we wrote + SynchronossPartHttpMessageReader synchronossReader = new SynchronossPartHttpMessageReader(); + MultipartHttpMessageReader reader = new MultipartHttpMessageReader(synchronossReader); + + MockServerHttpRequest request = MockServerHttpRequest.post("/") + .contentType(MediaType.parseMediaType(contentType.toString())) + .body(this.response.getBody()); + + ResolvableType elementType = ResolvableType.forClassWithGenerics( + MultiValueMap.class, String.class, Part.class); + + MultiValueMap result = reader.readMono(elementType, request, hints) + .block(Duration.ofSeconds(5)); + + assertNotNull(result); + return result; + } + + + private class Foo { + + private String bar; + + public Foo() { + } + + public Foo(String bar) { + this.bar = bar; + } + + public String getBar() { + return this.bar; + } + + public void setBar(String bar) { + this.bar = bar; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0b2884450aa7a9f23ca9cb5f7be1c0c9bcc6a83d --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java @@ -0,0 +1,238 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.multipart; + +import java.io.File; +import java.io.IOException; +import java.nio.channels.ReadableByteChannel; +import java.time.Duration; +import java.util.Map; +import java.util.function.Consumer; + +import org.junit.Test; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.DecodingException; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.client.MultipartBodyBuilder; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.mock.http.client.reactive.test.MockClientHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.util.MultiValueMap; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonMap; +import static org.hamcrest.core.StringStartsWith.startsWith; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.springframework.core.ResolvableType.forClassWithGenerics; +import static org.springframework.http.HttpHeaders.CONTENT_TYPE; +import static org.springframework.http.MediaType.MULTIPART_FORM_DATA; + +/** + * Unit tests for {@link SynchronossPartHttpMessageReader}. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @author Brian Clozel + */ +public class SynchronossPartHttpMessageReaderTests { + + private final MultipartHttpMessageReader reader = + new MultipartHttpMessageReader(new SynchronossPartHttpMessageReader()); + + private static final ResolvableType PARTS_ELEMENT_TYPE = + forClassWithGenerics(MultiValueMap.class, String.class, Part.class); + + @Test + public void canRead() { + assertTrue(this.reader.canRead( + forClassWithGenerics(MultiValueMap.class, String.class, Part.class), + MediaType.MULTIPART_FORM_DATA)); + + assertFalse(this.reader.canRead( + forClassWithGenerics(MultiValueMap.class, String.class, Object.class), + MediaType.MULTIPART_FORM_DATA)); + + assertFalse(this.reader.canRead( + forClassWithGenerics(MultiValueMap.class, String.class, String.class), + MediaType.MULTIPART_FORM_DATA)); + + assertFalse(this.reader.canRead( + forClassWithGenerics(Map.class, String.class, String.class), + MediaType.MULTIPART_FORM_DATA)); + + assertFalse(this.reader.canRead( + forClassWithGenerics(MultiValueMap.class, String.class, Part.class), + MediaType.APPLICATION_FORM_URLENCODED)); + } + + @Test + public void resolveParts() { + ServerHttpRequest request = generateMultipartRequest(); + ResolvableType elementType = forClassWithGenerics(MultiValueMap.class, String.class, Part.class); + MultiValueMap parts = this.reader.readMono(elementType, request, emptyMap()).block(); + assertEquals(2, parts.size()); + + assertTrue(parts.containsKey("filePart")); + Part part = parts.getFirst("filePart"); + assertTrue(part instanceof FilePart); + assertEquals("filePart", part.name()); + assertEquals("foo.txt", ((FilePart) part).filename()); + DataBuffer buffer = DataBufferUtils.join(part.content()).block(); + assertEquals(12, buffer.readableByteCount()); + byte[] byteContent = new byte[12]; + buffer.read(byteContent); + assertEquals("Lorem Ipsum.", new String(byteContent)); + + assertTrue(parts.containsKey("textPart")); + part = parts.getFirst("textPart"); + assertTrue(part instanceof FormFieldPart); + assertEquals("textPart", part.name()); + assertEquals("sample-text", ((FormFieldPart) part).value()); + } + + @Test // SPR-16545 + public void transferTo() throws IOException { + ServerHttpRequest request = generateMultipartRequest(); + MultiValueMap parts = this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap()).block(); + + assertNotNull(parts); + FilePart part = (FilePart) parts.getFirst("filePart"); + assertNotNull(part); + + File dest = File.createTempFile(part.filename(), "multipart"); + part.transferTo(dest).block(Duration.ofSeconds(5)); + + assertTrue(dest.exists()); + assertEquals(12, dest.length()); + assertTrue(dest.delete()); + } + + @Test + public void bodyError() { + ServerHttpRequest request = generateErrorMultipartRequest(); + StepVerifier.create(this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap())).verifyError(); + } + + @Test + public void readPartsWithoutDemand() { + ServerHttpRequest request = generateMultipartRequest(); + Mono> parts = this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap()); + ZeroDemandSubscriber subscriber = new ZeroDemandSubscriber(); + parts.subscribe(subscriber); + subscriber.cancel(); + } + + @Test + public void gh23768() throws IOException { + ReadableByteChannel channel = new ClassPathResource("invalid.multipart", getClass()).readableChannel(); + Flux body = DataBufferUtils.readByteChannel(() -> channel, new DefaultDataBufferFactory(), 1024); + + MediaType contentType = new MediaType("multipart", "form-data", + singletonMap("boundary", "NbjrKgjbsaMLdnMxMfDpD6myWomYc0qNX0w")); + ServerHttpRequest request = MockServerHttpRequest.post("/") + .contentType(contentType) + .body(body); + + Mono> parts = this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap()); + + StepVerifier.create(parts) + .assertNext(result -> assertTrue(result.isEmpty())) + .verifyComplete(); + } + + @Test + public void readTooManyParts() { + testMultipartExceptions(reader -> reader.setMaxParts(1), ex -> { + assertEquals(DecodingException.class, ex.getClass()); + assertThat(ex.getMessage(), startsWith("Failure while parsing part[2]")); + assertEquals("Too many parts (2 allowed)", ex.getCause().getMessage()); + }); + } + + @Test + public void readFilePartTooBig() { + testMultipartExceptions(reader -> reader.setMaxDiskUsagePerPart(5), ex -> { + assertEquals(DecodingException.class, ex.getClass()); + assertThat(ex.getMessage(), startsWith("Failure while parsing part[1]")); + assertEquals("Part[1] exceeded the disk usage limit of 5 bytes", ex.getCause().getMessage()); + }); + } + + @Test + public void readPartHeadersTooBig() { + testMultipartExceptions(reader -> reader.setMaxInMemorySize(1), ex -> { + assertEquals(DecodingException.class, ex.getClass()); + assertThat(ex.getMessage(), startsWith("Failure while parsing part[1]")); + assertEquals("Part[1] exceeded the in-memory limit of 1 bytes", ex.getCause().getMessage()); + }); + } + + private void testMultipartExceptions( + Consumer configurer, Consumer assertions) { + + SynchronossPartHttpMessageReader reader = new SynchronossPartHttpMessageReader(); + configurer.accept(reader); + MultipartHttpMessageReader multipartReader = new MultipartHttpMessageReader(reader); + StepVerifier.create(multipartReader.readMono(PARTS_ELEMENT_TYPE, generateMultipartRequest(), emptyMap())) + .consumeErrorWith(assertions) + .verify(); + } + + private ServerHttpRequest generateMultipartRequest() { + MultipartBodyBuilder partsBuilder = new MultipartBodyBuilder(); + partsBuilder.part("filePart", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt")); + partsBuilder.part("textPart", "sample-text"); + + MockClientHttpRequest outputMessage = new MockClientHttpRequest(HttpMethod.POST, "/"); + new MultipartHttpMessageWriter() + .write(Mono.just(partsBuilder.build()), null, MediaType.MULTIPART_FORM_DATA, outputMessage, null) + .block(Duration.ofSeconds(5)); + return MockServerHttpRequest.post("/") + .contentType(outputMessage.getHeaders().getContentType()) + .body(outputMessage.getBody()); + } + + private ServerHttpRequest generateErrorMultipartRequest() { + return MockServerHttpRequest.post("/") + .header(CONTENT_TYPE, MULTIPART_FORM_DATA.toString()) + .body(Flux.just(new DefaultDataBufferFactory().wrap("invalid content".getBytes()))); + } + + private static class ZeroDemandSubscriber extends BaseSubscriber> { + + @Override + protected void hookOnSubscribe(Subscription subscription) { + // Just subscribe without requesting + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java b/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..07e944fc38bfa136102595092865ba85d82b78e4 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java @@ -0,0 +1,235 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.protobuf; + +import java.io.IOException; +import java.util.Arrays; + +import com.google.protobuf.Message; +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.AbstractDecoderTestCase; +import org.springframework.core.codec.DecodingException; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.MediaType; +import org.springframework.protobuf.Msg; +import org.springframework.protobuf.SecondMsg; +import org.springframework.util.MimeType; + +import static java.util.Collections.*; +import static org.junit.Assert.*; +import static org.springframework.core.ResolvableType.*; +import static org.springframework.core.io.buffer.DataBufferUtils.*; + +/** + * Unit tests for {@link ProtobufDecoder}. + * + * @author Sebastien Deleuze + */ +public class ProtobufDecoderTests extends AbstractDecoderTestCase { + + private final static MimeType PROTOBUF_MIME_TYPE = new MimeType("application", "x-protobuf"); + + private final SecondMsg secondMsg = SecondMsg.newBuilder().setBlah(123).build(); + + private final Msg testMsg1 = Msg.newBuilder().setFoo("Foo").setBlah(secondMsg).build(); + + private final SecondMsg secondMsg2 = SecondMsg.newBuilder().setBlah(456).build(); + + private final Msg testMsg2 = Msg.newBuilder().setFoo("Bar").setBlah(secondMsg2).build(); + + public ProtobufDecoderTests() { + super(new ProtobufDecoder()); + } + + + @Test(expected = IllegalArgumentException.class) + public void extensionRegistryNull() { + new ProtobufDecoder(null); + } + + @Override + @Test + public void canDecode() { + assertTrue(this.decoder.canDecode(forClass(Msg.class), null)); + assertTrue(this.decoder.canDecode(forClass(Msg.class), PROTOBUF_MIME_TYPE)); + assertTrue(this.decoder.canDecode(forClass(Msg.class), MediaType.APPLICATION_OCTET_STREAM)); + assertFalse(this.decoder.canDecode(forClass(Msg.class), MediaType.APPLICATION_JSON)); + assertFalse(this.decoder.canDecode(forClass(Object.class), PROTOBUF_MIME_TYPE)); + } + + @Override + @Test + public void decodeToMono() { + Mono input = dataBuffer(this.testMsg1); + + testDecodeToMonoAll(input, Msg.class, step -> step + .expectNext(this.testMsg1) + .verifyComplete()); + } + + @Test + public void decodeChunksToMono() { + byte[] full = this.testMsg1.toByteArray(); + byte[] chunk1 = Arrays.copyOfRange(full, 0, full.length / 2); + byte[] chunk2 = Arrays.copyOfRange(full, chunk1.length, full.length); + + Flux input = Flux.just(chunk1, chunk2) + .flatMap(bytes -> Mono.defer(() -> { + DataBuffer dataBuffer = this.bufferFactory.allocateBuffer(bytes.length); + dataBuffer.write(bytes); + return Mono.just(dataBuffer); + })); + + testDecodeToMono(input, Msg.class, step -> step + .expectNext(this.testMsg1) + .verifyComplete()); + } + + @Override + @Test + public void decode() { + Flux input = Flux.just(this.testMsg1, this.testMsg2) + .flatMap(msg -> Mono.defer(() -> { + DataBuffer buffer = this.bufferFactory.allocateBuffer(); + try { + msg.writeDelimitedTo(buffer.asOutputStream()); + return Mono.just(buffer); + } + catch (IOException e) { + release(buffer); + return Mono.error(e); + } + })); + + testDecodeAll(input, Msg.class, step -> step + .expectNext(this.testMsg1) + .expectNext(this.testMsg2) + .verifyComplete()); + } + + @Test + public void decodeSplitChunks() { + + + Flux input = Flux.just(this.testMsg1, this.testMsg2) + .flatMap(msg -> Mono.defer(() -> { + DataBuffer buffer = this.bufferFactory.allocateBuffer(); + try { + msg.writeDelimitedTo(buffer.asOutputStream()); + return Mono.just(buffer); + } + catch (IOException e) { + release(buffer); + return Mono.error(e); + } + })) + .flatMap(buffer -> { + int len = buffer.readableByteCount() / 2; + Flux result = Flux.just( + DataBufferUtils.retain(buffer.slice(0, len)), + DataBufferUtils + .retain(buffer.slice(len, buffer.readableByteCount() - len)) + ); + release(buffer); + return result; + }); + + testDecode(input, Msg.class, step -> step + .expectNext(this.testMsg1) + .expectNext(this.testMsg2) + .verifyComplete()); + } + + @Test // SPR-17429 + public void decodeSplitMessageSize() { + this.decoder.setMaxMessageSize(100009); + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < 10000; i++) { + builder.append("azertyuiop"); + } + Msg bigMessage = Msg.newBuilder().setFoo(builder.toString()).setBlah(secondMsg2).build(); + + Flux input = Flux.just(bigMessage, bigMessage) + .flatMap(msg -> Mono.defer(() -> { + DataBuffer buffer = this.bufferFactory.allocateBuffer(); + try { + msg.writeDelimitedTo(buffer.asOutputStream()); + return Mono.just(buffer); + } + catch (IOException e) { + release(buffer); + return Mono.error(e); + } + })) + .flatMap(buffer -> { + int len = 2; + Flux result = Flux.just( + DataBufferUtils.retain(buffer.slice(0, len)), + DataBufferUtils + .retain(buffer.slice(len, buffer.readableByteCount() - len)) + ); + release(buffer); + return result; + }); + + testDecode(input, Msg.class, step -> step + .expectNext(bigMessage) + .expectNext(bigMessage) + .verifyComplete()); + } + + @Test + public void decodeMergedChunks() throws IOException { + DataBuffer buffer = this.bufferFactory.allocateBuffer(); + this.testMsg1.writeDelimitedTo(buffer.asOutputStream()); + this.testMsg1.writeDelimitedTo(buffer.asOutputStream()); + + ResolvableType elementType = forClass(Msg.class); + Flux messages = this.decoder.decode(Mono.just(buffer), elementType, null, emptyMap()); + + StepVerifier.create(messages) + .expectNext(testMsg1) + .expectNext(testMsg1) + .verifyComplete(); + } + + @Test + public void exceedMaxSize() { + this.decoder.setMaxMessageSize(1); + Mono input = dataBuffer(this.testMsg1); + + testDecode(input, Msg.class, step -> step + .verifyError(DecodingException.class)); + } + + private Mono dataBuffer(Msg msg) { + return Mono.fromCallable(() -> { + byte[] bytes = msg.toByteArray(); + DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length); + buffer.write(bytes); + return buffer; + }); + } + + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufEncoderTests.java b/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufEncoderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..4c707a498251ac4568954ae090cafc10011cb67a --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufEncoderTests.java @@ -0,0 +1,114 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.protobuf; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.function.Consumer; + +import com.google.protobuf.Message; +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.codec.AbstractEncoderTestCase; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.MediaType; +import org.springframework.protobuf.Msg; +import org.springframework.protobuf.SecondMsg; +import org.springframework.util.MimeType; + +import static org.junit.Assert.*; +import static org.springframework.core.ResolvableType.forClass; + +/** + * Unit tests for {@link ProtobufEncoder}. + * + * @author Sebastien Deleuze + */ +public class ProtobufEncoderTests extends AbstractEncoderTestCase { + + private final static MimeType PROTOBUF_MIME_TYPE = new MimeType("application", "x-protobuf"); + + private Msg msg1 = + Msg.newBuilder().setFoo("Foo").setBlah(SecondMsg.newBuilder().setBlah(123).build()).build(); + + private Msg msg2 = + Msg.newBuilder().setFoo("Bar").setBlah(SecondMsg.newBuilder().setBlah(456).build()).build(); + + + public ProtobufEncoderTests() { + super(new ProtobufEncoder()); + } + + @Override + @Test + public void canEncode() { + assertTrue(this.encoder.canEncode(forClass(Msg.class), null)); + assertTrue(this.encoder.canEncode(forClass(Msg.class), PROTOBUF_MIME_TYPE)); + assertTrue(this.encoder.canEncode(forClass(Msg.class), MediaType.APPLICATION_OCTET_STREAM)); + assertFalse(this.encoder.canEncode(forClass(Msg.class), MediaType.APPLICATION_JSON)); + assertFalse(this.encoder.canEncode(forClass(Object.class), PROTOBUF_MIME_TYPE)); + } + + @Override + @Test + public void encode() { + Mono input = Mono.just(this.msg1); + + testEncodeAll(input, Msg.class, step -> step + .consumeNextWith(dataBuffer -> { + try { + assertEquals(this.msg1, Msg.parseFrom(dataBuffer.asInputStream())); + + } + catch (IOException ex) { + throw new UncheckedIOException(ex); + } + finally { + DataBufferUtils.release(dataBuffer); + } + }) + .verifyComplete()); + } + + @Test + public void encodeStream() { + Flux input = Flux.just(this.msg1, this.msg2); + + testEncodeAll(input, Msg.class, step -> step + .consumeNextWith(expect(this.msg1)) + .consumeNextWith(expect(this.msg2)) + .verifyComplete()); + } + + protected final Consumer expect(Msg msg) { + return dataBuffer -> { + try { + assertEquals(msg, Msg.parseDelimitedFrom(dataBuffer.asInputStream())); + + } + catch (IOException ex) { + throw new UncheckedIOException(ex); + } + finally { + DataBufferUtils.release(dataBuffer); + } + }; + } +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/support/ClientCodecConfigurerTests.java b/spring-web/src/test/java/org/springframework/http/codec/support/ClientCodecConfigurerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..aed7f71db4e2a1c6f86069fc03db9a775863bc7f --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/support/ClientCodecConfigurerTests.java @@ -0,0 +1,261 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.support; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Test; +import reactor.core.publisher.Flux; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.ByteArrayDecoder; +import org.springframework.core.codec.ByteArrayEncoder; +import org.springframework.core.codec.ByteBufferDecoder; +import org.springframework.core.codec.ByteBufferEncoder; +import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.core.codec.DataBufferDecoder; +import org.springframework.core.codec.DataBufferEncoder; +import org.springframework.core.codec.Decoder; +import org.springframework.core.codec.Encoder; +import org.springframework.core.codec.ResourceDecoder; +import org.springframework.core.codec.StringDecoder; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ClientCodecConfigurer; +import org.springframework.http.codec.DecoderHttpMessageReader; +import org.springframework.http.codec.EncoderHttpMessageWriter; +import org.springframework.http.codec.FormHttpMessageReader; +import org.springframework.http.codec.FormHttpMessageWriter; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.ResourceHttpMessageWriter; +import org.springframework.http.codec.ServerSentEventHttpMessageReader; +import org.springframework.http.codec.json.Jackson2JsonDecoder; +import org.springframework.http.codec.json.Jackson2JsonEncoder; +import org.springframework.http.codec.json.Jackson2SmileDecoder; +import org.springframework.http.codec.json.Jackson2SmileEncoder; +import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; +import org.springframework.http.codec.protobuf.ProtobufDecoder; +import org.springframework.http.codec.protobuf.ProtobufHttpMessageWriter; +import org.springframework.http.codec.xml.Jaxb2XmlDecoder; +import org.springframework.http.codec.xml.Jaxb2XmlEncoder; +import org.springframework.util.MimeTypeUtils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.springframework.core.ResolvableType.forClass; + +/** + * Unit tests for {@link ClientCodecConfigurer}. + * + * @author Rossen Stoyanchev + */ +public class ClientCodecConfigurerTests { + + private final ClientCodecConfigurer configurer = new DefaultClientCodecConfigurer(); + + private final AtomicInteger index = new AtomicInteger(0); + + + @Test + public void defaultReaders() { + List> readers = this.configurer.getReaders(); + assertEquals(12, readers.size()); + assertEquals(ByteArrayDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(ByteBufferDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(DataBufferDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(ResourceDecoder.class, getNextDecoder(readers).getClass()); + assertStringDecoder(getNextDecoder(readers), true); + assertEquals(ProtobufDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(FormHttpMessageReader.class, readers.get(this.index.getAndIncrement()).getClass()); // SPR-16804 + assertEquals(Jackson2JsonDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(Jackson2SmileDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(Jaxb2XmlDecoder.class, getNextDecoder(readers).getClass()); + assertSseReader(readers); + assertStringDecoder(getNextDecoder(readers), false); + } + + @Test + public void defaultWriters() { + List> writers = this.configurer.getWriters(); + assertEquals(11, writers.size()); + assertEquals(ByteArrayEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(ByteBufferEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(DataBufferEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(ResourceHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass()); + assertStringEncoder(getNextEncoder(writers), true); + assertEquals(ProtobufHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass()); + assertEquals(MultipartHttpMessageWriter.class, writers.get(this.index.getAndIncrement()).getClass()); + assertEquals(Jackson2JsonEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(Jackson2SmileEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(Jaxb2XmlEncoder.class, getNextEncoder(writers).getClass()); + assertStringEncoder(getNextEncoder(writers), false); + } + + @Test + public void jackson2EncoderOverride() { + Jackson2JsonDecoder decoder = new Jackson2JsonDecoder(); + this.configurer.defaultCodecs().jackson2JsonDecoder(decoder); + + List> readers = this.configurer.getReaders(); + assertSame(decoder, findCodec(readers, ServerSentEventHttpMessageReader.class).getDecoder()); + } + + @Test + public void maxInMemorySize() { + int size = 99; + this.configurer.defaultCodecs().maxInMemorySize(size); + List> readers = this.configurer.getReaders(); + assertEquals(12, readers.size()); + assertEquals(size, ((ByteArrayDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((ByteBufferDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((DataBufferDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((ResourceDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((StringDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((ProtobufDecoder) getNextDecoder(readers)).getMaxMessageSize()); + assertEquals(size, ((FormHttpMessageReader) nextReader(readers)).getMaxInMemorySize()); + + assertEquals(size, ((Jackson2JsonDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((Jackson2SmileDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((Jaxb2XmlDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + + ServerSentEventHttpMessageReader reader = (ServerSentEventHttpMessageReader) nextReader(readers); + assertEquals(size, ((Jackson2JsonDecoder) reader.getDecoder()).getMaxInMemorySize()); + + assertEquals(size, ((StringDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + } + + @Test + public void enableLoggingRequestDetails() { + this.configurer.defaultCodecs().enableLoggingRequestDetails(true); + + List> writers = this.configurer.getWriters(); + MultipartHttpMessageWriter multipartWriter = findCodec(writers, MultipartHttpMessageWriter.class); + assertTrue(multipartWriter.isEnableLoggingRequestDetails()); + + FormHttpMessageWriter formWriter = (FormHttpMessageWriter) multipartWriter.getFormWriter(); + assertNotNull(formWriter); + assertTrue(formWriter.isEnableLoggingRequestDetails()); + } + + @Test + public void clonedConfigurer() { + ClientCodecConfigurer clone = this.configurer.clone(); + + Jackson2JsonDecoder jackson2Decoder = new Jackson2JsonDecoder(); + clone.defaultCodecs().serverSentEventDecoder(jackson2Decoder); + clone.defaultCodecs().multipartCodecs().encoder(new Jackson2SmileEncoder()); + clone.defaultCodecs().multipartCodecs().writer(new ResourceHttpMessageWriter()); + + // Clone has the customizations + + Decoder sseDecoder = findCodec(clone.getReaders(), ServerSentEventHttpMessageReader.class).getDecoder(); + List> writers = findCodec(clone.getWriters(), MultipartHttpMessageWriter.class).getPartWriters(); + + assertSame(jackson2Decoder, sseDecoder); + assertEquals(2, writers.size()); + + // Original does not have the customizations + + sseDecoder = findCodec(this.configurer.getReaders(), ServerSentEventHttpMessageReader.class).getDecoder(); + writers = findCodec(this.configurer.getWriters(), MultipartHttpMessageWriter.class).getPartWriters(); + + assertNotSame(jackson2Decoder, sseDecoder); + assertEquals(10, writers.size()); + } + + @Test // gh-24194 + public void cloneShouldNotDropMultipartCodecs() { + + ClientCodecConfigurer clone = this.configurer.clone(); + List> writers = + findCodec(clone.getWriters(), MultipartHttpMessageWriter.class).getPartWriters(); + + assertEquals(10, writers.size()); + } + + @Test + public void cloneShouldNotBeImpactedByChangesToOriginal() { + + ClientCodecConfigurer clone = this.configurer.clone(); + + this.configurer.registerDefaults(false); + this.configurer.customCodecs().register(new Jackson2JsonEncoder()); + + List> writers = + findCodec(clone.getWriters(), MultipartHttpMessageWriter.class).getPartWriters(); + + assertEquals(10, writers.size()); + } + + private Decoder getNextDecoder(List> readers) { + HttpMessageReader reader = readers.get(this.index.getAndIncrement()); + assertEquals(DecoderHttpMessageReader.class, reader.getClass()); + return ((DecoderHttpMessageReader) reader).getDecoder(); + } + + private HttpMessageReader nextReader(List> readers) { + return readers.get(this.index.getAndIncrement()); + } + + private Encoder getNextEncoder(List> writers) { + HttpMessageWriter writer = writers.get(this.index.getAndIncrement()); + assertEquals(EncoderHttpMessageWriter.class, writer.getClass()); + return ((EncoderHttpMessageWriter) writer).getEncoder(); + } + + @SuppressWarnings("unchecked") + private T findCodec(List codecs, Class type) { + return (T) codecs.stream().filter(type::isInstance).findFirst().get(); + } + + @SuppressWarnings("unchecked") + private void assertStringDecoder(Decoder decoder, boolean textOnly) { + assertEquals(StringDecoder.class, decoder.getClass()); + assertTrue(decoder.canDecode(forClass(String.class), MimeTypeUtils.TEXT_PLAIN)); + assertEquals(!textOnly, decoder.canDecode(forClass(String.class), MediaType.TEXT_EVENT_STREAM)); + + Flux decoded = (Flux) decoder.decode( + Flux.just(new DefaultDataBufferFactory().wrap("line1\nline2".getBytes(StandardCharsets.UTF_8))), + ResolvableType.forClass(String.class), MimeTypeUtils.TEXT_PLAIN, Collections.emptyMap()); + + assertEquals(Arrays.asList("line1", "line2"), decoded.collectList().block(Duration.ZERO)); + } + + private void assertStringEncoder(Encoder encoder, boolean textOnly) { + assertEquals(CharSequenceEncoder.class, encoder.getClass()); + assertTrue(encoder.canEncode(forClass(String.class), MimeTypeUtils.TEXT_PLAIN)); + assertEquals(!textOnly, encoder.canEncode(forClass(String.class), MediaType.TEXT_EVENT_STREAM)); + } + + private void assertSseReader(List> readers) { + HttpMessageReader reader = readers.get(this.index.getAndIncrement()); + assertEquals(ServerSentEventHttpMessageReader.class, reader.getClass()); + Decoder decoder = ((ServerSentEventHttpMessageReader) reader).getDecoder(); + assertNotNull(decoder); + assertEquals(Jackson2JsonDecoder.class, decoder.getClass()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/support/CodecConfigurerTests.java b/spring-web/src/test/java/org/springframework/http/codec/support/CodecConfigurerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..950f3929958f5e72d66f510207474a48ca86e29e --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/support/CodecConfigurerTests.java @@ -0,0 +1,439 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.support; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import com.google.protobuf.ExtensionRegistry; +import org.junit.Test; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.ByteArrayDecoder; +import org.springframework.core.codec.ByteArrayEncoder; +import org.springframework.core.codec.ByteBufferDecoder; +import org.springframework.core.codec.ByteBufferEncoder; +import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.core.codec.DataBufferDecoder; +import org.springframework.core.codec.DataBufferEncoder; +import org.springframework.core.codec.Decoder; +import org.springframework.core.codec.Encoder; +import org.springframework.core.codec.ResourceDecoder; +import org.springframework.core.codec.StringDecoder; +import org.springframework.http.MediaType; +import org.springframework.http.codec.CodecConfigurer; +import org.springframework.http.codec.DecoderHttpMessageReader; +import org.springframework.http.codec.EncoderHttpMessageWriter; +import org.springframework.http.codec.FormHttpMessageReader; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.ResourceHttpMessageWriter; +import org.springframework.http.codec.ServerSentEventHttpMessageReader; +import org.springframework.http.codec.ServerSentEventHttpMessageWriter; +import org.springframework.http.codec.json.Jackson2JsonDecoder; +import org.springframework.http.codec.json.Jackson2JsonEncoder; +import org.springframework.http.codec.json.Jackson2SmileDecoder; +import org.springframework.http.codec.json.Jackson2SmileEncoder; +import org.springframework.http.codec.protobuf.ProtobufDecoder; +import org.springframework.http.codec.protobuf.ProtobufEncoder; +import org.springframework.http.codec.protobuf.ProtobufHttpMessageWriter; +import org.springframework.http.codec.xml.Jaxb2XmlDecoder; +import org.springframework.http.codec.xml.Jaxb2XmlEncoder; +import org.springframework.util.MimeTypeUtils; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link BaseDefaultCodecs}. + * + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + */ +public class CodecConfigurerTests { + + private final CodecConfigurer configurer = new TestCodecConfigurer(); + + private final AtomicInteger index = new AtomicInteger(0); + + + @Test + public void defaultReaders() { + List> readers = this.configurer.getReaders(); + assertEquals(11, readers.size()); + assertEquals(ByteArrayDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(ByteBufferDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(DataBufferDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(ResourceDecoder.class, getNextDecoder(readers).getClass()); + assertStringDecoder(getNextDecoder(readers), true); + assertEquals(ProtobufDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(FormHttpMessageReader.class, readers.get(this.index.getAndIncrement()).getClass()); + assertEquals(Jackson2JsonDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(Jackson2SmileDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(Jaxb2XmlDecoder.class, getNextDecoder(readers).getClass()); + assertStringDecoder(getNextDecoder(readers), false); + } + + @Test + public void defaultWriters() { + List> writers = this.configurer.getWriters(); + assertEquals(10, writers.size()); + assertEquals(ByteArrayEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(ByteBufferEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(DataBufferEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(ResourceHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass()); + assertStringEncoder(getNextEncoder(writers), true); + assertEquals(ProtobufHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass()); + assertEquals(Jackson2JsonEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(Jackson2SmileEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(Jaxb2XmlEncoder.class, getNextEncoder(writers).getClass()); + assertStringEncoder(getNextEncoder(writers), false); + } + + @Test + public void defaultAndCustomReaders() { + Decoder customDecoder1 = mock(Decoder.class); + Decoder customDecoder2 = mock(Decoder.class); + + when(customDecoder1.canDecode(ResolvableType.forClass(Object.class), null)).thenReturn(false); + when(customDecoder2.canDecode(ResolvableType.forClass(Object.class), null)).thenReturn(true); + + HttpMessageReader customReader1 = mock(HttpMessageReader.class); + HttpMessageReader customReader2 = mock(HttpMessageReader.class); + + when(customReader1.canRead(ResolvableType.forClass(Object.class), null)).thenReturn(false); + when(customReader2.canRead(ResolvableType.forClass(Object.class), null)).thenReturn(true); + + this.configurer.customCodecs().register(customDecoder1); + this.configurer.customCodecs().register(customDecoder2); + + this.configurer.customCodecs().register(customReader1); + this.configurer.customCodecs().register(customReader2); + + List> readers = this.configurer.getReaders(); + + assertEquals(15, readers.size()); + assertEquals(ByteArrayDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(ByteBufferDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(DataBufferDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(ResourceDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(StringDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(ProtobufDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(FormHttpMessageReader.class, readers.get(this.index.getAndIncrement()).getClass()); + assertSame(customDecoder1, getNextDecoder(readers)); + assertSame(customReader1, readers.get(this.index.getAndIncrement())); + assertEquals(Jackson2JsonDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(Jackson2SmileDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(Jaxb2XmlDecoder.class, getNextDecoder(readers).getClass()); + assertSame(customDecoder2, getNextDecoder(readers)); + assertSame(customReader2, readers.get(this.index.getAndIncrement())); + assertEquals(StringDecoder.class, getNextDecoder(readers).getClass()); + } + + @Test + public void defaultAndCustomWriters() { + Encoder customEncoder1 = mock(Encoder.class); + Encoder customEncoder2 = mock(Encoder.class); + + when(customEncoder1.canEncode(ResolvableType.forClass(Object.class), null)).thenReturn(false); + when(customEncoder2.canEncode(ResolvableType.forClass(Object.class), null)).thenReturn(true); + + HttpMessageWriter customWriter1 = mock(HttpMessageWriter.class); + HttpMessageWriter customWriter2 = mock(HttpMessageWriter.class); + + when(customWriter1.canWrite(ResolvableType.forClass(Object.class), null)).thenReturn(false); + when(customWriter2.canWrite(ResolvableType.forClass(Object.class), null)).thenReturn(true); + + this.configurer.customCodecs().register(customEncoder1); + this.configurer.customCodecs().register(customEncoder2); + + this.configurer.customCodecs().register(customWriter1); + this.configurer.customCodecs().register(customWriter2); + + List> writers = this.configurer.getWriters(); + + assertEquals(14, writers.size()); + assertEquals(ByteArrayEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(ByteBufferEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(DataBufferEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(ResourceHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass()); + assertEquals(CharSequenceEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(ProtobufHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass()); + assertSame(customEncoder1, getNextEncoder(writers)); + assertSame(customWriter1, writers.get(this.index.getAndIncrement())); + assertEquals(Jackson2JsonEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(Jackson2SmileEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(Jaxb2XmlEncoder.class, getNextEncoder(writers).getClass()); + assertSame(customEncoder2, getNextEncoder(writers)); + assertSame(customWriter2, writers.get(this.index.getAndIncrement())); + assertEquals(CharSequenceEncoder.class, getNextEncoder(writers).getClass()); + } + + @Test + public void defaultsOffCustomReaders() { + Decoder customDecoder1 = mock(Decoder.class); + Decoder customDecoder2 = mock(Decoder.class); + + when(customDecoder1.canDecode(ResolvableType.forClass(Object.class), null)).thenReturn(false); + when(customDecoder2.canDecode(ResolvableType.forClass(Object.class), null)).thenReturn(true); + + HttpMessageReader customReader1 = mock(HttpMessageReader.class); + HttpMessageReader customReader2 = mock(HttpMessageReader.class); + + when(customReader1.canRead(ResolvableType.forClass(Object.class), null)).thenReturn(false); + when(customReader2.canRead(ResolvableType.forClass(Object.class), null)).thenReturn(true); + + this.configurer.customCodecs().register(customDecoder1); + this.configurer.customCodecs().register(customDecoder2); + + this.configurer.customCodecs().register(customReader1); + this.configurer.customCodecs().register(customReader2); + + this.configurer.registerDefaults(false); + + List> readers = this.configurer.getReaders(); + + assertEquals(4, readers.size()); + assertSame(customDecoder1, getNextDecoder(readers)); + assertSame(customReader1, readers.get(this.index.getAndIncrement())); + assertSame(customDecoder2, getNextDecoder(readers)); + assertSame(customReader2, readers.get(this.index.getAndIncrement())); + } + + @Test + public void defaultsOffWithCustomWriters() { + Encoder customEncoder1 = mock(Encoder.class); + Encoder customEncoder2 = mock(Encoder.class); + + when(customEncoder1.canEncode(ResolvableType.forClass(Object.class), null)).thenReturn(false); + when(customEncoder2.canEncode(ResolvableType.forClass(Object.class), null)).thenReturn(true); + + HttpMessageWriter customWriter1 = mock(HttpMessageWriter.class); + HttpMessageWriter customWriter2 = mock(HttpMessageWriter.class); + + when(customWriter1.canWrite(ResolvableType.forClass(Object.class), null)).thenReturn(false); + when(customWriter2.canWrite(ResolvableType.forClass(Object.class), null)).thenReturn(true); + + this.configurer.customCodecs().register(customEncoder1); + this.configurer.customCodecs().register(customEncoder2); + + this.configurer.customCodecs().register(customWriter1); + this.configurer.customCodecs().register(customWriter2); + + this.configurer.registerDefaults(false); + + List> writers = this.configurer.getWriters(); + + assertEquals(4, writers.size()); + assertSame(customEncoder1, getNextEncoder(writers)); + assertSame(customWriter1, writers.get(this.index.getAndIncrement())); + assertSame(customEncoder2, getNextEncoder(writers)); + assertSame(customWriter2, writers.get(this.index.getAndIncrement())); + } + + @Test + public void encoderDecoderOverrides() { + Jackson2JsonDecoder jacksonDecoder = new Jackson2JsonDecoder(); + Jackson2JsonEncoder jacksonEncoder = new Jackson2JsonEncoder(); + ProtobufDecoder protobufDecoder = new ProtobufDecoder(ExtensionRegistry.newInstance()); + ProtobufEncoder protobufEncoder = new ProtobufEncoder(); + Jaxb2XmlEncoder jaxb2Encoder = new Jaxb2XmlEncoder(); + Jaxb2XmlDecoder jaxb2Decoder = new Jaxb2XmlDecoder(); + + this.configurer.defaultCodecs().jackson2JsonDecoder(jacksonDecoder); + this.configurer.defaultCodecs().jackson2JsonEncoder(jacksonEncoder); + this.configurer.defaultCodecs().protobufDecoder(protobufDecoder); + this.configurer.defaultCodecs().protobufEncoder(protobufEncoder); + this.configurer.defaultCodecs().jaxb2Decoder(jaxb2Decoder); + this.configurer.defaultCodecs().jaxb2Encoder(jaxb2Encoder); + + assertDecoderInstance(jacksonDecoder); + assertDecoderInstance(protobufDecoder); + assertDecoderInstance(jaxb2Decoder); + assertEncoderInstance(jacksonEncoder); + assertEncoderInstance(protobufEncoder); + assertEncoderInstance(jaxb2Encoder); + } + + @Test + public void cloneEmptyCustomCodecs() { + this.configurer.registerDefaults(false); + assertEquals(0, this.configurer.getReaders().size()); + assertEquals(0, this.configurer.getWriters().size()); + + CodecConfigurer clone = this.configurer.clone(); + clone.customCodecs().register(new Jackson2JsonEncoder()); + clone.customCodecs().register(new Jackson2JsonDecoder()); + clone.customCodecs().register(new ServerSentEventHttpMessageReader()); + clone.customCodecs().register(new ServerSentEventHttpMessageWriter()); + + assertEquals(0, this.configurer.getReaders().size()); + assertEquals(0, this.configurer.getWriters().size()); + assertEquals(2, clone.getReaders().size()); + assertEquals(2, clone.getWriters().size()); + } + + @Test + public void cloneCustomCodecs() { + this.configurer.registerDefaults(false); + assertEquals(0, this.configurer.getReaders().size()); + assertEquals(0, this.configurer.getWriters().size()); + + this.configurer.customCodecs().register(new Jackson2JsonEncoder()); + this.configurer.customCodecs().register(new Jackson2JsonDecoder()); + this.configurer.customCodecs().register(new ServerSentEventHttpMessageReader()); + this.configurer.customCodecs().register(new ServerSentEventHttpMessageWriter()); + assertEquals(2, this.configurer.getReaders().size()); + assertEquals(2, this.configurer.getWriters().size()); + + CodecConfigurer clone = this.configurer.clone(); + assertEquals(2, this.configurer.getReaders().size()); + assertEquals(2, this.configurer.getWriters().size()); + assertEquals(2, clone.getReaders().size()); + assertEquals(2, clone.getWriters().size()); + } + + @Test + public void cloneDefaultCodecs() { + CodecConfigurer clone = this.configurer.clone(); + + Jackson2JsonDecoder jacksonDecoder = new Jackson2JsonDecoder(); + Jackson2JsonEncoder jacksonEncoder = new Jackson2JsonEncoder(); + Jaxb2XmlDecoder jaxb2Decoder = new Jaxb2XmlDecoder(); + Jaxb2XmlEncoder jaxb2Encoder = new Jaxb2XmlEncoder(); + ProtobufDecoder protoDecoder = new ProtobufDecoder(); + ProtobufEncoder protoEncoder = new ProtobufEncoder(); + + clone.defaultCodecs().jackson2JsonDecoder(jacksonDecoder); + clone.defaultCodecs().jackson2JsonEncoder(jacksonEncoder); + clone.defaultCodecs().jaxb2Decoder(jaxb2Decoder); + clone.defaultCodecs().jaxb2Encoder(jaxb2Encoder); + clone.defaultCodecs().protobufDecoder(protoDecoder); + clone.defaultCodecs().protobufEncoder(protoEncoder); + + // Clone has the customized the customizations + + List> decoders = clone.getReaders().stream() + .filter(reader -> reader instanceof DecoderHttpMessageReader) + .map(reader -> ((DecoderHttpMessageReader) reader).getDecoder()) + .collect(Collectors.toList()); + + List> encoders = clone.getWriters().stream() + .filter(writer -> writer instanceof EncoderHttpMessageWriter) + .map(reader -> ((EncoderHttpMessageWriter) reader).getEncoder()) + .collect(Collectors.toList()); + + assertTrue(decoders.containsAll(Arrays.asList(jacksonDecoder, jaxb2Decoder, protoDecoder))); + assertTrue(encoders.containsAll(Arrays.asList(jacksonEncoder, jaxb2Encoder, protoEncoder))); + + // Original does not have the customizations + + decoders = this.configurer.getReaders().stream() + .filter(reader -> reader instanceof DecoderHttpMessageReader) + .map(reader -> ((DecoderHttpMessageReader) reader).getDecoder()) + .collect(Collectors.toList()); + + encoders = this.configurer.getWriters().stream() + .filter(writer -> writer instanceof EncoderHttpMessageWriter) + .map(reader -> ((EncoderHttpMessageWriter) reader).getEncoder()) + .collect(Collectors.toList()); + + assertFalse(decoders.containsAll(Arrays.asList(jacksonDecoder, jaxb2Decoder, protoDecoder))); + assertFalse(encoders.containsAll(Arrays.asList(jacksonEncoder, jaxb2Encoder, protoEncoder))); + } + + @SuppressWarnings("deprecation") + @Test + public void withDefaultCodecConfig() { + AtomicBoolean callbackCalled = new AtomicBoolean(false); + this.configurer.defaultCodecs().enableLoggingRequestDetails(true); + this.configurer.customCodecs().withDefaultCodecConfig(config -> { + assertTrue(config.isEnableLoggingRequestDetails()); + callbackCalled.compareAndSet(false, true); + }); + this.configurer.getReaders(); + assertTrue(callbackCalled.get()); + } + + private Decoder getNextDecoder(List> readers) { + HttpMessageReader reader = readers.get(this.index.getAndIncrement()); + assertEquals(DecoderHttpMessageReader.class, reader.getClass()); + return ((DecoderHttpMessageReader) reader).getDecoder(); + } + + private Encoder getNextEncoder(List> writers) { + HttpMessageWriter writer = writers.get(this.index.getAndIncrement()); + assertEquals(EncoderHttpMessageWriter.class, writer.getClass()); + return ((EncoderHttpMessageWriter) writer).getEncoder(); + } + + private void assertStringDecoder(Decoder decoder, boolean textOnly) { + assertEquals(StringDecoder.class, decoder.getClass()); + assertTrue(decoder.canDecode(ResolvableType.forClass(String.class), MimeTypeUtils.TEXT_PLAIN)); + assertEquals(!textOnly, decoder.canDecode(ResolvableType.forClass(String.class), MediaType.TEXT_EVENT_STREAM)); + } + + private void assertStringEncoder(Encoder encoder, boolean textOnly) { + assertEquals(CharSequenceEncoder.class, encoder.getClass()); + assertTrue(encoder.canEncode(ResolvableType.forClass(String.class), MimeTypeUtils.TEXT_PLAIN)); + assertEquals(!textOnly, encoder.canEncode(ResolvableType.forClass(String.class), MediaType.TEXT_EVENT_STREAM)); + } + + private void assertDecoderInstance(Decoder decoder) { + assertSame(decoder, this.configurer.getReaders().stream() + .filter(writer -> writer instanceof DecoderHttpMessageReader) + .map(writer -> ((DecoderHttpMessageReader) writer).getDecoder()) + .filter(e -> decoder.getClass().equals(e.getClass())) + .findFirst() + .filter(e -> e == decoder).orElse(null)); + } + + private void assertEncoderInstance(Encoder encoder) { + assertSame(encoder, this.configurer.getWriters().stream() + .filter(writer -> writer instanceof EncoderHttpMessageWriter) + .map(writer -> ((EncoderHttpMessageWriter) writer).getEncoder()) + .filter(e -> encoder.getClass().equals(e.getClass())) + .findFirst() + .filter(e -> e == encoder).orElse(null)); + } + + + private static class TestCodecConfigurer extends BaseCodecConfigurer { + + TestCodecConfigurer() { + super(new BaseDefaultCodecs()); + } + + TestCodecConfigurer(TestCodecConfigurer other) { + super(other); + } + + @Override + protected BaseDefaultCodecs cloneDefaultCodecs() { + return new BaseDefaultCodecs((BaseDefaultCodecs) defaultCodecs()); + } + + @Override + public CodecConfigurer clone() { + return new TestCodecConfigurer(this); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/support/ServerCodecConfigurerTests.java b/spring-web/src/test/java/org/springframework/http/codec/support/ServerCodecConfigurerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b5a2c8d6e46ee84ba33d9649694d74b6e69eb9c1 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/support/ServerCodecConfigurerTests.java @@ -0,0 +1,287 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.support; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Test; +import reactor.core.publisher.Flux; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.ByteArrayDecoder; +import org.springframework.core.codec.ByteArrayEncoder; +import org.springframework.core.codec.ByteBufferDecoder; +import org.springframework.core.codec.ByteBufferEncoder; +import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.core.codec.DataBufferDecoder; +import org.springframework.core.codec.DataBufferEncoder; +import org.springframework.core.codec.Decoder; +import org.springframework.core.codec.Encoder; +import org.springframework.core.codec.ResourceDecoder; +import org.springframework.core.codec.StringDecoder; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.MediaType; +import org.springframework.http.codec.CodecConfigurer; +import org.springframework.http.codec.DecoderHttpMessageReader; +import org.springframework.http.codec.EncoderHttpMessageWriter; +import org.springframework.http.codec.FormHttpMessageReader; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.ResourceHttpMessageWriter; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.http.codec.ServerSentEventHttpMessageWriter; +import org.springframework.http.codec.json.Jackson2JsonDecoder; +import org.springframework.http.codec.json.Jackson2JsonEncoder; +import org.springframework.http.codec.json.Jackson2SmileDecoder; +import org.springframework.http.codec.json.Jackson2SmileEncoder; +import org.springframework.http.codec.multipart.MultipartHttpMessageReader; +import org.springframework.http.codec.multipart.SynchronossPartHttpMessageReader; +import org.springframework.http.codec.protobuf.ProtobufDecoder; +import org.springframework.http.codec.protobuf.ProtobufHttpMessageWriter; +import org.springframework.http.codec.xml.Jaxb2XmlDecoder; +import org.springframework.http.codec.xml.Jaxb2XmlEncoder; +import org.springframework.util.MimeTypeUtils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.springframework.core.ResolvableType.forClass; + +/** + * Unit tests for {@link ServerCodecConfigurer}. + * + * @author Rossen Stoyanchev + */ +public class ServerCodecConfigurerTests { + + private final ServerCodecConfigurer configurer = new DefaultServerCodecConfigurer(); + + private final AtomicInteger index = new AtomicInteger(0); + + + @Test + public void defaultReaders() { + List> readers = this.configurer.getReaders(); + assertEquals(13, readers.size()); + assertEquals(ByteArrayDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(ByteBufferDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(DataBufferDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(ResourceDecoder.class, getNextDecoder(readers).getClass()); + assertStringDecoder(getNextDecoder(readers), true); + assertEquals(ProtobufDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(FormHttpMessageReader.class, readers.get(this.index.getAndIncrement()).getClass()); + assertEquals(SynchronossPartHttpMessageReader.class, readers.get(this.index.getAndIncrement()).getClass()); + assertEquals(MultipartHttpMessageReader.class, readers.get(this.index.getAndIncrement()).getClass()); + assertEquals(Jackson2JsonDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(Jackson2SmileDecoder.class, getNextDecoder(readers).getClass()); + assertEquals(Jaxb2XmlDecoder.class, getNextDecoder(readers).getClass()); + assertStringDecoder(getNextDecoder(readers), false); + } + + @Test + public void defaultWriters() { + List> writers = this.configurer.getWriters(); + assertEquals(11, writers.size()); + assertEquals(ByteArrayEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(ByteBufferEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(DataBufferEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(ResourceHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass()); + assertStringEncoder(getNextEncoder(writers), true); + assertEquals(ProtobufHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass()); + assertEquals(Jackson2JsonEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(Jackson2SmileEncoder.class, getNextEncoder(writers).getClass()); + assertEquals(Jaxb2XmlEncoder.class, getNextEncoder(writers).getClass()); + assertSseWriter(writers); + assertStringEncoder(getNextEncoder(writers), false); + } + + @Test + public void jackson2EncoderOverride() { + Jackson2JsonEncoder encoder = new Jackson2JsonEncoder(); + this.configurer.defaultCodecs().jackson2JsonEncoder(encoder); + + assertSame(encoder, this.configurer.getWriters().stream() + .filter(writer -> ServerSentEventHttpMessageWriter.class.equals(writer.getClass())) + .map(writer -> (ServerSentEventHttpMessageWriter) writer) + .findFirst() + .map(ServerSentEventHttpMessageWriter::getEncoder) + .filter(e -> e == encoder).orElse(null)); + } + + @Test + public void maxInMemorySize() { + int size = 99; + this.configurer.defaultCodecs().maxInMemorySize(size); + + List> readers = this.configurer.getReaders(); + assertEquals(13, readers.size()); + assertEquals(size, ((ByteArrayDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((ByteBufferDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((DataBufferDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((ResourceDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((StringDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((ProtobufDecoder) getNextDecoder(readers)).getMaxMessageSize()); + assertEquals(size, ((FormHttpMessageReader) nextReader(readers)).getMaxInMemorySize()); + assertEquals(size, ((SynchronossPartHttpMessageReader) nextReader(readers)).getMaxInMemorySize()); + + MultipartHttpMessageReader multipartReader = (MultipartHttpMessageReader) nextReader(readers); + SynchronossPartHttpMessageReader reader = (SynchronossPartHttpMessageReader) multipartReader.getPartReader(); + assertEquals(size, (reader).getMaxInMemorySize()); + + assertEquals(size, ((Jackson2JsonDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((Jackson2SmileDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((Jaxb2XmlDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((StringDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + } + + @Test + public void maxInMemorySizeWithCustomCodecs() { + + int size = 99; + this.configurer.defaultCodecs().maxInMemorySize(size); + this.configurer.registerDefaults(false); + + CodecConfigurer.CustomCodecs customCodecs = this.configurer.customCodecs(); + customCodecs.register(new ByteArrayDecoder()); + customCodecs.registerWithDefaultConfig(new ByteArrayDecoder()); + customCodecs.register(new Jackson2JsonDecoder()); + customCodecs.registerWithDefaultConfig(new Jackson2JsonDecoder()); + + this.configurer.defaultCodecs().enableLoggingRequestDetails(true); + + List> readers = this.configurer.getReaders(); + assertEquals(-1, ((ByteArrayDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((ByteArrayDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(-1, ((Jackson2JsonDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + assertEquals(size, ((Jackson2JsonDecoder) getNextDecoder(readers)).getMaxInMemorySize()); + } + + @Test + public void enableRequestLoggingDetails() { + this.configurer.defaultCodecs().enableLoggingRequestDetails(true); + + List> readers = this.configurer.getReaders(); + assertTrue(findCodec(readers, FormHttpMessageReader.class).isEnableLoggingRequestDetails()); + + MultipartHttpMessageReader multipartReader = findCodec(readers, MultipartHttpMessageReader.class); + assertTrue(multipartReader.isEnableLoggingRequestDetails()); + + SynchronossPartHttpMessageReader reader = (SynchronossPartHttpMessageReader) multipartReader.getPartReader(); + assertTrue(reader.isEnableLoggingRequestDetails()); + } + + @Test + public void enableRequestLoggingDetailsWithCustomCodecs() { + + this.configurer.registerDefaults(false); + this.configurer.defaultCodecs().enableLoggingRequestDetails(true); + + CodecConfigurer.CustomCodecs customCodecs = this.configurer.customCodecs(); + customCodecs.register(new FormHttpMessageReader()); + customCodecs.registerWithDefaultConfig(new FormHttpMessageReader()); + + List> readers = this.configurer.getReaders(); + assertFalse(((FormHttpMessageReader) readers.get(0)).isEnableLoggingRequestDetails()); + assertTrue(((FormHttpMessageReader) readers.get(1)).isEnableLoggingRequestDetails()); + } + + @Test + public void cloneConfigurer() { + ServerCodecConfigurer clone = this.configurer.clone(); + + MultipartHttpMessageReader reader = new MultipartHttpMessageReader(new SynchronossPartHttpMessageReader()); + Jackson2JsonEncoder encoder = new Jackson2JsonEncoder(); + clone.defaultCodecs().multipartReader(reader); + clone.defaultCodecs().serverSentEventEncoder(encoder); + + // Clone has the customizations + + HttpMessageReader actualReader = + findCodec(clone.getReaders(), MultipartHttpMessageReader.class); + + ServerSentEventHttpMessageWriter actualWriter = + findCodec(clone.getWriters(), ServerSentEventHttpMessageWriter.class); + + assertSame(reader, actualReader); + assertSame(encoder, actualWriter.getEncoder()); + + // Original does not have the customizations + + actualReader = findCodec(this.configurer.getReaders(), MultipartHttpMessageReader.class); + actualWriter = findCodec(this.configurer.getWriters(), ServerSentEventHttpMessageWriter.class); + + assertNotSame(reader, actualReader); + assertNotSame(encoder, actualWriter.getEncoder()); + } + + private Decoder getNextDecoder(List> readers) { + HttpMessageReader reader = nextReader(readers); + assertEquals(DecoderHttpMessageReader.class, reader.getClass()); + return ((DecoderHttpMessageReader) reader).getDecoder(); + } + + private HttpMessageReader nextReader(List> readers) { + return readers.get(this.index.getAndIncrement()); + } + + private Encoder getNextEncoder(List> writers) { + HttpMessageWriter writer = writers.get(this.index.getAndIncrement()); + assertEquals(EncoderHttpMessageWriter.class, writer.getClass()); + return ((EncoderHttpMessageWriter) writer).getEncoder(); + } + + @SuppressWarnings("unchecked") + private T findCodec(List codecs, Class type) { + return (T) codecs.stream().filter(type::isInstance).findFirst().get(); + } + + @SuppressWarnings("unchecked") + private void assertStringDecoder(Decoder decoder, boolean textOnly) { + assertEquals(StringDecoder.class, decoder.getClass()); + assertTrue(decoder.canDecode(forClass(String.class), MimeTypeUtils.TEXT_PLAIN)); + assertEquals(!textOnly, decoder.canDecode(forClass(String.class), MediaType.TEXT_EVENT_STREAM)); + + Flux flux = (Flux) decoder.decode( + Flux.just(new DefaultDataBufferFactory().wrap("line1\nline2".getBytes(StandardCharsets.UTF_8))), + ResolvableType.forClass(String.class), MimeTypeUtils.TEXT_PLAIN, Collections.emptyMap()); + + assertEquals(Arrays.asList("line1", "line2"), flux.collectList().block(Duration.ZERO)); + } + + private void assertStringEncoder(Encoder encoder, boolean textOnly) { + assertEquals(CharSequenceEncoder.class, encoder.getClass()); + assertTrue(encoder.canEncode(forClass(String.class), MimeTypeUtils.TEXT_PLAIN)); + assertEquals(!textOnly, encoder.canEncode(forClass(String.class), MediaType.TEXT_EVENT_STREAM)); + } + + private void assertSseWriter(List> writers) { + HttpMessageWriter writer = writers.get(this.index.getAndIncrement()); + assertEquals(ServerSentEventHttpMessageWriter.class, writer.getClass()); + Encoder encoder = ((ServerSentEventHttpMessageWriter) writer).getEncoder(); + assertNotNull(encoder); + assertEquals(Jackson2JsonEncoder.class, encoder.getClass()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/xml/Jaxb2XmlDecoderTests.java b/spring-web/src/test/java/org/springframework/http/codec/xml/Jaxb2XmlDecoderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..494c89377f7fe1dfc3216a04f046ac4ec681da76 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/xml/Jaxb2XmlDecoderTests.java @@ -0,0 +1,305 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.xml; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; + +import javax.xml.namespace.QName; +import javax.xml.stream.events.XMLEvent; + +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.AbstractLeakCheckingTestCase; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.MediaType; +import org.springframework.http.codec.Pojo; +import org.springframework.http.codec.xml.jaxb.XmlRootElement; +import org.springframework.http.codec.xml.jaxb.XmlRootElementWithName; +import org.springframework.http.codec.xml.jaxb.XmlRootElementWithNameAndNamespace; +import org.springframework.http.codec.xml.jaxb.XmlType; +import org.springframework.http.codec.xml.jaxb.XmlTypeWithName; +import org.springframework.http.codec.xml.jaxb.XmlTypeWithNameAndNamespace; + +import static org.junit.Assert.*; + +/** + * @author Sebastien Deleuze + */ +public class Jaxb2XmlDecoderTests extends AbstractLeakCheckingTestCase { + + private static final String POJO_ROOT = "" + + "" + + "foofoo" + + "barbar" + + ""; + + private static final String POJO_CHILD = + "" + + "" + + "" + + "foo" + + "bar" + + "" + + "" + + "foofoo" + + "barbar" + + "" + + ""; + + + private final Jaxb2XmlDecoder decoder = new Jaxb2XmlDecoder(); + + private final XmlEventDecoder xmlEventDecoder = new XmlEventDecoder(); + + + @Test + public void canDecode() { + assertTrue(this.decoder.canDecode(ResolvableType.forClass(Pojo.class), + MediaType.APPLICATION_XML)); + assertTrue(this.decoder.canDecode(ResolvableType.forClass(Pojo.class), + MediaType.TEXT_XML)); + assertFalse(this.decoder.canDecode(ResolvableType.forClass(Pojo.class), + MediaType.APPLICATION_JSON)); + assertTrue(this.decoder.canDecode(ResolvableType.forClass(TypePojo.class), + MediaType.APPLICATION_XML)); + assertFalse(this.decoder.canDecode(ResolvableType.forClass(getClass()), + MediaType.APPLICATION_XML)); + } + + @Test + public void splitOneBranches() { + Flux xmlEvents = this.xmlEventDecoder + .decode(stringBuffer(POJO_ROOT), null, null, Collections.emptyMap()); + Flux> result = this.decoder.split(xmlEvents, new QName("pojo")); + + StepVerifier.create(result) + .consumeNextWith(events -> { + assertEquals(8, events.size()); + assertStartElement(events.get(0), "pojo"); + assertStartElement(events.get(1), "foo"); + assertCharacters(events.get(2), "foofoo"); + assertEndElement(events.get(3), "foo"); + assertStartElement(events.get(4), "bar"); + assertCharacters(events.get(5), "barbar"); + assertEndElement(events.get(6), "bar"); + assertEndElement(events.get(7), "pojo"); + }) + .expectComplete() + .verify(); + } + + @Test + public void splitMultipleBranches() throws Exception { + Flux xmlEvents = this.xmlEventDecoder + .decode(stringBuffer(POJO_CHILD), null, null, Collections.emptyMap()); + Flux> result = this.decoder.split(xmlEvents, new QName("pojo")); + + + StepVerifier.create(result) + .consumeNextWith(events -> { + assertEquals(8, events.size()); + assertStartElement(events.get(0), "pojo"); + assertStartElement(events.get(1), "foo"); + assertCharacters(events.get(2), "foo"); + assertEndElement(events.get(3), "foo"); + assertStartElement(events.get(4), "bar"); + assertCharacters(events.get(5), "bar"); + assertEndElement(events.get(6), "bar"); + assertEndElement(events.get(7), "pojo"); + }) + .consumeNextWith(events -> { + assertEquals(8, events.size()); + assertStartElement(events.get(0), "pojo"); + assertStartElement(events.get(1), "foo"); + assertCharacters(events.get(2), "foofoo"); + assertEndElement(events.get(3), "foo"); + assertStartElement(events.get(4), "bar"); + assertCharacters(events.get(5), "barbar"); + assertEndElement(events.get(6), "bar"); + assertEndElement(events.get(7), "pojo"); + }) + .expectComplete() + .verify(); + } + + private static void assertStartElement(XMLEvent event, String expectedLocalName) { + assertTrue(event.isStartElement()); + assertEquals(expectedLocalName, event.asStartElement().getName().getLocalPart()); + } + + private static void assertEndElement(XMLEvent event, String expectedLocalName) { + assertTrue(event.isEndElement()); + assertEquals(expectedLocalName, event.asEndElement().getName().getLocalPart()); + } + + private static void assertCharacters(XMLEvent event, String expectedData) { + assertTrue(event.isCharacters()); + assertEquals(expectedData, event.asCharacters().getData()); + } + + @Test + public void decodeSingleXmlRootElement() throws Exception { + Mono source = stringBuffer(POJO_ROOT); + Mono output = this.decoder.decodeToMono(source, ResolvableType.forClass(Pojo.class), + null, Collections.emptyMap()); + + StepVerifier.create(output) + .expectNext(new Pojo("foofoo", "barbar")) + .expectComplete() + .verify(); + } + + @Test + public void decodeSingleXmlTypeElement() throws Exception { + Mono source = stringBuffer(POJO_ROOT); + Mono output = this.decoder.decodeToMono(source, ResolvableType.forClass(TypePojo.class), + null, Collections.emptyMap()); + + StepVerifier.create(output) + .expectNext(new TypePojo("foofoo", "barbar")) + .expectComplete() + .verify(); + } + + @Test + public void decodeMultipleXmlRootElement() throws Exception { + Mono source = stringBuffer(POJO_CHILD); + Flux output = this.decoder.decode(source, ResolvableType.forClass(Pojo.class), + null, Collections.emptyMap()); + + StepVerifier.create(output) + .expectNext(new Pojo("foo", "bar")) + .expectNext(new Pojo("foofoo", "barbar")) + .expectComplete() + .verify(); + } + + @Test + public void decodeMultipleXmlTypeElement() throws Exception { + Mono source = stringBuffer(POJO_CHILD); + Flux output = this.decoder.decode(source, ResolvableType.forClass(TypePojo.class), + null, Collections.emptyMap()); + + StepVerifier.create(output) + .expectNext(new TypePojo("foo", "bar")) + .expectNext(new TypePojo("foofoo", "barbar")) + .expectComplete() + .verify(); + } + + @Test + public void decodeError() throws Exception { + Flux source = Flux.concat( + stringBuffer(""), + Flux.error(new RuntimeException())); + + Mono output = this.decoder.decodeToMono(source, ResolvableType.forClass(Pojo.class), + null, Collections.emptyMap()); + + StepVerifier.create(output) + .expectError(RuntimeException.class) + .verify(); + } + + @Test + public void toExpectedQName() { + assertEquals(new QName("pojo"), this.decoder.toQName(Pojo.class)); + assertEquals(new QName("pojo"), this.decoder.toQName(TypePojo.class)); + + assertEquals(new QName("namespace", "name"), + this.decoder.toQName(XmlRootElementWithNameAndNamespace.class)); + assertEquals(new QName("namespace", "name"), + this.decoder.toQName(XmlRootElementWithName.class)); + assertEquals(new QName("namespace", "xmlRootElement"), + this.decoder.toQName(XmlRootElement.class)); + + assertEquals(new QName("namespace", "name"), + this.decoder.toQName(XmlTypeWithNameAndNamespace.class)); + assertEquals(new QName("namespace", "name"), + this.decoder.toQName(XmlTypeWithName.class)); + assertEquals(new QName("namespace", "xmlType"), + this.decoder.toQName(XmlType.class)); + + } + + private Mono stringBuffer(String value) { + return Mono.defer(() -> { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length); + buffer.write(bytes); + return Mono.just(buffer); + }); + } + + + @javax.xml.bind.annotation.XmlType(name = "pojo") + public static class TypePojo { + + private String foo; + + private String bar; + + public TypePojo() { + } + + public TypePojo(String foo, String bar) { + this.foo = foo; + this.bar = bar; + } + + public String getFoo() { + return this.foo; + } + + public void setFoo(String foo) { + this.foo = foo; + } + + public String getBar() { + return this.bar; + } + + public void setBar(String bar) { + this.bar = bar; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o instanceof TypePojo) { + TypePojo other = (TypePojo) o; + return this.foo.equals(other.foo) && this.bar.equals(other.bar); + } + return false; + } + + @Override + public int hashCode() { + int result = this.foo.hashCode(); + result = 31 * result + this.bar.hashCode(); + return result; + } + } +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/xml/Jaxb2XmlEncoderTests.java b/spring-web/src/test/java/org/springframework/http/codec/xml/Jaxb2XmlEncoderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..af1c7f4efce691f797452891368b9d951e49aad4 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/xml/Jaxb2XmlEncoderTests.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.xml; + +import java.util.Arrays; +import java.util.List; +import java.util.function.Consumer; + +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlElements; +import javax.xml.bind.annotation.XmlRootElement; + +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.AbstractEncoderTestCase; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.MediaType; +import org.springframework.http.codec.Pojo; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.Assert.*; +import static org.springframework.core.io.buffer.DataBufferUtils.release; +import static org.xmlunit.matchers.CompareMatcher.isSimilarTo; + +/** + * @author Sebastien Deleuze + * @author Arjen Poutsma + */ +public class Jaxb2XmlEncoderTests extends AbstractEncoderTestCase { + + public Jaxb2XmlEncoderTests() { + super(new Jaxb2XmlEncoder()); + } + + @Override + @Test + public void canEncode() { + assertTrue(this.encoder.canEncode(ResolvableType.forClass(Pojo.class), + MediaType.APPLICATION_XML)); + assertTrue(this.encoder.canEncode(ResolvableType.forClass(Pojo.class), + MediaType.TEXT_XML)); + assertFalse(this.encoder.canEncode(ResolvableType.forClass(Pojo.class), + MediaType.APPLICATION_JSON)); + + assertTrue(this.encoder.canEncode( + ResolvableType.forClass(Jaxb2XmlDecoderTests.TypePojo.class), + MediaType.APPLICATION_XML)); + + assertFalse(this.encoder.canEncode(ResolvableType.forClass(getClass()), + MediaType.APPLICATION_XML)); + + // SPR-15464 + assertFalse(this.encoder.canEncode(ResolvableType.NONE, null)); + } + + @Override + @Test + public void encode() { + Mono input = Mono.just(new Pojo("foofoo", "barbar")); + + testEncode(input, Pojo.class, step -> step + .consumeNextWith( + expectXml("" + + "barbarfoofoo")) + .verifyComplete()); + } + + @Test + public void encodeError() { + Flux input = Flux.error(RuntimeException::new); + + testEncode(input, Pojo.class, step -> step + .expectError(RuntimeException.class) + .verify()); + } + + @Test + public void encodeElementsWithCommonType() { + Mono input = Mono.just(new Container()); + + testEncode(input, Pojo.class, step -> step + .consumeNextWith( + expectXml("" + + "name1title1")) + .verifyComplete()); + } + + protected Consumer expectXml(String expected) { + return dataBuffer -> { + byte[] resultBytes = new byte[dataBuffer.readableByteCount()]; + dataBuffer.read(resultBytes); + release(dataBuffer); + String actual = new String(resultBytes, UTF_8); + assertThat(actual, isSimilarTo(expected)); + }; + } + + public static class Model {} + + public static class Foo extends Model { + + private String name; + + public Foo(String name) { + this.name = name; + } + + public String getName() { + return this.name; + } + + public void setName(String name) { + this.name = name; + } + } + + public static class Bar extends Model { + + private String title; + + public Bar(String title) { + this.title = title; + } + + public String getTitle() { + return title; + } + + public void setTitle(String title) { + this.title = title; + } + } + + @XmlRootElement + public static class Container { + + @XmlElements({ + @XmlElement(name="foo", type=Foo.class), + @XmlElement(name="bar", type=Bar.class) + }) + public List getElements() { + return Arrays.asList(new Foo("name1"), new Bar("title1")); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/xml/XmlEventDecoderTests.java b/spring-web/src/test/java/org/springframework/http/codec/xml/XmlEventDecoderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5e4c394145303a0c1ef46bc34e075dbd10643f63 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/xml/XmlEventDecoderTests.java @@ -0,0 +1,172 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.xml; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; + +import javax.xml.stream.events.XMLEvent; + +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.io.buffer.AbstractLeakCheckingTestCase; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * @author Arjen Poutsma + */ +public class XmlEventDecoderTests extends AbstractLeakCheckingTestCase { + + private static final String XML = "" + + "" + + "foofoo" + + "barbar" + + ""; + + private XmlEventDecoder decoder = new XmlEventDecoder(); + + + @Test + public void toXMLEventsAalto() { + + Flux events = + this.decoder.decode(stringBufferMono(XML), null, null, Collections.emptyMap()); + + StepVerifier.create(events) + .consumeNextWith(e -> assertTrue(e.isStartDocument())) + .consumeNextWith(e -> assertStartElement(e, "pojo")) + .consumeNextWith(e -> assertStartElement(e, "foo")) + .consumeNextWith(e -> assertCharacters(e, "foofoo")) + .consumeNextWith(e -> assertEndElement(e, "foo")) + .consumeNextWith(e -> assertStartElement(e, "bar")) + .consumeNextWith(e -> assertCharacters(e, "barbar")) + .consumeNextWith(e -> assertEndElement(e, "bar")) + .consumeNextWith(e -> assertEndElement(e, "pojo")) + .expectComplete() + .verify(); + } + + @Test + public void toXMLEventsNonAalto() { + decoder.useAalto = false; + + Flux events = + this.decoder.decode(stringBufferMono(XML), null, null, Collections.emptyMap()); + + StepVerifier.create(events) + .consumeNextWith(e -> assertTrue(e.isStartDocument())) + .consumeNextWith(e -> assertStartElement(e, "pojo")) + .consumeNextWith(e -> assertStartElement(e, "foo")) + .consumeNextWith(e -> assertCharacters(e, "foofoo")) + .consumeNextWith(e -> assertEndElement(e, "foo")) + .consumeNextWith(e -> assertStartElement(e, "bar")) + .consumeNextWith(e -> assertCharacters(e, "barbar")) + .consumeNextWith(e -> assertEndElement(e, "bar")) + .consumeNextWith(e -> assertEndElement(e, "pojo")) + .consumeNextWith(e -> assertTrue(e.isEndDocument())) + .expectComplete() + .verify(); + } + + @Test + public void toXMLEventsWithLimit() { + + this.decoder.setMaxInMemorySize(6); + + Flux source = Flux.just( + "", "", "foofoo", "", "", "barbarbar", "", ""); + + Flux events = this.decoder.decode( + source.map(this::stringBuffer), null, null, Collections.emptyMap()); + + StepVerifier.create(events) + .consumeNextWith(e -> assertTrue(e.isStartDocument())) + .consumeNextWith(e -> assertStartElement(e, "pojo")) + .consumeNextWith(e -> assertStartElement(e, "foo")) + .consumeNextWith(e -> assertCharacters(e, "foofoo")) + .consumeNextWith(e -> assertEndElement(e, "foo")) + .consumeNextWith(e -> assertStartElement(e, "bar")) + .expectError(DataBufferLimitException.class) + .verify(); + } + + @Test + public void decodeErrorAalto() { + Flux source = Flux.concat( + stringBufferMono(""), + Flux.error(new RuntimeException())); + + Flux events = + this.decoder.decode(source, null, null, Collections.emptyMap()); + + StepVerifier.create(events) + .consumeNextWith(e -> assertTrue(e.isStartDocument())) + .consumeNextWith(e -> assertStartElement(e, "pojo")) + .expectError(RuntimeException.class) + .verify(); + } + + @Test + public void decodeErrorNonAalto() { + decoder.useAalto = false; + + Flux source = Flux.concat( + stringBufferMono(""), + Flux.error(new RuntimeException())); + + Flux events = + this.decoder.decode(source, null, null, Collections.emptyMap()); + + StepVerifier.create(events) + .expectError(RuntimeException.class) + .verify(); + } + + private static void assertStartElement(XMLEvent event, String expectedLocalName) { + assertTrue(event.isStartElement()); + assertEquals(expectedLocalName, event.asStartElement().getName().getLocalPart()); + } + + private static void assertEndElement(XMLEvent event, String expectedLocalName) { + assertTrue(event + " is no end element", event.isEndElement()); + assertEquals(expectedLocalName, event.asEndElement().getName().getLocalPart()); + } + + private static void assertCharacters(XMLEvent event, String expectedData) { + assertTrue(event.isCharacters()); + assertEquals(expectedData, event.asCharacters().getData()); + } + + private DataBuffer stringBuffer(String value) { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length); + buffer.write(bytes); + return buffer; + } + + private Mono stringBufferMono(String value) { + return Mono.defer(() -> Mono.just(stringBuffer(value))); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlRootElement.java b/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlRootElement.java new file mode 100644 index 0000000000000000000000000000000000000000..7a2dae1aca4c505a0c79c5442d7e62d741c3b12a --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlRootElement.java @@ -0,0 +1,25 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.xml.jaxb; + +/** + * @author Arjen Poutsma + */ +@javax.xml.bind.annotation.XmlRootElement +public class XmlRootElement { + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlRootElementWithName.java b/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlRootElementWithName.java new file mode 100644 index 0000000000000000000000000000000000000000..88de86671ac3913172811ee1b397a994faf8af66 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlRootElementWithName.java @@ -0,0 +1,27 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.xml.jaxb; + +import javax.xml.bind.annotation.XmlRootElement; + +/** + * @author Arjen Poutsma + */ +@XmlRootElement(name = "name") +public class XmlRootElementWithName { + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlRootElementWithNameAndNamespace.java b/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlRootElementWithNameAndNamespace.java new file mode 100644 index 0000000000000000000000000000000000000000..dfb765bd0981def06970b485eec63aaac7130822 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlRootElementWithNameAndNamespace.java @@ -0,0 +1,27 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.xml.jaxb; + +import javax.xml.bind.annotation.XmlRootElement; + +/** + * @author Arjen Poutsma + */ +@XmlRootElement(name = "name", namespace = "namespace") +public class XmlRootElementWithNameAndNamespace { + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlType.java b/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlType.java new file mode 100644 index 0000000000000000000000000000000000000000..cff0d9f4509138c5ccc3d0e7937795d2dbd25050 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlType.java @@ -0,0 +1,25 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.xml.jaxb; + +/** + * @author Arjen Poutsma + */ +@javax.xml.bind.annotation.XmlType +public class XmlType { + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlTypeWithName.java b/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlTypeWithName.java new file mode 100644 index 0000000000000000000000000000000000000000..60c42519a8db920840cd4baddffe515ff8856528 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlTypeWithName.java @@ -0,0 +1,27 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.xml.jaxb; + +import javax.xml.bind.annotation.XmlType; + +/** + * @author Arjen Poutsma + */ +@XmlType(name = "name") +public class XmlTypeWithName { + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlTypeWithNameAndNamespace.java b/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlTypeWithNameAndNamespace.java new file mode 100644 index 0000000000000000000000000000000000000000..f1ea796fbaa2019e165dad5df19db76410791c5e --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/XmlTypeWithNameAndNamespace.java @@ -0,0 +1,27 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.xml.jaxb; + +import javax.xml.bind.annotation.XmlType; + +/** + * @author Arjen Poutsma + */ +@XmlType(name = "name", namespace = "namespace") +public class XmlTypeWithNameAndNamespace { + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/package-info.java b/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..f3b3a59496b2292ff3b1068095986343e85046d5 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/xml/jaxb/package-info.java @@ -0,0 +1,2 @@ +@javax.xml.bind.annotation.XmlSchema(namespace = "namespace") +package org.springframework.http.codec.xml.jaxb; diff --git a/spring-web/src/test/java/org/springframework/http/converter/BufferedImageHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/BufferedImageHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..001ea18a8b859103c9c8f3f68373a396813f0948 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/BufferedImageHttpMessageConverterTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.awt.image.BufferedImage; +import java.io.ByteArrayInputStream; +import java.io.IOException; + +import javax.imageio.ImageIO; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; +import org.springframework.util.FileCopyUtils; + +import static org.junit.Assert.*; + +/** + * Unit tests for BufferedImageHttpMessageConverter. + * @author Arjen Poutsma + * @author Rossen Stoyanchev + */ +public class BufferedImageHttpMessageConverterTests { + + private BufferedImageHttpMessageConverter converter; + + @Before + public void setUp() { + converter = new BufferedImageHttpMessageConverter(); + } + + @Test + public void canRead() { + assertTrue("Image not supported", converter.canRead(BufferedImage.class, null)); + assertTrue("Image not supported", converter.canRead(BufferedImage.class, new MediaType("image", "png"))); + } + + @Test + public void canWrite() { + assertTrue("Image not supported", converter.canWrite(BufferedImage.class, null)); + assertTrue("Image not supported", converter.canWrite(BufferedImage.class, new MediaType("image", "png"))); + assertTrue("Image not supported", converter.canWrite(BufferedImage.class, new MediaType("*", "*"))); + } + + @Test + public void read() throws IOException { + Resource logo = new ClassPathResource("logo.jpg", BufferedImageHttpMessageConverterTests.class); + byte[] body = FileCopyUtils.copyToByteArray(logo.getInputStream()); + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body); + inputMessage.getHeaders().setContentType(new MediaType("image", "jpeg")); + BufferedImage result = converter.read(BufferedImage.class, inputMessage); + assertEquals("Invalid height", 500, result.getHeight()); + assertEquals("Invalid width", 750, result.getWidth()); + } + + @Test + public void write() throws IOException { + Resource logo = new ClassPathResource("logo.jpg", BufferedImageHttpMessageConverterTests.class); + BufferedImage body = ImageIO.read(logo.getFile()); + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MediaType contentType = new MediaType("image", "png"); + converter.write(body, contentType, outputMessage); + assertEquals("Invalid content type", contentType, outputMessage.getWrittenHeaders().getContentType()); + assertTrue("Invalid size", outputMessage.getBodyAsBytes().length > 0); + BufferedImage result = ImageIO.read(new ByteArrayInputStream(outputMessage.getBodyAsBytes())); + assertEquals("Invalid height", 500, result.getHeight()); + assertEquals("Invalid width", 750, result.getWidth()); + } + + @Test + public void writeDefaultContentType() throws IOException { + Resource logo = new ClassPathResource("logo.jpg", BufferedImageHttpMessageConverterTests.class); + MediaType contentType = new MediaType("image", "png"); + converter.setDefaultContentType(contentType); + BufferedImage body = ImageIO.read(logo.getFile()); + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(body, new MediaType("*", "*"), outputMessage); + assertEquals("Invalid content type", contentType, outputMessage.getWrittenHeaders().getContentType()); + assertTrue("Invalid size", outputMessage.getBodyAsBytes().length > 0); + BufferedImage result = ImageIO.read(new ByteArrayInputStream(outputMessage.getBodyAsBytes())); + assertEquals("Invalid height", 500, result.getHeight()); + assertEquals("Invalid width", 750, result.getWidth()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/ByteArrayHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/ByteArrayHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b0bd8b15c15ff43d3fe5bc8f46206d9e63ca468d --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/ByteArrayHttpMessageConverterTests.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2009 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; + +import static org.junit.Assert.*; + +/** @author Arjen Poutsma */ +public class ByteArrayHttpMessageConverterTests { + + private ByteArrayHttpMessageConverter converter; + + @Before + public void setUp() { + converter = new ByteArrayHttpMessageConverter(); + } + + @Test + public void canRead() { + assertTrue(converter.canRead(byte[].class, new MediaType("application", "octet-stream"))); + } + + @Test + public void canWrite() { + assertTrue(converter.canWrite(byte[].class, new MediaType("application", "octet-stream"))); + assertTrue(converter.canWrite(byte[].class, MediaType.ALL)); + } + + @Test + public void read() throws IOException { + byte[] body = new byte[]{0x1, 0x2}; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body); + inputMessage.getHeaders().setContentType(new MediaType("application", "octet-stream")); + byte[] result = converter.read(byte[].class, inputMessage); + assertArrayEquals("Invalid result", body, result); + } + + @Test + public void write() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + byte[] body = new byte[]{0x1, 0x2}; + converter.write(body, null, outputMessage); + assertArrayEquals("Invalid result", body, outputMessage.getBodyAsBytes()); + assertEquals("Invalid content-type", new MediaType("application", "octet-stream"), + outputMessage.getHeaders().getContentType()); + assertEquals("Invalid content-length", 2, outputMessage.getHeaders().getContentLength()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/FormHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/FormHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..27cbc8087cbabf75db92c0adf10eef75e14f2211 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/FormHttpMessageConverterTests.java @@ -0,0 +1,290 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.StringReader; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import javax.xml.transform.Source; +import javax.xml.transform.stream.StreamSource; + +import org.apache.commons.fileupload.FileItem; +import org.apache.commons.fileupload.FileItemFactory; +import org.apache.commons.fileupload.FileUpload; +import org.apache.commons.fileupload.RequestContext; +import org.apache.commons.fileupload.disk.DiskFileItemFactory; +import org.hamcrest.Matchers; +import org.junit.Test; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; +import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import static org.hamcrest.CoreMatchers.allOf; +import static org.hamcrest.CoreMatchers.endsWith; +import static org.hamcrest.CoreMatchers.startsWith; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.mockito.BDDMockito.never; +import static org.mockito.BDDMockito.verify; + +/** + * @author Arjen Poutsma + * @author Rossen Stoyanchev + */ +public class FormHttpMessageConverterTests { + + private final FormHttpMessageConverter converter = new AllEncompassingFormHttpMessageConverter(); + + + @Test + public void canRead() { + assertTrue(this.converter.canRead(MultiValueMap.class, + new MediaType("application", "x-www-form-urlencoded"))); + assertFalse(this.converter.canRead(MultiValueMap.class, + new MediaType("multipart", "form-data"))); + } + + @Test + public void canWrite() { + assertTrue(this.converter.canWrite(MultiValueMap.class, + new MediaType("application", "x-www-form-urlencoded"))); + assertTrue(this.converter.canWrite(MultiValueMap.class, + new MediaType("multipart", "form-data"))); + assertTrue(this.converter.canWrite(MultiValueMap.class, + new MediaType("multipart", "form-data", StandardCharsets.UTF_8))); + assertTrue(this.converter.canWrite(MultiValueMap.class, MediaType.ALL)); + } + + @Test + public void readForm() throws Exception { + String body = "name+1=value+1&name+2=value+2%2B1&name+2=value+2%2B2&name+3"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes(StandardCharsets.ISO_8859_1)); + inputMessage.getHeaders().setContentType( + new MediaType("application", "x-www-form-urlencoded", StandardCharsets.ISO_8859_1)); + MultiValueMap result = this.converter.read(null, inputMessage); + + assertEquals("Invalid result", 3, result.size()); + assertEquals("Invalid result", "value 1", result.getFirst("name 1")); + List values = result.get("name 2"); + assertEquals("Invalid result", 2, values.size()); + assertEquals("Invalid result", "value 2+1", values.get(0)); + assertEquals("Invalid result", "value 2+2", values.get(1)); + assertNull("Invalid result", result.getFirst("name 3")); + } + + @Test + public void writeForm() throws IOException { + MultiValueMap body = new LinkedMultiValueMap<>(); + body.set("name 1", "value 1"); + body.add("name 2", "value 2+1"); + body.add("name 2", "value 2+2"); + body.add("name 3", null); + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + this.converter.write(body, MediaType.APPLICATION_FORM_URLENCODED, outputMessage); + + assertEquals("Invalid result", "name+1=value+1&name+2=value+2%2B1&name+2=value+2%2B2&name+3", + outputMessage.getBodyAsString(StandardCharsets.UTF_8)); + assertEquals("Invalid content-type", "application/x-www-form-urlencoded;charset=UTF-8", + outputMessage.getHeaders().getContentType().toString()); + assertEquals("Invalid content-length", outputMessage.getBodyAsBytes().length, + outputMessage.getHeaders().getContentLength()); + } + + @Test + public void writeMultipart() throws Exception { + MultiValueMap parts = new LinkedMultiValueMap<>(); + parts.add("name 1", "value 1"); + parts.add("name 2", "value 2+1"); + parts.add("name 2", "value 2+2"); + parts.add("name 3", null); + + Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); + parts.add("logo", logo); + + // SPR-12108 + Resource utf8 = new ClassPathResource("/org/springframework/http/converter/logo.jpg") { + @Override + public String getFilename() { + return "Hall\u00F6le.jpg"; + } + }; + parts.add("utf8", utf8); + + Source xml = new StreamSource(new StringReader("")); + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.setContentType(MediaType.TEXT_XML); + HttpEntity entity = new HttpEntity<>(xml, entityHeaders); + parts.add("xml", entity); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + this.converter.write(parts, new MediaType("multipart", "form-data", StandardCharsets.UTF_8), outputMessage); + + final MediaType contentType = outputMessage.getHeaders().getContentType(); + // SPR-17030 + assertThat(contentType.getParameters().keySet(), Matchers.contains("charset", "boundary")); + + // see if Commons FileUpload can read what we wrote + FileItemFactory fileItemFactory = new DiskFileItemFactory(); + FileUpload fileUpload = new FileUpload(fileItemFactory); + RequestContext requestContext = new MockHttpOutputMessageRequestContext(outputMessage); + List items = fileUpload.parseRequest(requestContext); + assertEquals(6, items.size()); + FileItem item = items.get(0); + assertTrue(item.isFormField()); + assertEquals("name 1", item.getFieldName()); + assertEquals("value 1", item.getString()); + + item = items.get(1); + assertTrue(item.isFormField()); + assertEquals("name 2", item.getFieldName()); + assertEquals("value 2+1", item.getString()); + + item = items.get(2); + assertTrue(item.isFormField()); + assertEquals("name 2", item.getFieldName()); + assertEquals("value 2+2", item.getString()); + + item = items.get(3); + assertFalse(item.isFormField()); + assertEquals("logo", item.getFieldName()); + assertEquals("logo.jpg", item.getName()); + assertEquals("image/jpeg", item.getContentType()); + assertEquals(logo.getFile().length(), item.getSize()); + + item = items.get(4); + assertFalse(item.isFormField()); + assertEquals("utf8", item.getFieldName()); + assertEquals("Hall\u00F6le.jpg", item.getName()); + assertEquals("image/jpeg", item.getContentType()); + assertEquals(logo.getFile().length(), item.getSize()); + + item = items.get(5); + assertEquals("xml", item.getFieldName()); + assertEquals("text/xml", item.getContentType()); + verify(outputMessage.getBody(), never()).close(); + } + + // SPR-13309 + + @Test + public void writeMultipartOrder() throws Exception { + MyBean myBean = new MyBean(); + myBean.setString("foo"); + + MultiValueMap parts = new LinkedMultiValueMap<>(); + parts.add("part1", myBean); + + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.setContentType(MediaType.TEXT_XML); + HttpEntity entity = new HttpEntity<>(myBean, entityHeaders); + parts.add("part2", entity); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + this.converter.setMultipartCharset(StandardCharsets.UTF_8); + this.converter.write(parts, new MediaType("multipart", "form-data", StandardCharsets.UTF_8), outputMessage); + + final MediaType contentType = outputMessage.getHeaders().getContentType(); + assertNotNull("No boundary found", contentType.getParameter("boundary")); + + // see if Commons FileUpload can read what we wrote + FileItemFactory fileItemFactory = new DiskFileItemFactory(); + FileUpload fileUpload = new FileUpload(fileItemFactory); + RequestContext requestContext = new MockHttpOutputMessageRequestContext(outputMessage); + List items = fileUpload.parseRequest(requestContext); + assertEquals(2, items.size()); + + FileItem item = items.get(0); + assertTrue(item.isFormField()); + assertEquals("part1", item.getFieldName()); + assertEquals("{\"string\":\"foo\"}", item.getString()); + + item = items.get(1); + assertTrue(item.isFormField()); + assertEquals("part2", item.getFieldName()); + + // With developer builds we get: foo + // But on CI server we get: foo + // So... we make a compromise: + assertThat(item.getString(), + allOf(startsWith("foo"))); + } + + + private static class MockHttpOutputMessageRequestContext implements RequestContext { + + private final MockHttpOutputMessage outputMessage; + + + private MockHttpOutputMessageRequestContext(MockHttpOutputMessage outputMessage) { + this.outputMessage = outputMessage; + } + + + @Override + public String getCharacterEncoding() { + MediaType type = this.outputMessage.getHeaders().getContentType(); + return (type != null && type.getCharset() != null ? type.getCharset().name() : null); + } + + @Override + public String getContentType() { + MediaType type = this.outputMessage.getHeaders().getContentType(); + return (type != null ? type.toString() : null); + } + + @Override + @Deprecated + public int getContentLength() { + return this.outputMessage.getBodyAsBytes().length; + } + + @Override + public InputStream getInputStream() throws IOException { + return new ByteArrayInputStream(this.outputMessage.getBodyAsBytes()); + } + } + + public static class MyBean { + + private String string; + + public String getString() { + return this.string; + } + + public void setString(String string) { + this.string = string; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/HttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/HttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e3623277fb4bc0ed78b2845aa9e2cc3d8f058edf --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/HttpMessageConverterTests.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.IOException; + +import org.junit.Test; + +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Test-case for AbstractHttpMessageConverter. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + */ +public class HttpMessageConverterTests { + + + @Test + public void canRead() { + MediaType mediaType = new MediaType("foo", "bar"); + HttpMessageConverter converter = new MyHttpMessageConverter<>(mediaType); + + assertTrue(converter.canRead(MyType.class, mediaType)); + assertFalse(converter.canRead(MyType.class, new MediaType("foo", "*"))); + assertFalse(converter.canRead(MyType.class, MediaType.ALL)); + } + + @Test + public void canReadWithWildcardSubtype() { + MediaType mediaType = new MediaType("foo"); + HttpMessageConverter converter = new MyHttpMessageConverter<>(mediaType); + + assertTrue(converter.canRead(MyType.class, new MediaType("foo", "bar"))); + assertTrue(converter.canRead(MyType.class, new MediaType("foo", "*"))); + assertFalse(converter.canRead(MyType.class, MediaType.ALL)); + } + + @Test + public void canWrite() { + MediaType mediaType = new MediaType("foo", "bar"); + HttpMessageConverter converter = new MyHttpMessageConverter<>(mediaType); + + assertTrue(converter.canWrite(MyType.class, mediaType)); + assertTrue(converter.canWrite(MyType.class, new MediaType("foo", "*"))); + assertTrue(converter.canWrite(MyType.class, MediaType.ALL)); + } + + @Test + public void canWriteWithWildcardInSupportedSubtype() { + MediaType mediaType = new MediaType("foo"); + HttpMessageConverter converter = new MyHttpMessageConverter<>(mediaType); + + assertTrue(converter.canWrite(MyType.class, new MediaType("foo", "bar"))); + assertTrue(converter.canWrite(MyType.class, new MediaType("foo", "*"))); + assertTrue(converter.canWrite(MyType.class, MediaType.ALL)); + } + + + private static class MyHttpMessageConverter extends AbstractHttpMessageConverter { + + private MyHttpMessageConverter(MediaType supportedMediaType) { + super(supportedMediaType); + } + + @Override + protected boolean supports(Class clazz) { + return MyType.class.equals(clazz); + } + + @Override + protected T readInternal(Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + fail("Not expected"); + return null; + } + + @Override + protected void writeInternal(T t, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + fail("Not expected"); + } + } + + private static class MyType { + + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/ObjectToStringHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/ObjectToStringHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..1c56cab10a98f00995c398b10af3fdbb0545b8ee --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/ObjectToStringHttpMessageConverterTests.java @@ -0,0 +1,173 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.IOException; +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.util.Date; +import java.util.Locale; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.convert.ConversionService; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.core.io.Resource; +import org.springframework.http.MediaType; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; + +import static org.junit.Assert.*; + +/** + * Test cases for {@link ObjectToStringHttpMessageConverter} class. + * + * @author Dmitry Katsubo + * @author Rossen Stoyanchev + */ +public class ObjectToStringHttpMessageConverterTests { + + private ObjectToStringHttpMessageConverter converter; + + private MockHttpServletResponse servletResponse; + + private ServletServerHttpResponse response; + + + @Before + public void setup() { + ConversionService conversionService = new DefaultConversionService(); + this.converter = new ObjectToStringHttpMessageConverter(conversionService); + + this.servletResponse = new MockHttpServletResponse(); + this.response = new ServletServerHttpResponse(this.servletResponse); + } + + + @Test + public void canRead() { + assertFalse(this.converter.canRead(Math.class, null)); + assertFalse(this.converter.canRead(Resource.class, null)); + + assertTrue(this.converter.canRead(Locale.class, null)); + assertTrue(this.converter.canRead(BigInteger.class, null)); + + assertFalse(this.converter.canRead(BigInteger.class, MediaType.TEXT_HTML)); + assertFalse(this.converter.canRead(BigInteger.class, MediaType.TEXT_XML)); + assertFalse(this.converter.canRead(BigInteger.class, MediaType.APPLICATION_XML)); + } + + @Test + public void canWrite() { + assertFalse(this.converter.canWrite(Math.class, null)); + assertFalse(this.converter.canWrite(Resource.class, null)); + + assertTrue(this.converter.canWrite(Locale.class, null)); + assertTrue(this.converter.canWrite(Double.class, null)); + + assertFalse(this.converter.canWrite(BigInteger.class, MediaType.TEXT_HTML)); + assertFalse(this.converter.canWrite(BigInteger.class, MediaType.TEXT_XML)); + assertFalse(this.converter.canWrite(BigInteger.class, MediaType.APPLICATION_XML)); + + assertTrue(this.converter.canWrite(BigInteger.class, MediaType.valueOf("text/*"))); + } + + @Test + public void defaultCharset() throws IOException { + this.converter.write(Integer.valueOf(5), null, response); + + assertEquals("ISO-8859-1", servletResponse.getCharacterEncoding()); + } + + @Test + public void defaultCharsetModified() throws IOException { + ConversionService cs = new DefaultConversionService(); + ObjectToStringHttpMessageConverter converter = new ObjectToStringHttpMessageConverter(cs, StandardCharsets.UTF_16); + converter.write((byte) 31, null, this.response); + + assertEquals("UTF-16", this.servletResponse.getCharacterEncoding()); + } + + @Test + public void writeAcceptCharset() throws IOException { + this.converter.write(new Date(), null, this.response); + + assertNotNull(this.servletResponse.getHeader("Accept-Charset")); + } + + @Test + public void writeAcceptCharsetTurnedOff() throws IOException { + this.converter.setWriteAcceptCharset(false); + this.converter.write(new Date(), null, this.response); + + assertNull(this.servletResponse.getHeader("Accept-Charset")); + } + + @Test + public void read() throws IOException { + Short shortValue = Short.valueOf((short) 781); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setContentType(MediaType.TEXT_PLAIN_VALUE); + request.setContent(shortValue.toString().getBytes(StringHttpMessageConverter.DEFAULT_CHARSET)); + assertEquals(shortValue, this.converter.read(Short.class, new ServletServerHttpRequest(request))); + + Float floatValue = Float.valueOf(123); + request = new MockHttpServletRequest(); + request.setContentType(MediaType.TEXT_PLAIN_VALUE); + request.setCharacterEncoding("UTF-16"); + request.setContent(floatValue.toString().getBytes("UTF-16")); + assertEquals(floatValue, this.converter.read(Float.class, new ServletServerHttpRequest(request))); + + Long longValue = Long.valueOf(55819182821331L); + request = new MockHttpServletRequest(); + request.setContentType(MediaType.TEXT_PLAIN_VALUE); + request.setCharacterEncoding("UTF-8"); + request.setContent(longValue.toString().getBytes("UTF-8")); + assertEquals(longValue, this.converter.read(Long.class, new ServletServerHttpRequest(request))); + } + + @Test + public void write() throws IOException { + this.converter.write((byte) -8, null, this.response); + + assertEquals("ISO-8859-1", this.servletResponse.getCharacterEncoding()); + assertTrue(this.servletResponse.getContentType().startsWith(MediaType.TEXT_PLAIN_VALUE)); + assertEquals(2, this.servletResponse.getContentLength()); + assertArrayEquals(new byte[] { '-', '8' }, this.servletResponse.getContentAsByteArray()); + } + + @Test + public void writeUtf16() throws IOException { + MediaType contentType = new MediaType("text", "plain", StandardCharsets.UTF_16); + this.converter.write(Integer.valueOf(958), contentType, this.response); + + assertEquals("UTF-16", this.servletResponse.getCharacterEncoding()); + assertTrue(this.servletResponse.getContentType().startsWith(MediaType.TEXT_PLAIN_VALUE)); + assertEquals(8, this.servletResponse.getContentLength()); + // First two bytes: byte order mark + assertArrayEquals(new byte[] { -2, -1, 0, '9', 0, '5', 0, '8' }, this.servletResponse.getContentAsByteArray()); + } + + @Test(expected = IllegalArgumentException.class) + public void testConversionServiceRequired() { + new ObjectToStringHttpMessageConverter(null); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/ResourceHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/ResourceHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5aef1ca4074ea45a9fe2a05e03bd6e2f8065e931 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/ResourceHttpMessageConverterTests.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.core.io.ByteArrayResource; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.InputStreamResource; +import org.springframework.core.io.Resource; +import org.springframework.http.ContentDisposition; +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; +import org.springframework.util.FileCopyUtils; + +import static org.hamcrest.core.Is.*; +import static org.hamcrest.core.IsInstanceOf.*; +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.any; +import static org.mockito.BDDMockito.*; + +/** + * @author Arjen Poutsma + * @author Kazuki Shimizu + * @author Brian Clozel + */ +public class ResourceHttpMessageConverterTests { + + private final ResourceHttpMessageConverter converter = new ResourceHttpMessageConverter(); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + + @Test + public void canReadResource() { + assertTrue(converter.canRead(Resource.class, new MediaType("application", "octet-stream"))); + } + + @Test + public void canWriteResource() { + assertTrue(converter.canWrite(Resource.class, new MediaType("application", "octet-stream"))); + assertTrue(converter.canWrite(Resource.class, MediaType.ALL)); + } + + @Test + public void shouldReadImageResource() throws IOException { + byte[] body = FileCopyUtils.copyToByteArray(getClass().getResourceAsStream("logo.jpg")); + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body); + inputMessage.getHeaders().setContentType(MediaType.IMAGE_JPEG); + inputMessage.getHeaders().setContentDisposition( + ContentDisposition.builder("attachment").filename("yourlogo.jpg").build()); + Resource actualResource = converter.read(Resource.class, inputMessage); + assertThat(FileCopyUtils.copyToByteArray(actualResource.getInputStream()), is(body)); + assertEquals("yourlogo.jpg", actualResource.getFilename()); + } + + @Test // SPR-13443 + public void shouldReadInputStreamResource() throws IOException { + try (InputStream body = getClass().getResourceAsStream("logo.jpg") ) { + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body); + inputMessage.getHeaders().setContentType(MediaType.IMAGE_JPEG); + inputMessage.getHeaders().setContentDisposition( + ContentDisposition.builder("attachment").filename("yourlogo.jpg").build()); + Resource actualResource = converter.read(InputStreamResource.class, inputMessage); + assertThat(actualResource, instanceOf(InputStreamResource.class)); + assertThat(actualResource.getInputStream(), is(body)); + assertEquals("yourlogo.jpg", actualResource.getFilename()); + } + } + + @Test // SPR-14882 + public void shouldNotReadInputStreamResource() throws IOException { + ResourceHttpMessageConverter noStreamConverter = new ResourceHttpMessageConverter(false); + try (InputStream body = getClass().getResourceAsStream("logo.jpg") ) { + this.thrown.expect(HttpMessageNotReadableException.class); + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body); + inputMessage.getHeaders().setContentType(MediaType.IMAGE_JPEG); + noStreamConverter.read(InputStreamResource.class, inputMessage); + } + } + + @Test + public void shouldWriteImageResource() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Resource body = new ClassPathResource("logo.jpg", getClass()); + converter.write(body, null, outputMessage); + + assertEquals("Invalid content-type", MediaType.IMAGE_JPEG, + outputMessage.getHeaders().getContentType()); + assertEquals("Invalid content-length", body.getFile().length(), outputMessage.getHeaders().getContentLength()); + } + + @Test // SPR-10848 + public void writeByteArrayNullMediaType() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + byte[] byteArray = {1, 2, 3}; + Resource body = new ByteArrayResource(byteArray); + converter.write(body, null, outputMessage); + + assertTrue(Arrays.equals(byteArray, outputMessage.getBodyAsBytes())); + } + + @Test // SPR-12999 + @SuppressWarnings("unchecked") + public void writeContentNotGettingInputStream() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Resource resource = mock(Resource.class); + given(resource.getInputStream()).willThrow(FileNotFoundException.class); + converter.write(resource, MediaType.APPLICATION_OCTET_STREAM, outputMessage); + + assertEquals(0, outputMessage.getHeaders().getContentLength()); + } + + @Test // SPR-12999 + public void writeContentNotClosingInputStream() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Resource resource = mock(Resource.class); + InputStream inputStream = mock(InputStream.class); + given(resource.getInputStream()).willReturn(inputStream); + given(inputStream.read(any())).willReturn(-1); + doThrow(new NullPointerException()).when(inputStream).close(); + converter.write(resource, MediaType.APPLICATION_OCTET_STREAM, outputMessage); + + assertEquals(0, outputMessage.getHeaders().getContentLength()); + } + + @Test // SPR-13620 + @SuppressWarnings("unchecked") + public void writeContentInputStreamThrowingNullPointerException() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Resource resource = mock(Resource.class); + InputStream in = mock(InputStream.class); + given(resource.getInputStream()).willReturn(in); + given(in.read(any())).willThrow(NullPointerException.class); + converter.write(resource, MediaType.APPLICATION_OCTET_STREAM, outputMessage); + + assertEquals(0, outputMessage.getHeaders().getContentLength()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/ResourceRegionHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/ResourceRegionHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b5556a19599686a4f41cbb55ab406a01c3517941 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/ResourceRegionHttpMessageConverterTests.java @@ -0,0 +1,168 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.ByteArrayInputStream; +import java.lang.reflect.Type; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.hamcrest.Matchers; +import org.junit.Test; +import org.mockito.BDDMockito; +import org.mockito.Mockito; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.core.io.support.ResourceRegion; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpRange; +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpOutputMessage; +import org.springframework.util.StringUtils; + +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +/** + * Test cases for {@link ResourceRegionHttpMessageConverter} class. + * + * @author Brian Clozel + */ +public class ResourceRegionHttpMessageConverterTests { + + private final ResourceRegionHttpMessageConverter converter = new ResourceRegionHttpMessageConverter(); + + @Test + public void canReadResource() { + assertFalse(converter.canRead(Resource.class, MediaType.APPLICATION_OCTET_STREAM)); + assertFalse(converter.canRead(Resource.class, MediaType.ALL)); + assertFalse(converter.canRead(List.class, MediaType.APPLICATION_OCTET_STREAM)); + assertFalse(converter.canRead(List.class, MediaType.ALL)); + } + + @Test + public void canWriteResource() { + assertTrue(converter.canWrite(ResourceRegion.class, null, MediaType.APPLICATION_OCTET_STREAM)); + assertTrue(converter.canWrite(ResourceRegion.class, null, MediaType.ALL)); + assertFalse(converter.canWrite(Object.class, null, MediaType.ALL)); + } + + @Test + public void canWriteResourceCollection() { + Type resourceRegionList = new ParameterizedTypeReference>() {}.getType(); + assertTrue(converter.canWrite(resourceRegionList, null, MediaType.APPLICATION_OCTET_STREAM)); + assertTrue(converter.canWrite(resourceRegionList, null, MediaType.ALL)); + + assertFalse(converter.canWrite(List.class, MediaType.APPLICATION_OCTET_STREAM)); + assertFalse(converter.canWrite(List.class, MediaType.ALL)); + Type resourceObjectList = new ParameterizedTypeReference>() {}.getType(); + assertFalse(converter.canWrite(resourceObjectList, null, MediaType.ALL)); + } + + @Test + public void shouldWritePartialContentByteRange() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Resource body = new ClassPathResource("byterangeresource.txt", getClass()); + ResourceRegion region = HttpRange.createByteRange(0, 5).toResourceRegion(body); + converter.write(region, MediaType.TEXT_PLAIN, outputMessage); + + HttpHeaders headers = outputMessage.getHeaders(); + assertThat(headers.getContentType(), is(MediaType.TEXT_PLAIN)); + assertThat(headers.getContentLength(), is(6L)); + assertThat(headers.get(HttpHeaders.CONTENT_RANGE).size(), is(1)); + assertThat(headers.get(HttpHeaders.CONTENT_RANGE).get(0), is("bytes 0-5/39")); + assertThat(outputMessage.getBodyAsString(StandardCharsets.UTF_8), is("Spring")); + } + + @Test + public void shouldWritePartialContentByteRangeNoEnd() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Resource body = new ClassPathResource("byterangeresource.txt", getClass()); + ResourceRegion region = HttpRange.createByteRange(7).toResourceRegion(body); + converter.write(region, MediaType.TEXT_PLAIN, outputMessage); + + HttpHeaders headers = outputMessage.getHeaders(); + assertThat(headers.getContentType(), is(MediaType.TEXT_PLAIN)); + assertThat(headers.getContentLength(), is(32L)); + assertThat(headers.get(HttpHeaders.CONTENT_RANGE).size(), is(1)); + assertThat(headers.get(HttpHeaders.CONTENT_RANGE).get(0), is("bytes 7-38/39")); + assertThat(outputMessage.getBodyAsString(StandardCharsets.UTF_8), is("Framework test resource content.")); + } + + @Test + public void partialContentMultipleByteRanges() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Resource body = new ClassPathResource("byterangeresource.txt", getClass()); + List rangeList = HttpRange.parseRanges("bytes=0-5,7-15,17-20,22-38"); + List regions = new ArrayList<>(); + for(HttpRange range : rangeList) { + regions.add(range.toResourceRegion(body)); + } + + converter.write(regions, MediaType.TEXT_PLAIN, outputMessage); + + HttpHeaders headers = outputMessage.getHeaders(); + assertThat(headers.getContentType().toString(), Matchers.startsWith("multipart/byteranges;boundary=")); + String boundary = "--" + headers.getContentType().toString().substring(30); + String content = outputMessage.getBodyAsString(StandardCharsets.UTF_8); + String[] ranges = StringUtils.tokenizeToStringArray(content, "\r\n", false, true); + + assertThat(ranges[0], is(boundary)); + assertThat(ranges[1], is("Content-Type: text/plain")); + assertThat(ranges[2], is("Content-Range: bytes 0-5/39")); + assertThat(ranges[3], is("Spring")); + + assertThat(ranges[4], is(boundary)); + assertThat(ranges[5], is("Content-Type: text/plain")); + assertThat(ranges[6], is("Content-Range: bytes 7-15/39")); + assertThat(ranges[7], is("Framework")); + + assertThat(ranges[8], is(boundary)); + assertThat(ranges[9], is("Content-Type: text/plain")); + assertThat(ranges[10], is("Content-Range: bytes 17-20/39")); + assertThat(ranges[11], is("test")); + + assertThat(ranges[12], is(boundary)); + assertThat(ranges[13], is("Content-Type: text/plain")); + assertThat(ranges[14], is("Content-Range: bytes 22-38/39")); + assertThat(ranges[15], is("resource content.")); + } + + @Test // SPR-15041 + public void applicationOctetStreamDefaultContentType() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + ClassPathResource body = Mockito.mock(ClassPathResource.class); + BDDMockito.given(body.getFilename()).willReturn("spring.dat"); + BDDMockito.given(body.contentLength()).willReturn(12L); + BDDMockito.given(body.getInputStream()).willReturn(new ByteArrayInputStream("Spring Framework".getBytes())); + HttpRange range = HttpRange.createByteRange(0, 5); + ResourceRegion resourceRegion = range.toResourceRegion(body); + + converter.write(Collections.singletonList(resourceRegion), null, outputMessage); + + assertThat(outputMessage.getHeaders().getContentType(), is(MediaType.APPLICATION_OCTET_STREAM)); + assertThat(outputMessage.getHeaders().getFirst(HttpHeaders.CONTENT_RANGE), is("bytes 0-5/12")); + assertThat(outputMessage.getBodyAsString(StandardCharsets.UTF_8), is("Spring")); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/StringHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/StringHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..4a4803b36819527e67ad66835df296c9c1e9d721 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/StringHttpMessageConverterTests.java @@ -0,0 +1,112 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + * @author Rossen Stoyanchev + */ +public class StringHttpMessageConverterTests { + + public static final MediaType TEXT_PLAIN_UTF_8 = new MediaType("text", "plain", StandardCharsets.UTF_8); + + private StringHttpMessageConverter converter; + + private MockHttpOutputMessage outputMessage; + + + @Before + public void setUp() { + this.converter = new StringHttpMessageConverter(); + this.outputMessage = new MockHttpOutputMessage(); + } + + + @Test + public void canRead() { + assertTrue(this.converter.canRead(String.class, MediaType.TEXT_PLAIN)); + } + + @Test + public void canWrite() { + assertTrue(this.converter.canWrite(String.class, MediaType.TEXT_PLAIN)); + assertTrue(this.converter.canWrite(String.class, MediaType.ALL)); + } + + @Test + public void read() throws IOException { + String body = "Hello World"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes(StandardCharsets.UTF_8)); + inputMessage.getHeaders().setContentType(TEXT_PLAIN_UTF_8); + String result = this.converter.read(String.class, inputMessage); + + assertEquals("Invalid result", body, result); + } + + @Test + public void writeDefaultCharset() throws IOException { + String body = "H\u00e9llo W\u00f6rld"; + this.converter.write(body, null, this.outputMessage); + + HttpHeaders headers = this.outputMessage.getHeaders(); + assertEquals(body, this.outputMessage.getBodyAsString(StandardCharsets.ISO_8859_1)); + assertEquals(new MediaType("text", "plain", StandardCharsets.ISO_8859_1), headers.getContentType()); + assertEquals(body.getBytes(StandardCharsets.ISO_8859_1).length, headers.getContentLength()); + assertFalse(headers.getAcceptCharset().isEmpty()); + } + + @Test + public void writeUTF8() throws IOException { + String body = "H\u00e9llo W\u00f6rld"; + this.converter.write(body, TEXT_PLAIN_UTF_8, this.outputMessage); + + HttpHeaders headers = this.outputMessage.getHeaders(); + assertEquals(body, this.outputMessage.getBodyAsString(StandardCharsets.UTF_8)); + assertEquals(TEXT_PLAIN_UTF_8, headers.getContentType()); + assertEquals(body.getBytes(StandardCharsets.UTF_8).length, headers.getContentLength()); + assertFalse(headers.getAcceptCharset().isEmpty()); + } + + @Test // SPR-8867 + public void writeOverrideRequestedContentType() throws IOException { + String body = "H\u00e9llo W\u00f6rld"; + MediaType requestedContentType = new MediaType("text", "html"); + + HttpHeaders headers = this.outputMessage.getHeaders(); + headers.setContentType(TEXT_PLAIN_UTF_8); + this.converter.write(body, requestedContentType, this.outputMessage); + + assertEquals(body, this.outputMessage.getBodyAsString(StandardCharsets.UTF_8)); + assertEquals(TEXT_PLAIN_UTF_8, headers.getContentType()); + assertEquals(body.getBytes(StandardCharsets.UTF_8).length, headers.getContentLength()); + assertFalse(headers.getAcceptCharset().isEmpty()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/feed/AtomFeedHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/feed/AtomFeedHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d412524f45c6c19dbebfa3103bb072d780b639ee --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/feed/AtomFeedHttpMessageConverterTests.java @@ -0,0 +1,133 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.feed; + +import static org.junit.Assert.*; +import static org.xmlunit.matchers.CompareMatcher.*; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import com.rometools.rome.feed.atom.Entry; +import com.rometools.rome.feed.atom.Feed; +import org.junit.Before; +import org.junit.Test; +import org.xml.sax.SAXException; +import org.xmlunit.diff.DefaultNodeMatcher; +import org.xmlunit.diff.ElementSelectors; +import org.xmlunit.diff.NodeMatcher; + +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; + +/** + * @author Arjen Poutsma + */ +public class AtomFeedHttpMessageConverterTests { + + private AtomFeedHttpMessageConverter converter; + + + @Before + public void setUp() { + converter = new AtomFeedHttpMessageConverter(); + } + + + @Test + public void canRead() { + assertTrue(converter.canRead(Feed.class, new MediaType("application", "atom+xml"))); + assertTrue(converter.canRead(Feed.class, new MediaType("application", "atom+xml", StandardCharsets.UTF_8))); + } + + @Test + public void canWrite() { + assertTrue(converter.canWrite(Feed.class, new MediaType("application", "atom+xml"))); + assertTrue(converter.canWrite(Feed.class, new MediaType("application", "atom+xml", StandardCharsets.UTF_8))); + } + + @Test + public void read() throws IOException { + InputStream is = getClass().getResourceAsStream("atom.xml"); + MockHttpInputMessage inputMessage = new MockHttpInputMessage(is); + inputMessage.getHeaders().setContentType(new MediaType("application", "atom+xml", StandardCharsets.UTF_8)); + Feed result = converter.read(Feed.class, inputMessage); + assertEquals("title", result.getTitle()); + assertEquals("subtitle", result.getSubtitle().getValue()); + List entries = result.getEntries(); + assertEquals(2, entries.size()); + + Entry entry1 = (Entry) entries.get(0); + assertEquals("id1", entry1.getId()); + assertEquals("title1", entry1.getTitle()); + + Entry entry2 = (Entry) entries.get(1); + assertEquals("id2", entry2.getId()); + assertEquals("title2", entry2.getTitle()); + } + + @Test + public void write() throws IOException, SAXException { + Feed feed = new Feed("atom_1.0"); + feed.setTitle("title"); + + Entry entry1 = new Entry(); + entry1.setId("id1"); + entry1.setTitle("title1"); + + Entry entry2 = new Entry(); + entry2.setId("id2"); + entry2.setTitle("title2"); + + List entries = new ArrayList<>(2); + entries.add(entry1); + entries.add(entry2); + feed.setEntries(entries); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(feed, null, outputMessage); + + assertEquals("Invalid content-type", new MediaType("application", "atom+xml", StandardCharsets.UTF_8), + outputMessage.getHeaders().getContentType()); + String expected = "" + "title" + + "id1title1" + + "id2title2"; + NodeMatcher nm = new DefaultNodeMatcher(ElementSelectors.byName); + assertThat(outputMessage.getBodyAsString(StandardCharsets.UTF_8), + isSimilarTo(expected).ignoreWhitespace().withNodeMatcher(nm)); + } + + @Test + public void writeOtherCharset() throws IOException, SAXException { + Feed feed = new Feed("atom_1.0"); + feed.setTitle("title"); + String encoding = "ISO-8859-1"; + feed.setEncoding(encoding); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(feed, null, outputMessage); + + assertEquals("Invalid content-type", new MediaType("application", "atom+xml", Charset.forName(encoding)), + outputMessage.getHeaders().getContentType()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/feed/RssChannelHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/feed/RssChannelHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..aa674b4d80bb08931b7ed008b071e7e5bb1d97f2 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/feed/RssChannelHttpMessageConverterTests.java @@ -0,0 +1,140 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.feed; + +import static org.junit.Assert.*; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import com.rometools.rome.feed.rss.Channel; +import com.rometools.rome.feed.rss.Item; +import org.junit.Before; +import org.junit.Test; +import org.xml.sax.SAXException; +import org.xmlunit.matchers.CompareMatcher; + +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; + +/** + * @author Arjen Poutsma + */ +public class RssChannelHttpMessageConverterTests { + + private RssChannelHttpMessageConverter converter; + + + @Before + public void setUp() { + converter = new RssChannelHttpMessageConverter(); + } + + + @Test + public void canRead() { + assertTrue(converter.canRead(Channel.class, new MediaType("application", "rss+xml"))); + assertTrue(converter.canRead(Channel.class, new MediaType("application", "rss+xml", StandardCharsets.UTF_8))); + } + + @Test + public void canWrite() { + assertTrue(converter.canWrite(Channel.class, new MediaType("application", "rss+xml"))); + assertTrue(converter.canWrite(Channel.class, new MediaType("application", "rss+xml", StandardCharsets.UTF_8))); + } + + @Test + public void read() throws IOException { + InputStream is = getClass().getResourceAsStream("rss.xml"); + MockHttpInputMessage inputMessage = new MockHttpInputMessage(is); + inputMessage.getHeaders().setContentType(new MediaType("application", "rss+xml", StandardCharsets.UTF_8)); + Channel result = converter.read(Channel.class, inputMessage); + assertEquals("title", result.getTitle()); + assertEquals("https://example.com", result.getLink()); + assertEquals("description", result.getDescription()); + + List items = result.getItems(); + assertEquals(2, items.size()); + + Item item1 = (Item) items.get(0); + assertEquals("title1", item1.getTitle()); + + Item item2 = (Item) items.get(1); + assertEquals("title2", item2.getTitle()); + } + + @Test + public void write() throws IOException, SAXException { + Channel channel = new Channel("rss_2.0"); + channel.setTitle("title"); + channel.setLink("https://example.com"); + channel.setDescription("description"); + + Item item1 = new Item(); + item1.setTitle("title1"); + + Item item2 = new Item(); + item2.setTitle("title2"); + + List items = new ArrayList<>(2); + items.add(item1); + items.add(item2); + channel.setItems(items); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(channel, null, outputMessage); + + assertEquals("Invalid content-type", new MediaType("application", "rss+xml", StandardCharsets.UTF_8), + outputMessage.getHeaders().getContentType()); + String expected = "" + + "titlehttps://example.comdescription" + + "title1" + + "title2" + + ""; + assertThat(outputMessage.getBodyAsString(StandardCharsets.UTF_8), isSimilarTo(expected)); + } + + @Test + public void writeOtherCharset() throws IOException, SAXException { + Channel channel = new Channel("rss_2.0"); + channel.setTitle("title"); + channel.setLink("https://example.com"); + channel.setDescription("description"); + + String encoding = "ISO-8859-1"; + channel.setEncoding(encoding); + + Item item1 = new Item(); + item1.setTitle("title1"); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(channel, null, outputMessage); + + assertEquals("Invalid content-type", new MediaType("application", "rss+xml", Charset.forName(encoding)), + outputMessage.getHeaders().getContentType()); + } + + private static CompareMatcher isSimilarTo(final String content) { + return CompareMatcher.isSimilarTo(content) + .ignoreWhitespace(); + } +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/json/GsonFactoryBeanTests.java b/spring-web/src/test/java/org/springframework/http/converter/json/GsonFactoryBeanTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d1c8fda1d2bf08e04561a79a617dbdd057acd3db --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/json/GsonFactoryBeanTests.java @@ -0,0 +1,217 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.util.Calendar; +import java.util.Date; + +import com.google.gson.Gson; + +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * {@link GsonFactoryBean} tests. + * + * @author Roy Clarkson + */ +public class GsonFactoryBeanTests { + + private static final String DATE_FORMAT = "yyyy-MM-dd"; + + private GsonFactoryBean factory = new GsonFactoryBean(); + + + @Test + public void serializeNulls() throws Exception { + this.factory.setSerializeNulls(true); + this.factory.afterPropertiesSet(); + Gson gson = this.factory.getObject(); + StringBean bean = new StringBean(); + String result = gson.toJson(bean); + assertEquals("{\"name\":null}", result); + } + + @Test + public void serializeNullsFalse() throws Exception { + this.factory.setSerializeNulls(false); + this.factory.afterPropertiesSet(); + Gson gson = this.factory.getObject(); + StringBean bean = new StringBean(); + String result = gson.toJson(bean); + assertEquals("{}", result); + } + + @Test + public void prettyPrinting() throws Exception { + this.factory.setPrettyPrinting(true); + this.factory.afterPropertiesSet(); + Gson gson = this.factory.getObject(); + StringBean bean = new StringBean(); + bean.setName("Jason"); + String result = gson.toJson(bean); + assertTrue(result.contains(" \"name\": \"Jason\"")); + } + + @Test + public void prettyPrintingFalse() throws Exception { + this.factory.setPrettyPrinting(false); + this.factory.afterPropertiesSet(); + Gson gson = this.factory.getObject(); + StringBean bean = new StringBean(); + bean.setName("Jason"); + String result = gson.toJson(bean); + assertEquals("{\"name\":\"Jason\"}", result); + } + + @Test + public void disableHtmlEscaping() throws Exception { + this.factory.setDisableHtmlEscaping(true); + this.factory.afterPropertiesSet(); + Gson gson = this.factory.getObject(); + StringBean bean = new StringBean(); + bean.setName("Bob=Bob"); + String result = gson.toJson(bean); + assertEquals("{\"name\":\"Bob=Bob\"}", result); + } + + @Test + public void disableHtmlEscapingFalse() throws Exception { + this.factory.setDisableHtmlEscaping(false); + this.factory.afterPropertiesSet(); + Gson gson = this.factory.getObject(); + StringBean bean = new StringBean(); + bean.setName("Bob=Bob"); + String result = gson.toJson(bean); + assertEquals("{\"name\":\"Bob\\u003dBob\"}", result); + } + + @Test + public void customizeDateFormatPattern() throws Exception { + this.factory.setDateFormatPattern(DATE_FORMAT); + this.factory.afterPropertiesSet(); + Gson gson = this.factory.getObject(); + DateBean bean = new DateBean(); + Calendar cal = Calendar.getInstance(); + cal.clear(); + cal.set(Calendar.YEAR, 2014); + cal.set(Calendar.MONTH, Calendar.JANUARY); + cal.set(Calendar.DATE, 1); + Date date = cal.getTime(); + bean.setDate(date); + String result = gson.toJson(bean); + assertEquals("{\"date\":\"2014-01-01\"}", result); + } + + @Test + public void customizeDateFormatNone() throws Exception { + this.factory.afterPropertiesSet(); + Gson gson = this.factory.getObject(); + DateBean bean = new DateBean(); + Calendar cal = Calendar.getInstance(); + cal.clear(); + cal.set(Calendar.YEAR, 2014); + cal.set(Calendar.MONTH, Calendar.JANUARY); + cal.set(Calendar.DATE, 1); + Date date = cal.getTime(); + bean.setDate(date); + String result = gson.toJson(bean); + assertTrue(result.startsWith("{\"date\":\"Jan 1, 2014")); + assertTrue(result.endsWith("12:00:00 AM\"}")); + } + + @Test + public void base64EncodeByteArrays() throws Exception { + this.factory.setBase64EncodeByteArrays(true); + this.factory.afterPropertiesSet(); + Gson gson = this.factory.getObject(); + ByteArrayBean bean = new ByteArrayBean(); + bean.setBytes(new byte[] {0x1, 0x2}); + String result = gson.toJson(bean); + assertEquals("{\"bytes\":\"AQI\\u003d\"}", result); + } + + @Test + public void base64EncodeByteArraysDisableHtmlEscaping() throws Exception { + this.factory.setBase64EncodeByteArrays(true); + this.factory.setDisableHtmlEscaping(true); + this.factory.afterPropertiesSet(); + Gson gson = this.factory.getObject(); + ByteArrayBean bean = new ByteArrayBean(); + bean.setBytes(new byte[] {0x1, 0x2}); + String result = gson.toJson(bean); + assertEquals("{\"bytes\":\"AQI=\"}", result); + } + + @Test + public void base64EncodeByteArraysFalse() throws Exception { + this.factory.setBase64EncodeByteArrays(false); + this.factory.afterPropertiesSet(); + Gson gson = this.factory.getObject(); + ByteArrayBean bean = new ByteArrayBean(); + bean.setBytes(new byte[] {0x1, 0x2}); + String result = gson.toJson(bean); + assertEquals("{\"bytes\":[1,2]}", result); + } + + + private static class StringBean { + + private String name; + + @SuppressWarnings("unused") + public String getName() { + return this.name; + } + + public void setName(String name) { + this.name = name; + } + } + + + private static class DateBean { + + private Date date; + + @SuppressWarnings("unused") + public Date getDate() { + return this.date; + } + + public void setDate(Date date) { + this.date = date; + } + } + + + private static class ByteArrayBean { + + private byte[] bytes; + + @SuppressWarnings("unused") + public byte[] getBytes() { + return this.bytes; + } + + public void setBytes(byte[] bytes) { + this.bytes = bytes; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/json/GsonHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/json/GsonHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..2dfb358a52b9f6b2863adfddd6a36ffd53135583 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/json/GsonHttpMessageConverterTests.java @@ -0,0 +1,347 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.Type; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Test; +import org.skyscreamer.jsonassert.JSONAssert; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; +import org.springframework.http.converter.HttpMessageNotReadableException; + +import static org.junit.Assert.*; + +/** + * Gson 2.x converter tests. + * + * @author Roy Clarkson + * @author Juergen Hoeller + */ +public class GsonHttpMessageConverterTests { + + private final GsonHttpMessageConverter converter = new GsonHttpMessageConverter(); + + + @Test + public void canRead() { + assertTrue(this.converter.canRead(MyBean.class, new MediaType("application", "json"))); + assertTrue(this.converter.canRead(Map.class, new MediaType("application", "json"))); + } + + @Test + public void canWrite() { + assertTrue(this.converter.canWrite(MyBean.class, new MediaType("application", "json"))); + assertTrue(this.converter.canWrite(Map.class, new MediaType("application", "json"))); + } + + @Test + public void canReadAndWriteMicroformats() { + assertTrue(this.converter.canRead(MyBean.class, new MediaType("application", "vnd.test-micro-type+json"))); + assertTrue(this.converter.canWrite(MyBean.class, new MediaType("application", "vnd.test-micro-type+json"))); + } + + @Test + public void readTyped() throws IOException { + String body = "{\"bytes\":[1,2],\"array\":[\"Foo\",\"Bar\"]," + + "\"number\":42,\"string\":\"Foo\",\"bool\":true,\"fraction\":42.0}"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + MyBean result = (MyBean) this.converter.read(MyBean.class, inputMessage); + + assertEquals("Foo", result.getString()); + assertEquals(42, result.getNumber()); + assertEquals(42F, result.getFraction(), 0F); + assertArrayEquals(new String[] {"Foo", "Bar"}, result.getArray()); + assertTrue(result.isBool()); + assertArrayEquals(new byte[] {0x1, 0x2}, result.getBytes()); + } + + @Test + @SuppressWarnings("unchecked") + public void readUntyped() throws IOException { + String body = "{\"bytes\":[1,2],\"array\":[\"Foo\",\"Bar\"]," + + "\"number\":42,\"string\":\"Foo\",\"bool\":true,\"fraction\":42.0}"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + HashMap result = (HashMap) this.converter.read(HashMap.class, inputMessage); + assertEquals("Foo", result.get("string")); + Number n = (Number) result.get("number"); + assertEquals(42, n.longValue()); + n = (Number) result.get("fraction"); + assertEquals(42D, n.doubleValue(), 0D); + List array = new ArrayList<>(); + array.add("Foo"); + array.add("Bar"); + assertEquals(array, result.get("array")); + assertEquals(Boolean.TRUE, result.get("bool")); + byte[] bytes = new byte[2]; + List resultBytes = (ArrayList)result.get("bytes"); + for (int i = 0; i < 2; i++) { + bytes[i] = resultBytes.get(i).byteValue(); + } + assertArrayEquals(new byte[] {0x1, 0x2}, bytes); + } + + @Test + public void write() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MyBean body = new MyBean(); + body.setString("Foo"); + body.setNumber(42); + body.setFraction(42F); + body.setArray(new String[] {"Foo", "Bar"}); + body.setBool(true); + body.setBytes(new byte[] {0x1, 0x2}); + this.converter.write(body, null, outputMessage); + Charset utf8 = StandardCharsets.UTF_8; + String result = outputMessage.getBodyAsString(utf8); + assertTrue(result.contains("\"string\":\"Foo\"")); + assertTrue(result.contains("\"number\":42")); + assertTrue(result.contains("fraction\":42.0")); + assertTrue(result.contains("\"array\":[\"Foo\",\"Bar\"]")); + assertTrue(result.contains("\"bool\":true")); + assertTrue(result.contains("\"bytes\":[1,2]")); + assertEquals("Invalid content-type", new MediaType("application", "json", utf8), + outputMessage.getHeaders().getContentType()); + } + + @Test + public void writeWithBaseType() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MyBean body = new MyBean(); + body.setString("Foo"); + body.setNumber(42); + body.setFraction(42F); + body.setArray(new String[] {"Foo", "Bar"}); + body.setBool(true); + body.setBytes(new byte[] {0x1, 0x2}); + this.converter.write(body, MyBase.class, null, outputMessage); + Charset utf8 = StandardCharsets.UTF_8; + String result = outputMessage.getBodyAsString(utf8); + assertTrue(result.contains("\"string\":\"Foo\"")); + assertTrue(result.contains("\"number\":42")); + assertTrue(result.contains("fraction\":42.0")); + assertTrue(result.contains("\"array\":[\"Foo\",\"Bar\"]")); + assertTrue(result.contains("\"bool\":true")); + assertTrue(result.contains("\"bytes\":[1,2]")); + assertEquals("Invalid content-type", new MediaType("application", "json", utf8), + outputMessage.getHeaders().getContentType()); + } + + @Test + public void writeUTF16() throws IOException { + MediaType contentType = new MediaType("application", "json", StandardCharsets.UTF_16BE); + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + String body = "H\u00e9llo W\u00f6rld"; + this.converter.write(body, contentType, outputMessage); + assertEquals("Invalid result", "\"" + body + "\"", outputMessage.getBodyAsString(StandardCharsets.UTF_16BE)); + assertEquals("Invalid content-type", contentType, outputMessage.getHeaders().getContentType()); + } + + @Test(expected = HttpMessageNotReadableException.class) + public void readInvalidJson() throws IOException { + String body = "FooBar"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + this.converter.read(MyBean.class, inputMessage); + } + + @Test + @SuppressWarnings("unchecked") + public void readAndWriteGenerics() throws Exception { + Field beansList = ListHolder.class.getField("listField"); + + String body = "[{\"bytes\":[1,2],\"array\":[\"Foo\",\"Bar\"]," + + "\"number\":42,\"string\":\"Foo\",\"bool\":true,\"fraction\":42.0}]"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes(StandardCharsets.UTF_8)); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + + Type genericType = beansList.getGenericType(); + List results = (List) converter.read(genericType, MyBeanListHolder.class, inputMessage); + assertEquals(1, results.size()); + MyBean result = results.get(0); + assertEquals("Foo", result.getString()); + assertEquals(42, result.getNumber()); + assertEquals(42F, result.getFraction(), 0F); + assertArrayEquals(new String[] {"Foo", "Bar"}, result.getArray()); + assertTrue(result.isBool()); + assertArrayEquals(new byte[] {0x1, 0x2}, result.getBytes()); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(results, genericType, new MediaType("application", "json"), outputMessage); + JSONAssert.assertEquals(body, outputMessage.getBodyAsString(StandardCharsets.UTF_8), true); + } + + @Test + @SuppressWarnings("unchecked") + public void readAndWriteParameterizedType() throws Exception { + ParameterizedTypeReference> beansList = new ParameterizedTypeReference>() { + }; + + String body = "[{\"bytes\":[1,2],\"array\":[\"Foo\",\"Bar\"]," + + "\"number\":42,\"string\":\"Foo\",\"bool\":true,\"fraction\":42.0}]"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes(StandardCharsets.UTF_8)); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + + List results = (List) converter.read(beansList.getType(), null, inputMessage); + assertEquals(1, results.size()); + MyBean result = results.get(0); + assertEquals("Foo", result.getString()); + assertEquals(42, result.getNumber()); + assertEquals(42F, result.getFraction(), 0F); + assertArrayEquals(new String[] {"Foo", "Bar"}, result.getArray()); + assertTrue(result.isBool()); + assertArrayEquals(new byte[] {0x1, 0x2}, result.getBytes()); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(results, beansList.getType(), new MediaType("application", "json"), outputMessage); + JSONAssert.assertEquals(body, outputMessage.getBodyAsString(StandardCharsets.UTF_8), true); + } + + @Test + @SuppressWarnings("unchecked") + public void writeParameterizedBaseType() throws Exception { + ParameterizedTypeReference> beansList = new ParameterizedTypeReference>() {}; + ParameterizedTypeReference> baseList = new ParameterizedTypeReference>() {}; + + String body = "[{\"bytes\":[1,2],\"array\":[\"Foo\",\"Bar\"]," + + "\"number\":42,\"string\":\"Foo\",\"bool\":true,\"fraction\":42.0}]"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes(StandardCharsets.UTF_8)); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + + List results = (List) converter.read(beansList.getType(), null, inputMessage); + assertEquals(1, results.size()); + MyBean result = results.get(0); + assertEquals("Foo", result.getString()); + assertEquals(42, result.getNumber()); + assertEquals(42F, result.getFraction(), 0F); + assertArrayEquals(new String[] {"Foo", "Bar"}, result.getArray()); + assertTrue(result.isBool()); + assertArrayEquals(new byte[] {0x1, 0x2}, result.getBytes()); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(results, baseList.getType(), new MediaType("application", "json"), outputMessage); + JSONAssert.assertEquals(body, outputMessage.getBodyAsString(StandardCharsets.UTF_8), true); + } + + @Test + public void prefixJson() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + this.converter.setPrefixJson(true); + this.converter.writeInternal("foo", null, outputMessage); + assertEquals(")]}', \"foo\"", outputMessage.getBodyAsString(StandardCharsets.UTF_8)); + } + + @Test + public void prefixJsonCustom() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + this.converter.setJsonPrefix(")))"); + this.converter.writeInternal("foo", null, outputMessage); + assertEquals(")))\"foo\"", outputMessage.getBodyAsString(StandardCharsets.UTF_8)); + } + + + public static class MyBase { + + private String string; + + public String getString() { + return string; + } + + public void setString(String string) { + this.string = string; + } + } + + + public static class MyBean extends MyBase { + + private int number; + + private float fraction; + + private String[] array; + + private boolean bool; + + private byte[] bytes; + + public int getNumber() { + return number; + } + + public void setNumber(int number) { + this.number = number; + } + + public float getFraction() { + return fraction; + } + + public void setFraction(float fraction) { + this.fraction = fraction; + } + + public String[] getArray() { + return array; + } + + public void setArray(String[] array) { + this.array = array; + } + + public boolean isBool() { + return bool; + } + + public void setBool(boolean bool) { + this.bool = bool; + } + + public byte[] getBytes() { + return bytes; + } + + public void setBytes(byte[] bytes) { + this.bytes = bytes; + } + } + + + public static class ListHolder { + + public List listField; + } + + + public static class MyBeanListHolder extends ListHolder { + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/json/Jackson2ObjectMapperBuilderTests.java b/spring-web/src/test/java/org/springframework/http/converter/json/Jackson2ObjectMapperBuilderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..cc38beb839e838d694099846bd24e2fdeffd56ca --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/json/Jackson2ObjectMapperBuilderTests.java @@ -0,0 +1,723 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.text.SimpleDateFormat; +import java.time.OffsetDateTime; +import java.time.format.DateTimeParseException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Date; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.TimeZone; +import java.util.stream.StreamSupport; + +import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility; +import com.fasterxml.jackson.annotation.JsonFilter; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.PropertyAccessor; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.Version; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.MapperFeature; +import com.fasterxml.jackson.databind.Module; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.cfg.DeserializerFactoryConfig; +import com.fasterxml.jackson.databind.cfg.SerializerFactoryConfig; +import com.fasterxml.jackson.databind.deser.BasicDeserializerFactory; +import com.fasterxml.jackson.databind.deser.Deserializers; +import com.fasterxml.jackson.databind.deser.std.DateDeserializers; +import com.fasterxml.jackson.databind.introspect.NopAnnotationIntrospector; +import com.fasterxml.jackson.databind.module.SimpleModule; +import com.fasterxml.jackson.databind.module.SimpleSerializers; +import com.fasterxml.jackson.databind.ser.BasicSerializerFactory; +import com.fasterxml.jackson.databind.ser.Serializers; +import com.fasterxml.jackson.databind.ser.impl.SimpleBeanPropertyFilter; +import com.fasterxml.jackson.databind.ser.impl.SimpleFilterProvider; +import com.fasterxml.jackson.databind.ser.std.ClassSerializer; +import com.fasterxml.jackson.databind.ser.std.NumberSerializer; +import com.fasterxml.jackson.databind.type.SimpleType; +import com.fasterxml.jackson.dataformat.cbor.CBORFactory; +import com.fasterxml.jackson.dataformat.smile.SmileFactory; +import com.fasterxml.jackson.dataformat.xml.XmlFactory; +import com.fasterxml.jackson.dataformat.xml.XmlMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import kotlin.ranges.IntRange; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; +import org.junit.Test; + +import org.springframework.beans.FatalBeanException; +import org.springframework.util.StringUtils; + +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; + +/** + * Test class for {@link Jackson2ObjectMapperBuilder}. + * + * @author Sebastien Deleuze + * @author Eddú Meléndez + */ +@SuppressWarnings("deprecation") +public class Jackson2ObjectMapperBuilderTests { + + private static final String DATE_FORMAT = "yyyy-MM-dd"; + + private static final String DATA = "{\"offsetDateTime\": \"2020-01-01T00:00:00\"}"; + + + @Test(expected = FatalBeanException.class) + public void unknownFeature() { + Jackson2ObjectMapperBuilder.json().featuresToEnable(Boolean.TRUE).build(); + } + + @Test + public void defaultProperties() { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json().build(); + assertNotNull(objectMapper); + assertFalse(objectMapper.isEnabled(MapperFeature.DEFAULT_VIEW_INCLUSION)); + assertFalse(objectMapper.isEnabled(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)); + assertTrue(objectMapper.isEnabled(MapperFeature.AUTO_DETECT_FIELDS)); + assertTrue(objectMapper.isEnabled(MapperFeature.AUTO_DETECT_GETTERS)); + assertTrue(objectMapper.isEnabled(MapperFeature.AUTO_DETECT_IS_GETTERS)); + assertTrue(objectMapper.isEnabled(MapperFeature.AUTO_DETECT_SETTERS)); + assertFalse(objectMapper.isEnabled(SerializationFeature.INDENT_OUTPUT)); + assertTrue(objectMapper.isEnabled(SerializationFeature.FAIL_ON_EMPTY_BEANS)); + } + + @Test + public void propertiesShortcut() { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json().autoDetectFields(false) + .defaultViewInclusion(true).failOnUnknownProperties(true).failOnEmptyBeans(false) + .autoDetectGettersSetters(false).indentOutput(true).build(); + assertNotNull(objectMapper); + assertTrue(objectMapper.isEnabled(MapperFeature.DEFAULT_VIEW_INCLUSION)); + assertTrue(objectMapper.isEnabled(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)); + assertFalse(objectMapper.isEnabled(MapperFeature.AUTO_DETECT_FIELDS)); + assertFalse(objectMapper.isEnabled(MapperFeature.AUTO_DETECT_GETTERS)); + assertFalse(objectMapper.isEnabled(MapperFeature.AUTO_DETECT_IS_GETTERS)); + assertFalse(objectMapper.isEnabled(MapperFeature.AUTO_DETECT_SETTERS)); + assertTrue(objectMapper.isEnabled(SerializationFeature.INDENT_OUTPUT)); + assertFalse(objectMapper.isEnabled(SerializationFeature.FAIL_ON_EMPTY_BEANS)); + } + + @Test + public void booleanSetters() { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json() + .featuresToEnable(MapperFeature.DEFAULT_VIEW_INCLUSION, + DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, + SerializationFeature.INDENT_OUTPUT) + .featuresToDisable(MapperFeature.AUTO_DETECT_FIELDS, + MapperFeature.AUTO_DETECT_GETTERS, + MapperFeature.AUTO_DETECT_SETTERS, + SerializationFeature.FAIL_ON_EMPTY_BEANS).build(); + assertNotNull(objectMapper); + assertTrue(objectMapper.isEnabled(MapperFeature.DEFAULT_VIEW_INCLUSION)); + assertTrue(objectMapper.isEnabled(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)); + assertFalse(objectMapper.isEnabled(MapperFeature.AUTO_DETECT_FIELDS)); + assertFalse(objectMapper.isEnabled(MapperFeature.AUTO_DETECT_GETTERS)); + assertFalse(objectMapper.isEnabled(MapperFeature.AUTO_DETECT_SETTERS)); + assertTrue(objectMapper.isEnabled(SerializationFeature.INDENT_OUTPUT)); + assertFalse(objectMapper.isEnabled(SerializationFeature.FAIL_ON_EMPTY_BEANS)); + } + + @Test + public void setNotNullSerializationInclusion() { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json().build(); + assertSame(JsonInclude.Include.ALWAYS, objectMapper.getSerializationConfig().getSerializationInclusion()); + objectMapper = Jackson2ObjectMapperBuilder.json().serializationInclusion(JsonInclude.Include.NON_NULL).build(); + assertSame(JsonInclude.Include.NON_NULL, objectMapper.getSerializationConfig().getSerializationInclusion()); + } + + @Test + public void setNotDefaultSerializationInclusion() { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json().build(); + assertSame(JsonInclude.Include.ALWAYS, objectMapper.getSerializationConfig().getSerializationInclusion()); + objectMapper = Jackson2ObjectMapperBuilder.json().serializationInclusion(JsonInclude.Include.NON_DEFAULT).build(); + assertSame(JsonInclude.Include.NON_DEFAULT, objectMapper.getSerializationConfig().getSerializationInclusion()); + } + + @Test + public void setNotEmptySerializationInclusion() { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json().build(); + assertSame(JsonInclude.Include.ALWAYS, objectMapper.getSerializationConfig().getSerializationInclusion()); + objectMapper = Jackson2ObjectMapperBuilder.json().serializationInclusion(JsonInclude.Include.NON_EMPTY).build(); + assertSame(JsonInclude.Include.NON_EMPTY, objectMapper.getSerializationConfig().getSerializationInclusion()); + } + + @Test + public void dateTimeFormatSetter() { + SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_FORMAT); + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json().dateFormat(dateFormat).build(); + assertEquals(dateFormat, objectMapper.getSerializationConfig().getDateFormat()); + assertEquals(dateFormat, objectMapper.getDeserializationConfig().getDateFormat()); + } + + @Test + public void simpleDateFormatStringSetter() { + SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_FORMAT); + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json().simpleDateFormat(DATE_FORMAT).build(); + assertEquals(dateFormat, objectMapper.getSerializationConfig().getDateFormat()); + assertEquals(dateFormat, objectMapper.getDeserializationConfig().getDateFormat()); + } + + @Test + public void localeSetter() { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json().locale(Locale.FRENCH).build(); + assertEquals(Locale.FRENCH, objectMapper.getSerializationConfig().getLocale()); + assertEquals(Locale.FRENCH, objectMapper.getDeserializationConfig().getLocale()); + } + + @Test + public void timeZoneSetter() { + TimeZone timeZone = TimeZone.getTimeZone("Europe/Paris"); + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json().timeZone(timeZone).build(); + assertEquals(timeZone, objectMapper.getSerializationConfig().getTimeZone()); + assertEquals(timeZone, objectMapper.getDeserializationConfig().getTimeZone()); + } + + @Test + public void timeZoneStringSetter() { + String zoneId = "Europe/Paris"; + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json().timeZone(zoneId).build(); + TimeZone timeZone = TimeZone.getTimeZone(zoneId); + assertEquals(timeZone, objectMapper.getSerializationConfig().getTimeZone()); + assertEquals(timeZone, objectMapper.getDeserializationConfig().getTimeZone()); + } + + @Test(expected = IllegalArgumentException.class) + public void wrongTimeZoneStringSetter() { + String zoneId = "foo"; + Jackson2ObjectMapperBuilder.json().timeZone(zoneId).build(); + } + + @Test + public void modules() { + NumberSerializer serializer1 = new NumberSerializer(Integer.class); + SimpleModule module = new SimpleModule(); + module.addSerializer(Integer.class, serializer1); + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json().modules(module).build(); + Serializers serializers = getSerializerFactoryConfig(objectMapper).serializers().iterator().next(); + assertSame(serializer1, serializers.findSerializer(null, SimpleType.construct(Integer.class), null)); + } + + @Test + @SuppressWarnings("unchecked") + public void modulesToInstallByClass() { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json() + .modulesToInstall(CustomIntegerModule.class) + .build(); + Serializers serializers = getSerializerFactoryConfig(objectMapper).serializers().iterator().next(); + assertSame(CustomIntegerSerializer.class, + serializers.findSerializer(null, SimpleType.construct(Integer.class), null).getClass()); + } + + @Test + public void modulesToInstallByInstance() { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json() + .modulesToInstall(new CustomIntegerModule()) + .build(); + Serializers serializers = getSerializerFactoryConfig(objectMapper).serializers().iterator().next(); + assertSame(CustomIntegerSerializer.class, + serializers.findSerializer(null, SimpleType.construct(Integer.class), null).getClass()); + } + + @Test + public void wellKnownModules() throws JsonProcessingException, UnsupportedEncodingException { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json().build(); + + Long timestamp = 1322903730000L; + DateTime dateTime = new DateTime(timestamp, DateTimeZone.UTC); + assertEquals(timestamp.toString(), new String(objectMapper.writeValueAsBytes(dateTime), "UTF-8")); + + Path file = Paths.get("foo"); + assertTrue(new String(objectMapper.writeValueAsBytes(file), "UTF-8").endsWith("foo\"")); + + Optional optional = Optional.of("test"); + assertEquals("\"test\"", new String(objectMapper.writeValueAsBytes(optional), "UTF-8")); + + // Kotlin module + IntRange range = new IntRange(1, 3); + assertEquals("{\"start\":1,\"end\":3}", new String(objectMapper.writeValueAsBytes(range), "UTF-8")); + } + + @Test // SPR-12634 + public void customizeWellKnownModulesWithModule() + throws JsonProcessingException, UnsupportedEncodingException { + + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json() + .modulesToInstall(new CustomIntegerModule()) + .build(); + DateTime dateTime = new DateTime(1322903730000L, DateTimeZone.UTC); + assertEquals("1322903730000", new String(objectMapper.writeValueAsBytes(dateTime), "UTF-8")); + assertThat(new String(objectMapper.writeValueAsBytes(new Integer(4)), "UTF-8"), containsString("customid")); + } + + @Test // SPR-12634 + @SuppressWarnings("unchecked") + public void customizeWellKnownModulesWithModuleClass() + throws JsonProcessingException, UnsupportedEncodingException { + + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json() + .modulesToInstall(CustomIntegerModule.class) + .build(); + DateTime dateTime = new DateTime(1322903730000L, DateTimeZone.UTC); + assertEquals("1322903730000", new String(objectMapper.writeValueAsBytes(dateTime), "UTF-8")); + assertThat(new String(objectMapper.writeValueAsBytes(new Integer(4)), "UTF-8"), containsString("customid")); + } + + @Test // SPR-12634 + public void customizeWellKnownModulesWithSerializer() + throws JsonProcessingException, UnsupportedEncodingException { + + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json() + .serializerByType(Integer.class, new CustomIntegerSerializer()).build(); + DateTime dateTime = new DateTime(1322903730000L, DateTimeZone.UTC); + assertEquals("1322903730000", new String(objectMapper.writeValueAsBytes(dateTime), "UTF-8")); + assertThat(new String(objectMapper.writeValueAsBytes(new Integer(4)), "UTF-8"), containsString("customid")); + } + + @Test // gh-22576 + public void overrideWellKnownModuleWithModule() throws IOException { + Jackson2ObjectMapperBuilder builder = new Jackson2ObjectMapperBuilder(); + JavaTimeModule javaTimeModule = new JavaTimeModule(); + javaTimeModule.addDeserializer(OffsetDateTime.class, new OffsetDateTimeDeserializer()); + builder.modulesToInstall(javaTimeModule); + builder.featuresToDisable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS); + ObjectMapper objectMapper = builder.build(); + DemoPojo demoPojo = objectMapper.readValue(DATA, DemoPojo.class); + assertNotNull(demoPojo.getOffsetDateTime()); + } + + @Test // gh-22740 + public void registerMultipleModulesWithNullTypeId() { + Jackson2ObjectMapperBuilder builder = new Jackson2ObjectMapperBuilder(); + SimpleModule fooModule = new SimpleModule(); + fooModule.addSerializer(new FooSerializer()); + SimpleModule barModule = new SimpleModule(); + barModule.addSerializer(new BarSerializer()); + builder.modulesToInstall(fooModule, barModule); + ObjectMapper objectMapper = builder.build(); + assertEquals(1, StreamSupport + .stream(getSerializerFactoryConfig(objectMapper).serializers().spliterator(), false) + .filter(s -> s.findSerializer(null, SimpleType.construct(Foo.class), null) != null) + .count()); + assertEquals(1, StreamSupport + .stream(getSerializerFactoryConfig(objectMapper).serializers().spliterator(), false) + .filter(s -> s.findSerializer(null, SimpleType.construct(Bar.class), null) != null) + .count()); + } + + private static SerializerFactoryConfig getSerializerFactoryConfig(ObjectMapper objectMapper) { + return ((BasicSerializerFactory) objectMapper.getSerializerFactory()).getFactoryConfig(); + } + + private static DeserializerFactoryConfig getDeserializerFactoryConfig(ObjectMapper objectMapper) { + return ((BasicDeserializerFactory) objectMapper.getDeserializationContext().getFactory()).getFactoryConfig(); + } + + @Test + public void propertyNamingStrategy() { + PropertyNamingStrategy strategy = new PropertyNamingStrategy.SnakeCaseStrategy(); + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json().propertyNamingStrategy(strategy).build(); + assertSame(strategy, objectMapper.getSerializationConfig().getPropertyNamingStrategy()); + assertSame(strategy, objectMapper.getDeserializationConfig().getPropertyNamingStrategy()); + } + + @Test + public void serializerByType() { + JsonSerializer serializer = new NumberSerializer(Integer.class); + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json() + .modules(new ArrayList<>()) // Disable well-known modules detection + .serializerByType(Boolean.class, serializer) + .build(); + assertTrue(getSerializerFactoryConfig(objectMapper).hasSerializers()); + Serializers serializers = getSerializerFactoryConfig(objectMapper).serializers().iterator().next(); + assertSame(serializer, serializers.findSerializer(null, SimpleType.construct(Boolean.class), null)); + } + + @Test + public void deserializerByType() throws JsonMappingException { + JsonDeserializer deserializer = new DateDeserializers.DateDeserializer(); + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json() + .modules(new ArrayList<>()) // Disable well-known modules detection + .deserializerByType(Date.class, deserializer) + .build(); + assertTrue(getDeserializerFactoryConfig(objectMapper).hasDeserializers()); + Deserializers deserializers = getDeserializerFactoryConfig(objectMapper).deserializers().iterator().next(); + assertSame(deserializer, deserializers.findBeanDeserializer(SimpleType.construct(Date.class), null, null)); + } + + @Test + public void mixIn() { + Class target = String.class; + Class mixInSource = Object.class; + + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json() + .modules().mixIn(target, mixInSource) + .build(); + + assertEquals(1, objectMapper.mixInCount()); + assertSame(mixInSource, objectMapper.findMixInClassFor(target)); + } + + @Test + public void mixIns() { + Class target = String.class; + Class mixInSource = Object.class; + Map, Class> mixIns = new HashMap<>(); + mixIns.put(target, mixInSource); + + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json() + .modules().mixIns(mixIns) + .build(); + + assertEquals(1, objectMapper.mixInCount()); + assertSame(mixInSource, objectMapper.findMixInClassFor(target)); + } + + @Test + public void filters() throws JsonProcessingException { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json() + .filters(new SimpleFilterProvider().setFailOnUnknownId(false)).build(); + JacksonFilteredBean bean = new JacksonFilteredBean("value1", "value2"); + String output = objectMapper.writeValueAsString(bean); + assertThat(output, containsString("value1")); + assertThat(output, containsString("value2")); + + SimpleFilterProvider provider = new SimpleFilterProvider() + .setFailOnUnknownId(false) + .setDefaultFilter(SimpleBeanPropertyFilter.serializeAllExcept("property2")); + objectMapper = Jackson2ObjectMapperBuilder.json().filters(provider).build(); + output = objectMapper.writeValueAsString(bean); + assertThat(output, containsString("value1")); + assertThat(output, not(containsString("value2"))); + } + + @Test + public void completeSetup() throws JsonMappingException { + NopAnnotationIntrospector annotationIntrospector = NopAnnotationIntrospector.instance; + + Map, JsonDeserializer> deserializerMap = new HashMap<>(); + JsonDeserializer deserializer = new DateDeserializers.DateDeserializer(); + deserializerMap.put(Date.class, deserializer); + + JsonSerializer> serializer1 = new ClassSerializer(); + JsonSerializer serializer2 = new NumberSerializer(Integer.class); + + Jackson2ObjectMapperBuilder builder = Jackson2ObjectMapperBuilder.json() + .modules(new ArrayList<>()) // Disable well-known modules detection + .serializers(serializer1) + .serializersByType(Collections.singletonMap(Boolean.class, serializer2)) + .deserializersByType(deserializerMap) + .annotationIntrospector(annotationIntrospector) + .featuresToEnable(SerializationFeature.FAIL_ON_EMPTY_BEANS, + DeserializationFeature.UNWRAP_ROOT_VALUE, + JsonParser.Feature.ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER, + JsonGenerator.Feature.WRITE_NUMBERS_AS_STRINGS) + .featuresToDisable(MapperFeature.AUTO_DETECT_GETTERS, + MapperFeature.AUTO_DETECT_FIELDS, + JsonParser.Feature.AUTO_CLOSE_SOURCE, + JsonGenerator.Feature.QUOTE_FIELD_NAMES) + .serializationInclusion(JsonInclude.Include.NON_NULL); + + ObjectMapper mapper = new ObjectMapper(); + builder.configure(mapper); + + assertTrue(getSerializerFactoryConfig(mapper).hasSerializers()); + assertTrue(getDeserializerFactoryConfig(mapper).hasDeserializers()); + + Serializers serializers = getSerializerFactoryConfig(mapper).serializers().iterator().next(); + assertSame(serializer1, serializers.findSerializer(null, SimpleType.construct(Class.class), null)); + assertSame(serializer2, serializers.findSerializer(null, SimpleType.construct(Boolean.class), null)); + assertNull(serializers.findSerializer(null, SimpleType.construct(Number.class), null)); + + Deserializers deserializers = getDeserializerFactoryConfig(mapper).deserializers().iterator().next(); + assertSame(deserializer, deserializers.findBeanDeserializer(SimpleType.construct(Date.class), null, null)); + + assertSame(annotationIntrospector, mapper.getSerializationConfig().getAnnotationIntrospector()); + assertSame(annotationIntrospector, mapper.getDeserializationConfig().getAnnotationIntrospector()); + + assertTrue(mapper.getSerializationConfig().isEnabled(SerializationFeature.FAIL_ON_EMPTY_BEANS)); + assertTrue(mapper.getDeserializationConfig().isEnabled(DeserializationFeature.UNWRAP_ROOT_VALUE)); + assertTrue(mapper.getFactory().isEnabled(JsonParser.Feature.ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER)); + assertTrue(mapper.getFactory().isEnabled(JsonGenerator.Feature.WRITE_NUMBERS_AS_STRINGS)); + + assertFalse(mapper.getSerializationConfig().isEnabled(MapperFeature.AUTO_DETECT_GETTERS)); + assertFalse(mapper.getDeserializationConfig().isEnabled(MapperFeature.DEFAULT_VIEW_INCLUSION)); + assertFalse(mapper.getDeserializationConfig().isEnabled(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)); + assertFalse(mapper.getDeserializationConfig().isEnabled(MapperFeature.AUTO_DETECT_FIELDS)); + assertFalse(mapper.getFactory().isEnabled(JsonParser.Feature.AUTO_CLOSE_SOURCE)); + assertFalse(mapper.getFactory().isEnabled(JsonGenerator.Feature.QUOTE_FIELD_NAMES)); + assertSame(JsonInclude.Include.NON_NULL, mapper.getSerializationConfig().getSerializationInclusion()); + } + + @Test + public void xmlMapper() { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.xml().build(); + assertNotNull(objectMapper); + assertEquals(XmlMapper.class, objectMapper.getClass()); + } + + @Test // gh-22428 + public void xmlMapperAndCustomFactory() { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.xml().factory(new MyXmlFactory()).build(); + assertNotNull(objectMapper); + assertEquals(XmlMapper.class, objectMapper.getClass()); + assertEquals(MyXmlFactory.class, objectMapper.getFactory().getClass()); + } + + @Test + public void createXmlMapper() { + Jackson2ObjectMapperBuilder builder = Jackson2ObjectMapperBuilder.json().indentOutput(true); + ObjectMapper jsonObjectMapper = builder.build(); + ObjectMapper xmlObjectMapper = builder.createXmlMapper(true).build(); + assertTrue(jsonObjectMapper.isEnabled(SerializationFeature.INDENT_OUTPUT)); + assertTrue(xmlObjectMapper.isEnabled(SerializationFeature.INDENT_OUTPUT)); + assertTrue(xmlObjectMapper.getClass().isAssignableFrom(XmlMapper.class)); + } + + @Test // SPR-13975 + public void defaultUseWrapper() throws JsonProcessingException { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.xml().defaultUseWrapper(false).build(); + assertNotNull(objectMapper); + assertEquals(XmlMapper.class, objectMapper.getClass()); + ListContainer container = new ListContainer<>(Arrays.asList("foo", "bar")); + String output = objectMapper.writeValueAsString(container); + assertThat(output, containsString("foobar")); + } + + @Test // SPR-14435 + public void smile() { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.smile().build(); + assertNotNull(objectMapper); + assertEquals(SmileFactory.class, objectMapper.getFactory().getClass()); + } + + @Test // SPR-14435 + public void cbor() { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.cbor().build(); + assertNotNull(objectMapper); + assertEquals(CBORFactory.class, objectMapper.getFactory().getClass()); + } + + @Test // SPR-14435 + public void factory() { + ObjectMapper objectMapper = new Jackson2ObjectMapperBuilder().factory(new SmileFactory()).build(); + assertNotNull(objectMapper); + assertEquals(SmileFactory.class, objectMapper.getFactory().getClass()); + } + + @Test + public void visibility() throws JsonProcessingException { + ObjectMapper objectMapper = Jackson2ObjectMapperBuilder.json() + .visibility(PropertyAccessor.GETTER, Visibility.NONE) + .visibility(PropertyAccessor.FIELD, Visibility.ANY) + .build(); + + String json = objectMapper.writeValueAsString(new JacksonVisibilityBean()); + assertThat(json, containsString("property1")); + assertThat(json, containsString("property2")); + assertThat(json, not(containsString("property3"))); + } + + + public static class CustomIntegerModule extends Module { + + @Override + public String getModuleName() { + return this.getClass().getSimpleName(); + } + + @Override + public Version version() { + return Version.unknownVersion(); + } + + @Override + public void setupModule(SetupContext context) { + SimpleSerializers serializers = new SimpleSerializers(); + serializers.addSerializer(Integer.class, new CustomIntegerSerializer()); + context.addSerializers(serializers); + } + } + + + public static class CustomIntegerSerializer extends JsonSerializer { + + @Override + public void serialize(Integer value, JsonGenerator gen, SerializerProvider serializers) + throws IOException { + + gen.writeStartObject(); + gen.writeNumberField("customid", value); + gen.writeEndObject(); + } + } + + + @JsonFilter("myJacksonFilter") + public static class JacksonFilteredBean { + + public JacksonFilteredBean() { + } + + public JacksonFilteredBean(String property1, String property2) { + this.property1 = property1; + this.property2 = property2; + } + + private String property1; + private String property2; + + public String getProperty1() { + return property1; + } + + public void setProperty1(String property1) { + this.property1 = property1; + } + + public String getProperty2() { + return property2; + } + + public void setProperty2(String property2) { + this.property2 = property2; + } + } + + + public static class ListContainer { + + private List list; + + public ListContainer() { + } + + public ListContainer(List list) { + this.list = list; + } + + public List getList() { + return list; + } + + public void setList(List list) { + this.list = list; + } + } + + + public static class JacksonVisibilityBean { + + private String property1; + + public String property2; + + public String getProperty3() { + return null; + } + } + + + static class OffsetDateTimeDeserializer extends JsonDeserializer { + + private static final String CURRENT_ZONE_OFFSET = OffsetDateTime.now().getOffset().toString(); + + @Override + public OffsetDateTime deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException { + final String value = jsonParser.getValueAsString(); + if (StringUtils.isEmpty(value)) { + return null; + } + try { + return OffsetDateTime.parse(value); + + } catch (DateTimeParseException exception) { + return OffsetDateTime.parse(value + CURRENT_ZONE_OFFSET); + } + } + } + + + @JsonDeserialize + static class DemoPojo { + + private OffsetDateTime offsetDateTime; + + public OffsetDateTime getOffsetDateTime() { + return offsetDateTime; + } + + public void setOffsetDateTime(OffsetDateTime offsetDateTime) { + this.offsetDateTime = offsetDateTime; + } + } + + + @SuppressWarnings("serial") + public static class MyXmlFactory extends XmlFactory { + } + + + static class Foo {} + + static class Bar {} + + static class FooSerializer extends JsonSerializer { + @Override + public void serialize(Foo value, JsonGenerator gen, SerializerProvider serializers) throws IOException { + } + + @Override + public Class handledType() { + return Foo.class; + } + } + + static class BarSerializer extends JsonSerializer { + @Override + public void serialize(Bar value, JsonGenerator gen, SerializerProvider serializers) throws IOException { + } + @Override + public Class handledType() { + return Bar.class; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/json/Jackson2ObjectMapperFactoryBeanTests.java b/spring-web/src/test/java/org/springframework/http/converter/json/Jackson2ObjectMapperFactoryBeanTests.java new file mode 100644 index 0000000000000000000000000000000000000000..6e391cc7cbe88c0e4e1ffa20b9f2586b313bc910 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/json/Jackson2ObjectMapperFactoryBeanTests.java @@ -0,0 +1,457 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Date; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.TimeZone; + +import com.fasterxml.jackson.dataformat.smile.SmileFactory; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; + +import org.junit.Test; + +import org.springframework.beans.FatalBeanException; + +import com.fasterxml.jackson.annotation.JsonFilter; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.Version; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.MapperFeature; +import com.fasterxml.jackson.databind.Module; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.cfg.DeserializerFactoryConfig; +import com.fasterxml.jackson.databind.cfg.SerializerFactoryConfig; +import com.fasterxml.jackson.databind.deser.BasicDeserializerFactory; +import com.fasterxml.jackson.databind.deser.std.DateDeserializers.DateDeserializer; +import com.fasterxml.jackson.databind.introspect.NopAnnotationIntrospector; +import com.fasterxml.jackson.databind.module.SimpleModule; +import com.fasterxml.jackson.databind.module.SimpleSerializers; +import com.fasterxml.jackson.databind.ser.BasicSerializerFactory; +import com.fasterxml.jackson.databind.ser.Serializers; +import com.fasterxml.jackson.databind.ser.impl.SimpleFilterProvider; +import com.fasterxml.jackson.databind.ser.std.ClassSerializer; +import com.fasterxml.jackson.databind.ser.std.NumberSerializer; +import com.fasterxml.jackson.databind.type.SimpleType; +import com.fasterxml.jackson.dataformat.xml.XmlMapper; + +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; + +/** + * Test cases for {@link Jackson2ObjectMapperFactoryBean}. + * + * @author Dmitry Katsubo + * @author Brian Clozel + * @author Sebastien Deleuze + * @author Sam Brannen + */ +@SuppressWarnings("deprecation") +public class Jackson2ObjectMapperFactoryBeanTests { + + private static final String DATE_FORMAT = "yyyy-MM-dd"; + + private final SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_FORMAT); + + private final Jackson2ObjectMapperFactoryBean factory = new Jackson2ObjectMapperFactoryBean(); + + + @Test(expected = FatalBeanException.class) + public void unknownFeature() { + this.factory.setFeaturesToEnable(Boolean.TRUE); + this.factory.afterPropertiesSet(); + } + + @Test + public void booleanSetters() { + this.factory.setAutoDetectFields(false); + this.factory.setAutoDetectGettersSetters(false); + this.factory.setDefaultViewInclusion(false); + this.factory.setFailOnEmptyBeans(false); + this.factory.setIndentOutput(true); + this.factory.afterPropertiesSet(); + + ObjectMapper objectMapper = this.factory.getObject(); + + assertFalse(objectMapper.getSerializationConfig().isEnabled(MapperFeature.AUTO_DETECT_FIELDS)); + assertFalse(objectMapper.getDeserializationConfig().isEnabled(MapperFeature.AUTO_DETECT_FIELDS)); + assertFalse(objectMapper.getSerializationConfig().isEnabled(MapperFeature.AUTO_DETECT_GETTERS)); + assertFalse(objectMapper.getDeserializationConfig().isEnabled(MapperFeature.AUTO_DETECT_SETTERS)); + assertFalse(objectMapper.getDeserializationConfig().isEnabled(MapperFeature.DEFAULT_VIEW_INCLUSION)); + assertFalse(objectMapper.getSerializationConfig().isEnabled(SerializationFeature.FAIL_ON_EMPTY_BEANS)); + assertTrue(objectMapper.getSerializationConfig().isEnabled(SerializationFeature.INDENT_OUTPUT)); + assertSame(Include.ALWAYS, objectMapper.getSerializationConfig().getSerializationInclusion()); + } + + @Test + public void defaultSerializationInclusion() { + this.factory.afterPropertiesSet(); + assertSame(Include.ALWAYS, this.factory.getObject().getSerializationConfig().getSerializationInclusion()); + } + + @Test + public void nonNullSerializationInclusion() { + this.factory.setSerializationInclusion(Include.NON_NULL); + this.factory.afterPropertiesSet(); + assertSame(Include.NON_NULL, this.factory.getObject().getSerializationConfig().getSerializationInclusion()); + } + + @Test + public void nonDefaultSerializationInclusion() { + this.factory.setSerializationInclusion(Include.NON_DEFAULT); + this.factory.afterPropertiesSet(); + assertSame(Include.NON_DEFAULT, this.factory.getObject().getSerializationConfig().getSerializationInclusion()); + } + + @Test + public void nonEmptySerializationInclusion() { + this.factory.setSerializationInclusion(Include.NON_EMPTY); + this.factory.afterPropertiesSet(); + assertSame(Include.NON_EMPTY, this.factory.getObject().getSerializationConfig().getSerializationInclusion()); + } + + @Test + public void setDateFormat() { + this.factory.setDateFormat(this.dateFormat); + this.factory.afterPropertiesSet(); + + assertEquals(this.dateFormat, this.factory.getObject().getSerializationConfig().getDateFormat()); + assertEquals(this.dateFormat, this.factory.getObject().getDeserializationConfig().getDateFormat()); + } + + @Test + public void setSimpleDateFormat() { + this.factory.setSimpleDateFormat(DATE_FORMAT); + this.factory.afterPropertiesSet(); + + assertEquals(this.dateFormat, this.factory.getObject().getSerializationConfig().getDateFormat()); + assertEquals(this.dateFormat, this.factory.getObject().getDeserializationConfig().getDateFormat()); + } + + @Test + public void setLocale() { + this.factory.setLocale(Locale.FRENCH); + this.factory.afterPropertiesSet(); + + assertEquals(Locale.FRENCH, this.factory.getObject().getSerializationConfig().getLocale()); + assertEquals(Locale.FRENCH, this.factory.getObject().getDeserializationConfig().getLocale()); + } + + @Test + public void setTimeZone() { + TimeZone timeZone = TimeZone.getTimeZone("Europe/Paris"); + + this.factory.setTimeZone(timeZone); + this.factory.afterPropertiesSet(); + + assertEquals(timeZone, this.factory.getObject().getSerializationConfig().getTimeZone()); + assertEquals(timeZone, this.factory.getObject().getDeserializationConfig().getTimeZone()); + } + + @Test + public void setTimeZoneWithInvalidZoneId() { + this.factory.setTimeZone(TimeZone.getTimeZone("bogusZoneId")); + this.factory.afterPropertiesSet(); + + TimeZone timeZone = TimeZone.getTimeZone("GMT"); + assertEquals(timeZone, this.factory.getObject().getSerializationConfig().getTimeZone()); + assertEquals(timeZone, this.factory.getObject().getDeserializationConfig().getTimeZone()); + } + + @Test + public void setModules() { + NumberSerializer serializer = new NumberSerializer(Integer.class); + SimpleModule module = new SimpleModule(); + module.addSerializer(Integer.class, serializer); + + this.factory.setModules(Arrays.asList(new Module[]{module})); + this.factory.afterPropertiesSet(); + ObjectMapper objectMapper = this.factory.getObject(); + + Serializers serializers = getSerializerFactoryConfig(objectMapper).serializers().iterator().next(); + assertSame(serializer, serializers.findSerializer(null, SimpleType.construct(Integer.class), null)); + } + + @Test + public void defaultModules() throws JsonProcessingException, UnsupportedEncodingException { + this.factory.afterPropertiesSet(); + ObjectMapper objectMapper = this.factory.getObject(); + + Long timestamp = 1322903730000L; + DateTime dateTime = new DateTime(timestamp, DateTimeZone.UTC); + assertEquals(timestamp.toString(), new String(objectMapper.writeValueAsBytes(dateTime), "UTF-8")); + } + + @Test // SPR-12634 + public void customizeDefaultModulesWithModuleClass() throws JsonProcessingException, UnsupportedEncodingException { + this.factory.setModulesToInstall(CustomIntegerModule.class); + this.factory.afterPropertiesSet(); + ObjectMapper objectMapper = this.factory.getObject(); + + DateTime dateTime = new DateTime(1322903730000L, DateTimeZone.UTC); + assertEquals("1322903730000", new String(objectMapper.writeValueAsBytes(dateTime), "UTF-8")); + assertThat(new String(objectMapper.writeValueAsBytes(new Integer(4)), "UTF-8"), containsString("customid")); + } + + @Test // SPR-12634 + public void customizeDefaultModulesWithSerializer() throws JsonProcessingException, UnsupportedEncodingException { + Map, JsonSerializer> serializers = new HashMap<>(); + serializers.put(Integer.class, new CustomIntegerSerializer()); + + this.factory.setSerializersByType(serializers); + this.factory.afterPropertiesSet(); + ObjectMapper objectMapper = this.factory.getObject(); + + DateTime dateTime = new DateTime(1322903730000L, DateTimeZone.UTC); + assertEquals("1322903730000", new String(objectMapper.writeValueAsBytes(dateTime), "UTF-8")); + assertThat(new String(objectMapper.writeValueAsBytes(new Integer(4)), "UTF-8"), containsString("customid")); + } + + @Test + public void simpleSetup() { + this.factory.afterPropertiesSet(); + + assertNotNull(this.factory.getObject()); + assertTrue(this.factory.isSingleton()); + assertEquals(ObjectMapper.class, this.factory.getObjectType()); + } + + @Test + public void undefinedObjectType() { + assertNull(this.factory.getObjectType()); + } + + private static SerializerFactoryConfig getSerializerFactoryConfig(ObjectMapper objectMapper) { + return ((BasicSerializerFactory) objectMapper.getSerializerFactory()).getFactoryConfig(); + } + + private static DeserializerFactoryConfig getDeserializerFactoryConfig(ObjectMapper objectMapper) { + return ((BasicDeserializerFactory) objectMapper.getDeserializationContext().getFactory()).getFactoryConfig(); + } + + @Test + public void propertyNamingStrategy() { + PropertyNamingStrategy strategy = new PropertyNamingStrategy.SnakeCaseStrategy(); + this.factory.setPropertyNamingStrategy(strategy); + this.factory.afterPropertiesSet(); + + assertSame(strategy, this.factory.getObject().getSerializationConfig().getPropertyNamingStrategy()); + assertSame(strategy, this.factory.getObject().getDeserializationConfig().getPropertyNamingStrategy()); + } + + @Test + public void setMixIns() { + Class target = String.class; + Class mixinSource = Object.class; + Map, Class> mixIns = new HashMap<>(); + mixIns.put(target, mixinSource); + + this.factory.setModules(Collections.emptyList()); + this.factory.setMixIns(mixIns); + this.factory.afterPropertiesSet(); + ObjectMapper objectMapper = this.factory.getObject(); + + assertEquals(1, objectMapper.mixInCount()); + assertSame(mixinSource, objectMapper.findMixInClassFor(target)); + } + + @Test + public void setFilters() throws JsonProcessingException { + this.factory.setFilters(new SimpleFilterProvider().setFailOnUnknownId(false)); + this.factory.afterPropertiesSet(); + ObjectMapper objectMapper = this.factory.getObject(); + + JacksonFilteredBean bean = new JacksonFilteredBean("value1", "value2"); + String output = objectMapper.writeValueAsString(bean); + assertThat(output, containsString("value1")); + assertThat(output, containsString("value2")); + } + + @Test + public void completeSetup() { + NopAnnotationIntrospector annotationIntrospector = NopAnnotationIntrospector.instance; + ObjectMapper objectMapper = new ObjectMapper(); + + this.factory.setObjectMapper(objectMapper); + assertTrue(this.factory.isSingleton()); + assertEquals(ObjectMapper.class, this.factory.getObjectType()); + + Map, JsonDeserializer> deserializers = new HashMap<>(); + deserializers.put(Date.class, new DateDeserializer()); + + JsonSerializer> serializer1 = new ClassSerializer(); + JsonSerializer serializer2 = new NumberSerializer(Integer.class); + + // Disable well-known modules detection + this.factory.setModules(new ArrayList<>()); + this.factory.setSerializers(serializer1); + this.factory.setSerializersByType(Collections.singletonMap(Boolean.class, serializer2)); + this.factory.setDeserializersByType(deserializers); + this.factory.setAnnotationIntrospector(annotationIntrospector); + + this.factory.setFeaturesToEnable(SerializationFeature.FAIL_ON_EMPTY_BEANS, + DeserializationFeature.UNWRAP_ROOT_VALUE, + JsonParser.Feature.ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER, + JsonGenerator.Feature.WRITE_NUMBERS_AS_STRINGS); + + this.factory.setFeaturesToDisable(MapperFeature.AUTO_DETECT_GETTERS, + MapperFeature.AUTO_DETECT_FIELDS, + JsonParser.Feature.AUTO_CLOSE_SOURCE, + JsonGenerator.Feature.QUOTE_FIELD_NAMES); + + assertFalse(getSerializerFactoryConfig(objectMapper).hasSerializers()); + assertFalse(getDeserializerFactoryConfig(objectMapper).hasDeserializers()); + + this.factory.setSerializationInclusion(Include.NON_NULL); + this.factory.afterPropertiesSet(); + + assertSame(objectMapper, this.factory.getObject()); + assertTrue(getSerializerFactoryConfig(objectMapper).hasSerializers()); + assertTrue(getDeserializerFactoryConfig(objectMapper).hasDeserializers()); + + Serializers serializers = getSerializerFactoryConfig(objectMapper).serializers().iterator().next(); + assertSame(serializer1, serializers.findSerializer(null, SimpleType.construct(Class.class), null)); + assertSame(serializer2, serializers.findSerializer(null, SimpleType.construct(Boolean.class), null)); + assertNull(serializers.findSerializer(null, SimpleType.construct(Number.class), null)); + + assertSame(annotationIntrospector, objectMapper.getSerializationConfig().getAnnotationIntrospector()); + assertSame(annotationIntrospector, objectMapper.getDeserializationConfig().getAnnotationIntrospector()); + + assertTrue(objectMapper.getSerializationConfig().isEnabled(SerializationFeature.FAIL_ON_EMPTY_BEANS)); + assertTrue(objectMapper.getDeserializationConfig().isEnabled(DeserializationFeature.UNWRAP_ROOT_VALUE)); + assertTrue(objectMapper.getFactory().isEnabled(JsonParser.Feature.ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER)); + assertTrue(objectMapper.getFactory().isEnabled(JsonGenerator.Feature.WRITE_NUMBERS_AS_STRINGS)); + + assertFalse(objectMapper.getSerializationConfig().isEnabled(MapperFeature.AUTO_DETECT_GETTERS)); + assertFalse(objectMapper.getDeserializationConfig().isEnabled(MapperFeature.DEFAULT_VIEW_INCLUSION)); + assertFalse(objectMapper.getDeserializationConfig().isEnabled(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)); + assertFalse(objectMapper.getDeserializationConfig().isEnabled(MapperFeature.AUTO_DETECT_FIELDS)); + assertFalse(objectMapper.getFactory().isEnabled(JsonParser.Feature.AUTO_CLOSE_SOURCE)); + assertFalse(objectMapper.getFactory().isEnabled(JsonGenerator.Feature.QUOTE_FIELD_NAMES)); + assertSame(Include.NON_NULL, objectMapper.getSerializationConfig().getSerializationInclusion()); + } + + @Test + public void setObjectMapper() { + this.factory.setObjectMapper(new XmlMapper()); + this.factory.afterPropertiesSet(); + + assertNotNull(this.factory.getObject()); + assertTrue(this.factory.isSingleton()); + assertEquals(XmlMapper.class, this.factory.getObjectType()); + } + + @Test + public void setCreateXmlMapper() { + this.factory.setCreateXmlMapper(true); + this.factory.afterPropertiesSet(); + + assertNotNull(this.factory.getObject()); + assertTrue(this.factory.isSingleton()); + assertEquals(XmlMapper.class, this.factory.getObjectType()); + } + + @Test // SPR-14435 + public void setFactory() { + this.factory.setFactory(new SmileFactory()); + this.factory.afterPropertiesSet(); + + assertNotNull(this.factory.getObject()); + assertTrue(this.factory.isSingleton()); + assertEquals(SmileFactory.class, this.factory.getObject().getFactory().getClass()); + } + + + public static class CustomIntegerModule extends Module { + + @Override + public String getModuleName() { + return this.getClass().getSimpleName(); + } + + @Override + public Version version() { + return Version.unknownVersion(); + } + + @Override + public void setupModule(SetupContext context) { + SimpleSerializers serializers = new SimpleSerializers(); + serializers.addSerializer(Integer.class, new CustomIntegerSerializer()); + context.addSerializers(serializers); + } + } + + + public static class CustomIntegerSerializer extends JsonSerializer { + + @Override + public void serialize(Integer value, JsonGenerator gen, SerializerProvider serializers) throws IOException { + gen.writeStartObject(); + gen.writeNumberField("customid", value); + gen.writeEndObject(); + } + } + + + @JsonFilter("myJacksonFilter") + public static class JacksonFilteredBean { + + private String property1; + private String property2; + + + public JacksonFilteredBean(String property1, String property2) { + this.property1 = property1; + this.property2 = property2; + } + + public String getProperty1() { + return property1; + } + + public void setProperty1(String property1) { + this.property1 = property1; + } + + public String getProperty2() { + return property2; + } + + public void setProperty2(String property2) { + this.property2 = property2; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/json/JsonbHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/json/JsonbHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0016895538c3075549170e3ec36532a2f9201b1b --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/json/JsonbHttpMessageConverterTests.java @@ -0,0 +1,346 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.Type; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Test; +import org.skyscreamer.jsonassert.JSONAssert; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; +import org.springframework.http.converter.HttpMessageNotReadableException; + +import static org.junit.Assert.*; + +/** + * Integration tests for the JSON Binding API, running against Apache Johnzon. + * + * @author Juergen Hoeller + * @since 5.0 + */ +public class JsonbHttpMessageConverterTests { + + private final JsonbHttpMessageConverter converter = new JsonbHttpMessageConverter(); + + + @Test + public void canRead() { + assertTrue(this.converter.canRead(MyBean.class, new MediaType("application", "json"))); + assertTrue(this.converter.canRead(Map.class, new MediaType("application", "json"))); + } + + @Test + public void canWrite() { + assertTrue(this.converter.canWrite(MyBean.class, new MediaType("application", "json"))); + assertTrue(this.converter.canWrite(Map.class, new MediaType("application", "json"))); + } + + @Test + public void canReadAndWriteMicroformats() { + assertTrue(this.converter.canRead(MyBean.class, new MediaType("application", "vnd.test-micro-type+json"))); + assertTrue(this.converter.canWrite(MyBean.class, new MediaType("application", "vnd.test-micro-type+json"))); + } + + @Test + public void readTyped() throws IOException { + String body = "{\"bytes\":[1,2],\"array\":[\"Foo\",\"Bar\"]," + + "\"number\":42,\"string\":\"Foo\",\"bool\":true,\"fraction\":42.0}"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + MyBean result = (MyBean) this.converter.read(MyBean.class, inputMessage); + + assertEquals("Foo", result.getString()); + assertEquals(42, result.getNumber()); + assertEquals(42F, result.getFraction(), 0F); + assertArrayEquals(new String[] {"Foo", "Bar"}, result.getArray()); + assertTrue(result.isBool()); + assertArrayEquals(new byte[] {0x1, 0x2}, result.getBytes()); + } + + @Test + @SuppressWarnings("unchecked") + public void readUntyped() throws IOException { + String body = "{\"bytes\":[1,2],\"array\":[\"Foo\",\"Bar\"]," + + "\"number\":42,\"string\":\"Foo\",\"bool\":true,\"fraction\":42.0}"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + HashMap result = (HashMap) this.converter.read(HashMap.class, inputMessage); + assertEquals("Foo", result.get("string")); + Number n = (Number) result.get("number"); + assertEquals(42, n.longValue()); + n = (Number) result.get("fraction"); + assertEquals(42D, n.doubleValue(), 0D); + List array = new ArrayList<>(); + array.add("Foo"); + array.add("Bar"); + assertEquals(array, result.get("array")); + assertEquals(Boolean.TRUE, result.get("bool")); + byte[] bytes = new byte[2]; + List resultBytes = (ArrayList)result.get("bytes"); + for (int i = 0; i < 2; i++) { + bytes[i] = resultBytes.get(i).byteValue(); + } + assertArrayEquals(new byte[] {0x1, 0x2}, bytes); + } + + @Test + public void write() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MyBean body = new MyBean(); + body.setString("Foo"); + body.setNumber(42); + body.setFraction(42F); + body.setArray(new String[] {"Foo", "Bar"}); + body.setBool(true); + body.setBytes(new byte[] {0x1, 0x2}); + this.converter.write(body, null, outputMessage); + Charset utf8 = StandardCharsets.UTF_8; + String result = outputMessage.getBodyAsString(utf8); + assertTrue(result.contains("\"string\":\"Foo\"")); + assertTrue(result.contains("\"number\":42")); + assertTrue(result.contains("fraction\":42.0")); + assertTrue(result.contains("\"array\":[\"Foo\",\"Bar\"]")); + assertTrue(result.contains("\"bool\":true")); + assertTrue(result.contains("\"bytes\":[1,2]")); + assertEquals("Invalid content-type", new MediaType("application", "json", utf8), + outputMessage.getHeaders().getContentType()); + } + + @Test + public void writeWithBaseType() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MyBean body = new MyBean(); + body.setString("Foo"); + body.setNumber(42); + body.setFraction(42F); + body.setArray(new String[] {"Foo", "Bar"}); + body.setBool(true); + body.setBytes(new byte[] {0x1, 0x2}); + this.converter.write(body, MyBase.class, null, outputMessage); + Charset utf8 = StandardCharsets.UTF_8; + String result = outputMessage.getBodyAsString(utf8); + assertTrue(result.contains("\"string\":\"Foo\"")); + assertTrue(result.contains("\"number\":42")); + assertTrue(result.contains("fraction\":42.0")); + assertTrue(result.contains("\"array\":[\"Foo\",\"Bar\"]")); + assertTrue(result.contains("\"bool\":true")); + assertTrue(result.contains("\"bytes\":[1,2]")); + assertEquals("Invalid content-type", new MediaType("application", "json", utf8), + outputMessage.getHeaders().getContentType()); + } + + @Test + public void writeUTF16() throws IOException { + MediaType contentType = new MediaType("application", "json", StandardCharsets.UTF_16BE); + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + String body = "H\u00e9llo W\u00f6rld"; + this.converter.write(body, contentType, outputMessage); + assertEquals("Invalid result", body, outputMessage.getBodyAsString(StandardCharsets.UTF_16BE)); + assertEquals("Invalid content-type", contentType, outputMessage.getHeaders().getContentType()); + } + + @Test(expected = HttpMessageNotReadableException.class) + public void readInvalidJson() throws IOException { + String body = "FooBar"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + this.converter.read(MyBean.class, inputMessage); + } + + @Test + @SuppressWarnings("unchecked") + public void readAndWriteGenerics() throws Exception { + Field beansList = ListHolder.class.getField("listField"); + + String body = "[{\"bytes\":[1,2],\"array\":[\"Foo\",\"Bar\"]," + + "\"number\":42,\"string\":\"Foo\",\"bool\":true,\"fraction\":42.0}]"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes(StandardCharsets.UTF_8)); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + + Type genericType = beansList.getGenericType(); + List results = (List) converter.read(genericType, MyBeanListHolder.class, inputMessage); + assertEquals(1, results.size()); + MyBean result = results.get(0); + assertEquals("Foo", result.getString()); + assertEquals(42, result.getNumber()); + assertEquals(42F, result.getFraction(), 0F); + assertArrayEquals(new String[] {"Foo", "Bar"}, result.getArray()); + assertTrue(result.isBool()); + assertArrayEquals(new byte[] {0x1, 0x2}, result.getBytes()); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(results, genericType, new MediaType("application", "json"), outputMessage); + JSONAssert.assertEquals(body, outputMessage.getBodyAsString(StandardCharsets.UTF_8), true); + } + + @Test + @SuppressWarnings("unchecked") + public void readAndWriteParameterizedType() throws Exception { + ParameterizedTypeReference> beansList = new ParameterizedTypeReference>() {}; + + String body = "[{\"bytes\":[1,2],\"array\":[\"Foo\",\"Bar\"]," + + "\"number\":42,\"string\":\"Foo\",\"bool\":true,\"fraction\":42.0}]"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes(StandardCharsets.UTF_8)); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + + List results = (List) converter.read(beansList.getType(), null, inputMessage); + assertEquals(1, results.size()); + MyBean result = results.get(0); + assertEquals("Foo", result.getString()); + assertEquals(42, result.getNumber()); + assertEquals(42F, result.getFraction(), 0F); + assertArrayEquals(new String[] {"Foo", "Bar"}, result.getArray()); + assertTrue(result.isBool()); + assertArrayEquals(new byte[] {0x1, 0x2}, result.getBytes()); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(results, beansList.getType(), new MediaType("application", "json"), outputMessage); + JSONAssert.assertEquals(body, outputMessage.getBodyAsString(StandardCharsets.UTF_8), true); + } + + @Test + @SuppressWarnings("unchecked") + public void writeParameterizedBaseType() throws Exception { + ParameterizedTypeReference> beansList = new ParameterizedTypeReference>() {}; + ParameterizedTypeReference> baseList = new ParameterizedTypeReference>() {}; + + String body = "[{\"bytes\":[1,2],\"array\":[\"Foo\",\"Bar\"]," + + "\"number\":42,\"string\":\"Foo\",\"bool\":true,\"fraction\":42.0}]"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes(StandardCharsets.UTF_8)); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + + List results = (List) converter.read(beansList.getType(), null, inputMessage); + assertEquals(1, results.size()); + MyBean result = results.get(0); + assertEquals("Foo", result.getString()); + assertEquals(42, result.getNumber()); + assertEquals(42F, result.getFraction(), 0F); + assertArrayEquals(new String[] {"Foo", "Bar"}, result.getArray()); + assertTrue(result.isBool()); + assertArrayEquals(new byte[] {0x1, 0x2}, result.getBytes()); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(results, baseList.getType(), new MediaType("application", "json"), outputMessage); + JSONAssert.assertEquals(body, outputMessage.getBodyAsString(StandardCharsets.UTF_8), true); + } + + @Test + public void prefixJson() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + this.converter.setPrefixJson(true); + this.converter.writeInternal("foo", null, outputMessage); + assertEquals(")]}', foo", outputMessage.getBodyAsString(StandardCharsets.UTF_8)); + } + + @Test + public void prefixJsonCustom() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + this.converter.setJsonPrefix(")))"); + this.converter.writeInternal("foo", null, outputMessage); + assertEquals(")))foo", outputMessage.getBodyAsString(StandardCharsets.UTF_8)); + } + + + public static class MyBase { + + private String string; + + public String getString() { + return string; + } + + public void setString(String string) { + this.string = string; + } + } + + + public static class MyBean extends MyBase { + + private int number; + + private float fraction; + + private String[] array; + + private boolean bool; + + private byte[] bytes; + + public int getNumber() { + return number; + } + + public void setNumber(int number) { + this.number = number; + } + + public float getFraction() { + return fraction; + } + + public void setFraction(float fraction) { + this.fraction = fraction; + } + + public String[] getArray() { + return array; + } + + public void setArray(String[] array) { + this.array = array; + } + + public boolean isBool() { + return bool; + } + + public void setBool(boolean bool) { + this.bool = bool; + } + + public byte[] getBytes() { + return bytes; + } + + public void setBytes(byte[] bytes) { + this.bytes = bytes; + } + } + + + public static class ListHolder { + + public List listField; + } + + + public static class MyBeanListHolder extends ListHolder { + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/json/MappingJackson2HttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/json/MappingJackson2HttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a2f9308a3fca5a964a41f58102dbdb67dc7182f4 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/json/MappingJackson2HttpMessageConverterTests.java @@ -0,0 +1,686 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.io.IOException; +import java.lang.reflect.Type; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonFilter; +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ser.FilterProvider; +import com.fasterxml.jackson.databind.ser.impl.SimpleBeanPropertyFilter; +import com.fasterxml.jackson.databind.ser.impl.SimpleFilterProvider; +import org.junit.Test; +import org.skyscreamer.jsonassert.JSONAssert; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; +import org.springframework.http.converter.HttpMessageConversionException; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.lang.Nullable; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.not; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Jackson 2.x converter tests. + * + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + * @author Juergen Hoeller + */ +public class MappingJackson2HttpMessageConverterTests { + + protected static final String NEWLINE_SYSTEM_PROPERTY = System.getProperty("line.separator"); + + private final MappingJackson2HttpMessageConverter converter = new MappingJackson2HttpMessageConverter(); + + + @Test + public void canRead() { + assertTrue(converter.canRead(MyBean.class, new MediaType("application", "json"))); + assertTrue(converter.canRead(Map.class, new MediaType("application", "json"))); + assertTrue(converter.canRead(MyBean.class, new MediaType("application", "json", StandardCharsets.UTF_8))); + assertTrue(converter.canRead(MyBean.class, new MediaType("application", "json", StandardCharsets.US_ASCII))); + assertTrue(converter.canRead(MyBean.class, new MediaType("application", "json", StandardCharsets.ISO_8859_1))); + } + + @Test + public void canWrite() { + assertTrue(converter.canWrite(MyBean.class, new MediaType("application", "json"))); + assertTrue(converter.canWrite(Map.class, new MediaType("application", "json"))); + assertTrue(converter.canWrite(MyBean.class, new MediaType("application", "json", StandardCharsets.UTF_8))); + assertTrue(converter.canWrite(MyBean.class, new MediaType("application", "json", StandardCharsets.US_ASCII))); + assertFalse(converter.canWrite(MyBean.class, new MediaType("application", "json", StandardCharsets.ISO_8859_1))); + } + + @Test // SPR-7905 + public void canReadAndWriteMicroformats() { + assertTrue(converter.canRead(MyBean.class, new MediaType("application", "vnd.test-micro-type+json"))); + assertTrue(converter.canWrite(MyBean.class, new MediaType("application", "vnd.test-micro-type+json"))); + } + + @Test + public void readTyped() throws IOException { + String body = "{" + + "\"bytes\":\"AQI=\"," + + "\"array\":[\"Foo\",\"Bar\"]," + + "\"number\":42," + + "\"string\":\"Foo\"," + + "\"bool\":true," + + "\"fraction\":42.0}"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + MyBean result = (MyBean) converter.read(MyBean.class, inputMessage); + assertEquals("Foo", result.getString()); + assertEquals(42, result.getNumber()); + assertEquals(42F, result.getFraction(), 0F); + assertArrayEquals(new String[] {"Foo", "Bar"}, result.getArray()); + assertTrue(result.isBool()); + assertArrayEquals(new byte[] {0x1, 0x2}, result.getBytes()); + } + + @Test + @SuppressWarnings("unchecked") + public void readUntyped() throws IOException { + String body = "{" + + "\"bytes\":\"AQI=\"," + + "\"array\":[\"Foo\",\"Bar\"]," + + "\"number\":42," + + "\"string\":\"Foo\"," + + "\"bool\":true," + + "\"fraction\":42.0}"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + HashMap result = (HashMap) converter.read(HashMap.class, inputMessage); + assertEquals("Foo", result.get("string")); + assertEquals(42, result.get("number")); + assertEquals(42D, (Double) result.get("fraction"), 0D); + List array = new ArrayList<>(); + array.add("Foo"); + array.add("Bar"); + assertEquals(array, result.get("array")); + assertEquals(Boolean.TRUE, result.get("bool")); + assertEquals("AQI=", result.get("bytes")); + } + + @Test + public void write() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MyBean body = new MyBean(); + body.setString("Foo"); + body.setNumber(42); + body.setFraction(42F); + body.setArray(new String[] {"Foo", "Bar"}); + body.setBool(true); + body.setBytes(new byte[] {0x1, 0x2}); + converter.write(body, null, outputMessage); + String result = outputMessage.getBodyAsString(StandardCharsets.UTF_8); + assertTrue(result.contains("\"string\":\"Foo\"")); + assertTrue(result.contains("\"number\":42")); + assertTrue(result.contains("fraction\":42.0")); + assertTrue(result.contains("\"array\":[\"Foo\",\"Bar\"]")); + assertTrue(result.contains("\"bool\":true")); + assertTrue(result.contains("\"bytes\":\"AQI=\"")); + assertEquals("Invalid content-type", new MediaType("application", "json", StandardCharsets.UTF_8), + outputMessage.getHeaders().getContentType()); + } + + @Test + public void writeWithBaseType() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MyBean body = new MyBean(); + body.setString("Foo"); + body.setNumber(42); + body.setFraction(42F); + body.setArray(new String[] {"Foo", "Bar"}); + body.setBool(true); + body.setBytes(new byte[] {0x1, 0x2}); + converter.write(body, MyBase.class, null, outputMessage); + String result = outputMessage.getBodyAsString(StandardCharsets.UTF_8); + assertTrue(result.contains("\"string\":\"Foo\"")); + assertTrue(result.contains("\"number\":42")); + assertTrue(result.contains("fraction\":42.0")); + assertTrue(result.contains("\"array\":[\"Foo\",\"Bar\"]")); + assertTrue(result.contains("\"bool\":true")); + assertTrue(result.contains("\"bytes\":\"AQI=\"")); + assertEquals("Invalid content-type", new MediaType("application", "json", StandardCharsets.UTF_8), + outputMessage.getHeaders().getContentType()); + } + + @Test + public void writeUTF16() throws IOException { + MediaType contentType = new MediaType("application", "json", StandardCharsets.UTF_16BE); + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + String body = "H\u00e9llo W\u00f6rld"; + converter.write(body, contentType, outputMessage); + assertEquals("Invalid result", "\"" + body + "\"", outputMessage.getBodyAsString(StandardCharsets.UTF_16BE)); + assertEquals("Invalid content-type", contentType, outputMessage.getHeaders().getContentType()); + } + + @Test(expected = HttpMessageNotReadableException.class) + public void readInvalidJson() throws IOException { + String body = "FooBar"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + converter.read(MyBean.class, inputMessage); + } + + @Test + public void readValidJsonWithUnknownProperty() throws IOException { + String body = "{\"string\":\"string\",\"unknownProperty\":\"value\"}"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + converter.read(MyBean.class, inputMessage); + // Assert no HttpMessageNotReadableException is thrown + } + + @Test + @SuppressWarnings("unchecked") + public void readAndWriteGenerics() throws Exception { + MappingJackson2HttpMessageConverter converter = new MappingJackson2HttpMessageConverter() { + @Override + protected JavaType getJavaType(Type type, @Nullable Class contextClass) { + if (type instanceof Class && List.class.isAssignableFrom((Class)type)) { + return new ObjectMapper().getTypeFactory().constructCollectionType(ArrayList.class, MyBean.class); + } + else { + return super.getJavaType(type, contextClass); + } + } + }; + String body = "[{" + + "\"bytes\":\"AQI=\"," + + "\"array\":[\"Foo\",\"Bar\"]," + + "\"number\":42," + + "\"string\":\"Foo\"," + + "\"bool\":true," + + "\"fraction\":42.0}]"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + + List results = (List) converter.read(List.class, inputMessage); + assertEquals(1, results.size()); + MyBean result = results.get(0); + assertEquals("Foo", result.getString()); + assertEquals(42, result.getNumber()); + assertEquals(42F, result.getFraction(), 0F); + assertArrayEquals(new String[] {"Foo", "Bar"}, result.getArray()); + assertTrue(result.isBool()); + assertArrayEquals(new byte[] {0x1, 0x2}, result.getBytes()); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(results, new MediaType("application", "json"), outputMessage); + JSONAssert.assertEquals(body, outputMessage.getBodyAsString(StandardCharsets.UTF_8), true); + } + + @Test + @SuppressWarnings("unchecked") + public void readAndWriteParameterizedType() throws Exception { + ParameterizedTypeReference> beansList = new ParameterizedTypeReference>() {}; + + String body = "[{" + + "\"bytes\":\"AQI=\"," + + "\"array\":[\"Foo\",\"Bar\"]," + + "\"number\":42," + + "\"string\":\"Foo\"," + + "\"bool\":true," + + "\"fraction\":42.0}]"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + + MappingJackson2HttpMessageConverter converter = new MappingJackson2HttpMessageConverter(); + List results = (List) converter.read(beansList.getType(), null, inputMessage); + assertEquals(1, results.size()); + MyBean result = results.get(0); + assertEquals("Foo", result.getString()); + assertEquals(42, result.getNumber()); + assertEquals(42F, result.getFraction(), 0F); + assertArrayEquals(new String[] {"Foo", "Bar"}, result.getArray()); + assertTrue(result.isBool()); + assertArrayEquals(new byte[] {0x1, 0x2}, result.getBytes()); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(results, beansList.getType(), new MediaType("application", "json"), outputMessage); + JSONAssert.assertEquals(body, outputMessage.getBodyAsString(StandardCharsets.UTF_8), true); + } + + @Test + @SuppressWarnings("unchecked") + public void writeParameterizedBaseType() throws Exception { + ParameterizedTypeReference> beansList = new ParameterizedTypeReference>() {}; + ParameterizedTypeReference> baseList = new ParameterizedTypeReference>() {}; + + String body = "[{" + + "\"bytes\":\"AQI=\"," + + "\"array\":[\"Foo\",\"Bar\"]," + + "\"number\":42," + + "\"string\":\"Foo\"," + + "\"bool\":true," + + "\"fraction\":42.0}]"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + + MappingJackson2HttpMessageConverter converter = new MappingJackson2HttpMessageConverter(); + List results = (List) converter.read(beansList.getType(), null, inputMessage); + assertEquals(1, results.size()); + MyBean result = results.get(0); + assertEquals("Foo", result.getString()); + assertEquals(42, result.getNumber()); + assertEquals(42F, result.getFraction(), 0F); + assertArrayEquals(new String[] {"Foo", "Bar"}, result.getArray()); + assertTrue(result.isBool()); + assertArrayEquals(new byte[] {0x1, 0x2}, result.getBytes()); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(results, baseList.getType(), new MediaType("application", "json"), outputMessage); + JSONAssert.assertEquals(body, outputMessage.getBodyAsString(StandardCharsets.UTF_8), true); + } + + @Test + public void prettyPrint() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + PrettyPrintBean bean = new PrettyPrintBean(); + bean.setName("Jason"); + + this.converter.setPrettyPrint(true); + this.converter.writeInternal(bean, null, outputMessage); + String result = outputMessage.getBodyAsString(StandardCharsets.UTF_8); + + assertEquals("{" + NEWLINE_SYSTEM_PROPERTY + + " \"name\" : \"Jason\"" + NEWLINE_SYSTEM_PROPERTY + "}", result); + } + + @Test + public void prettyPrintWithSse() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + outputMessage.getHeaders().setContentType(MediaType.TEXT_EVENT_STREAM); + PrettyPrintBean bean = new PrettyPrintBean(); + bean.setName("Jason"); + + this.converter.setPrettyPrint(true); + this.converter.writeInternal(bean, null, outputMessage); + String result = outputMessage.getBodyAsString(StandardCharsets.UTF_8); + + assertEquals("{\ndata: \"name\" : \"Jason\"\ndata:}", result); + } + + @Test + public void prefixJson() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + this.converter.setPrefixJson(true); + this.converter.writeInternal("foo", null, outputMessage); + + assertEquals(")]}', \"foo\"", outputMessage.getBodyAsString(StandardCharsets.UTF_8)); + } + + @Test + public void prefixJsonCustom() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + this.converter.setJsonPrefix(")))"); + this.converter.writeInternal("foo", null, outputMessage); + + assertEquals(")))\"foo\"", outputMessage.getBodyAsString(StandardCharsets.UTF_8)); + } + + @Test + public void fieldLevelJsonView() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + JacksonViewBean bean = new JacksonViewBean(); + bean.setWithView1("with"); + bean.setWithView2("with"); + bean.setWithoutView("without"); + + MappingJacksonValue jacksonValue = new MappingJacksonValue(bean); + jacksonValue.setSerializationView(MyJacksonView1.class); + this.converter.writeInternal(jacksonValue, null, outputMessage); + + String result = outputMessage.getBodyAsString(StandardCharsets.UTF_8); + assertThat(result, containsString("\"withView1\":\"with\"")); + assertThat(result, not(containsString("\"withView2\":\"with\""))); + assertThat(result, not(containsString("\"withoutView\":\"without\""))); + } + + @Test + public void classLevelJsonView() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + JacksonViewBean bean = new JacksonViewBean(); + bean.setWithView1("with"); + bean.setWithView2("with"); + bean.setWithoutView("without"); + + MappingJacksonValue jacksonValue = new MappingJacksonValue(bean); + jacksonValue.setSerializationView(MyJacksonView3.class); + this.converter.writeInternal(jacksonValue, null, outputMessage); + + String result = outputMessage.getBodyAsString(StandardCharsets.UTF_8); + assertThat(result, not(containsString("\"withView1\":\"with\""))); + assertThat(result, not(containsString("\"withView2\":\"with\""))); + assertThat(result, containsString("\"withoutView\":\"without\"")); + } + + @Test + public void filters() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + JacksonFilteredBean bean = new JacksonFilteredBean(); + bean.setProperty1("value"); + bean.setProperty2("value"); + + MappingJacksonValue jacksonValue = new MappingJacksonValue(bean); + FilterProvider filters = new SimpleFilterProvider().addFilter("myJacksonFilter", + SimpleBeanPropertyFilter.serializeAllExcept("property2")); + jacksonValue.setFilters(filters); + this.converter.writeInternal(jacksonValue, null, outputMessage); + + String result = outputMessage.getBodyAsString(StandardCharsets.UTF_8); + assertThat(result, containsString("\"property1\":\"value\"")); + assertThat(result, not(containsString("\"property2\":\"value\""))); + } + + @Test // SPR-13318 + public void writeSubType() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MyBean bean = new MyBean(); + bean.setString("Foo"); + bean.setNumber(42); + + this.converter.writeInternal(bean, MyInterface.class, outputMessage); + + String result = outputMessage.getBodyAsString(StandardCharsets.UTF_8); + assertTrue(result.contains("\"string\":\"Foo\"")); + assertTrue(result.contains("\"number\":42")); + } + + @Test // SPR-13318 + public void writeSubTypeList() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + List beans = new ArrayList<>(); + MyBean foo = new MyBean(); + foo.setString("Foo"); + foo.setNumber(42); + beans.add(foo); + MyBean bar = new MyBean(); + bar.setString("Bar"); + bar.setNumber(123); + beans.add(bar); + ParameterizedTypeReference> typeReference = + new ParameterizedTypeReference>() {}; + + this.converter.writeInternal(beans, typeReference.getType(), outputMessage); + + String result = outputMessage.getBodyAsString(StandardCharsets.UTF_8); + assertTrue(result.contains("\"string\":\"Foo\"")); + assertTrue(result.contains("\"number\":42")); + assertTrue(result.contains("\"string\":\"Bar\"")); + assertTrue(result.contains("\"number\":123")); + } + + @Test + public void readWithNoDefaultConstructor() throws Exception { + String body = "{\"property1\":\"foo\",\"property2\":\"bar\"}"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes(StandardCharsets.UTF_8)); + inputMessage.getHeaders().setContentType(new MediaType("application", "json")); + try { + converter.read(BeanWithNoDefaultConstructor.class, inputMessage); + } + catch (HttpMessageConversionException ex) { + assertTrue(ex.getMessage(), ex.getMessage().startsWith("Type definition error:")); + return; + } + fail(); + } + + @Test + @SuppressWarnings("unchecked") + public void readNonUnicode() throws Exception { + String body = "{\"føø\":\"bår\"}"; + Charset charset = StandardCharsets.ISO_8859_1; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes(charset)); + inputMessage.getHeaders().setContentType(new MediaType("application", "json", charset)); + HashMap result = (HashMap) this.converter.read(HashMap.class, inputMessage); + + assertEquals(1, result.size()); + assertEquals("bår", result.get("føø")); + } + + @Test + @SuppressWarnings("unchecked") + public void readAscii() throws Exception { + String body = "{\"foo\":\"bar\"}"; + Charset charset = StandardCharsets.US_ASCII; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes(charset)); + inputMessage.getHeaders().setContentType(new MediaType("application", "json", charset)); + HashMap result = (HashMap) this.converter.read(HashMap.class, inputMessage); + + assertEquals(1, result.size()); + assertEquals("bar", result.get("foo")); + } + + @Test + @SuppressWarnings("unchecked") + public void writeAscii() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Map body = new HashMap<>(); + body.put("foo", "bar"); + Charset charset = StandardCharsets.US_ASCII; + MediaType contentType = new MediaType("application", "json", charset); + converter.write(body, contentType, outputMessage); + + String result = outputMessage.getBodyAsString(charset); + assertEquals("{\"foo\":\"bar\"}", result); + assertEquals(contentType, outputMessage.getHeaders().getContentType()); + } + + + interface MyInterface { + + String getString(); + + void setString(String string); + } + + + public static class MyBase implements MyInterface{ + + private String string; + + public String getString() { + return string; + } + + public void setString(String string) { + this.string = string; + } + } + + + public static class MyBean extends MyBase { + + private int number; + + private float fraction; + + private String[] array; + + private boolean bool; + + private byte[] bytes; + + public int getNumber() { + return number; + } + + public void setNumber(int number) { + this.number = number; + } + + public float getFraction() { + return fraction; + } + + public void setFraction(float fraction) { + this.fraction = fraction; + } + + public String[] getArray() { + return array; + } + + public void setArray(String[] array) { + this.array = array; + } + + public boolean isBool() { + return bool; + } + + public void setBool(boolean bool) { + this.bool = bool; + } + + public byte[] getBytes() { + return bytes; + } + + public void setBytes(byte[] bytes) { + this.bytes = bytes; + } + } + + + public static class PrettyPrintBean { + + private String name; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + } + + + private interface MyJacksonView1 {} + + private interface MyJacksonView2 {} + + private interface MyJacksonView3 {} + + + @SuppressWarnings("unused") + @JsonView(MyJacksonView3.class) + private static class JacksonViewBean { + + @JsonView(MyJacksonView1.class) + private String withView1; + + @JsonView(MyJacksonView2.class) + private String withView2; + + private String withoutView; + + public String getWithView1() { + return withView1; + } + + public void setWithView1(String withView1) { + this.withView1 = withView1; + } + + public String getWithView2() { + return withView2; + } + + public void setWithView2(String withView2) { + this.withView2 = withView2; + } + + public String getWithoutView() { + return withoutView; + } + + public void setWithoutView(String withoutView) { + this.withoutView = withoutView; + } + } + + + @JsonFilter("myJacksonFilter") + @SuppressWarnings("unused") + private static class JacksonFilteredBean { + + private String property1; + + private String property2; + + public String getProperty1() { + return property1; + } + + public void setProperty1(String property1) { + this.property1 = property1; + } + + public String getProperty2() { + return property2; + } + + public void setProperty2(String property2) { + this.property2 = property2; + } + } + + + private static class BeanWithNoDefaultConstructor { + + private final String property1; + + private final String property2; + + public BeanWithNoDefaultConstructor(String property1, String property2) { + this.property1 = property1; + this.property2 = property2; + } + + public String getProperty1() { + return property1; + } + + public String getProperty2() { + return property2; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/json/SpringHandlerInstantiatorTests.java b/spring-web/src/test/java/org/springframework/http/converter/json/SpringHandlerInstantiatorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8cadcdd09fb03b688be815a1cbc4cd24b770926d --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/json/SpringHandlerInstantiatorTests.java @@ -0,0 +1,281 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.json; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.ObjectCodec; +import com.fasterxml.jackson.databind.DatabindContext; +import com.fasterxml.jackson.databind.DeserializationConfig; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.KeyDeserializer; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationConfig; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import com.fasterxml.jackson.databind.annotation.JsonTypeIdResolver; +import com.fasterxml.jackson.databind.annotation.JsonTypeResolver; +import com.fasterxml.jackson.databind.jsontype.NamedType; +import com.fasterxml.jackson.databind.jsontype.TypeDeserializer; +import com.fasterxml.jackson.databind.jsontype.TypeIdResolver; +import com.fasterxml.jackson.databind.jsontype.TypeSerializer; +import com.fasterxml.jackson.databind.jsontype.impl.StdTypeResolverBuilder; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; + +import static org.junit.Assert.*; + +/** + * Test class for {@link SpringHandlerInstantiatorTests}. + * + * @author Sebastien Deleuze + */ +public class SpringHandlerInstantiatorTests { + + private SpringHandlerInstantiator instantiator; + + private ObjectMapper objectMapper; + + + @Before + public void setup() { + DefaultListableBeanFactory bf = new DefaultListableBeanFactory(); + AutowiredAnnotationBeanPostProcessor bpp = new AutowiredAnnotationBeanPostProcessor(); + bpp.setBeanFactory(bf); + bf.addBeanPostProcessor(bpp); + bf.registerBeanDefinition("capitalizer", new RootBeanDefinition(Capitalizer.class)); + instantiator = new SpringHandlerInstantiator(bf); + objectMapper = Jackson2ObjectMapperBuilder.json().handlerInstantiator(instantiator).build(); + } + + + @Test + public void autowiredSerializer() throws JsonProcessingException { + User user = new User("bob"); + String json = this.objectMapper.writeValueAsString(user); + assertEquals("{\"username\":\"BOB\"}", json); + } + + @Test + public void autowiredDeserializer() throws IOException { + String json = "{\"username\":\"bob\"}"; + User user = this.objectMapper.readValue(json, User.class); + assertEquals("BOB", user.getUsername()); + } + + @Test + public void autowiredKeyDeserializer() throws IOException { + String json = "{\"credentials\":{\"bob\":\"admin\"}}"; + SecurityRegistry registry = this.objectMapper.readValue(json, SecurityRegistry.class); + assertTrue(registry.getCredentials().keySet().contains("BOB")); + assertFalse(registry.getCredentials().keySet().contains("bob")); + } + + @Test + public void applicationContextAwaretypeResolverBuilder() throws JsonProcessingException { + this.objectMapper.writeValueAsString(new Group()); + assertTrue(CustomTypeResolverBuilder.isAutowiredFiledInitialized); + } + + @Test + public void applicationContextAwareTypeIdResolver() throws JsonProcessingException { + this.objectMapper.writeValueAsString(new Group()); + assertTrue(CustomTypeIdResolver.isAutowiredFiledInitialized); + } + + + public static class UserDeserializer extends JsonDeserializer { + + @Autowired + private Capitalizer capitalizer; + + @Override + public User deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException { + ObjectCodec oc = jsonParser.getCodec(); + JsonNode node = oc.readTree(jsonParser); + return new User(this.capitalizer.capitalize(node.get("username").asText())); + } + } + + + public static class UserSerializer extends JsonSerializer { + + @Autowired + private Capitalizer capitalizer; + + @Override + public void serialize(User user, JsonGenerator jsonGenerator, + SerializerProvider serializerProvider) throws IOException { + + jsonGenerator.writeStartObject(); + jsonGenerator.writeStringField("username", this.capitalizer.capitalize(user.getUsername())); + jsonGenerator.writeEndObject(); + } + } + + + public static class UpperCaseKeyDeserializer extends KeyDeserializer { + + @Autowired + private Capitalizer capitalizer; + + @Override + public Object deserializeKey(String key, DeserializationContext context) throws IOException { + return this.capitalizer.capitalize(key); + } + } + + + public static class CustomTypeResolverBuilder extends StdTypeResolverBuilder { + + @Autowired + private Capitalizer capitalizer; + + public static boolean isAutowiredFiledInitialized = false; + + @Override + public TypeSerializer buildTypeSerializer(SerializationConfig config, JavaType baseType, + Collection subtypes) { + + isAutowiredFiledInitialized = (this.capitalizer != null); + return super.buildTypeSerializer(config, baseType, subtypes); + } + + @Override + public TypeDeserializer buildTypeDeserializer(DeserializationConfig config, + JavaType baseType, Collection subtypes) { + + return super.buildTypeDeserializer(config, baseType, subtypes); + } + } + + + public static class CustomTypeIdResolver implements TypeIdResolver { + + @Autowired + private Capitalizer capitalizer; + + public static boolean isAutowiredFiledInitialized = false; + + public CustomTypeIdResolver() { + } + + @Override + public String idFromValueAndType(Object o, Class type) { + return type.getClass().getName(); + } + + @Override + public JsonTypeInfo.Id getMechanism() { + return JsonTypeInfo.Id.CUSTOM; + } + + @Override + public String idFromValue(Object value) { + isAutowiredFiledInitialized = (this.capitalizer != null); + return value.getClass().getName(); + } + + @Override + public void init(JavaType type) { + } + + @Override + public String idFromBaseType() { + return null; + } + + @Override + public JavaType typeFromId(DatabindContext context, String id) { + return null; + } + + @Override + public String getDescForKnownTypeIds() { + return null; + } + } + + + @JsonDeserialize(using = UserDeserializer.class) + @JsonSerialize(using = UserSerializer.class) + public static class User { + + private String username; + + public User() { + } + + public User(String username) { + this.username = username; + } + + public String getUsername() { return this.username; } + } + + + public static class SecurityRegistry { + + @JsonDeserialize(keyUsing = UpperCaseKeyDeserializer.class) + private Map credentials = new HashMap<>(); + + public void addCredential(String username, String credential) { + this.credentials.put(username, credential); + } + + public Map getCredentials() { + return credentials; + } + } + + + @JsonTypeInfo(use = JsonTypeInfo.Id.CUSTOM, property = "type") + @JsonTypeResolver(CustomTypeResolverBuilder.class) + @JsonTypeIdResolver(CustomTypeIdResolver.class) + public static class Group { + + public String getType() { + return Group.class.getName(); + } + } + + + public static class Capitalizer { + + public String capitalize(String text) { + return text.toUpperCase(); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/protobuf/ProtobufHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/protobuf/ProtobufHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e9671a06cfd664765c9c9054d946bd4cd3d968e2 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/protobuf/ProtobufHttpMessageConverterTests.java @@ -0,0 +1,201 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.protobuf; + +import java.io.IOException; +import java.nio.charset.Charset; + +import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.Message; +import com.google.protobuf.util.JsonFormat; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; +import org.springframework.protobuf.Msg; +import org.springframework.protobuf.SecondMsg; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * Test suite for {@link ProtobufHttpMessageConverter}. + * + * @author Alex Antonov + * @author Juergen Hoeller + * @author Andreas Ahlenstorf + * @author Sebastien Deleuze + */ +@SuppressWarnings("deprecation") +public class ProtobufHttpMessageConverterTests { + + private ProtobufHttpMessageConverter converter; + + private ExtensionRegistry extensionRegistry; + + private ExtensionRegistryInitializer registryInitializer; + + private Msg testMsg; + + + @Before + public void setup() { + this.registryInitializer = mock(ExtensionRegistryInitializer.class); + this.extensionRegistry = mock(ExtensionRegistry.class); + this.converter = new ProtobufHttpMessageConverter(this.registryInitializer); + this.testMsg = Msg.newBuilder().setFoo("Foo").setBlah(SecondMsg.newBuilder().setBlah(123).build()).build(); + } + + + @Test + public void extensionRegistryInitialized() { + verify(this.registryInitializer, times(1)).initializeExtensionRegistry(any()); + } + + @Test + public void extensionRegistryInitializerNull() { + ProtobufHttpMessageConverter converter = new ProtobufHttpMessageConverter((ExtensionRegistryInitializer)null); + assertNotNull(converter.extensionRegistry); + } + + @Test + public void extensionRegistryNull() { + ProtobufHttpMessageConverter converter = new ProtobufHttpMessageConverter((ExtensionRegistry)null); + assertNotNull(converter.extensionRegistry); + } + + @Test + public void canRead() { + assertTrue(this.converter.canRead(Msg.class, null)); + assertTrue(this.converter.canRead(Msg.class, ProtobufHttpMessageConverter.PROTOBUF)); + assertTrue(this.converter.canRead(Msg.class, MediaType.APPLICATION_JSON)); + assertTrue(this.converter.canRead(Msg.class, MediaType.APPLICATION_XML)); + assertTrue(this.converter.canRead(Msg.class, MediaType.TEXT_PLAIN)); + + // only supported as an output format + assertFalse(this.converter.canRead(Msg.class, MediaType.TEXT_HTML)); + } + + @Test + public void canWrite() { + assertTrue(this.converter.canWrite(Msg.class, null)); + assertTrue(this.converter.canWrite(Msg.class, ProtobufHttpMessageConverter.PROTOBUF)); + assertTrue(this.converter.canWrite(Msg.class, MediaType.APPLICATION_JSON)); + assertTrue(this.converter.canWrite(Msg.class, MediaType.APPLICATION_XML)); + assertTrue(this.converter.canWrite(Msg.class, MediaType.TEXT_PLAIN)); + assertTrue(this.converter.canWrite(Msg.class, MediaType.TEXT_HTML)); + } + + @Test + public void read() throws IOException { + byte[] body = this.testMsg.toByteArray(); + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body); + inputMessage.getHeaders().setContentType(ProtobufHttpMessageConverter.PROTOBUF); + Message result = this.converter.read(Msg.class, inputMessage); + assertEquals(this.testMsg, result); + } + + @Test + public void readNoContentType() throws IOException { + byte[] body = this.testMsg.toByteArray(); + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body); + Message result = this.converter.read(Msg.class, inputMessage); + assertEquals(this.testMsg, result); + } + + @Test + public void writeProtobuf() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MediaType contentType = ProtobufHttpMessageConverter.PROTOBUF; + this.converter.write(this.testMsg, contentType, outputMessage); + assertEquals(contentType, outputMessage.getHeaders().getContentType()); + assertTrue(outputMessage.getBodyAsBytes().length > 0); + Message result = Msg.parseFrom(outputMessage.getBodyAsBytes()); + assertEquals(this.testMsg, result); + + String messageHeader = + outputMessage.getHeaders().getFirst(ProtobufHttpMessageConverter.X_PROTOBUF_MESSAGE_HEADER); + assertEquals("Msg", messageHeader); + String schemaHeader = + outputMessage.getHeaders().getFirst(ProtobufHttpMessageConverter.X_PROTOBUF_SCHEMA_HEADER); + assertEquals("sample.proto", schemaHeader); + } + + @Test + public void writeJsonWithGoogleProtobuf() throws IOException { + this.converter = new ProtobufHttpMessageConverter( + new ProtobufHttpMessageConverter.ProtobufJavaUtilSupport(null, null), + this.extensionRegistry); + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MediaType contentType = MediaType.APPLICATION_JSON_UTF8; + this.converter.write(this.testMsg, contentType, outputMessage); + + assertEquals(contentType, outputMessage.getHeaders().getContentType()); + + final String body = outputMessage.getBodyAsString(Charset.forName("UTF-8")); + assertFalse("body is empty", body.isEmpty()); + + Msg.Builder builder = Msg.newBuilder(); + JsonFormat.parser().merge(body, builder); + assertEquals(this.testMsg, builder.build()); + + assertNull(outputMessage.getHeaders().getFirst( + ProtobufHttpMessageConverter.X_PROTOBUF_MESSAGE_HEADER)); + assertNull(outputMessage.getHeaders().getFirst( + ProtobufHttpMessageConverter.X_PROTOBUF_SCHEMA_HEADER)); + } + + @Test + public void writeJsonWithJavaFormat() throws IOException { + this.converter = new ProtobufHttpMessageConverter( + new ProtobufHttpMessageConverter.ProtobufJavaFormatSupport(), + this.extensionRegistry); + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MediaType contentType = MediaType.APPLICATION_JSON_UTF8; + this.converter.write(this.testMsg, contentType, outputMessage); + + assertEquals(contentType, outputMessage.getHeaders().getContentType()); + + final String body = outputMessage.getBodyAsString(Charset.forName("UTF-8")); + assertFalse("body is empty", body.isEmpty()); + + Msg.Builder builder = Msg.newBuilder(); + JsonFormat.parser().merge(body, builder); + assertEquals(this.testMsg, builder.build()); + + assertNull(outputMessage.getHeaders().getFirst( + ProtobufHttpMessageConverter.X_PROTOBUF_MESSAGE_HEADER)); + assertNull(outputMessage.getHeaders().getFirst( + ProtobufHttpMessageConverter.X_PROTOBUF_SCHEMA_HEADER)); + } + + @Test + public void defaultContentType() throws Exception { + assertEquals(ProtobufHttpMessageConverter.PROTOBUF, this.converter.getDefaultContentType(this.testMsg)); + } + + @Test + public void getContentLength() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MediaType contentType = ProtobufHttpMessageConverter.PROTOBUF; + this.converter.write(this.testMsg, contentType, outputMessage); + assertEquals(-1, outputMessage.getHeaders().getContentLength()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/protobuf/ProtobufJsonFormatHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/protobuf/ProtobufJsonFormatHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..9272b97fa5453bdb2a1e1de4166e3888654adf54 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/protobuf/ProtobufJsonFormatHttpMessageConverterTests.java @@ -0,0 +1,145 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.protobuf; + +import java.io.IOException; + +import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.Message; +import com.google.protobuf.util.JsonFormat; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; +import org.springframework.protobuf.Msg; +import org.springframework.protobuf.SecondMsg; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * Test suite for {@link ProtobufJsonFormatHttpMessageConverter}. + * + * @author Juergen Hoeller + * @author Sebastien Deleuze + */ +@SuppressWarnings("deprecation") +public class ProtobufJsonFormatHttpMessageConverterTests { + + private ProtobufHttpMessageConverter converter; + + private ExtensionRegistry extensionRegistry; + + private ExtensionRegistryInitializer registryInitializer; + + private Msg testMsg; + + + @Before + public void setup() { + this.registryInitializer = mock(ExtensionRegistryInitializer.class); + this.extensionRegistry = mock(ExtensionRegistry.class); + this.converter = new ProtobufJsonFormatHttpMessageConverter( + JsonFormat.parser(), JsonFormat.printer(), this.registryInitializer); + this.testMsg = Msg.newBuilder().setFoo("Foo").setBlah(SecondMsg.newBuilder().setBlah(123).build()).build(); + } + + + @Test + public void extensionRegistryInitialized() { + verify(this.registryInitializer, times(1)).initializeExtensionRegistry(any()); + } + + @Test + public void extensionRegistryInitializerNull() { + ProtobufHttpMessageConverter converter = new ProtobufHttpMessageConverter((ExtensionRegistryInitializer)null); + assertNotNull(converter); + } + + @Test + public void extensionRegistryInitializer() { + ProtobufHttpMessageConverter converter = new ProtobufHttpMessageConverter((ExtensionRegistry)null); + assertNotNull(converter); + } + + @Test + public void canRead() { + assertTrue(this.converter.canRead(Msg.class, null)); + assertTrue(this.converter.canRead(Msg.class, ProtobufHttpMessageConverter.PROTOBUF)); + assertTrue(this.converter.canRead(Msg.class, MediaType.APPLICATION_JSON)); + assertTrue(this.converter.canRead(Msg.class, MediaType.TEXT_PLAIN)); + } + + @Test + public void canWrite() { + assertTrue(this.converter.canWrite(Msg.class, null)); + assertTrue(this.converter.canWrite(Msg.class, ProtobufHttpMessageConverter.PROTOBUF)); + assertTrue(this.converter.canWrite(Msg.class, MediaType.APPLICATION_JSON)); + assertTrue(this.converter.canWrite(Msg.class, MediaType.TEXT_PLAIN)); + } + + @Test + public void read() throws IOException { + byte[] body = this.testMsg.toByteArray(); + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body); + inputMessage.getHeaders().setContentType(ProtobufHttpMessageConverter.PROTOBUF); + Message result = this.converter.read(Msg.class, inputMessage); + assertEquals(this.testMsg, result); + } + + @Test + public void readNoContentType() throws IOException { + byte[] body = this.testMsg.toByteArray(); + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body); + Message result = this.converter.read(Msg.class, inputMessage); + assertEquals(this.testMsg, result); + } + + @Test + public void write() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MediaType contentType = ProtobufHttpMessageConverter.PROTOBUF; + this.converter.write(this.testMsg, contentType, outputMessage); + assertEquals(contentType, outputMessage.getHeaders().getContentType()); + assertTrue(outputMessage.getBodyAsBytes().length > 0); + Message result = Msg.parseFrom(outputMessage.getBodyAsBytes()); + assertEquals(this.testMsg, result); + + String messageHeader = + outputMessage.getHeaders().getFirst(ProtobufHttpMessageConverter.X_PROTOBUF_MESSAGE_HEADER); + assertEquals("Msg", messageHeader); + String schemaHeader = + outputMessage.getHeaders().getFirst(ProtobufHttpMessageConverter.X_PROTOBUF_SCHEMA_HEADER); + assertEquals("sample.proto", schemaHeader); + } + + @Test + public void defaultContentType() throws Exception { + assertEquals(ProtobufHttpMessageConverter.PROTOBUF, this.converter.getDefaultContentType(this.testMsg)); + } + + @Test + public void getContentLength() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MediaType contentType = ProtobufHttpMessageConverter.PROTOBUF; + this.converter.write(this.testMsg, contentType, outputMessage); + assertEquals(-1, outputMessage.getHeaders().getContentLength()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/smile/MappingJackson2SmileHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/smile/MappingJackson2SmileHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..9860ae36ac9715ab6214b152d0cf1757bb8d4e77 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/smile/MappingJackson2SmileHttpMessageConverterTests.java @@ -0,0 +1,161 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.smile; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.smile.SmileFactory; +import static org.junit.Assert.*; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; + +/** + * Jackson 2.x Smile converter tests. + * + * @author Sebastien Deleuze + */ +public class MappingJackson2SmileHttpMessageConverterTests { + + private final MappingJackson2SmileHttpMessageConverter converter = new MappingJackson2SmileHttpMessageConverter(); + private final ObjectMapper mapper = new ObjectMapper(new SmileFactory()); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + + @Test + public void canRead() { + assertTrue(converter.canRead(MyBean.class, new MediaType("application", "x-jackson-smile"))); + assertFalse(converter.canRead(MyBean.class, new MediaType("application", "json"))); + assertFalse(converter.canRead(MyBean.class, new MediaType("application", "xml"))); + } + + @Test + public void canWrite() { + assertTrue(converter.canWrite(MyBean.class, new MediaType("application", "x-jackson-smile"))); + assertFalse(converter.canWrite(MyBean.class, new MediaType("application", "json"))); + assertFalse(converter.canWrite(MyBean.class, new MediaType("application", "xml"))); + } + + @Test + public void read() throws IOException { + MyBean body = new MyBean(); + body.setString("Foo"); + body.setNumber(42); + body.setFraction(42F); + body.setArray(new String[]{"Foo", "Bar"}); + body.setBool(true); + body.setBytes(new byte[]{0x1, 0x2}); + MockHttpInputMessage inputMessage = new MockHttpInputMessage(mapper.writeValueAsBytes(body)); + inputMessage.getHeaders().setContentType(new MediaType("application", "x-jackson-smile")); + MyBean result = (MyBean) converter.read(MyBean.class, inputMessage); + assertEquals("Foo", result.getString()); + assertEquals(42, result.getNumber()); + assertEquals(42F, result.getFraction(), 0F); + assertArrayEquals(new String[]{"Foo", "Bar"}, result.getArray()); + assertTrue(result.isBool()); + assertArrayEquals(new byte[]{0x1, 0x2}, result.getBytes()); + } + + @Test + public void write() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MyBean body = new MyBean(); + body.setString("Foo"); + body.setNumber(42); + body.setFraction(42F); + body.setArray(new String[]{"Foo", "Bar"}); + body.setBool(true); + body.setBytes(new byte[]{0x1, 0x2}); + converter.write(body, null, outputMessage); + assertArrayEquals(mapper.writeValueAsBytes(body), outputMessage.getBodyAsBytes()); + assertEquals("Invalid content-type", new MediaType("application", "x-jackson-smile", StandardCharsets.UTF_8), + outputMessage.getHeaders().getContentType()); + } + + + public static class MyBean { + + private String string; + + private int number; + + private float fraction; + + private String[] array; + + private boolean bool; + + private byte[] bytes; + + public byte[] getBytes() { + return bytes; + } + + public void setBytes(byte[] bytes) { + this.bytes = bytes; + } + + public boolean isBool() { + return bool; + } + + public void setBool(boolean bool) { + this.bool = bool; + } + + public String getString() { + return string; + } + + public void setString(String string) { + this.string = string; + } + + public int getNumber() { + return number; + } + + public void setNumber(int number) { + this.number = number; + } + + public float getFraction() { + return fraction; + } + + public void setFraction(float fraction) { + this.fraction = fraction; + } + + public String[] getArray() { + return array; + } + + public void setArray(String[] array) { + this.array = array; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/xml/Jaxb2CollectionHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/xml/Jaxb2CollectionHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..6646ad85dc39b12292658b4c9efa9f30355bb19f --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/xml/Jaxb2CollectionHttpMessageConverterTests.java @@ -0,0 +1,275 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.xml; + +import java.lang.reflect.Type; +import java.util.Collection; +import java.util.List; +import java.util.Set; + +import javax.xml.bind.annotation.XmlAttribute; +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlRootElement; +import javax.xml.bind.annotation.XmlType; +import javax.xml.stream.XMLInputFactory; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.converter.HttpMessageNotReadableException; + +import static org.junit.Assert.*; + +/** + * Test fixture for {@link Jaxb2CollectionHttpMessageConverter}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + */ +public class Jaxb2CollectionHttpMessageConverterTests { + + private Jaxb2CollectionHttpMessageConverter converter; + + private Type rootElementListType; + + private Type rootElementSetType; + + private Type typeListType; + + private Type typeSetType; + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + + @Before + public void setup() { + converter = new Jaxb2CollectionHttpMessageConverter>(); + rootElementListType = new ParameterizedTypeReference>() {}.getType(); + rootElementSetType = new ParameterizedTypeReference>() {}.getType(); + typeListType = new ParameterizedTypeReference>() {}.getType(); + typeSetType = new ParameterizedTypeReference>() {}.getType(); + } + + + @Test + public void canRead() { + assertTrue(converter.canRead(rootElementListType, null, null)); + assertTrue(converter.canRead(rootElementSetType, null, null)); + assertTrue(converter.canRead(typeSetType, null, null)); + } + + @Test + @SuppressWarnings("unchecked") + public void readXmlRootElementList() throws Exception { + String content = ""; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(content.getBytes("UTF-8")); + List result = (List) converter.read(rootElementListType, null, inputMessage); + + assertEquals("Invalid result", 2, result.size()); + assertEquals("Invalid result", "1", result.get(0).type.s); + assertEquals("Invalid result", "2", result.get(1).type.s); + } + + @Test + @SuppressWarnings("unchecked") + public void readXmlRootElementSet() throws Exception { + String content = ""; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(content.getBytes("UTF-8")); + Set result = (Set) converter.read(rootElementSetType, null, inputMessage); + + assertEquals("Invalid result", 2, result.size()); + assertTrue("Invalid result", result.contains(new RootElement("1"))); + assertTrue("Invalid result", result.contains(new RootElement("2"))); + } + + @Test + @SuppressWarnings("unchecked") + public void readXmlTypeList() throws Exception { + String content = ""; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(content.getBytes("UTF-8")); + List result = (List) converter.read(typeListType, null, inputMessage); + + assertEquals("Invalid result", 2, result.size()); + assertEquals("Invalid result", "1", result.get(0).s); + assertEquals("Invalid result", "2", result.get(1).s); + } + + @Test + @SuppressWarnings("unchecked") + public void readXmlTypeSet() throws Exception { + String content = ""; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(content.getBytes("UTF-8")); + Set result = (Set) converter.read(typeSetType, null, inputMessage); + + assertEquals("Invalid result", 2, result.size()); + assertTrue("Invalid result", result.contains(new TestType("1"))); + assertTrue("Invalid result", result.contains(new TestType("2"))); + } + + @Test + @SuppressWarnings("unchecked") + public void readXmlRootElementExternalEntityDisabled() throws Exception { + Resource external = new ClassPathResource("external.txt", getClass()); + String content = "\n" + + " ]>" + + " &ext;"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(content.getBytes("UTF-8")); + + converter = new Jaxb2CollectionHttpMessageConverter>() { + @Override + protected XMLInputFactory createXmlInputFactory() { + XMLInputFactory inputFactory = super.createXmlInputFactory(); + inputFactory.setProperty(XMLInputFactory.SUPPORT_DTD, true); + return inputFactory; + } + }; + + try { + Collection result = converter.read(rootElementListType, null, inputMessage); + assertEquals(1, result.size()); + assertEquals("", result.iterator().next().external); + } + catch (HttpMessageNotReadableException ex) { + // Some parsers raise an exception + } + } + + @Test + @SuppressWarnings("unchecked") + public void readXmlRootElementExternalEntityEnabled() throws Exception { + Resource external = new ClassPathResource("external.txt", getClass()); + String content = "\n" + + " ]>" + + " &ext;"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(content.getBytes("UTF-8")); + + Jaxb2CollectionHttpMessageConverter c = new Jaxb2CollectionHttpMessageConverter>() { + @Override + protected XMLInputFactory createXmlInputFactory() { + XMLInputFactory inputFactory = XMLInputFactory.newInstance(); + inputFactory.setProperty(XMLInputFactory.IS_REPLACING_ENTITY_REFERENCES, true); + return inputFactory; + } + }; + + Collection result = c.read(rootElementListType, null, inputMessage); + assertEquals(1, result.size()); + assertEquals("Foo Bar", result.iterator().next().external); + } + + @Test + public void testXmlBomb() throws Exception { + // https://en.wikipedia.org/wiki/Billion_laughs + // https://msdn.microsoft.com/en-us/magazine/ee335713.aspx + String content = "\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + "]>\n" + + "&lol9;"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(content.getBytes("UTF-8")); + this.thrown.expect(HttpMessageNotReadableException.class); + this.thrown.expectMessage("\"lol9\""); + this.converter.read(this.rootElementListType, null, inputMessage); + } + + + @XmlRootElement + public static class RootElement { + + public RootElement() { + } + + public RootElement(String s) { + this.type = new TestType(s); + } + + @XmlElement + public TestType type = new TestType(); + + @XmlElement(required=false) + public String external; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o instanceof RootElement) { + RootElement other = (RootElement) o; + return this.type.equals(other.type); + } + return false; + } + + @Override + public int hashCode() { + return type.hashCode(); + } + } + + + @XmlType + public static class TestType { + + public TestType() { + } + + public TestType(String s) { + this.s = s; + } + + @XmlAttribute + public String s = "Hello World"; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o instanceof TestType) { + TestType other = (TestType) o; + return this.s.equals(other.s); + } + return false; + } + + @Override + public int hashCode() { + return s.hashCode(); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..afb3abdcc76015c11f2d6be0bb3cf8cd45fc7e5e --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverterTests.java @@ -0,0 +1,341 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.xml; + +import java.nio.charset.StandardCharsets; + +import javax.xml.bind.Marshaller; +import javax.xml.bind.Unmarshaller; +import javax.xml.bind.annotation.XmlAttribute; +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlRootElement; +import javax.xml.bind.annotation.XmlType; +import javax.xml.bind.annotation.adapters.XmlAdapter; +import javax.xml.bind.annotation.adapters.XmlJavaTypeAdapter; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.xmlunit.diff.DifferenceEvaluator; + +import org.springframework.aop.framework.AdvisedSupport; +import org.springframework.aop.framework.AopProxy; +import org.springframework.aop.framework.DefaultAopProxyFactory; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; +import org.springframework.http.converter.HttpMessageNotReadableException; + +import static org.junit.Assert.*; +import static org.xmlunit.diff.ComparisonType.*; +import static org.xmlunit.diff.DifferenceEvaluators.*; +import static org.xmlunit.matchers.CompareMatcher.*; + +/** + * Tests for {@link Jaxb2RootElementHttpMessageConverter}. + * + * @author Arjen Poutsma + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + */ +public class Jaxb2RootElementHttpMessageConverterTests { + + private Jaxb2RootElementHttpMessageConverter converter; + + private RootElement rootElement; + + private RootElement rootElementCglib; + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + + @Before + public void setup() { + converter = new Jaxb2RootElementHttpMessageConverter(); + rootElement = new RootElement(); + DefaultAopProxyFactory proxyFactory = new DefaultAopProxyFactory(); + AdvisedSupport advisedSupport = new AdvisedSupport(); + advisedSupport.setTarget(rootElement); + advisedSupport.setProxyTargetClass(true); + AopProxy proxy = proxyFactory.createAopProxy(advisedSupport); + rootElementCglib = (RootElement) proxy.getProxy(); + } + + + @Test + public void canRead() { + assertTrue("Converter does not support reading @XmlRootElement", + converter.canRead(RootElement.class, null)); + assertTrue("Converter does not support reading @XmlType", + converter.canRead(Type.class, null)); + } + + @Test + public void canWrite() { + assertTrue("Converter does not support writing @XmlRootElement", + converter.canWrite(RootElement.class, null)); + assertTrue("Converter does not support writing @XmlRootElement subclass", + converter.canWrite(RootElementSubclass.class, null)); + assertTrue("Converter does not support writing @XmlRootElement subclass", + converter.canWrite(rootElementCglib.getClass(), null)); + assertFalse("Converter supports writing @XmlType", converter.canWrite(Type.class, null)); + } + + @Test + public void readXmlRootElement() throws Exception { + byte[] body = "".getBytes("UTF-8"); + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body); + RootElement result = (RootElement) converter.read(RootElement.class, inputMessage); + assertEquals("Invalid result", "Hello World", result.type.s); + } + + @Test + public void readXmlRootElementSubclass() throws Exception { + byte[] body = "".getBytes("UTF-8"); + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body); + RootElementSubclass result = (RootElementSubclass) converter.read(RootElementSubclass.class, inputMessage); + assertEquals("Invalid result", "Hello World", result.getType().s); + } + + @Test + public void readXmlType() throws Exception { + byte[] body = "".getBytes("UTF-8"); + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body); + Type result = (Type) converter.read(Type.class, inputMessage); + assertEquals("Invalid result", "Hello World", result.s); + } + + @Test + public void readXmlRootElementExternalEntityDisabled() throws Exception { + Resource external = new ClassPathResource("external.txt", getClass()); + String content = "\n" + + " ]>" + + " &ext;"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(content.getBytes("UTF-8")); + converter.setSupportDtd(true); + RootElement rootElement = (RootElement) converter.read(RootElement.class, inputMessage); + + assertEquals("", rootElement.external); + } + + @Test + public void readXmlRootElementExternalEntityEnabled() throws Exception { + Resource external = new ClassPathResource("external.txt", getClass()); + String content = "\n" + + " ]>" + + " &ext;"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(content.getBytes("UTF-8")); + this.converter.setProcessExternalEntities(true); + RootElement rootElement = (RootElement) converter.read(RootElement.class, inputMessage); + + assertEquals("Foo Bar", rootElement.external); + } + + @Test + public void testXmlBomb() throws Exception { + // https://en.wikipedia.org/wiki/Billion_laughs + // https://msdn.microsoft.com/en-us/magazine/ee335713.aspx + String content = "\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + "]>\n" + + "&lol9;"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(content.getBytes("UTF-8")); + this.thrown.expect(HttpMessageNotReadableException.class); + this.thrown.expectMessage("DOCTYPE"); + this.converter.read(RootElement.class, inputMessage); + } + + @Test + public void writeXmlRootElement() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(rootElement, null, outputMessage); + assertEquals("Invalid content-type", new MediaType("application", "xml"), + outputMessage.getHeaders().getContentType()); + DifferenceEvaluator ev = chain(Default, downgradeDifferencesToEqual(XML_STANDALONE)); + assertThat("Invalid result", outputMessage.getBodyAsString(StandardCharsets.UTF_8), + isSimilarTo("").withDifferenceEvaluator(ev)); + } + + @Test + public void writeXmlRootElementSubclass() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(rootElementCglib, null, outputMessage); + assertEquals("Invalid content-type", new MediaType("application", "xml"), + outputMessage.getHeaders().getContentType()); + DifferenceEvaluator ev = chain(Default, downgradeDifferencesToEqual(XML_STANDALONE)); + assertThat("Invalid result", outputMessage.getBodyAsString(StandardCharsets.UTF_8), + isSimilarTo("").withDifferenceEvaluator(ev)); + } + + // SPR-11488 + + @Test + public void customizeMarshaller() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MyJaxb2RootElementHttpMessageConverter myConverter = new MyJaxb2RootElementHttpMessageConverter(); + myConverter.write(new MyRootElement(new MyCustomElement("a", "b")), null, outputMessage); + DifferenceEvaluator ev = chain(Default, downgradeDifferencesToEqual(XML_STANDALONE)); + assertThat("Invalid result", outputMessage.getBodyAsString(StandardCharsets.UTF_8), + isSimilarTo("a|||b").withDifferenceEvaluator(ev)); + } + + @Test + public void customizeUnmarshaller() throws Exception { + byte[] body = "a|||b".getBytes("UTF-8"); + MyJaxb2RootElementHttpMessageConverter myConverter = new MyJaxb2RootElementHttpMessageConverter(); + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body); + MyRootElement result = (MyRootElement) myConverter.read(MyRootElement.class, inputMessage); + assertEquals("a", result.getElement().getField1()); + assertEquals("b", result.getElement().getField2()); + } + + + @XmlRootElement + public static class RootElement { + + private Type type = new Type(); + + @XmlElement(required=false) + public String external; + + public Type getType() { + return this.type; + } + + @XmlElement + public void setType(Type type) { + this.type = type; + } + } + + + @XmlType + public static class Type { + + @XmlAttribute + public String s = "Hello World"; + + } + + + public static class RootElementSubclass extends RootElement { + } + + + public static class MyJaxb2RootElementHttpMessageConverter extends Jaxb2RootElementHttpMessageConverter { + + @Override + protected void customizeMarshaller(Marshaller marshaller) { + marshaller.setAdapter(new MyCustomElementAdapter()); + } + + @Override + protected void customizeUnmarshaller(Unmarshaller unmarshaller) { + unmarshaller.setAdapter(new MyCustomElementAdapter()); + } + } + + + public static class MyCustomElement { + + private String field1; + + private String field2; + + public MyCustomElement() { + } + + public MyCustomElement(String field1, String field2) { + this.field1 = field1; + this.field2 = field2; + } + + public String getField1() { + return field1; + } + + public void setField1(String field1) { + this.field1 = field1; + } + + public String getField2() { + return field2; + } + + public void setField2(String field2) { + this.field2 = field2; + } + } + + + @XmlRootElement + public static class MyRootElement { + + private MyCustomElement element; + + public MyRootElement() { + + } + + public MyRootElement(MyCustomElement element) { + this.element = element; + } + + @XmlJavaTypeAdapter(MyCustomElementAdapter.class) + public MyCustomElement getElement() { + return element; + } + + public void setElement(MyCustomElement element) { + this.element = element; + } + } + + + public static class MyCustomElementAdapter extends XmlAdapter { + + @Override + public String marshal(MyCustomElement c) throws Exception { + return c.getField1() + "|||" + c.getField2(); + } + + @Override + public MyCustomElement unmarshal(String c) throws Exception { + String[] t = c.split("\\|\\|\\|"); + return new MyCustomElement(t[0], t[1]); + } + + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/xml/MappingJackson2XmlHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/xml/MappingJackson2XmlHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e374a0555db53b7501d5e9a13265d7fe13fa1585 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/xml/MappingJackson2XmlHttpMessageConverterTests.java @@ -0,0 +1,324 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.xml; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; + +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.dataformat.xml.XmlMapper; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.json.MappingJacksonValue; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +/** + * Jackson 2.x XML converter tests. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + */ +public class MappingJackson2XmlHttpMessageConverterTests { + + private final MappingJackson2XmlHttpMessageConverter converter = new MappingJackson2XmlHttpMessageConverter(); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + + @Test + public void canRead() { + assertTrue(converter.canRead(MyBean.class, new MediaType("application", "xml"))); + assertTrue(converter.canRead(MyBean.class, new MediaType("text", "xml"))); + assertTrue(converter.canRead(MyBean.class, new MediaType("application", "soap+xml"))); + assertTrue(converter.canRead(MyBean.class, new MediaType("text", "xml", StandardCharsets.UTF_8))); + assertTrue(converter.canRead(MyBean.class, new MediaType("text", "xml", StandardCharsets.ISO_8859_1))); + } + + @Test + public void canWrite() { + assertTrue(converter.canWrite(MyBean.class, new MediaType("application", "xml"))); + assertTrue(converter.canWrite(MyBean.class, new MediaType("text", "xml"))); + assertTrue(converter.canWrite(MyBean.class, new MediaType("application", "soap+xml"))); + assertTrue(converter.canWrite(MyBean.class, new MediaType("text", "xml", StandardCharsets.UTF_8))); + assertFalse(converter.canWrite(MyBean.class, new MediaType("text", "xml", StandardCharsets.ISO_8859_1))); + } + + @Test + public void read() throws IOException { + String body = "" + + "Foo" + + "42" + + "42.0" + + "Foo" + + "Bar" + + "true" + + "AQI="; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "xml")); + MyBean result = (MyBean) converter.read(MyBean.class, inputMessage); + assertEquals("Foo", result.getString()); + assertEquals(42, result.getNumber()); + assertEquals(42F, result.getFraction(), 0F); + assertArrayEquals(new String[]{"Foo", "Bar"}, result.getArray()); + assertTrue(result.isBool()); + assertArrayEquals(new byte[]{0x1, 0x2}, result.getBytes()); + } + + @Test + public void write() throws IOException { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MyBean body = new MyBean(); + body.setString("Foo"); + body.setNumber(42); + body.setFraction(42F); + body.setArray(new String[]{"Foo", "Bar"}); + body.setBool(true); + body.setBytes(new byte[]{0x1, 0x2}); + converter.write(body, null, outputMessage); + String result = outputMessage.getBodyAsString(StandardCharsets.UTF_8); + assertTrue(result.contains("Foo")); + assertTrue(result.contains("42")); + assertTrue(result.contains("42.0")); + assertTrue(result.contains("FooBar")); + assertTrue(result.contains("true")); + assertTrue(result.contains("AQI=")); + assertEquals("Invalid content-type", new MediaType("application", "xml", StandardCharsets.UTF_8), + outputMessage.getHeaders().getContentType()); + } + + @Test + public void readInvalidXml() throws IOException { + String body = "FooBar"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "xml")); + this.thrown.expect(HttpMessageNotReadableException.class); + converter.read(MyBean.class, inputMessage); + } + + @Test + public void readValidXmlWithUnknownProperty() throws IOException { + String body = "stringvalue"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "xml")); + converter.read(MyBean.class, inputMessage); + // Assert no HttpMessageNotReadableException is thrown + } + + @Test + public void jsonView() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + JacksonViewBean bean = new JacksonViewBean(); + bean.setWithView1("with"); + bean.setWithView2("with"); + bean.setWithoutView("without"); + + MappingJacksonValue jacksonValue = new MappingJacksonValue(bean); + jacksonValue.setSerializationView(MyJacksonView1.class); + this.converter.write(jacksonValue, null, outputMessage); + + String result = outputMessage.getBodyAsString(StandardCharsets.UTF_8); + assertThat(result, containsString("with")); + assertThat(result, not(containsString("with"))); + assertThat(result, not(containsString("without"))); + } + + @Test + public void customXmlMapper() { + new MappingJackson2XmlHttpMessageConverter(new MyXmlMapper()); + // Assert no exception is thrown + } + + @Test + public void readWithExternalReference() throws IOException { + String body = "\n" + + " ]>&ext;"; + + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "xml")); + + this.thrown.expect(HttpMessageNotReadableException.class); + this.converter.read(MyBean.class, inputMessage); + } + + @Test + public void readWithXmlBomb() throws IOException { + // https://en.wikipedia.org/wiki/Billion_laughs + // https://msdn.microsoft.com/en-us/magazine/ee335713.aspx + String body = "\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + "]>\n" + + "&lol9;"; + + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "xml")); + + this.thrown.expect(HttpMessageNotReadableException.class); + this.converter.read(MyBean.class, inputMessage); + } + + @Test + @SuppressWarnings("unchecked") + public void readNonUnicode() throws Exception { + String body = "" + + "føø bår" + + ""; + + Charset charset = StandardCharsets.ISO_8859_1; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes(charset)); + inputMessage.getHeaders().setContentType(new MediaType("application", "xml", charset)); + MyBean result = (MyBean) converter.read(MyBean.class, inputMessage); + assertEquals("føø bår", result.getString()); + } + + + + public static class MyBean { + + private String string; + + private int number; + + private float fraction; + + private String[] array; + + private boolean bool; + + private byte[] bytes; + + public byte[] getBytes() { + return bytes; + } + + public void setBytes(byte[] bytes) { + this.bytes = bytes; + } + + public boolean isBool() { + return bool; + } + + public void setBool(boolean bool) { + this.bool = bool; + } + + public String getString() { + return string; + } + + public void setString(String string) { + this.string = string; + } + + public int getNumber() { + return number; + } + + public void setNumber(int number) { + this.number = number; + } + + public float getFraction() { + return fraction; + } + + public void setFraction(float fraction) { + this.fraction = fraction; + } + + public String[] getArray() { + return array; + } + + public void setArray(String[] array) { + this.array = array; + } + } + + + private interface MyJacksonView1 {}; + + private interface MyJacksonView2 {}; + + + @SuppressWarnings("unused") + private static class JacksonViewBean { + + @JsonView(MyJacksonView1.class) + private String withView1; + + @JsonView(MyJacksonView2.class) + private String withView2; + + private String withoutView; + + public String getWithView1() { + return withView1; + } + + public void setWithView1(String withView1) { + this.withView1 = withView1; + } + + public String getWithView2() { + return withView2; + } + + public void setWithView2(String withView2) { + this.withView2 = withView2; + } + + public String getWithoutView() { + return withoutView; + } + + public void setWithoutView(String withoutView) { + this.withoutView = withoutView; + } + } + + + @SuppressWarnings("serial") + private static class MyXmlMapper extends XmlMapper { + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e9b3e592555914d8db12665ef662a3125091a7a5 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverterTests.java @@ -0,0 +1,166 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.xml; + +import javax.xml.transform.Result; +import javax.xml.transform.stream.StreamSource; + +import org.junit.Test; + +import org.springframework.beans.TypeMismatchException; +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.HttpMessageNotWritableException; +import org.springframework.oxm.Marshaller; +import org.springframework.oxm.MarshallingFailureException; +import org.springframework.oxm.Unmarshaller; +import org.springframework.oxm.UnmarshallingFailureException; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * Tests for {@link MarshallingHttpMessageConverter}. + * + * @author Arjen Poutsma + */ +public class MarshallingHttpMessageConverterTests { + + @Test + public void canRead() { + Unmarshaller unmarshaller = mock(Unmarshaller.class); + + given(unmarshaller.supports(Integer.class)).willReturn(false); + given(unmarshaller.supports(String.class)).willReturn(true); + + MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(); + converter.setUnmarshaller(unmarshaller); + + assertFalse(converter.canRead(Boolean.class, MediaType.TEXT_PLAIN)); + assertFalse(converter.canRead(Integer.class, MediaType.TEXT_XML)); + assertTrue(converter.canRead(String.class, MediaType.TEXT_XML)); + } + + @Test + public void canWrite() { + Marshaller marshaller = mock(Marshaller.class); + + given(marshaller.supports(Integer.class)).willReturn(false); + given(marshaller.supports(String.class)).willReturn(true); + + MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(); + converter.setMarshaller(marshaller); + + assertFalse(converter.canWrite(Boolean.class, MediaType.TEXT_PLAIN)); + assertFalse(converter.canWrite(Integer.class, MediaType.TEXT_XML)); + assertTrue(converter.canWrite(String.class, MediaType.TEXT_XML)); + } + + @Test + public void read() throws Exception { + String body = "Hello World"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); + + Unmarshaller unmarshaller = mock(Unmarshaller.class); + given(unmarshaller.unmarshal(isA(StreamSource.class))).willReturn(body); + + MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(); + converter.setUnmarshaller(unmarshaller); + + String result = (String) converter.read(Object.class, inputMessage); + assertEquals("Invalid result", body, result); + } + + @Test + public void readWithTypeMismatchException() throws Exception { + MockHttpInputMessage inputMessage = new MockHttpInputMessage(new byte[0]); + + Marshaller marshaller = mock(Marshaller.class); + Unmarshaller unmarshaller = mock(Unmarshaller.class); + given(unmarshaller.unmarshal(isA(StreamSource.class))).willReturn(Integer.valueOf(3)); + + MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(marshaller, unmarshaller); + try { + converter.read(String.class, inputMessage); + fail("Should have thrown HttpMessageNotReadableException"); + } + catch (HttpMessageNotReadableException ex) { + assertTrue(ex.getCause() instanceof TypeMismatchException); + } + } + + @Test + public void readWithMarshallingFailureException() throws Exception { + MockHttpInputMessage inputMessage = new MockHttpInputMessage(new byte[0]); + UnmarshallingFailureException ex = new UnmarshallingFailureException("forced"); + + Unmarshaller unmarshaller = mock(Unmarshaller.class); + given(unmarshaller.unmarshal(isA(StreamSource.class))).willThrow(ex); + + MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(); + converter.setUnmarshaller(unmarshaller); + + try { + converter.read(Object.class, inputMessage); + fail("HttpMessageNotReadableException should be thrown"); + } + catch (HttpMessageNotReadableException e) { + assertTrue("Invalid exception hierarchy", e.getCause() == ex); + } + } + + @Test + public void write() throws Exception { + String body = "Hello World"; + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + + Marshaller marshaller = mock(Marshaller.class); + willDoNothing().given(marshaller).marshal(eq(body), isA(Result.class)); + + MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(marshaller); + converter.write(body, null, outputMessage); + + assertEquals("Invalid content-type", new MediaType("application", "xml"), + outputMessage.getHeaders().getContentType()); + } + + @Test + public void writeWithMarshallingFailureException() throws Exception { + String body = "Hello World"; + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MarshallingFailureException ex = new MarshallingFailureException("forced"); + + Marshaller marshaller = mock(Marshaller.class); + willThrow(ex).given(marshaller).marshal(eq(body), isA(Result.class)); + + try { + MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(marshaller); + converter.write(body, null, outputMessage); + fail("HttpMessageNotWritableException should be thrown"); + } + catch (HttpMessageNotWritableException e) { + assertTrue("Invalid exception hierarchy", e.getCause() == ex); + } + } + + @Test(expected = UnsupportedOperationException.class) + public void supports() throws Exception { + new MarshallingHttpMessageConverter().supports(Object.class); + } +} diff --git a/spring-web/src/test/java/org/springframework/http/converter/xml/SourceHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/xml/SourceHttpMessageConverterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b9a991f2abb83dd540f81ef55e51cfe43463b701 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/converter/xml/SourceHttpMessageConverterTests.java @@ -0,0 +1,333 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter.xml; + +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.StringReader; +import java.nio.charset.StandardCharsets; + +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.stream.XMLStreamException; +import javax.xml.stream.XMLStreamReader; +import javax.xml.transform.Source; +import javax.xml.transform.dom.DOMSource; +import javax.xml.transform.sax.SAXSource; +import javax.xml.transform.stax.StAXSource; +import javax.xml.transform.stream.StreamSource; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.w3c.dom.Document; +import org.w3c.dom.Element; +import org.xml.sax.InputSource; +import org.xml.sax.SAXException; +import org.xml.sax.XMLReader; +import org.xml.sax.helpers.DefaultHandler; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.http.MediaType; +import org.springframework.http.MockHttpInputMessage; +import org.springframework.http.MockHttpOutputMessage; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.util.FileCopyUtils; + +import static org.junit.Assert.*; +import static org.xmlunit.matchers.CompareMatcher.*; + +/** + * @author Arjen Poutsma + * @author Rossen Stoyanchev + */ +public class SourceHttpMessageConverterTests { + + private static final String BODY = "Hello World"; + + private SourceHttpMessageConverter converter; + + private String bodyExternal; + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + + @Before + public void setup() throws IOException { + converter = new SourceHttpMessageConverter<>(); + Resource external = new ClassPathResource("external.txt", getClass()); + + bodyExternal = "\n" + + " ]>&ext;"; + } + + + @Test + public void canRead() { + assertTrue(converter.canRead(Source.class, new MediaType("application", "xml"))); + assertTrue(converter.canRead(Source.class, new MediaType("application", "soap+xml"))); + } + + @Test + public void canWrite() { + assertTrue(converter.canWrite(Source.class, new MediaType("application", "xml"))); + assertTrue(converter.canWrite(Source.class, new MediaType("application", "soap+xml"))); + assertTrue(converter.canWrite(Source.class, MediaType.ALL)); + } + + @Test + public void readDOMSource() throws Exception { + MockHttpInputMessage inputMessage = new MockHttpInputMessage(BODY.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "xml")); + DOMSource result = (DOMSource) converter.read(DOMSource.class, inputMessage); + Document document = (Document) result.getNode(); + assertEquals("Invalid result", "root", document.getDocumentElement().getLocalName()); + } + + @Test + public void readDOMSourceExternal() throws Exception { + MockHttpInputMessage inputMessage = new MockHttpInputMessage(bodyExternal.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "xml")); + converter.setSupportDtd(true); + DOMSource result = (DOMSource) converter.read(DOMSource.class, inputMessage); + Document document = (Document) result.getNode(); + assertEquals("Invalid result", "root", document.getDocumentElement().getLocalName()); + assertNotEquals("Invalid result", "Foo Bar", document.getDocumentElement().getTextContent()); + } + + @Test + public void readDomSourceWithXmlBomb() throws Exception { + // https://en.wikipedia.org/wiki/Billion_laughs + // https://msdn.microsoft.com/en-us/magazine/ee335713.aspx + String content = "\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + "]>\n" + + "&lol9;"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(content.getBytes("UTF-8")); + + this.thrown.expect(HttpMessageNotReadableException.class); + this.thrown.expectMessage("DOCTYPE"); + + this.converter.read(DOMSource.class, inputMessage); + } + + @Test + public void readSAXSource() throws Exception { + MockHttpInputMessage inputMessage = new MockHttpInputMessage(BODY.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "xml")); + SAXSource result = (SAXSource) converter.read(SAXSource.class, inputMessage); + InputSource inputSource = result.getInputSource(); + String s = FileCopyUtils.copyToString(new InputStreamReader(inputSource.getByteStream())); + assertThat("Invalid result", s, isSimilarTo(BODY)); + } + + @Test + public void readSAXSourceExternal() throws Exception { + MockHttpInputMessage inputMessage = new MockHttpInputMessage(bodyExternal.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "xml")); + converter.setSupportDtd(true); + SAXSource result = (SAXSource) converter.read(SAXSource.class, inputMessage); + InputSource inputSource = result.getInputSource(); + XMLReader reader = result.getXMLReader(); + reader.setContentHandler(new DefaultHandler() { + @Override + public void characters(char[] ch, int start, int length) { + String s = new String(ch, start, length); + assertNotEquals("Invalid result", "Foo Bar", s); + } + }); + reader.parse(inputSource); + } + + @Test + public void readSAXSourceWithXmlBomb() throws Exception { + // https://en.wikipedia.org/wiki/Billion_laughs + // https://msdn.microsoft.com/en-us/magazine/ee335713.aspx + String content = "\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + "]>\n" + + "&lol9;"; + + MockHttpInputMessage inputMessage = new MockHttpInputMessage(content.getBytes("UTF-8")); + SAXSource result = (SAXSource) this.converter.read(SAXSource.class, inputMessage); + + this.thrown.expect(SAXException.class); + this.thrown.expectMessage("DOCTYPE"); + + InputSource inputSource = result.getInputSource(); + XMLReader reader = result.getXMLReader(); + reader.parse(inputSource); + } + + @Test + public void readStAXSource() throws Exception { + MockHttpInputMessage inputMessage = new MockHttpInputMessage(BODY.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "xml")); + StAXSource result = (StAXSource) converter.read(StAXSource.class, inputMessage); + XMLStreamReader streamReader = result.getXMLStreamReader(); + assertTrue(streamReader.hasNext()); + streamReader.nextTag(); + String s = streamReader.getLocalName(); + assertEquals("root", s); + s = streamReader.getElementText(); + assertEquals("Hello World", s); + streamReader.close(); + } + + @Test + public void readStAXSourceExternal() throws Exception { + MockHttpInputMessage inputMessage = new MockHttpInputMessage(bodyExternal.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "xml")); + converter.setSupportDtd(true); + StAXSource result = (StAXSource) converter.read(StAXSource.class, inputMessage); + XMLStreamReader streamReader = result.getXMLStreamReader(); + assertTrue(streamReader.hasNext()); + streamReader.next(); + streamReader.next(); + String s = streamReader.getLocalName(); + assertEquals("root", s); + try { + s = streamReader.getElementText(); + assertNotEquals("Foo Bar", s); + } + catch (XMLStreamException ex) { + // Some parsers raise a parse exception + } + streamReader.close(); + } + + @Test + public void readStAXSourceWithXmlBomb() throws Exception { + // https://en.wikipedia.org/wiki/Billion_laughs + // https://msdn.microsoft.com/en-us/magazine/ee335713.aspx + String content = "\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + "]>\n" + + "&lol9;"; + MockHttpInputMessage inputMessage = new MockHttpInputMessage(content.getBytes("UTF-8")); + StAXSource result = (StAXSource) this.converter.read(StAXSource.class, inputMessage); + + XMLStreamReader streamReader = result.getXMLStreamReader(); + assertTrue(streamReader.hasNext()); + streamReader.next(); + streamReader.next(); + String s = streamReader.getLocalName(); + assertEquals("root", s); + + this.thrown.expectMessage("\"lol9\""); + s = streamReader.getElementText(); + } + + @Test + public void readStreamSource() throws Exception { + MockHttpInputMessage inputMessage = new MockHttpInputMessage(BODY.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "xml")); + StreamSource result = (StreamSource) converter.read(StreamSource.class, inputMessage); + String s = FileCopyUtils.copyToString(new InputStreamReader(result.getInputStream())); + assertThat("Invalid result", s, isSimilarTo(BODY)); + } + + @Test + public void readSource() throws Exception { + MockHttpInputMessage inputMessage = new MockHttpInputMessage(BODY.getBytes("UTF-8")); + inputMessage.getHeaders().setContentType(new MediaType("application", "xml")); + converter.read(Source.class, inputMessage); + } + + @Test + public void writeDOMSource() throws Exception { + DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance(); + documentBuilderFactory.setNamespaceAware(true); + Document document = documentBuilderFactory.newDocumentBuilder().newDocument(); + Element rootElement = document.createElement("root"); + document.appendChild(rootElement); + rootElement.setTextContent("Hello World"); + DOMSource domSource = new DOMSource(document); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(domSource, null, outputMessage); + assertThat("Invalid result", outputMessage.getBodyAsString(StandardCharsets.UTF_8), + isSimilarTo("Hello World")); + assertEquals("Invalid content-type", new MediaType("application", "xml"), + outputMessage.getHeaders().getContentType()); + assertEquals("Invalid content-length", outputMessage.getBodyAsBytes().length, + outputMessage.getHeaders().getContentLength()); + } + + @Test + public void writeSAXSource() throws Exception { + String xml = "Hello World"; + SAXSource saxSource = new SAXSource(new InputSource(new StringReader(xml))); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(saxSource, null, outputMessage); + assertThat("Invalid result", outputMessage.getBodyAsString(StandardCharsets.UTF_8), + isSimilarTo("Hello World")); + assertEquals("Invalid content-type", new MediaType("application", "xml"), + outputMessage.getHeaders().getContentType()); + } + + @Test + public void writeStreamSource() throws Exception { + String xml = "Hello World"; + StreamSource streamSource = new StreamSource(new StringReader(xml)); + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + converter.write(streamSource, null, outputMessage); + assertThat("Invalid result", outputMessage.getBodyAsString(StandardCharsets.UTF_8), + isSimilarTo("Hello World")); + assertEquals("Invalid content-type", new MediaType("application", "xml"), + outputMessage.getHeaders().getContentType()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/DefaultPathContainerTests.java b/spring-web/src/test/java/org/springframework/http/server/DefaultPathContainerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d039458be2a7824026518ebdc1e9d0c500682990 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/DefaultPathContainerTests.java @@ -0,0 +1,141 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.http.server; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import org.junit.Test; + +import org.springframework.http.server.PathContainer.PathSegment; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; + +/** + * Unit tests for {@link DefaultPathContainer}. + * @author Rossen Stoyanchev + */ +public class DefaultPathContainerTests { + + @Test + public void pathSegment() throws Exception { + // basic + testPathSegment("cars", "cars", new LinkedMultiValueMap<>()); + + // empty + testPathSegment("", "", new LinkedMultiValueMap<>()); + + // spaces + testPathSegment("%20%20", " ", new LinkedMultiValueMap<>()); + testPathSegment("%20a%20", " a ", new LinkedMultiValueMap<>()); + } + + @Test + public void pathSegmentParams() throws Exception { + // basic + LinkedMultiValueMap params = new LinkedMultiValueMap<>(); + params.add("colors", "red"); + params.add("colors", "blue"); + params.add("colors", "green"); + params.add("year", "2012"); + testPathSegment("cars;colors=red,blue,green;year=2012", "cars", params); + + // trailing semicolon + params = new LinkedMultiValueMap<>(); + params.add("p", "1"); + testPathSegment("path;p=1;", "path", params); + + // params with spaces + params = new LinkedMultiValueMap<>(); + params.add("param name", "param value"); + testPathSegment("path;param%20name=param%20value;%20", "path", params); + + // empty params + params = new LinkedMultiValueMap<>(); + params.add("p", "1"); + testPathSegment("path;;;%20;%20;p=1;%20", "path", params); + } + + private void testPathSegment(String rawValue, String valueToMatch, MultiValueMap params) { + + PathContainer container = PathContainer.parsePath(rawValue); + + if ("".equals(rawValue)) { + assertEquals(0, container.elements().size()); + return; + } + + assertEquals(1, container.elements().size()); + PathSegment segment = (PathSegment) container.elements().get(0); + + assertEquals("value: '" + rawValue + "'", rawValue, segment.value()); + assertEquals("valueToMatch: '" + rawValue + "'", valueToMatch, segment.valueToMatch()); + assertEquals("params: '" + rawValue + "'", params, segment.parameters()); + } + + @Test + public void path() throws Exception { + // basic + testPath("/a/b/c", "/a/b/c", Arrays.asList("/", "a", "/", "b", "/", "c")); + + // root path + testPath("/", "/", Collections.singletonList("/")); + + // empty path + testPath("", "", Collections.emptyList()); + testPath("%20%20", "%20%20", Collections.singletonList("%20%20")); + + // trailing slash + testPath("/a/b/", "/a/b/", Arrays.asList("/", "a", "/", "b", "/")); + testPath("/a/b//", "/a/b//", Arrays.asList("/", "a", "/", "b", "/", "/")); + + // extra slashes and spaces + testPath("/%20", "/%20", Arrays.asList("/", "%20")); + testPath("//%20/%20", "//%20/%20", Arrays.asList("/", "/", "%20", "/", "%20")); + } + + private void testPath(String input, String value, List expectedElements) { + + PathContainer path = PathContainer.parsePath(input); + + assertEquals("value: '" + input + "'", value, path.value()); + assertEquals("elements: " + input, expectedElements, path.elements().stream() + .map(PathContainer.Element::value).collect(Collectors.toList())); + } + + @Test + public void subPath() throws Exception { + // basic + PathContainer path = PathContainer.parsePath("/a/b/c"); + assertSame(path, path.subPath(0)); + assertEquals("/b/c", path.subPath(2).value()); + assertEquals("/c", path.subPath(4).value()); + + // root path + path = PathContainer.parsePath("/"); + assertEquals("/", path.subPath(0).value()); + + // trailing slash + path = PathContainer.parsePath("/a/b/"); + assertEquals("/b/", path.subPath(2).value()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/DefaultRequestPathTests.java b/spring-web/src/test/java/org/springframework/http/server/DefaultRequestPathTests.java new file mode 100644 index 0000000000000000000000000000000000000000..30a88935b621ae3e76698957a969ad6eea9a112f --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/DefaultRequestPathTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.http.server; + +import java.net.URI; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Unit tests for {@link DefaultRequestPath}. + * @author Rossen Stoyanchev + */ +public class DefaultRequestPathTests { + + @Test + public void requestPath() throws Exception { + // basic + testRequestPath("/app/a/b/c", "/app", "/a/b/c"); + + // no context path + testRequestPath("/a/b/c", "", "/a/b/c"); + + // context path only + testRequestPath("/a/b", "/a/b", ""); + + // root path + testRequestPath("/", "", "/"); + + // empty path + testRequestPath("", "", ""); + testRequestPath("", "/", ""); + + // trailing slash + testRequestPath("/app/a/", "/app", "/a/"); + testRequestPath("/app/a//", "/app", "/a//"); + } + + private void testRequestPath(String fullPath, String contextPath, String pathWithinApplication) { + + URI uri = URI.create("http://localhost:8080" + fullPath); + RequestPath requestPath = RequestPath.parse(uri, contextPath); + + assertEquals(contextPath.equals("/") ? "" : contextPath, requestPath.contextPath().value()); + assertEquals(pathWithinApplication, requestPath.pathWithinApplication().value()); + } + + @Test + public void updateRequestPath() throws Exception { + + URI uri = URI.create("http://localhost:8080/aA/bB/cC"); + RequestPath requestPath = RequestPath.parse(uri, null); + + assertEquals("", requestPath.contextPath().value()); + assertEquals("/aA/bB/cC", requestPath.pathWithinApplication().value()); + + requestPath = requestPath.modifyContextPath("/aA"); + + assertEquals("/aA", requestPath.contextPath().value()); + assertEquals("/bB/cC", requestPath.pathWithinApplication().value()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpRequestTests.java b/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpRequestTests.java new file mode 100644 index 0000000000000000000000000000000000000000..2c1aa4151c75cb5ce90c27dc0c2d340515dd75d3 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpRequestTests.java @@ -0,0 +1,174 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.util.FileCopyUtils; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + * @author Juergen Hoeller + */ +public class ServletServerHttpRequestTests { + + private ServletServerHttpRequest request; + + private MockHttpServletRequest mockRequest; + + + @Before + public void create() { + mockRequest = new MockHttpServletRequest(); + request = new ServletServerHttpRequest(mockRequest); + } + + + @Test + public void getMethod() { + mockRequest.setMethod("POST"); + assertEquals("Invalid method", HttpMethod.POST, request.getMethod()); + } + + @Test + public void getUriForSimplePath() throws URISyntaxException { + URI uri = new URI("https://example.com/path"); + mockRequest.setScheme(uri.getScheme()); + mockRequest.setServerName(uri.getHost()); + mockRequest.setServerPort(uri.getPort()); + mockRequest.setRequestURI(uri.getPath()); + mockRequest.setQueryString(uri.getQuery()); + assertEquals(uri, request.getURI()); + } + + @Test + public void getUriWithQueryString() throws URISyntaxException { + URI uri = new URI("https://example.com/path?query"); + mockRequest.setScheme(uri.getScheme()); + mockRequest.setServerName(uri.getHost()); + mockRequest.setServerPort(uri.getPort()); + mockRequest.setRequestURI(uri.getPath()); + mockRequest.setQueryString(uri.getQuery()); + assertEquals(uri, request.getURI()); + } + + @Test // SPR-16414 + public void getUriWithQueryParam() throws URISyntaxException { + mockRequest.setServerName("example.com"); + mockRequest.setRequestURI("/path"); + mockRequest.setQueryString("query=foo"); + assertEquals(new URI("http://example.com/path?query=foo"), request.getURI()); + } + + @Test // SPR-16414 + public void getUriWithMalformedQueryParam() throws URISyntaxException { + mockRequest.setServerName("example.com"); + mockRequest.setRequestURI("/path"); + mockRequest.setQueryString("query=foo%%x"); + assertEquals(new URI("http://example.com/path"), request.getURI()); + } + + @Test // SPR-13876 + public void getUriWithEncoding() throws URISyntaxException { + URI uri = new URI("https://example.com/%E4%B8%AD%E6%96%87" + + "?redirect=https%3A%2F%2Fgithub.com%2Fspring-projects%2Fspring-framework"); + mockRequest.setScheme(uri.getScheme()); + mockRequest.setServerName(uri.getHost()); + mockRequest.setServerPort(uri.getPort()); + mockRequest.setRequestURI(uri.getRawPath()); + mockRequest.setQueryString(uri.getRawQuery()); + assertEquals(uri, request.getURI()); + } + + @Test + public void getHeaders() { + String headerName = "MyHeader"; + String headerValue1 = "value1"; + String headerValue2 = "value2"; + mockRequest.addHeader(headerName, headerValue1); + mockRequest.addHeader(headerName, headerValue2); + mockRequest.setContentType("text/plain"); + mockRequest.setCharacterEncoding("UTF-8"); + + HttpHeaders headers = request.getHeaders(); + assertNotNull("No HttpHeaders returned", headers); + assertTrue("Invalid headers returned", headers.containsKey(headerName)); + List headerValues = headers.get(headerName); + assertEquals("Invalid header values returned", 2, headerValues.size()); + assertTrue("Invalid header values returned", headerValues.contains(headerValue1)); + assertTrue("Invalid header values returned", headerValues.contains(headerValue2)); + assertEquals("Invalid Content-Type", new MediaType("text", "plain", StandardCharsets.UTF_8), + headers.getContentType()); + } + + @Test + public void getHeadersWithEmptyContentTypeAndEncoding() { + String headerName = "MyHeader"; + String headerValue1 = "value1"; + String headerValue2 = "value2"; + mockRequest.addHeader(headerName, headerValue1); + mockRequest.addHeader(headerName, headerValue2); + mockRequest.setContentType(""); + mockRequest.setCharacterEncoding(""); + + HttpHeaders headers = request.getHeaders(); + assertNotNull("No HttpHeaders returned", headers); + assertTrue("Invalid headers returned", headers.containsKey(headerName)); + List headerValues = headers.get(headerName); + assertEquals("Invalid header values returned", 2, headerValues.size()); + assertTrue("Invalid header values returned", headerValues.contains(headerValue1)); + assertTrue("Invalid header values returned", headerValues.contains(headerValue2)); + assertNull(headers.getContentType()); + } + + @Test + public void getBody() throws IOException { + byte[] content = "Hello World".getBytes("UTF-8"); + mockRequest.setContent(content); + + byte[] result = FileCopyUtils.copyToByteArray(request.getBody()); + assertArrayEquals("Invalid content returned", content, result); + } + + @Test + public void getFormBody() throws IOException { + // Charset (SPR-8676) + mockRequest.setContentType("application/x-www-form-urlencoded; charset=UTF-8"); + mockRequest.setMethod("POST"); + mockRequest.addParameter("name 1", "value 1"); + mockRequest.addParameter("name 2", new String[] {"value 2+1", "value 2+2"}); + mockRequest.addParameter("name 3", (String) null); + + byte[] result = FileCopyUtils.copyToByteArray(request.getBody()); + byte[] content = "name+1=value+1&name+2=value+2%2B1&name+2=value+2%2B2&name+3".getBytes("UTF-8"); + assertArrayEquals("Invalid content returned", content, result); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java b/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java new file mode 100644 index 0000000000000000000000000000000000000000..76f48b18de11cd555ada64430c07f6c73c0fb2a2 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java @@ -0,0 +1,102 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.util.FileCopyUtils; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Juergen Hoeller + */ +public class ServletServerHttpResponseTests { + + private ServletServerHttpResponse response; + + private MockHttpServletResponse mockResponse; + + + @Before + public void create() throws Exception { + mockResponse = new MockHttpServletResponse(); + response = new ServletServerHttpResponse(mockResponse); + } + + + @Test + public void setStatusCode() throws Exception { + response.setStatusCode(HttpStatus.NOT_FOUND); + assertEquals("Invalid status code", 404, mockResponse.getStatus()); + } + + @Test + public void getHeaders() throws Exception { + HttpHeaders headers = response.getHeaders(); + String headerName = "MyHeader"; + String headerValue1 = "value1"; + headers.add(headerName, headerValue1); + String headerValue2 = "value2"; + headers.add(headerName, headerValue2); + headers.setContentType(new MediaType("text", "plain", StandardCharsets.UTF_8)); + + response.close(); + assertTrue("Header not set", mockResponse.getHeaderNames().contains(headerName)); + List headerValues = mockResponse.getHeaders(headerName); + assertTrue("Header not set", headerValues.contains(headerValue1)); + assertTrue("Header not set", headerValues.contains(headerValue2)); + assertEquals("Invalid Content-Type", "text/plain;charset=UTF-8", mockResponse.getHeader("Content-Type")); + assertEquals("Invalid Content-Type", "text/plain;charset=UTF-8", mockResponse.getContentType()); + assertEquals("Invalid Content-Type", "UTF-8", mockResponse.getCharacterEncoding()); + } + + @Test + public void preExistingHeadersFromHttpServletResponse() { + String headerName = "Access-Control-Allow-Origin"; + String headerValue = "localhost:8080"; + + this.mockResponse.addHeader(headerName, headerValue); + this.response = new ServletServerHttpResponse(this.mockResponse); + + assertEquals(headerValue, this.response.getHeaders().getFirst(headerName)); + assertEquals(Collections.singletonList(headerValue), this.response.getHeaders().get(headerName)); + assertTrue(this.response.getHeaders().containsKey(headerName)); + assertEquals(headerValue, this.response.getHeaders().getFirst(headerName)); + assertEquals(headerValue, this.response.getHeaders().getAccessControlAllowOrigin()); + } + + @Test + public void getBody() throws Exception { + byte[] content = "Hello World".getBytes("UTF-8"); + FileCopyUtils.copy(content, response.getBody()); + + assertArrayEquals("Invalid content written", content, mockResponse.getContentAsByteArray()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/AbstractHttpHandlerIntegrationTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/AbstractHttpHandlerIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..1f6c6d658d4fca895a42d3886d5350abd80d16ba --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/AbstractHttpHandlerIntegrationTests.java @@ -0,0 +1,95 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.File; +import java.time.Duration; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.After; +import org.junit.Before; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import reactor.core.publisher.Flux; + +import org.springframework.http.server.reactive.bootstrap.HttpServer; +import org.springframework.http.server.reactive.bootstrap.JettyHttpServer; +import org.springframework.http.server.reactive.bootstrap.ReactorHttpServer; +import org.springframework.http.server.reactive.bootstrap.TomcatHttpServer; +import org.springframework.http.server.reactive.bootstrap.UndertowHttpServer; + +@RunWith(Parameterized.class) +public abstract class AbstractHttpHandlerIntegrationTests { + + protected Log logger = LogFactory.getLog(getClass()); + + protected int port; + + @Parameterized.Parameter(0) + public HttpServer server; + + + @Parameterized.Parameters(name = "server [{0}]") + public static Object[][] arguments() { + File base = new File(System.getProperty("java.io.tmpdir")); + return new Object[][] { + {new JettyHttpServer()}, + {new ReactorHttpServer()}, + {new TomcatHttpServer(base.getAbsolutePath())}, + {new UndertowHttpServer()} + }; + } + + + @Before + public void setup() throws Exception { + this.server.setHandler(createHttpHandler()); + this.server.afterPropertiesSet(); + this.server.start(); + + // Set dynamically chosen port + this.port = this.server.getPort(); + } + + @After + public void tearDown() throws Exception { + this.server.stop(); + this.port = 0; + } + + + protected abstract HttpHandler createHttpHandler(); + + + /** + * Return an interval stream of N number of ticks and buffer the emissions + * to avoid back pressure failures (e.g. on slow CI server). + * + *

Use this method as follows: + *

    + *
  • Tests that verify N number of items followed by verifyOnComplete() + * should set the number of emissions to N. + *
  • Tests that verify N number of items followed by thenCancel() should + * set the number of buffered to an arbitrary number greater than N. + *
+ */ + public static Flux testInterval(Duration period, int count) { + return Flux.interval(period).take(count).onBackpressureBuffer(count); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/AsyncIntegrationTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/AsyncIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..4d22fb51a71f20af51dd301116a17150b48e5fae --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/AsyncIntegrationTests.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.URI; +import java.time.Duration; + +import org.hamcrest.Matchers; +import org.junit.Ignore; +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; + +import static org.junit.Assert.*; + +/** + * @author Stephane Maldini + * @since 5.0 + */ +public class AsyncIntegrationTests extends AbstractHttpHandlerIntegrationTests { + + private final Scheduler asyncGroup = Schedulers.parallel(); + + private final DataBufferFactory dataBufferFactory = new DefaultDataBufferFactory(); + + + @Override + protected AsyncHandler createHttpHandler() { + return new AsyncHandler(); + } + + @Test + @Ignore // TODO: fragile due to socket failures + public void basicTest() throws Exception { + URI url = new URI("http://localhost:" + port); + ResponseEntity response = new RestTemplate().exchange( + RequestEntity.get(url).build(), String.class); + + assertThat(response.getBody(), Matchers.equalTo("hello")); + } + + + private class AsyncHandler implements HttpHandler { + + @Override + public Mono handle(ServerHttpRequest request, ServerHttpResponse response) { + return response.writeWith(Flux.just("h", "e", "l", "l", "o") + .delayElements(Duration.ofMillis(100)) + .publishOn(asyncGroup) + .collect(dataBufferFactory::allocateBuffer, (buffer, str) -> buffer.write(str.getBytes()))); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b0cfa5131553dd3796c41e78789ba8b6a2d9fd7f --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java @@ -0,0 +1,305 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.junit.Test; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Signal; +import reactor.test.StepVerifier; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.LeakAwareDataBufferFactory; + +import static org.junit.Assert.*; + +/** + * @author Rossen Stoyanchev + * @author Stephane Maldini + */ +public class ChannelSendOperatorTests { + + private OneByOneAsyncWriter writer; + + + @Before + public void setUp() throws Exception { + this.writer = new OneByOneAsyncWriter(); + } + + + @Test + public void errorBeforeFirstItem() throws Exception { + IllegalStateException error = new IllegalStateException("boo"); + Mono completion = Mono.error(error).as(this::sendOperator); + Signal signal = completion.materialize().block(); + + assertNotNull(signal); + assertSame("Unexpected signal: " + signal, error, signal.getThrowable()); + } + + @Test + public void completionBeforeFirstItem() throws Exception { + Mono completion = Flux.empty().as(this::sendOperator); + Signal signal = completion.materialize().block(); + + assertNotNull(signal); + assertTrue("Unexpected signal: " + signal, signal.isOnComplete()); + + assertEquals(0, this.writer.items.size()); + assertTrue(this.writer.completed); + } + + @Test + public void writeOneItem() throws Exception { + Mono completion = Flux.just("one").as(this::sendOperator); + Signal signal = completion.materialize().block(); + + assertNotNull(signal); + assertTrue("Unexpected signal: " + signal, signal.isOnComplete()); + + assertEquals(1, this.writer.items.size()); + assertEquals("one", this.writer.items.get(0)); + assertTrue(this.writer.completed); + } + + + @Test + public void writeMultipleItems() { + List items = Arrays.asList("one", "two", "three"); + Mono completion = Flux.fromIterable(items).as(this::sendOperator); + Signal signal = completion.materialize().block(); + + assertNotNull(signal); + assertTrue("Unexpected signal: " + signal, signal.isOnComplete()); + + assertEquals(3, this.writer.items.size()); + assertEquals("one", this.writer.items.get(0)); + assertEquals("two", this.writer.items.get(1)); + assertEquals("three", this.writer.items.get(2)); + assertTrue(this.writer.completed); + } + + @Test + public void errorAfterMultipleItems() { + IllegalStateException error = new IllegalStateException("boo"); + Flux publisher = Flux.generate(() -> 0, (idx , subscriber) -> { + int i = ++idx; + subscriber.next(String.valueOf(i)); + if (i == 3) { + subscriber.error(error); + } + return i; + }); + Mono completion = publisher.as(this::sendOperator); + Signal signal = completion.materialize().block(); + + assertNotNull(signal); + assertSame("Unexpected signal: " + signal, error, signal.getThrowable()); + + assertEquals(3, this.writer.items.size()); + assertEquals("1", this.writer.items.get(0)); + assertEquals("2", this.writer.items.get(1)); + assertEquals("3", this.writer.items.get(2)); + assertSame(error, this.writer.error); + } + + @Test // gh-22720 + public void cancelWhileItemCached() { + LeakAwareDataBufferFactory bufferFactory = new LeakAwareDataBufferFactory(); + + ChannelSendOperator operator = new ChannelSendOperator<>( + Mono.fromCallable(() -> { + DataBuffer dataBuffer = bufferFactory.allocateBuffer(); + dataBuffer.write("foo", StandardCharsets.UTF_8); + return dataBuffer; + }), + publisher -> { + ZeroDemandSubscriber subscriber = new ZeroDemandSubscriber(); + publisher.subscribe(subscriber); + return Mono.never(); + }); + + BaseSubscriber subscriber = new BaseSubscriber() {}; + operator.subscribe(subscriber); + subscriber.cancel(); + + bufferFactory.checkForLeaks(); + } + + @Test // gh-22720 + public void errorFromWriteSourceWhileItemCached() { + + // 1. First item received + // 2. writeFunction applied and writeCompletionBarrier subscribed to it + // 3. Write Publisher fails right after that and before request(n) from server + + LeakAwareDataBufferFactory bufferFactory = new LeakAwareDataBufferFactory(); + ZeroDemandSubscriber writeSubscriber = new ZeroDemandSubscriber(); + + ChannelSendOperator operator = new ChannelSendOperator<>( + Flux.create(sink -> { + DataBuffer dataBuffer = bufferFactory.allocateBuffer(); + dataBuffer.write("foo", StandardCharsets.UTF_8); + sink.next(dataBuffer); + sink.error(new IllegalStateException("err")); + }), + publisher -> { + publisher.subscribe(writeSubscriber); + return Mono.never(); + }); + + + operator.subscribe(new BaseSubscriber() {}); + try { + writeSubscriber.signalDemand(1); // Let cached signals ("foo" and error) be published.. + } + catch (Throwable ex) { + assertNotNull(ex.getCause()); + assertEquals("err", ex.getCause().getMessage()); + } + + bufferFactory.checkForLeaks(); + } + + @Test // gh-22720 + public void errorFromWriteFunctionWhileItemCached() { + + // 1. First item received + // 2. writeFunction applied and writeCompletionBarrier subscribed to it + // 3. writeFunction fails, e.g. to flush status and headers, before request(n) from server + + LeakAwareDataBufferFactory bufferFactory = new LeakAwareDataBufferFactory(); + + ChannelSendOperator operator = new ChannelSendOperator<>( + Flux.create(sink -> { + DataBuffer dataBuffer = bufferFactory.allocateBuffer(); + dataBuffer.write("foo", StandardCharsets.UTF_8); + sink.next(dataBuffer); + }), + publisher -> { + publisher.subscribe(new ZeroDemandSubscriber()); + return Mono.error(new IllegalStateException("err")); + }); + + StepVerifier.create(operator).expectErrorMessage("err").verify(Duration.ofSeconds(5)); + bufferFactory.checkForLeaks(); + } + + @Test // gh-23175 + public void errorInWriteFunction() { + + StepVerifier + .create(new ChannelSendOperator<>(Mono.just("one"), p -> { + throw new IllegalStateException("boo"); + })) + .expectErrorMessage("boo") + .verify(Duration.ofMillis(5000)); + + StepVerifier + .create(new ChannelSendOperator<>(Mono.empty(), p -> { + throw new IllegalStateException("boo"); + })) + .expectErrorMessage("boo") + .verify(Duration.ofMillis(5000)); + } + + + private Mono sendOperator(Publisher source){ + return new ChannelSendOperator<>(source, writer::send); + } + + + private static class OneByOneAsyncWriter { + + private List items = new ArrayList<>(); + + private boolean completed = false; + + private Throwable error; + + + public Publisher send(Publisher publisher) { + return subscriber -> Executors.newSingleThreadScheduledExecutor().schedule(() -> + publisher.subscribe(new WriteSubscriber(subscriber)),50, TimeUnit.MILLISECONDS); + } + + + private class WriteSubscriber implements Subscriber { + + private Subscription subscription; + + private final Subscriber subscriber; + + public WriteSubscriber(Subscriber subscriber) { + this.subscriber = subscriber; + } + + @Override + public void onSubscribe(Subscription subscription) { + this.subscription = subscription; + this.subscription.request(1); + } + + @Override + public void onNext(String item) { + items.add(item); + this.subscription.request(1); + } + + @Override + public void onError(Throwable ex) { + error = ex; + this.subscriber.onError(ex); + } + + @Override + public void onComplete() { + completed = true; + this.subscriber.onComplete(); + } + } + } + + + private static class ZeroDemandSubscriber extends BaseSubscriber { + + + @Override + protected void hookOnSubscribe(Subscription subscription) { + // Just subscribe without requesting + } + + public void signalDemand(long demand) { + upstream().request(demand); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ContextPathCompositeHandlerTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ContextPathCompositeHandlerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8c44de36e07354ab3232d59fdee8e80bb98dea32 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ContextPathCompositeHandlerTests.java @@ -0,0 +1,188 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpStatus; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; + +import static junit.framework.TestCase.assertFalse; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Unit tests for {@link ContextPathCompositeHandler}. + * + * @author Rossen Stoyanchev + */ +public class ContextPathCompositeHandlerTests { + + + @Test + public void invalidContextPath() { + testInvalid(" ", "Context path must not be empty"); + testInvalid("path", "Context path must begin with '/'"); + testInvalid("/path/", "Context path must not end with '/'"); + } + + private void testInvalid(String contextPath, String expectedError) { + try { + new ContextPathCompositeHandler(Collections.singletonMap(contextPath, new TestHttpHandler())); + fail(); + } + catch (IllegalArgumentException ex) { + assertEquals(expectedError, ex.getMessage()); + } + } + + @Test + public void match() { + TestHttpHandler handler1 = new TestHttpHandler(); + TestHttpHandler handler2 = new TestHttpHandler(); + TestHttpHandler handler3 = new TestHttpHandler(); + + Map map = new HashMap<>(); + map.put("/path", handler1); + map.put("/another/path", handler2); + map.put("/yet/another/path", handler3); + + testHandle("/another/path/and/more", map); + + assertInvoked(handler2, "/another/path"); + assertNotInvoked(handler1, handler3); + } + + @Test + public void matchWithContextPathEqualToPath() { + TestHttpHandler handler1 = new TestHttpHandler(); + TestHttpHandler handler2 = new TestHttpHandler(); + TestHttpHandler handler3 = new TestHttpHandler(); + + Map map = new HashMap<>(); + map.put("/path", handler1); + map.put("/another/path", handler2); + map.put("/yet/another/path", handler3); + + testHandle("/path", map); + + assertInvoked(handler1, "/path"); + assertNotInvoked(handler2, handler3); + } + + @Test + public void matchWithNativeContextPath() { + MockServerHttpRequest request = MockServerHttpRequest + .get("/yet/another/path") + .contextPath("/yet") // contextPath in underlying request + .build(); + + TestHttpHandler handler = new TestHttpHandler(); + Map map = Collections.singletonMap("/another/path", handler); + + new ContextPathCompositeHandler(map).handle(request, new MockServerHttpResponse()); + + assertTrue(handler.wasInvoked()); + assertEquals("/yet/another/path", handler.getRequest().getPath().contextPath().value()); + } + + @Test + public void notFound() { + TestHttpHandler handler1 = new TestHttpHandler(); + TestHttpHandler handler2 = new TestHttpHandler(); + + Map map = new HashMap<>(); + map.put("/path", handler1); + map.put("/another/path", handler2); + + ServerHttpResponse response = testHandle("/yet/another/path", map); + + assertNotInvoked(handler1, handler2); + assertEquals(HttpStatus.NOT_FOUND, response.getStatusCode()); + } + + @Test // SPR-17144 + public void notFoundWithCommitAction() { + + AtomicBoolean commitInvoked = new AtomicBoolean(false); + + ServerHttpRequest request = MockServerHttpRequest.get("/unknown/path").build(); + ServerHttpResponse response = new MockServerHttpResponse(); + response.beforeCommit(() -> { + commitInvoked.set(true); + return Mono.empty(); + }); + + Map map = new HashMap<>(); + TestHttpHandler handler = new TestHttpHandler(); + map.put("/path", handler); + new ContextPathCompositeHandler(map).handle(request, response).block(Duration.ofSeconds(5)); + + assertNotInvoked(handler); + assertEquals(HttpStatus.NOT_FOUND, response.getStatusCode()); + assertTrue(commitInvoked.get()); + } + + + private ServerHttpResponse testHandle(String pathToHandle, Map handlerMap) { + ServerHttpRequest request = MockServerHttpRequest.get(pathToHandle).build(); + ServerHttpResponse response = new MockServerHttpResponse(); + new ContextPathCompositeHandler(handlerMap).handle(request, response).block(Duration.ofSeconds(5)); + return response; + } + + private void assertInvoked(TestHttpHandler handler, String contextPath) { + assertTrue(handler.wasInvoked()); + assertEquals(contextPath, handler.getRequest().getPath().contextPath().value()); + } + + private void assertNotInvoked(TestHttpHandler... handlers) { + Arrays.stream(handlers).forEach(handler -> assertFalse(handler.wasInvoked())); + } + + + @SuppressWarnings("WeakerAccess") + private static class TestHttpHandler implements HttpHandler { + + private ServerHttpRequest request; + + public boolean wasInvoked() { + return (this.request != null); + } + + public ServerHttpRequest getRequest() { + return this.request; + } + + @Override + public Mono handle(ServerHttpRequest request, ServerHttpResponse response) { + this.request = request; + return Mono.empty(); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/CookieIntegrationTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/CookieIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..78de21cfe46bff629eb442be92f0d2bbb2cf1dbe --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/CookieIntegrationTests.java @@ -0,0 +1,113 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.URI; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpCookie; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseCookie; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; + +/** + * @author Rossen Stoyanchev + */ +@RunWith(Parameterized.class) +public class CookieIntegrationTests extends AbstractHttpHandlerIntegrationTests { + + private final CookieHandler cookieHandler = new CookieHandler(); + + + @Override + protected HttpHandler createHttpHandler() { + return this.cookieHandler; + } + + + @SuppressWarnings("unchecked") + @Test + public void basicTest() throws Exception { + URI url = new URI("http://localhost:" + port); + String header = "SID=31d4d96e407aad42; lang=en-US"; + ResponseEntity response = new RestTemplate().exchange( + RequestEntity.get(url).header("Cookie", header).build(), Void.class); + + Map> requestCookies = this.cookieHandler.requestCookies; + assertEquals(2, requestCookies.size()); + + List list = requestCookies.get("SID"); + assertEquals(1, list.size()); + assertEquals("31d4d96e407aad42", list.iterator().next().getValue()); + + list = requestCookies.get("lang"); + assertEquals(1, list.size()); + assertEquals("en-US", list.iterator().next().getValue()); + + List headerValues = response.getHeaders().get("Set-Cookie"); + assertEquals(2, headerValues.size()); + + assertThat(splitCookie(headerValues.get(0)), containsInAnyOrder(equalTo("SID=31d4d96e407aad42"), + equalToIgnoringCase("Path=/"), equalToIgnoringCase("Secure"), equalToIgnoringCase("HttpOnly"))); + + assertThat(splitCookie(headerValues.get(1)), containsInAnyOrder(equalTo("lang=en-US"), + equalToIgnoringCase("Path=/"), equalToIgnoringCase("Domain=example.com"))); + } + + // No client side HttpCookie support yet + private List splitCookie(String value) { + List list = new ArrayList<>(); + for (String s : value.split(";")){ + list.add(s.trim()); + } + return list; + } + + + private class CookieHandler implements HttpHandler { + + private Map> requestCookies; + + + @Override + public Mono handle(ServerHttpRequest request, ServerHttpResponse response) { + + this.requestCookies = request.getCookies(); + this.requestCookies.size(); // Cause lazy loading + + response.getCookies().add("SID", ResponseCookie.from("SID", "31d4d96e407aad42") + .path("/").secure(true).httpOnly(true).build()); + response.getCookies().add("lang", ResponseCookie.from("lang", "en-US") + .domain("example.com").path("/").build()); + + return response.setComplete(); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/EchoHandlerIntegrationTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/EchoHandlerIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..7f71865b21e5dca9e2334ea60a09874e4cab855e --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/EchoHandlerIntegrationTests.java @@ -0,0 +1,75 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.URI; +import java.util.Random; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + */ +public class EchoHandlerIntegrationTests extends AbstractHttpHandlerIntegrationTests { + + private static final int REQUEST_SIZE = 4096 * 3; + + private final Random rnd = new Random(); + + + @Override + protected EchoHandler createHttpHandler() { + return new EchoHandler(); + } + + + @Test + public void echo() throws Exception { + RestTemplate restTemplate = new RestTemplate(); + + byte[] body = randomBytes(); + RequestEntity request = RequestEntity.post(new URI("http://localhost:" + port)).body(body); + ResponseEntity response = restTemplate.exchange(request, byte[].class); + + assertArrayEquals(body, response.getBody()); + } + + + private byte[] randomBytes() { + byte[] buffer = new byte[REQUEST_SIZE]; + rnd.nextBytes(buffer); + return buffer; + } + + /** + * @author Arjen Poutsma + */ + public static class EchoHandler implements HttpHandler { + + @Override + public Mono handle(ServerHttpRequest request, ServerHttpResponse response) { + return response.writeWith(request.getBody()); + } + } +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ErrorHandlerIntegrationTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ErrorHandlerIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..63b2af4af707860a9f4d284a713657c6f91df058 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ErrorHandlerIntegrationTests.java @@ -0,0 +1,112 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.URI; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestTemplate; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + */ +public class ErrorHandlerIntegrationTests extends AbstractHttpHandlerIntegrationTests { + + private final ErrorHandler handler = new ErrorHandler(); + + + @Override + protected HttpHandler createHttpHandler() { + return handler; + } + + + @Test + public void responseBodyError() throws Exception { + RestTemplate restTemplate = new RestTemplate(); + restTemplate.setErrorHandler(NO_OP_ERROR_HANDLER); + + URI url = new URI("http://localhost:" + port + "/response-body-error"); + ResponseEntity response = restTemplate.getForEntity(url, String.class); + + assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, response.getStatusCode()); + } + + @Test + public void handlingError() throws Exception { + RestTemplate restTemplate = new RestTemplate(); + restTemplate.setErrorHandler(NO_OP_ERROR_HANDLER); + + URI url = new URI("http://localhost:" + port + "/handling-error"); + ResponseEntity response = restTemplate.getForEntity(url, String.class); + + assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, response.getStatusCode()); + } + + @Test // SPR-15560 + public void emptyPathSegments() throws Exception { + + RestTemplate restTemplate = new RestTemplate(); + restTemplate.setErrorHandler(NO_OP_ERROR_HANDLER); + + URI url = new URI("http://localhost:" + port + "//"); + ResponseEntity response = restTemplate.getForEntity(url, String.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + } + + + private static class ErrorHandler implements HttpHandler { + + @Override + public Mono handle(ServerHttpRequest request, ServerHttpResponse response) { + Exception error = new UnsupportedOperationException(); + String path = request.getURI().getPath(); + if (path.endsWith("response-body-error")) { + return response.writeWith(Mono.error(error)); + } + else if (path.endsWith("handling-error")) { + return Mono.error(error); + } + else { + return Mono.empty(); + } + } + } + + + private static final ResponseErrorHandler NO_OP_ERROR_HANDLER = new ResponseErrorHandler() { + + @Override + public boolean hasError(ClientHttpResponse response) { + return false; + } + + @Override + public void handleError(ClientHttpResponse response) { + } + }; + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/HeadersAdaptersTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/HeadersAdaptersTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e97182745c693d10f1f89405863abdf36786220c --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/HeadersAdaptersTests.java @@ -0,0 +1,114 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.Arrays; +import java.util.Locale; + +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.undertow.util.HeaderMap; +import org.apache.tomcat.util.http.MimeHeaders; +import org.eclipse.jetty.http.HttpFields; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.MultiValueMap; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** + * Unit tests for {@code HeadersAdapters} {@code MultiValueMap} implementations. + * + * @author Brian Clozel + */ +@RunWith(Parameterized.class) +public class HeadersAdaptersTests { + + @Parameterized.Parameter(0) + public MultiValueMap headers; + + @Parameterized.Parameters(name = "headers [{0}]") + public static Object[][] arguments() { + return new Object[][] { + {CollectionUtils.toMultiValueMap( + new LinkedCaseInsensitiveMap<>(8, Locale.ENGLISH))}, + {new NettyHeadersAdapter(new DefaultHttpHeaders())}, + {new TomcatHeadersAdapter(new MimeHeaders())}, + {new UndertowHeadersAdapter(new HeaderMap())}, + {new JettyHeadersAdapter(new HttpFields())} + }; + } + + @After + public void tearDown() { + this.headers.clear(); + } + + @Test + public void getWithUnknownHeaderShouldReturnNull() { + assertNull(this.headers.get("Unknown")); + } + + @Test + public void getFirstWithUnknownHeaderShouldReturnNull() { + assertNull(this.headers.getFirst("Unknown")); + } + + @Test + public void sizeWithMultipleValuesForHeaderShouldCountHeaders() { + this.headers.add("TestHeader", "first"); + this.headers.add("TestHeader", "second"); + assertEquals(1, this.headers.size()); + } + + @Test + public void keySetShouldNotDuplicateHeaderNames() { + this.headers.add("TestHeader", "first"); + this.headers.add("OtherHeader", "test"); + this.headers.add("TestHeader", "second"); + assertEquals(2, this.headers.keySet().size()); + } + + @Test + public void containsKeyShouldBeCaseInsensitive() { + this.headers.add("TestHeader", "first"); + assertTrue(this.headers.containsKey("testheader")); + } + + @Test + public void addShouldKeepOrdering() { + this.headers.add("TestHeader", "first"); + this.headers.add("TestHeader", "second"); + assertEquals("first", this.headers.getFirst("TestHeader")); + assertEquals("first", this.headers.get("TestHeader").get(0)); + } + + @Test + public void putShouldOverrideExisting() { + this.headers.add("TestHeader", "first"); + this.headers.put("TestHeader", Arrays.asList("override")); + assertEquals("override", this.headers.getFirst("TestHeader")); + assertEquals(1, this.headers.get("TestHeader").size()); + } + +} \ No newline at end of file diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/HttpHeadResponseDecoratorTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/HttpHeadResponseDecoratorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8231bc4789cc89bbde10185561af8fb8df3f49c8 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/HttpHeadResponseDecoratorTests.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.http.server.reactive; + +import java.nio.charset.StandardCharsets; + +import io.netty.buffer.PooledByteBufAllocator; +import org.junit.After; +import org.junit.Test; +import reactor.core.publisher.Flux; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.LeakAwareDataBufferFactory; +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; + +import static org.junit.Assert.assertEquals; + +/** + * Unit tests for {@link HttpHeadResponseDecorator}. + * @author Rossen Stoyanchev + */ +public class HttpHeadResponseDecoratorTests { + + private final LeakAwareDataBufferFactory bufferFactory = + new LeakAwareDataBufferFactory(new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT)); + + private final ServerHttpResponse response = + new HttpHeadResponseDecorator(new MockServerHttpResponse(this.bufferFactory)); + + + @After + public void tearDown() { + this.bufferFactory.checkForLeaks(); + } + + + @Test + public void write() { + Flux body = Flux.just(toDataBuffer("data1"), toDataBuffer("data2")); + response.writeWith(body).block(); + assertEquals(10, response.getHeaders().getContentLength()); + } + + @Test // gh-23484 + public void writeWithGivenContentLength() { + int length = 15; + this.response.getHeaders().setContentLength(length); + this.response.writeWith(Flux.empty()).block(); + assertEquals(length, this.response.getHeaders().getContentLength()); + } + + + private DataBuffer toDataBuffer(String s) { + DataBuffer buffer = this.bufferFactory.allocateBuffer(); + buffer.write(s.getBytes(StandardCharsets.UTF_8)); + return buffer; + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerReadPublisherTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerReadPublisherTests.java new file mode 100644 index 0000000000000000000000000000000000000000..78f0936c1c5a28f1edb919d94c17377323f00e93 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerReadPublisherTests.java @@ -0,0 +1,160 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import org.junit.Before; +import org.junit.Test; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import org.springframework.core.io.buffer.DataBuffer; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link AbstractListenerReadPublisher}. + * + * @author Violeta Georgieva + * @author Rossen Stoyanchev + */ +public class ListenerReadPublisherTests { + + private final TestListenerReadPublisher publisher = new TestListenerReadPublisher(); + + private final TestSubscriber subscriber = new TestSubscriber(); + + + @Before + public void setup() { + this.publisher.subscribe(this.subscriber); + } + + + @Test + public void twoReads() { + + this.subscriber.getSubscription().request(2); + this.publisher.onDataAvailable(); + + assertEquals(2, this.publisher.getReadCalls()); + } + + @Test // SPR-17410 + public void discardDataOnError() { + + this.subscriber.getSubscription().request(2); + this.publisher.onDataAvailable(); + this.publisher.onError(new IllegalStateException()); + + assertEquals(2, this.publisher.getReadCalls()); + assertEquals(1, this.publisher.getDiscardCalls()); + } + + @Test // SPR-17410 + public void discardDataOnCancel() { + + this.subscriber.getSubscription().request(2); + this.subscriber.setCancelOnNext(true); + this.publisher.onDataAvailable(); + + assertEquals(1, this.publisher.getReadCalls()); + assertEquals(1, this.publisher.getDiscardCalls()); + } + + + private static final class TestListenerReadPublisher extends AbstractListenerReadPublisher { + + private int readCalls = 0; + + private int discardCalls = 0; + + + public TestListenerReadPublisher() { + super(""); + } + + + public int getReadCalls() { + return this.readCalls; + } + + public int getDiscardCalls() { + return this.discardCalls; + } + + @Override + protected void checkOnDataAvailable() { + // no-op + } + + @Override + protected DataBuffer read() { + this.readCalls++; + return mock(DataBuffer.class); + } + + @Override + protected void readingPaused() { + // No-op + } + + @Override + protected void discardData() { + this.discardCalls++; + } + } + + + private static final class TestSubscriber implements Subscriber { + + private Subscription subscription; + + private boolean cancelOnNext; + + + public Subscription getSubscription() { + return this.subscription; + } + + public void setCancelOnNext(boolean cancelOnNext) { + this.cancelOnNext = cancelOnNext; + } + + + @Override + public void onSubscribe(Subscription subscription) { + this.subscription = subscription; + } + + @Override + public void onNext(DataBuffer dataBuffer) { + if (this.cancelOnNext) { + this.subscription.cancel(); + } + } + + @Override + public void onError(Throwable t) { + } + + @Override + public void onComplete() { + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerWriteProcessorTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerWriteProcessorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e6bae84d5ccc01dbf2d92a12924f5efa3d6cce6a --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerWriteProcessorTests.java @@ -0,0 +1,206 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.http.server.reactive; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import org.springframework.core.io.buffer.DataBuffer; + +import static junit.framework.TestCase.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link AbstractListenerWriteProcessor}. + * + * @author Rossen Stoyanchev + */ +public class ListenerWriteProcessorTests { + + private final TestListenerWriteProcessor processor = new TestListenerWriteProcessor(); + + private final TestResultSubscriber resultSubscriber = new TestResultSubscriber(); + + private final TestSubscription subscription = new TestSubscription(); + + + @Before + public void setup() { + this.processor.subscribe(this.resultSubscriber); + this.processor.onSubscribe(this.subscription); + assertEquals(1, subscription.getDemand()); + } + + + @Test // SPR-17410 + public void writePublisherError() { + + // Turn off writing so next item will be cached + this.processor.setWritePossible(false); + DataBuffer buffer = mock(DataBuffer.class); + this.processor.onNext(buffer); + + // Send error while item cached + this.processor.onError(new IllegalStateException()); + + assertNotNull("Error should flow to result publisher", this.resultSubscriber.getError()); + assertEquals(1, this.processor.getDiscardedBuffers().size()); + assertSame(buffer, this.processor.getDiscardedBuffers().get(0)); + } + + @Test // SPR-17410 + public void ioExceptionDuringWrite() { + + // Fail on next write + this.processor.setWritePossible(true); + this.processor.setFailOnWrite(true); + + // Write + DataBuffer buffer = mock(DataBuffer.class); + this.processor.onNext(buffer); + + assertNotNull("Error should flow to result publisher", this.resultSubscriber.getError()); + assertEquals(1, this.processor.getDiscardedBuffers().size()); + assertSame(buffer, this.processor.getDiscardedBuffers().get(0)); + } + + @Test // SPR-17410 + public void onNextWithoutDemand() { + + // Disable writing: next item will be cached.. + this.processor.setWritePossible(false); + DataBuffer buffer1 = mock(DataBuffer.class); + this.processor.onNext(buffer1); + + // Send more data illegally + DataBuffer buffer2 = mock(DataBuffer.class); + this.processor.onNext(buffer2); + + assertNotNull("Error should flow to result publisher", this.resultSubscriber.getError()); + assertEquals(2, this.processor.getDiscardedBuffers().size()); + assertSame(buffer2, this.processor.getDiscardedBuffers().get(0)); + assertSame(buffer1, this.processor.getDiscardedBuffers().get(1)); + } + + + private static final class TestListenerWriteProcessor extends AbstractListenerWriteProcessor { + + private final List discardedBuffers = new ArrayList<>(); + + private boolean writePossible; + + private boolean failOnWrite; + + + public List getDiscardedBuffers() { + return this.discardedBuffers; + } + + public void setWritePossible(boolean writePossible) { + this.writePossible = writePossible; + } + + public void setFailOnWrite(boolean failOnWrite) { + this.failOnWrite = failOnWrite; + } + + + @Override + protected boolean isDataEmpty(DataBuffer dataBuffer) { + return false; + } + + @Override + protected boolean isWritePossible() { + return this.writePossible; + } + + @Override + protected boolean write(DataBuffer dataBuffer) throws IOException { + if (this.failOnWrite) { + throw new IOException("write failed"); + } + return true; + } + + @Override + protected void writingFailed(Throwable ex) { + cancel(); + onError(ex); + } + + @Override + protected void discardData(DataBuffer dataBuffer) { + this.discardedBuffers.add(dataBuffer); + } + } + + + private static final class TestSubscription implements Subscription { + + private long demand; + + + public long getDemand() { + return this.demand; + } + + + @Override + public void request(long n) { + this.demand = (n == Long.MAX_VALUE ? n : this.demand + n); + } + + @Override + public void cancel() { + } + } + + private static final class TestResultSubscriber implements Subscriber { + + private Throwable error; + + + public Throwable getError() { + return this.error; + } + + + @Override + public void onSubscribe(Subscription subscription) { + } + + @Override + public void onNext(Void aVoid) { + } + + @Override + public void onError(Throwable ex) { + this.error = ex; + } + + @Override + public void onComplete() { + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/MultipartIntegrationTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/MultipartIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..2e0c0dff644788e849ff80100c69ea6ee1bb9e6d --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/MultipartIntegrationTests.java @@ -0,0 +1,124 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.URI; + +import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.http.codec.multipart.FilePart; +import org.springframework.http.codec.multipart.FormFieldPart; +import org.springframework.http.codec.multipart.Part; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebHandler; +import org.springframework.web.server.adapter.HttpWebHandlerAdapter; + +import static org.junit.Assert.*; + +/** + * @author Sebastien Deleuze + */ +public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTests { + + @Override + protected HttpHandler createHttpHandler() { + return new HttpWebHandlerAdapter(new CheckRequestHandler()); + } + + @Test + public void getFormParts() throws Exception { + RestTemplate restTemplate = new RestTemplate(); + RequestEntity> request = RequestEntity + .post(new URI("http://localhost:" + port + "/form-parts")) + .contentType(MediaType.MULTIPART_FORM_DATA) + .body(generateBody()); + ResponseEntity response = restTemplate.exchange(request, Void.class); + assertEquals(HttpStatus.OK, response.getStatusCode()); + } + + private MultiValueMap generateBody() { + HttpHeaders fooHeaders = new HttpHeaders(); + fooHeaders.setContentType(MediaType.TEXT_PLAIN); + ClassPathResource fooResource = new ClassPathResource("org/springframework/http/codec/multipart/foo.txt"); + HttpEntity fooPart = new HttpEntity<>(fooResource, fooHeaders); + HttpEntity barPart = new HttpEntity<>("bar"); + MultiValueMap parts = new LinkedMultiValueMap<>(); + parts.add("fooPart", fooPart); + parts.add("barPart", barPart); + return parts; + } + + + public static class CheckRequestHandler implements WebHandler { + + @Override + public Mono handle(ServerWebExchange exchange) { + if (exchange.getRequest().getURI().getPath().equals("/form-parts")) { + return assertGetFormParts(exchange); + } + return Mono.error(new AssertionError()); + } + + private Mono assertGetFormParts(ServerWebExchange exchange) { + return exchange + .getMultipartData() + .doOnNext(parts -> { + assertEquals(2, parts.size()); + assertTrue(parts.containsKey("fooPart")); + assertFooPart(parts.getFirst("fooPart")); + assertTrue(parts.containsKey("barPart")); + assertBarPart(parts.getFirst("barPart")); + }) + .then(); + } + + private void assertFooPart(Part part) { + assertEquals("fooPart", part.name()); + assertTrue(part instanceof FilePart); + assertEquals("foo.txt", ((FilePart) part).filename()); + + StepVerifier.create(DataBufferUtils.join(part.content())) + .consumeNextWith(buffer -> { + assertEquals(12, buffer.readableByteCount()); + byte[] byteContent = new byte[12]; + buffer.read(byteContent); + assertEquals("Lorem Ipsum.", new String(byteContent)); + }) + .verifyComplete(); + } + + private void assertBarPart(Part part) { + assertEquals("barPart", part.name()); + assertTrue(part instanceof FormFieldPart); + assertEquals("bar", ((FormFieldPart) part).value()); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/RandomHandlerIntegrationTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/RandomHandlerIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d4146061837587e2d5b5253b657877032f7ae78d --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/RandomHandlerIntegrationTests.java @@ -0,0 +1,114 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.URI; +import java.util.Random; + +import org.junit.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + */ +public class RandomHandlerIntegrationTests extends AbstractHttpHandlerIntegrationTests { + + public static final int REQUEST_SIZE = 4096 * 3; + + public static final int RESPONSE_SIZE = 1024 * 4; + + private final Random rnd = new Random(); + + private final RandomHandler handler = new RandomHandler(); + + private final DataBufferFactory dataBufferFactory = new DefaultDataBufferFactory(); + + + @Override + protected RandomHandler createHttpHandler() { + return handler; + } + + + @Test + public void random() throws Throwable { + // TODO: fix Reactor support + + RestTemplate restTemplate = new RestTemplate(); + + byte[] body = randomBytes(); + RequestEntity request = RequestEntity.post(new URI("http://localhost:" + port)).body(body); + ResponseEntity response = restTemplate.exchange(request, byte[].class); + + assertNotNull(response.getBody()); + assertEquals(RESPONSE_SIZE, + response.getHeaders().getContentLength()); + assertEquals(RESPONSE_SIZE, response.getBody().length); + } + + + private byte[] randomBytes() { + byte[] buffer = new byte[REQUEST_SIZE]; + rnd.nextBytes(buffer); + return buffer; + } + + private class RandomHandler implements HttpHandler { + + public static final int CHUNKS = 16; + + @Override + public Mono handle(ServerHttpRequest request, ServerHttpResponse response) { + Mono requestSizeMono = request.getBody(). + reduce(0, (integer, dataBuffer) -> integer + + dataBuffer.readableByteCount()). + doOnSuccessOrError((size, throwable) -> { + assertNull(throwable); + assertEquals(REQUEST_SIZE, (long) size); + }); + + response.getHeaders().setContentLength(RESPONSE_SIZE); + + return requestSizeMono.then(response.writeWith(multipleChunks())); + } + + private Publisher multipleChunks() { + int chunkSize = RESPONSE_SIZE / CHUNKS; + return Flux.range(1, CHUNKS).map(integer -> randomBuffer(chunkSize)); + } + + private DataBuffer randomBuffer(int size) { + byte[] bytes = new byte[size]; + rnd.nextBytes(bytes); + DataBuffer buffer = dataBufferFactory.allocateBuffer(size); + buffer.write(bytes); + return buffer; + } + + } +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestIntegrationTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..27a05830a65d3cbaa7a280152854b0777cbd804e --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestIntegrationTests.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.URI; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpStatus; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; + +import static org.junit.Assert.*; + +/** + * @author Sebastien Deleuze + */ +public class ServerHttpRequestIntegrationTests extends AbstractHttpHandlerIntegrationTests { + + @Override + protected CheckRequestHandler createHttpHandler() { + return new CheckRequestHandler(); + } + + @Test + public void checkUri() throws Exception { + URI url = new URI("http://localhost:" + port + "/foo?param=bar"); + RequestEntity request = RequestEntity.post(url).build(); + ResponseEntity response = new RestTemplate().exchange(request, Void.class); + assertEquals(HttpStatus.OK, response.getStatusCode()); + } + + + public static class CheckRequestHandler implements HttpHandler { + + @Override + public Mono handle(ServerHttpRequest request, ServerHttpResponse response) { + URI uri = request.getURI(); + assertEquals("http", uri.getScheme()); + assertNotNull(uri.getHost()); + assertNotEquals(-1, uri.getPort()); + assertNotNull(request.getRemoteAddress()); + assertEquals("/foo", uri.getPath()); + assertEquals("param=bar", uri.getQuery()); + return Mono.empty(); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c378cfe0fd7bd702d2065968e5df742c0fb64514 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java @@ -0,0 +1,212 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.ByteArrayInputStream; +import java.net.URI; +import java.util.Arrays; +import java.util.Collections; + +import javax.servlet.AsyncContext; +import javax.servlet.ReadListener; +import javax.servlet.ServletInputStream; + +import org.junit.Test; + +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpMethod; +import org.springframework.mock.web.test.DelegatingServletInputStream; +import org.springframework.mock.web.test.MockAsyncContext; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.util.MultiValueMap; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link AbstractServerHttpRequest}. + * + * @author Rossen Stoyanchev + * @author Sam Brannen + */ +public class ServerHttpRequestTests { + + @Test + public void queryParamsNone() throws Exception { + MultiValueMap params = createHttpRequest("/path").getQueryParams(); + assertEquals(0, params.size()); + } + + @Test + public void queryParams() throws Exception { + MultiValueMap params = createHttpRequest("/path?a=A&b=B").getQueryParams(); + assertEquals(2, params.size()); + assertEquals(Collections.singletonList("A"), params.get("a")); + assertEquals(Collections.singletonList("B"), params.get("b")); + } + + @Test + public void queryParamsWithMultipleValues() throws Exception { + MultiValueMap params = createHttpRequest("/path?a=1&a=2").getQueryParams(); + assertEquals(1, params.size()); + assertEquals(Arrays.asList("1", "2"), params.get("a")); + } + + @Test // SPR-15140 + public void queryParamsWithEncodedValue() throws Exception { + MultiValueMap params = createHttpRequest("/path?a=%20%2B+%C3%A0").getQueryParams(); + assertEquals(1, params.size()); + assertEquals(Collections.singletonList(" + \u00e0"), params.get("a")); + } + + @Test + public void queryParamsWithEmptyValue() throws Exception { + MultiValueMap params = createHttpRequest("/path?a=").getQueryParams(); + assertEquals(1, params.size()); + assertEquals(Collections.singletonList(""), params.get("a")); + } + + @Test + public void queryParamsWithNoValue() throws Exception { + MultiValueMap params = createHttpRequest("/path?a").getQueryParams(); + assertEquals(1, params.size()); + assertEquals(Collections.singletonList(null), params.get("a")); + } + + @Test + public void mutateRequest() throws Exception { + SslInfo sslInfo = mock(SslInfo.class); + ServerHttpRequest request = createHttpRequest("/").mutate().sslInfo(sslInfo).build(); + assertSame(sslInfo, request.getSslInfo()); + + request = createHttpRequest("/").mutate().method(HttpMethod.DELETE).build(); + assertEquals(HttpMethod.DELETE, request.getMethod()); + + String baseUri = "http://www.aaa.org/articles/"; + + request = createHttpRequest(baseUri).mutate().uri(URI.create("http://bbb.org:9090/b")).build(); + assertEquals("http://bbb.org:9090/b", request.getURI().toString()); + + request = createHttpRequest(baseUri).mutate().path("/b/c/d").build(); + assertEquals("http://www.aaa.org/b/c/d", request.getURI().toString()); + + request = createHttpRequest(baseUri).mutate().path("/app/b/c/d").contextPath("/app").build(); + assertEquals("http://www.aaa.org/app/b/c/d", request.getURI().toString()); + assertEquals("/app", request.getPath().contextPath().value()); + } + + @Test(expected = IllegalArgumentException.class) + public void mutateWithInvalidPath() throws Exception { + createHttpRequest("/").mutate().path("foo-bar"); + } + + @Test // SPR-16434 + public void mutatePathWithEncodedQueryParams() throws Exception { + ServerHttpRequest request = createHttpRequest("/path?name=%E6%89%8E%E6%A0%B9"); + request = request.mutate().path("/mutatedPath").build(); + + assertEquals("/mutatedPath", request.getURI().getRawPath()); + assertEquals("name=%E6%89%8E%E6%A0%B9", request.getURI().getRawQuery()); + } + + @Test + @SuppressWarnings("deprecation") + public void mutateHeaderByAddingHeaderValues() throws Exception { + String headerName = "key"; + String headerValue1 = "value1"; + String headerValue2 = "value2"; + + ServerHttpRequest request = createHttpRequest("/path"); + assertNull(request.getHeaders().get(headerName)); + + request = request.mutate().header(headerName, headerValue1).build(); + + assertNotNull(request.getHeaders().get(headerName)); + assertEquals(1, request.getHeaders().get(headerName).size()); + assertEquals(headerValue1, request.getHeaders().get(headerName).get(0)); + + request = request.mutate().header(headerName, headerValue2).build(); + + assertNotNull(request.getHeaders().get(headerName)); + assertEquals(2, request.getHeaders().get(headerName).size()); + assertEquals(headerValue1, request.getHeaders().get(headerName).get(0)); + assertEquals(headerValue2, request.getHeaders().get(headerName).get(1)); + } + + @Test + public void mutateHeaderBySettingHeaderValues() throws Exception { + String headerName = "key"; + String headerValue1 = "value1"; + String headerValue2 = "value2"; + String headerValue3 = "value3"; + + ServerHttpRequest request = createHttpRequest("/path"); + assertNull(request.getHeaders().get(headerName)); + + request = request.mutate().header(headerName, headerValue1, headerValue2).build(); + + assertNotNull(request.getHeaders().get(headerName)); + assertEquals(2, request.getHeaders().get(headerName).size()); + assertEquals(headerValue1, request.getHeaders().get(headerName).get(0)); + assertEquals(headerValue2, request.getHeaders().get(headerName).get(1)); + + request = request.mutate().header(headerName, new String[] { headerValue3 }).build(); + + assertNotNull(request.getHeaders().get(headerName)); + assertEquals(1, request.getHeaders().get(headerName).size()); + assertEquals(headerValue3, request.getHeaders().get(headerName).get(0)); + } + + private ServerHttpRequest createHttpRequest(String uriString) throws Exception { + URI uri = URI.create(uriString); + MockHttpServletRequest request = new TestHttpServletRequest(uri); + AsyncContext asyncContext = new MockAsyncContext(request, new MockHttpServletResponse()); + return new ServletServerHttpRequest(request, asyncContext, "", new DefaultDataBufferFactory(), 1024); + } + + + private static class TestHttpServletRequest extends MockHttpServletRequest { + + TestHttpServletRequest(URI uri) { + super("GET", uri.getRawPath()); + if (uri.getScheme() != null) { + setScheme(uri.getScheme()); + } + if (uri.getHost() != null) { + setServerName(uri.getHost()); + } + if (uri.getPort() != -1) { + setServerPort(uri.getPort()); + } + if (uri.getRawQuery() != null) { + setQueryString(uri.getRawQuery()); + } + } + + @Override + public ServletInputStream getInputStream() { + return new DelegatingServletInputStream(new ByteArrayInputStream(new byte[0])) { + @Override + public void setReadListener(ReadListener readListener) { + // Ignore + } + }; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpResponseTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpResponseTests.java new file mode 100644 index 0000000000000000000000000000000000000000..59252e13f50d96133c07d88f65f045b52ba22aac --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpResponseTests.java @@ -0,0 +1,200 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DefaultDataBuffer; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.ResponseCookie; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; + +/** + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + */ +public class ServerHttpResponseTests { + + + @Test + public void writeWith() throws Exception { + TestServerHttpResponse response = new TestServerHttpResponse(); + response.writeWith(Flux.just(wrap("a"), wrap("b"), wrap("c"))).block(); + + assertTrue(response.statusCodeWritten); + assertTrue(response.headersWritten); + assertTrue(response.cookiesWritten); + + assertEquals(3, response.body.size()); + assertEquals("a", new String(response.body.get(0).asByteBuffer().array(), StandardCharsets.UTF_8)); + assertEquals("b", new String(response.body.get(1).asByteBuffer().array(), StandardCharsets.UTF_8)); + assertEquals("c", new String(response.body.get(2).asByteBuffer().array(), StandardCharsets.UTF_8)); + } + + @Test // SPR-14952 + public void writeAndFlushWithFluxOfDefaultDataBuffer() throws Exception { + TestServerHttpResponse response = new TestServerHttpResponse(); + Flux> flux = Flux.just(Flux.just(wrap("foo"))); + response.writeAndFlushWith(flux).block(); + + assertTrue(response.statusCodeWritten); + assertTrue(response.headersWritten); + assertTrue(response.cookiesWritten); + + assertEquals(1, response.body.size()); + assertEquals("foo", new String(response.body.get(0).asByteBuffer().array(), StandardCharsets.UTF_8)); + } + + @Test + public void writeWithError() throws Exception { + TestServerHttpResponse response = new TestServerHttpResponse(); + response.getHeaders().setContentLength(12); + IllegalStateException error = new IllegalStateException("boo"); + response.writeWith(Flux.error(error)).onErrorResume(ex -> Mono.empty()).block(); + + assertFalse(response.statusCodeWritten); + assertFalse(response.headersWritten); + assertFalse(response.cookiesWritten); + assertFalse(response.getHeaders().containsKey(HttpHeaders.CONTENT_LENGTH)); + assertTrue(response.body.isEmpty()); + } + + @Test + public void setComplete() throws Exception { + TestServerHttpResponse response = new TestServerHttpResponse(); + response.setComplete().block(); + + assertTrue(response.statusCodeWritten); + assertTrue(response.headersWritten); + assertTrue(response.cookiesWritten); + assertTrue(response.body.isEmpty()); + } + + @Test + public void beforeCommitWithComplete() throws Exception { + ResponseCookie cookie = ResponseCookie.from("ID", "123").build(); + TestServerHttpResponse response = new TestServerHttpResponse(); + response.beforeCommit(() -> Mono.fromRunnable(() -> response.getCookies().add(cookie.getName(), cookie))); + response.writeWith(Flux.just(wrap("a"), wrap("b"), wrap("c"))).block(); + + assertTrue(response.statusCodeWritten); + assertTrue(response.headersWritten); + assertTrue(response.cookiesWritten); + assertSame(cookie, response.getCookies().getFirst("ID")); + + assertEquals(3, response.body.size()); + assertEquals("a", new String(response.body.get(0).asByteBuffer().array(), StandardCharsets.UTF_8)); + assertEquals("b", new String(response.body.get(1).asByteBuffer().array(), StandardCharsets.UTF_8)); + assertEquals("c", new String(response.body.get(2).asByteBuffer().array(), StandardCharsets.UTF_8)); + } + + @Test + public void beforeCommitActionWithSetComplete() throws Exception { + ResponseCookie cookie = ResponseCookie.from("ID", "123").build(); + TestServerHttpResponse response = new TestServerHttpResponse(); + response.beforeCommit(() -> { + response.getCookies().add(cookie.getName(), cookie); + return Mono.empty(); + }); + response.setComplete().block(); + + assertTrue(response.statusCodeWritten); + assertTrue(response.headersWritten); + assertTrue(response.cookiesWritten); + assertTrue(response.body.isEmpty()); + assertSame(cookie, response.getCookies().getFirst("ID")); + } + + + + private DefaultDataBuffer wrap(String a) { + return new DefaultDataBufferFactory().wrap(ByteBuffer.wrap(a.getBytes(StandardCharsets.UTF_8))); + } + + + private static class TestServerHttpResponse extends AbstractServerHttpResponse { + + private boolean statusCodeWritten; + + private boolean headersWritten; + + private boolean cookiesWritten; + + private final List body = new ArrayList<>(); + + public TestServerHttpResponse() { + super(new DefaultDataBufferFactory()); + } + + @Override + public T getNativeResponse() { + throw new IllegalStateException("This is a mock. No running server, no native response."); + } + + @Override + public void applyStatusCode() { + assertFalse(this.statusCodeWritten); + this.statusCodeWritten = true; + } + + @Override + protected void applyHeaders() { + assertFalse(this.headersWritten); + this.headersWritten = true; + } + + @Override + protected void applyCookies() { + assertFalse(this.cookiesWritten); + this.cookiesWritten = true; + } + + @Override + protected Mono writeWithInternal(Publisher body) { + return Flux.from(body).map(b -> { + this.body.add(b); + return b; + }).then(); + } + + @Override + protected Mono writeAndFlushWithInternal( + Publisher> bodyWithFlush) { + return Flux.from(bodyWithFlush).flatMap(body -> + Flux.from(body).map(b -> { + this.body.add(b); + return b; + }) + ).then(); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpsRequestIntegrationTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpsRequestIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d409923eb54648b20713f1164d0cbc46657ea920 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpsRequestIntegrationTests.java @@ -0,0 +1,114 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.URI; + +import org.apache.http.conn.ssl.NoopHostnameVerifier; +import org.apache.http.conn.ssl.SSLConnectionSocketFactory; +import org.apache.http.conn.ssl.TrustSelfSignedStrategy; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClients; +import org.apache.http.ssl.SSLContextBuilder; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpStatus; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.http.server.reactive.bootstrap.HttpServer; +import org.springframework.http.server.reactive.bootstrap.ReactorHttpsServer; +import org.springframework.web.client.RestTemplate; + +import static org.junit.Assert.*; + +/** + * HTTPS-specific integration test for {@link ServerHttpRequest}. + * @author Arjen Poutsma + */ +@RunWith(Parameterized.class) +public class ServerHttpsRequestIntegrationTests { + + private int port; + + @Parameterized.Parameter(0) + public HttpServer server; + + private RestTemplate restTemplate; + + @Parameterized.Parameters(name = "server [{0}]") + public static Object[][] arguments() { + return new Object[][]{ + {new ReactorHttpsServer()}, + }; + } + + @Before + public void setup() throws Exception { + this.server.setHandler(new CheckRequestHandler()); + this.server.afterPropertiesSet(); + this.server.start(); + + // Set dynamically chosen port + this.port = this.server.getPort(); + + SSLContextBuilder builder = new SSLContextBuilder(); + builder.loadTrustMaterial(new TrustSelfSignedStrategy()); + SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory( + builder.build(), NoopHostnameVerifier.INSTANCE); + CloseableHttpClient httpclient = HttpClients.custom().setSSLSocketFactory( + socketFactory).build(); + HttpComponentsClientHttpRequestFactory requestFactory = + new HttpComponentsClientHttpRequestFactory(httpclient); + this.restTemplate = new RestTemplate(requestFactory); + } + + @After + public void tearDown() throws Exception { + this.server.stop(); + this.port = 0; + } + + @Test + public void checkUri() throws Exception { + URI url = new URI("https://localhost:" + port + "/foo?param=bar"); + RequestEntity request = RequestEntity.post(url).build(); + ResponseEntity response = this.restTemplate.exchange(request, Void.class); + assertEquals(HttpStatus.OK, response.getStatusCode()); + } + + public static class CheckRequestHandler implements HttpHandler { + + @Override + public Mono handle(ServerHttpRequest request, ServerHttpResponse response) { + URI uri = request.getURI(); + assertEquals("https", uri.getScheme()); + assertNotNull(uri.getHost()); + assertNotEquals(-1, uri.getPort()); + assertNotNull(request.getRemoteAddress()); + assertEquals("/foo", uri.getPath()); + assertEquals("param=bar", uri.getQuery()); + return Mono.empty(); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/WriteOnlyHandlerIntegrationTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/WriteOnlyHandlerIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..6ff62dba5a89bd5c032492def3f694f3f56b5ac0 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/WriteOnlyHandlerIntegrationTests.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Random; + +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; + +import static org.junit.Assert.*; + +/** + * @author Violeta Georgieva + * @since 5.0 + */ +public class WriteOnlyHandlerIntegrationTests extends AbstractHttpHandlerIntegrationTests { + + private static final int REQUEST_SIZE = 4096 * 3; + + private Random rnd = new Random(); + + private byte[] body; + + + @Override + protected WriteOnlyHandler createHttpHandler() { + return new WriteOnlyHandler(); + } + + @Test + public void writeOnly() throws Exception { + RestTemplate restTemplate = new RestTemplate(); + + this.body = randomBytes(); + RequestEntity request = RequestEntity.post( + new URI("http://localhost:" + port)).body( + "".getBytes(StandardCharsets.UTF_8)); + ResponseEntity response = restTemplate.exchange(request, byte[].class); + + assertArrayEquals(body, response.getBody()); + } + + private byte[] randomBytes() { + byte[] buffer = new byte[REQUEST_SIZE]; + rnd.nextBytes(buffer); + return buffer; + } + + + public class WriteOnlyHandler implements HttpHandler { + + @Override + public Mono handle(ServerHttpRequest request, ServerHttpResponse response) { + DataBuffer buffer = response.bufferFactory().allocateBuffer(body.length); + buffer.write(body); + return response.writeAndFlushWith(Flux.just(Flux.just(buffer))); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ZeroCopyIntegrationTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ZeroCopyIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..295e21ff4d23d5f676290220ecdcc4f6051f2cf7 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ZeroCopyIntegrationTests.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.File; +import java.net.URI; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.http.ZeroCopyHttpOutputMessage; +import org.springframework.http.server.reactive.bootstrap.ReactorHttpServer; +import org.springframework.http.server.reactive.bootstrap.UndertowHttpServer; +import org.springframework.web.client.RestTemplate; + +import static org.junit.Assert.*; +import static org.junit.Assume.*; + +/** + * @author Arjen Poutsma + */ +public class ZeroCopyIntegrationTests extends AbstractHttpHandlerIntegrationTests { + + private final ZeroCopyHandler handler = new ZeroCopyHandler(); + + + @Override + protected HttpHandler createHttpHandler() { + return this.handler; + } + + + @Test + public void zeroCopy() throws Exception { + // Zero-copy only does not support servlet + assumeTrue(server instanceof ReactorHttpServer || server instanceof UndertowHttpServer); + + URI url = new URI("http://localhost:" + port); + RequestEntity request = RequestEntity.get(url).build(); + ResponseEntity response = new RestTemplate().exchange(request, byte[].class); + + Resource logo = new ClassPathResource("spring.png", ZeroCopyIntegrationTests.class); + + assertTrue(response.hasBody()); + assertEquals(logo.contentLength(), response.getHeaders().getContentLength()); + assertEquals(logo.contentLength(), response.getBody().length); + assertEquals(MediaType.IMAGE_PNG, response.getHeaders().getContentType()); + } + + + private static class ZeroCopyHandler implements HttpHandler { + + @Override + public Mono handle(ServerHttpRequest request, ServerHttpResponse response) { + try { + ZeroCopyHttpOutputMessage zeroCopyResponse = (ZeroCopyHttpOutputMessage) response; + Resource logo = new ClassPathResource("spring.png", ZeroCopyIntegrationTests.class); + File logoFile = logo.getFile(); + zeroCopyResponse.getHeaders().setContentType(MediaType.IMAGE_PNG); + zeroCopyResponse.getHeaders().setContentLength(logoFile.length()); + return zeroCopyResponse.writeWith(logoFile, 0, logoFile.length()); + } + catch (Throwable ex) { + return Mono.error(ex); + } + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/AbstractHttpServer.java b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/AbstractHttpServer.java new file mode 100644 index 0000000000000000000000000000000000000000..8aca3b152e5ef3b9e5083f29daa220ce39482e8f --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/AbstractHttpServer.java @@ -0,0 +1,183 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive.bootstrap; + +import java.util.LinkedHashMap; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.http.server.reactive.ContextPathCompositeHandler; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.util.Assert; +import org.springframework.util.StopWatch; + +/** + * @author Rossen Stoyanchev + */ +public abstract class AbstractHttpServer implements HttpServer { + + protected Log logger = LogFactory.getLog(getClass().getName()); + + private String host = "0.0.0.0"; + + private int port = 0; + + private HttpHandler httpHandler; + + private Map handlerMap; + + private volatile boolean running; + + private final Object lifecycleMonitor = new Object(); + + + @Override + public void setHost(String host) { + this.host = host; + } + + public String getHost() { + return host; + } + + @Override + public void setPort(int port) { + this.port = port; + } + + @Override + public int getPort() { + return this.port; + } + + @Override + public void setHandler(HttpHandler handler) { + this.httpHandler = handler; + } + + public HttpHandler getHttpHandler() { + return this.httpHandler; + } + + public void registerHttpHandler(String contextPath, HttpHandler handler) { + if (this.handlerMap == null) { + this.handlerMap = new LinkedHashMap<>(); + } + this.handlerMap.put(contextPath, handler); + } + + public Map getHttpHandlerMap() { + return this.handlerMap; + } + + protected HttpHandler resolveHttpHandler() { + return (getHttpHandlerMap() != null ? + new ContextPathCompositeHandler(getHttpHandlerMap()) : getHttpHandler()); + } + + + // InitializingBean + + @Override + public final void afterPropertiesSet() throws Exception { + Assert.notNull(this.host, "Host must not be null"); + Assert.isTrue(this.port >= 0, "Port must not be a negative number"); + Assert.isTrue(this.httpHandler != null || this.handlerMap != null, "No HttpHandler configured"); + Assert.state(!this.running, "Cannot reconfigure while running"); + + synchronized (this.lifecycleMonitor) { + initServer(); + } + } + + protected abstract void initServer() throws Exception; + + + // Lifecycle + + @Override + public final void start() { + synchronized (this.lifecycleMonitor) { + if (!isRunning()) { + String serverName = getClass().getSimpleName(); + if (logger.isDebugEnabled()) { + logger.debug("Starting " + serverName + "..."); + } + this.running = true; + try { + StopWatch stopWatch = new StopWatch(); + stopWatch.start(); + startInternal(); + long millis = stopWatch.getTotalTimeMillis(); + if (logger.isDebugEnabled()) { + logger.debug("Server started on port " + getPort() + "(" + millis + " millis)."); + } + } + catch (Throwable ex) { + throw new IllegalStateException(ex); + } + } + } + + } + + protected abstract void startInternal() throws Exception; + + @Override + public final void stop() { + synchronized (this.lifecycleMonitor) { + if (isRunning()) { + String serverName = getClass().getSimpleName(); + logger.debug("Stopping " + serverName + "..."); + this.running = false; + try { + StopWatch stopWatch = new StopWatch(); + stopWatch.start(); + stopInternal(); + logger.debug("Server stopped (" + stopWatch.getTotalTimeMillis() + " millis)."); + } + catch (Throwable ex) { + throw new IllegalStateException(ex); + } + finally { + reset(); + } + } + } + } + + protected abstract void stopInternal() throws Exception; + + @Override + public boolean isRunning() { + return this.running; + } + + + private void reset() { + this.host = "0.0.0.0"; + this.port = 0; + this.httpHandler = null; + this.handlerMap = null; + resetInternal(); + } + + protected abstract void resetInternal(); + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/HttpServer.java b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/HttpServer.java new file mode 100644 index 0000000000000000000000000000000000000000..9abb1fe14c15dd955cffa3a4821b08dc2935d7e6 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/HttpServer.java @@ -0,0 +1,36 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive.bootstrap; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.context.Lifecycle; +import org.springframework.http.server.reactive.HttpHandler; + +/** + * @author Rossen Stoyanchev + */ +public interface HttpServer extends InitializingBean, Lifecycle { + + void setHost(String host); + + void setPort(int port); + + int getPort(); + + void setHandler(HttpHandler handler); + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/JettyHttpServer.java b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/JettyHttpServer.java new file mode 100644 index 0000000000000000000000000000000000000000..4d63e75202bdf9b5fbb79126f6f015338fb67a82 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/JettyHttpServer.java @@ -0,0 +1,105 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive.bootstrap; + +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; + +import org.springframework.http.server.reactive.JettyHttpHandlerAdapter; +import org.springframework.http.server.reactive.ServletHttpHandlerAdapter; + +/** + * @author Rossen Stoyanchev + */ +public class JettyHttpServer extends AbstractHttpServer { + + private Server jettyServer; + + private ServletContextHandler contextHandler; + + + @Override + protected void initServer() throws Exception { + + this.jettyServer = new Server(); + + ServletHttpHandlerAdapter servlet = createServletAdapter(); + ServletHolder servletHolder = new ServletHolder(servlet); + servletHolder.setAsyncSupported(true); + + this.contextHandler = new ServletContextHandler(this.jettyServer, "", false, false); + this.contextHandler.addServlet(servletHolder, "/"); + this.contextHandler.start(); + + ServerConnector connector = new ServerConnector(this.jettyServer); + connector.setHost(getHost()); + connector.setPort(getPort()); + this.jettyServer.addConnector(connector); + } + + private ServletHttpHandlerAdapter createServletAdapter() { + return new JettyHttpHandlerAdapter(resolveHttpHandler()); + } + + @Override + protected void startInternal() throws Exception { + this.jettyServer.start(); + setPort(((ServerConnector) this.jettyServer.getConnectors()[0]).getLocalPort()); + } + + @Override + protected void stopInternal() throws Exception { + try { + if (this.contextHandler.isRunning()) { + this.contextHandler.stop(); + } + } + finally { + try { + if (this.jettyServer.isRunning()) { + this.jettyServer.setStopTimeout(5000); + this.jettyServer.stop(); + this.jettyServer.destroy(); + } + } + catch (Exception ex) { + // ignore + } + } + } + + @Override + protected void resetInternal() { + try { + if (this.jettyServer.isRunning()) { + this.jettyServer.setStopTimeout(5000); + this.jettyServer.stop(); + this.jettyServer.destroy(); + } + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + finally { + this.jettyServer = null; + this.contextHandler = null; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/ReactorHttpServer.java b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/ReactorHttpServer.java new file mode 100644 index 0000000000000000000000000000000000000000..25e9b2ac3c7d000a2a966dd51bd84721d5b5111c --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/ReactorHttpServer.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive.bootstrap; + +import java.util.concurrent.atomic.AtomicReference; + +import reactor.netty.DisposableServer; + +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; + +/** + * @author Stephane Maldini + */ +public class ReactorHttpServer extends AbstractHttpServer { + + private ReactorHttpHandlerAdapter reactorHandler; + + private reactor.netty.http.server.HttpServer reactorServer; + + private AtomicReference serverRef = new AtomicReference<>(); + + + @Override + protected void initServer() { + this.reactorHandler = createHttpHandlerAdapter(); + this.reactorServer = reactor.netty.http.server.HttpServer.create() + .tcpConfiguration(server -> server.host(getHost())) + .port(getPort()); + } + + private ReactorHttpHandlerAdapter createHttpHandlerAdapter() { + return new ReactorHttpHandlerAdapter(resolveHttpHandler()); + } + + @Override + protected void startInternal() { + DisposableServer server = this.reactorServer.handle(this.reactorHandler).bind().block(); + setPort(server.address().getPort()); + this.serverRef.set(server); + } + + @Override + protected void stopInternal() { + this.serverRef.get().dispose(); + } + + @Override + protected void resetInternal() { + this.reactorServer = null; + this.reactorHandler = null; + this.serverRef.set(null); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/ReactorHttpsServer.java b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/ReactorHttpsServer.java new file mode 100644 index 0000000000000000000000000000000000000000..1a7636d93516bb85f58291e44d0a8a4802b5d4df --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/ReactorHttpsServer.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive.bootstrap; + +import java.util.concurrent.atomic.AtomicReference; + +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import reactor.netty.DisposableServer; +import reactor.netty.tcp.SslProvider.DefaultConfigurationType; + +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; + +/** + * @author Stephane Maldini + */ +public class ReactorHttpsServer extends AbstractHttpServer { + + private ReactorHttpHandlerAdapter reactorHandler; + + private reactor.netty.http.server.HttpServer reactorServer; + + private AtomicReference serverRef = new AtomicReference<>(); + + + @Override + protected void initServer() throws Exception { + + SelfSignedCertificate cert = new SelfSignedCertificate(); + SslContextBuilder builder = SslContextBuilder.forServer(cert.certificate(), cert.privateKey()); + + this.reactorHandler = createHttpHandlerAdapter(); + this.reactorServer = reactor.netty.http.server.HttpServer.create() + .host(getHost()) + .port(getPort()) + .secure(spec -> spec.sslContext(builder).defaultConfiguration(DefaultConfigurationType.TCP)); + } + + private ReactorHttpHandlerAdapter createHttpHandlerAdapter() { + return new ReactorHttpHandlerAdapter(resolveHttpHandler()); + } + + @Override + protected void startInternal() { + DisposableServer server = this.reactorServer.handle(this.reactorHandler).bind().block(); + setPort(server.address().getPort()); + this.serverRef.set(server); + } + + @Override + protected void stopInternal() { + this.serverRef.get().dispose(); + } + + @Override + protected void resetInternal() { + this.reactorServer = null; + this.reactorHandler = null; + this.serverRef.set(null); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/TomcatHttpServer.java b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/TomcatHttpServer.java new file mode 100644 index 0000000000000000000000000000000000000000..0cbc79c37f007a68f8a7f2ef5bf194bfc30dd68d --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/TomcatHttpServer.java @@ -0,0 +1,105 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive.bootstrap; + +import java.io.File; + +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; + +import org.springframework.http.server.reactive.ServletHttpHandlerAdapter; +import org.springframework.http.server.reactive.TomcatHttpHandlerAdapter; +import org.springframework.util.Assert; + +/** + * @author Rossen Stoyanchev + */ +public class TomcatHttpServer extends AbstractHttpServer { + + private final String baseDir; + + private final Class wsListener; + + private String contextPath = ""; + + private String servletMapping = "/"; + + private Tomcat tomcatServer; + + + public TomcatHttpServer(String baseDir) { + this(baseDir, null); + } + + public TomcatHttpServer(String baseDir, Class wsListener) { + Assert.notNull(baseDir, "Base dir must not be null"); + this.baseDir = baseDir; + this.wsListener = wsListener; + } + + + public void setContextPath(String contextPath) { + this.contextPath = contextPath; + } + + public void setServletMapping(String servletMapping) { + this.servletMapping = servletMapping; + } + + + @Override + protected void initServer() throws Exception { + this.tomcatServer = new Tomcat(); + this.tomcatServer.setBaseDir(baseDir); + this.tomcatServer.setHostname(getHost()); + this.tomcatServer.setPort(getPort()); + + ServletHttpHandlerAdapter servlet = initServletAdapter(); + + File base = new File(System.getProperty("java.io.tmpdir")); + Context rootContext = tomcatServer.addContext(this.contextPath, base.getAbsolutePath()); + Tomcat.addServlet(rootContext, "httpHandlerServlet", servlet).setAsyncSupported(true); + rootContext.addServletMappingDecoded(this.servletMapping, "httpHandlerServlet"); + if (wsListener != null) { + rootContext.addApplicationListener(wsListener.getName()); + } + } + + private ServletHttpHandlerAdapter initServletAdapter() { + return new TomcatHttpHandlerAdapter(resolveHttpHandler()); + } + + + @Override + protected void startInternal() throws LifecycleException { + this.tomcatServer.start(); + setPort(this.tomcatServer.getConnector().getLocalPort()); + } + + @Override + protected void stopInternal() throws Exception { + this.tomcatServer.stop(); + this.tomcatServer.destroy(); + } + + @Override + protected void resetInternal() { + this.tomcatServer = null; + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/UndertowHttpServer.java b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/UndertowHttpServer.java new file mode 100644 index 0000000000000000000000000000000000000000..cdfc5711ac66867e6bd0b89174e9a57a5f06f3df --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/UndertowHttpServer.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive.bootstrap; + +import java.net.InetSocketAddress; + +import io.undertow.Undertow; + +import org.springframework.http.server.reactive.UndertowHttpHandlerAdapter; + +/** + * @author Marek Hawrylczak + */ +public class UndertowHttpServer extends AbstractHttpServer { + + private Undertow server; + + + @Override + protected void initServer() throws Exception { + this.server = Undertow.builder().addHttpListener(getPort(), getHost()) + .setHandler(initHttpHandlerAdapter()) + .build(); + } + + private UndertowHttpHandlerAdapter initHttpHandlerAdapter() { + return new UndertowHttpHandlerAdapter(resolveHttpHandler()); + } + + @Override + protected void startInternal() { + this.server.start(); + Undertow.ListenerInfo info = this.server.getListenerInfo().get(0); + setPort(((InetSocketAddress) info.getAddress()).getPort()); + } + + @Override + protected void stopInternal() { + this.server.stop(); + } + + @Override + protected void resetInternal() { + this.server = null; + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/package-info.java b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..93e5dac54947119729c6818744d93f52871e60f4 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/package-info.java @@ -0,0 +1,5 @@ +/** + * This package contains temporary interfaces and classes for running embedded servers. + * They are expected to be replaced by an upcoming Spring Boot support. + */ +package org.springframework.http.server.reactive.bootstrap; diff --git a/spring-web/src/test/java/org/springframework/mock/http/client/reactive/test/MockClientHttpRequest.java b/spring-web/src/test/java/org/springframework/mock/http/client/reactive/test/MockClientHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..02b9bceed263219cb0ee06fad675c9917a84de95 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/http/client/reactive/test/MockClientHttpRequest.java @@ -0,0 +1,168 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.http.client.reactive.test; + +import java.net.URI; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Collection; +import java.util.Optional; +import java.util.function.Function; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.client.reactive.AbstractClientHttpRequest; +import org.springframework.http.client.reactive.ClientHttpRequest; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * Mock implementation of {@link ClientHttpRequest}. + * + * @author Brian Clozel + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class MockClientHttpRequest extends AbstractClientHttpRequest { + + private HttpMethod httpMethod; + + private URI url; + + private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(); + + private Flux body = Flux.error( + new IllegalStateException("The body is not set. " + + "Did handling complete with success? Is a custom \"writeHandler\" configured?")); + + private Function, Mono> writeHandler; + + + public MockClientHttpRequest(HttpMethod httpMethod, String urlTemplate, Object... vars) { + this(httpMethod, UriComponentsBuilder.fromUriString(urlTemplate).buildAndExpand(vars).encode().toUri()); + } + + public MockClientHttpRequest(HttpMethod httpMethod, URI url) { + this.httpMethod = httpMethod; + this.url = url; + this.writeHandler = body -> { + this.body = body.cache(); + return this.body.then(); + }; + } + + + /** + * Configure a custom handler for writing the request body. + * + *

The default write handler consumes and caches the request body so it + * may be accessed subsequently, e.g. in test assertions. Use this property + * when the request body is an infinite stream. + * + * @param writeHandler the write handler to use returning {@code Mono} + * when the body has been "written" (i.e. consumed). + */ + public void setWriteHandler(Function, Mono> writeHandler) { + Assert.notNull(writeHandler, "'writeHandler' is required"); + this.writeHandler = writeHandler; + } + + + @Override + public HttpMethod getMethod() { + return this.httpMethod; + } + + @Override + public URI getURI() { + return this.url; + } + + @Override + public DataBufferFactory bufferFactory() { + return this.bufferFactory; + } + + @Override + protected void applyHeaders() { + } + + @Override + protected void applyCookies() { + getCookies().values().stream().flatMap(Collection::stream) + .forEach(cookie -> getHeaders().add(HttpHeaders.COOKIE, cookie.toString())); + } + + @Override + public Mono writeWith(Publisher body) { + return doCommit(() -> Mono.defer(() -> this.writeHandler.apply(Flux.from(body)))); + } + + @Override + public Mono writeAndFlushWith(Publisher> body) { + return writeWith(Flux.from(body).flatMap(p -> p)); + } + + @Override + public Mono setComplete() { + return writeWith(Flux.empty()); + } + + + /** + * Return the request body, or an error stream if the body was never set + * or when {@link #setWriteHandler} is configured. + */ + public Flux getBody() { + return this.body; + } + + /** + * Aggregate response data and convert to a String using the "Content-Type" + * charset or "UTF-8" by default. + */ + public Mono getBodyAsString() { + + Charset charset = Optional.ofNullable(getHeaders().getContentType()).map(MimeType::getCharset) + .orElse(StandardCharsets.UTF_8); + + return getBody() + .reduce(bufferFactory().allocateBuffer(), (previous, current) -> { + previous.write(current); + DataBufferUtils.release(current); + return previous; + }) + .map(buffer -> bufferToString(buffer, charset)); + } + + private static String bufferToString(DataBuffer buffer, Charset charset) { + Assert.notNull(charset, "'charset' must not be null"); + byte[] bytes = new byte[buffer.readableByteCount()]; + buffer.read(bytes); + return new String(bytes, charset); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/http/client/reactive/test/MockClientHttpResponse.java b/spring-web/src/test/java/org/springframework/mock/http/client/reactive/test/MockClientHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..f7b772a74c8ec21b9f107e0eb9b8afafbb7e01ae --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/http/client/reactive/test/MockClientHttpResponse.java @@ -0,0 +1,151 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.http.client.reactive.test; + +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Collection; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseCookie; +import org.springframework.http.client.reactive.ClientHttpResponse; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * Mock implementation of {@link ClientHttpResponse}. + * + * @author Brian Clozel + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class MockClientHttpResponse implements ClientHttpResponse { + + private final int status; + + private final HttpHeaders headers = new HttpHeaders(); + + private final MultiValueMap cookies = new LinkedMultiValueMap<>(); + + private Flux body = Flux.empty(); + + private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(); + + + public MockClientHttpResponse(HttpStatus status) { + Assert.notNull(status, "HttpStatus is required"); + this.status = status.value(); + } + + public MockClientHttpResponse(int status) { + Assert.isTrue(status >= 100 && status < 600, "Status must be between 1xx and 5xx"); + this.status = status; + } + + + @Override + public HttpStatus getStatusCode() { + return HttpStatus.valueOf(this.status); + } + + @Override + public int getRawStatusCode() { + return this.status; + } + + @Override + public HttpHeaders getHeaders() { + if (!getCookies().isEmpty() && this.headers.get(HttpHeaders.SET_COOKIE) == null) { + getCookies().values().stream().flatMap(Collection::stream) + .forEach(cookie -> getHeaders().add(HttpHeaders.SET_COOKIE, cookie.toString())); + } + return this.headers; + } + + @Override + public MultiValueMap getCookies() { + return this.cookies; + } + + public void setBody(Publisher body) { + this.body = Flux.from(body); + } + + public void setBody(String body) { + setBody(body, StandardCharsets.UTF_8); + } + + public void setBody(String body, Charset charset) { + DataBuffer buffer = toDataBuffer(body, charset); + this.body = Flux.just(buffer); + } + + private DataBuffer toDataBuffer(String body, Charset charset) { + byte[] bytes = body.getBytes(charset); + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + return this.bufferFactory.wrap(byteBuffer); + } + + @Override + public Flux getBody() { + return this.body; + } + + /** + * Return the response body aggregated and converted to a String using the + * charset of the Content-Type response or otherwise as "UTF-8". + */ + public Mono getBodyAsString() { + Charset charset = getCharset(); + return Flux.from(getBody()) + .reduce(this.bufferFactory.allocateBuffer(), (previous, current) -> { + previous.write(current); + DataBufferUtils.release(current); + return previous; + }) + .map(buffer -> dumpString(buffer, charset)); + } + + private static String dumpString(DataBuffer buffer, Charset charset) { + Assert.notNull(charset, "'charset' must not be null"); + byte[] bytes = new byte[buffer.readableByteCount()]; + buffer.read(bytes); + return new String(bytes, charset); + } + + private Charset getCharset() { + Charset charset = null; + MediaType contentType = getHeaders().getContentType(); + if (contentType != null) { + charset = contentType.getCharset(); + } + return (charset != null ? charset : StandardCharsets.UTF_8); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/http/server/reactive/test/MockServerHttpRequest.java b/spring-web/src/test/java/org/springframework/mock/http/server/reactive/test/MockServerHttpRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..6b09e8830ad4c6a2c1c18e072014105aaa5d0b7c --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/http/server/reactive/test/MockServerHttpRequest.java @@ -0,0 +1,565 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.http.server.reactive.test; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Locale; +import java.util.Optional; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRange; +import org.springframework.http.MediaType; +import org.springframework.http.server.reactive.AbstractServerHttpRequest; +import org.springframework.http.server.reactive.SslInfo; +import org.springframework.lang.Nullable; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MimeType; +import org.springframework.util.MultiValueMap; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * Mock extension of {@link AbstractServerHttpRequest} for use in tests without + * an actual server. Use the static methods to obtain a builder. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public final class MockServerHttpRequest extends AbstractServerHttpRequest { + + private final HttpMethod httpMethod; + + private final MultiValueMap cookies; + + @Nullable + private final InetSocketAddress remoteAddress; + + @Nullable + private final SslInfo sslInfo; + + private final Flux body; + + + private MockServerHttpRequest(HttpMethod httpMethod, URI uri, @Nullable String contextPath, + HttpHeaders headers, MultiValueMap cookies, + @Nullable InetSocketAddress remoteAddress, @Nullable SslInfo sslInfo, + Publisher body) { + + super(uri, contextPath, headers); + this.httpMethod = httpMethod; + this.cookies = cookies; + this.remoteAddress = remoteAddress; + this.sslInfo = sslInfo; + this.body = Flux.from(body); + } + + + @Override + public HttpMethod getMethod() { + return this.httpMethod; + } + + @Override + public String getMethodValue() { + return this.httpMethod.name(); + } + + @Override + @Nullable + public InetSocketAddress getRemoteAddress() { + return this.remoteAddress; + } + + @Override + @Nullable + protected SslInfo initSslInfo() { + return this.sslInfo; + } + + @Override + public Flux getBody() { + return this.body; + } + + @Override + protected MultiValueMap initCookies() { + return this.cookies; + } + + @Override + public T getNativeRequest() { + throw new IllegalStateException("This is a mock. No running server, no native request."); + } + + + // Static builder methods + + /** + * Create a builder with the given HTTP method and a {@link URI}. + * @param method the HTTP method (GET, POST, etc) + * @param url the URL + * @return the created builder + */ + public static BodyBuilder method(HttpMethod method, URI url) { + return new DefaultBodyBuilder(method, url); + } + + /** + * Alternative to {@link #method(HttpMethod, URI)} that accepts a URI template. + * The given URI may contain query parameters, or those may be added later via + * {@link BaseBuilder#queryParam queryParam} builder methods. + * @param method the HTTP method (GET, POST, etc) + * @param urlTemplate the URL template + * @param vars variables to expand into the template + * @return the created builder + */ + public static BodyBuilder method(HttpMethod method, String urlTemplate, Object... vars) { + URI url = UriComponentsBuilder.fromUriString(urlTemplate).buildAndExpand(vars).encode().toUri(); + return new DefaultBodyBuilder(method, url); + } + + /** + * Create an HTTP GET builder with the given URI template. The given URI may + * contain query parameters, or those may be added later via + * {@link BaseBuilder#queryParam queryParam} builder methods. + * @param urlTemplate a URL template; the resulting URL will be encoded + * @param uriVars zero or more URI variables + * @return the created builder + */ + public static BaseBuilder get(String urlTemplate, Object... uriVars) { + return method(HttpMethod.GET, urlTemplate, uriVars); + } + + /** + * HTTP HEAD variant. See {@link #get(String, Object...)} for general info. + * @param urlTemplate a URL template; the resulting URL will be encoded + * @param uriVars zero or more URI variables + * @return the created builder + */ + public static BaseBuilder head(String urlTemplate, Object... uriVars) { + return method(HttpMethod.HEAD, urlTemplate, uriVars); + } + + /** + * HTTP POST variant. See {@link #get(String, Object...)} for general info. + * @param urlTemplate a URL template; the resulting URL will be encoded + * @param uriVars zero or more URI variables + * @return the created builder + */ + public static BodyBuilder post(String urlTemplate, Object... uriVars) { + return method(HttpMethod.POST, urlTemplate, uriVars); + } + + /** + * HTTP PUT variant. See {@link #get(String, Object...)} for general info. + * {@link BaseBuilder#queryParam queryParam} builder methods. + * @param urlTemplate a URL template; the resulting URL will be encoded + * @param uriVars zero or more URI variables + * @return the created builder + */ + public static BodyBuilder put(String urlTemplate, Object... uriVars) { + return method(HttpMethod.PUT, urlTemplate, uriVars); + } + + /** + * HTTP PATCH variant. See {@link #get(String, Object...)} for general info. + * @param urlTemplate a URL template; the resulting URL will be encoded + * @param uriVars zero or more URI variables + * @return the created builder + */ + public static BodyBuilder patch(String urlTemplate, Object... uriVars) { + return method(HttpMethod.PATCH, urlTemplate, uriVars); + } + + /** + * HTTP DELETE variant. See {@link #get(String, Object...)} for general info. + * @param urlTemplate a URL template; the resulting URL will be encoded + * @param uriVars zero or more URI variables + * @return the created builder + */ + public static BaseBuilder delete(String urlTemplate, Object... uriVars) { + return method(HttpMethod.DELETE, urlTemplate, uriVars); + } + + /** + * HTTP OPTIONS variant. See {@link #get(String, Object...)} for general info. + * @param urlTemplate a URL template; the resulting URL will be encoded + * @param uriVars zero or more URI variables + * @return the created builder + */ + public static BaseBuilder options(String urlTemplate, Object... uriVars) { + return method(HttpMethod.OPTIONS, urlTemplate, uriVars); + } + + + /** + * Request builder exposing properties not related to the body. + * @param the builder sub-class + */ + public interface BaseBuilder> { + + /** + * Set the contextPath to return. + */ + B contextPath(String contextPath); + + /** + * Append the given query parameter to the existing query parameters. + * If no values are given, the resulting URI will contain the query + * parameter name only (i.e. {@code ?foo} instead of {@code ?foo=bar}). + *

The provided query name and values will be encoded. + * @param name the query parameter name + * @param values the query parameter values + * @return this UriComponentsBuilder + */ + B queryParam(String name, Object... values); + + /** + * Add the given query parameters and values. The provided query name + * and corresponding values will be encoded. + * @param params the params + * @return this UriComponentsBuilder + */ + B queryParams(MultiValueMap params); + + /** + * Set the remote address to return. + */ + B remoteAddress(InetSocketAddress remoteAddress); + + /** + * Set SSL session information and certificates. + */ + void sslInfo(SslInfo sslInfo); + + /** + * Add one or more cookies. + */ + B cookie(HttpCookie... cookie); + + /** + * Add the given cookies. + * @param cookies the cookies. + */ + B cookies(MultiValueMap cookies); + + /** + * Add the given, single header value under the given name. + * @param headerName the header name + * @param headerValues the header value(s) + * @see HttpHeaders#add(String, String) + */ + B header(String headerName, String... headerValues); + + /** + * Add the given header values. + * @param headers the header values + */ + B headers(MultiValueMap headers); + + /** + * Set the list of acceptable {@linkplain MediaType media types}, as + * specified by the {@code Accept} header. + * @param acceptableMediaTypes the acceptable media types + */ + B accept(MediaType... acceptableMediaTypes); + + /** + * Set the list of acceptable {@linkplain Charset charsets}, as specified + * by the {@code Accept-Charset} header. + * @param acceptableCharsets the acceptable charsets + */ + B acceptCharset(Charset... acceptableCharsets); + + /** + * Set the list of acceptable {@linkplain Locale locales}, as specified + * by the {@code Accept-Languages} header. + * @param acceptableLocales the acceptable locales + */ + B acceptLanguageAsLocales(Locale... acceptableLocales); + + /** + * Set the value of the {@code If-Modified-Since} header. + *

The date should be specified as the number of milliseconds since + * January 1, 1970 GMT. + * @param ifModifiedSince the new value of the header + */ + B ifModifiedSince(long ifModifiedSince); + + /** + * Set the (new) value of the {@code If-Unmodified-Since} header. + *

The date should be specified as the number of milliseconds since + * January 1, 1970 GMT. + * @param ifUnmodifiedSince the new value of the header + * @see HttpHeaders#setIfUnmodifiedSince(long) + */ + B ifUnmodifiedSince(long ifUnmodifiedSince); + + /** + * Set the values of the {@code If-None-Match} header. + * @param ifNoneMatches the new value of the header + */ + B ifNoneMatch(String... ifNoneMatches); + + /** + * Set the (new) value of the Range header. + * @param ranges the HTTP ranges + * @see HttpHeaders#setRange(List) + */ + B range(HttpRange... ranges); + + /** + * Builds the request with no body. + * @return the request + * @see BodyBuilder#body(Publisher) + * @see BodyBuilder#body(String) + */ + MockServerHttpRequest build(); + } + + + /** + * A builder that adds a body to the request. + */ + public interface BodyBuilder extends BaseBuilder { + + /** + * Set the length of the body in bytes, as specified by the + * {@code Content-Length} header. + * @param contentLength the content length + * @return this builder + * @see HttpHeaders#setContentLength(long) + */ + BodyBuilder contentLength(long contentLength); + + /** + * Set the {@linkplain MediaType media type} of the body, as specified + * by the {@code Content-Type} header. + * @param contentType the content type + * @return this builder + * @see HttpHeaders#setContentType(MediaType) + */ + BodyBuilder contentType(MediaType contentType); + + /** + * Set the body of the request and build it. + * @param body the body + * @return the built request entity + */ + MockServerHttpRequest body(Publisher body); + + /** + * Set the body of the request and build it. + *

The String is assumed to be UTF-8 encoded unless the request has a + * "content-type" header with a charset attribute. + * @param body the body as text + * @return the built request entity + */ + MockServerHttpRequest body(String body); + } + + + private static class DefaultBodyBuilder implements BodyBuilder { + + private static final DataBufferFactory BUFFER_FACTORY = new DefaultDataBufferFactory(); + + private final HttpMethod method; + + private final URI url; + + @Nullable + private String contextPath; + + private final UriComponentsBuilder queryParamsBuilder = UriComponentsBuilder.newInstance(); + + private final HttpHeaders headers = new HttpHeaders(); + + private final MultiValueMap cookies = new LinkedMultiValueMap<>(); + + @Nullable + private InetSocketAddress remoteAddress; + + @Nullable + private SslInfo sslInfo; + + public DefaultBodyBuilder(HttpMethod method, URI url) { + this.method = method; + this.url = url; + } + + @Override + public BodyBuilder contextPath(String contextPath) { + this.contextPath = contextPath; + return this; + } + + @Override + public BodyBuilder queryParam(String name, Object... values) { + this.queryParamsBuilder.queryParam(name, values); + return this; + } + + @Override + public BodyBuilder queryParams(MultiValueMap params) { + this.queryParamsBuilder.queryParams(params); + return this; + } + + @Override + public BodyBuilder remoteAddress(InetSocketAddress remoteAddress) { + this.remoteAddress = remoteAddress; + return this; + } + + @Override + public void sslInfo(SslInfo sslInfo) { + this.sslInfo = sslInfo; + } + + @Override + public BodyBuilder cookie(HttpCookie... cookies) { + Arrays.stream(cookies).forEach(cookie -> this.cookies.add(cookie.getName(), cookie)); + return this; + } + + @Override + public BodyBuilder cookies(MultiValueMap cookies) { + this.cookies.putAll(cookies); + return this; + } + + @Override + public BodyBuilder header(String headerName, String... headerValues) { + for (String headerValue : headerValues) { + this.headers.add(headerName, headerValue); + } + return this; + } + + @Override + public BodyBuilder headers(MultiValueMap headers) { + this.headers.putAll(headers); + return this; + } + + @Override + public BodyBuilder accept(MediaType... acceptableMediaTypes) { + this.headers.setAccept(Arrays.asList(acceptableMediaTypes)); + return this; + } + + @Override + public BodyBuilder acceptCharset(Charset... acceptableCharsets) { + this.headers.setAcceptCharset(Arrays.asList(acceptableCharsets)); + return this; + } + + @Override + public BodyBuilder acceptLanguageAsLocales(Locale... acceptableLocales) { + this.headers.setAcceptLanguageAsLocales(Arrays.asList(acceptableLocales)); + return this; + } + + @Override + public BodyBuilder contentLength(long contentLength) { + this.headers.setContentLength(contentLength); + return this; + } + + @Override + public BodyBuilder contentType(MediaType contentType) { + this.headers.setContentType(contentType); + return this; + } + + @Override + public BodyBuilder ifModifiedSince(long ifModifiedSince) { + this.headers.setIfModifiedSince(ifModifiedSince); + return this; + } + + @Override + public BodyBuilder ifUnmodifiedSince(long ifUnmodifiedSince) { + this.headers.setIfUnmodifiedSince(ifUnmodifiedSince); + return this; + } + + @Override + public BodyBuilder ifNoneMatch(String... ifNoneMatches) { + this.headers.setIfNoneMatch(Arrays.asList(ifNoneMatches)); + return this; + } + + @Override + public BodyBuilder range(HttpRange... ranges) { + this.headers.setRange(Arrays.asList(ranges)); + return this; + } + + @Override + public MockServerHttpRequest build() { + return body(Flux.empty()); + } + + @Override + public MockServerHttpRequest body(String body) { + return body(Flux.just(BUFFER_FACTORY.wrap(body.getBytes(getCharset())))); + } + + private Charset getCharset() { + return Optional.ofNullable(this.headers.getContentType()) + .map(MimeType::getCharset).orElse(StandardCharsets.UTF_8); + } + + @Override + public MockServerHttpRequest body(Publisher body) { + applyCookiesIfNecessary(); + return new MockServerHttpRequest(this.method, getUrlToUse(), this.contextPath, + this.headers, this.cookies, this.remoteAddress, this.sslInfo, body); + } + + private void applyCookiesIfNecessary() { + if (this.headers.get(HttpHeaders.COOKIE) == null) { + this.cookies.values().stream().flatMap(Collection::stream) + .forEach(cookie -> this.headers.add(HttpHeaders.COOKIE, cookie.toString())); + } + } + + private URI getUrlToUse() { + MultiValueMap params = + this.queryParamsBuilder.buildAndExpand().encode().getQueryParams(); + if (!params.isEmpty()) { + return UriComponentsBuilder.fromUri(this.url).queryParams(params).build(true).toUri(); + } + return this.url; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/http/server/reactive/test/MockServerHttpResponse.java b/spring-web/src/test/java/org/springframework/mock/http/server/reactive/test/MockServerHttpResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..98b92da526d0c32271c510580fd50ae027ff59fa --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/http/server/reactive/test/MockServerHttpResponse.java @@ -0,0 +1,162 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.http.server.reactive.test; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.ResponseCookie; +import org.springframework.http.server.reactive.AbstractServerHttpResponse; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; + +/** + * Mock extension of {@link AbstractServerHttpResponse} for use in tests without + * an actual server. + * + *

By default response content is consumed in full upon writing and cached + * for subsequent access, however it is also possible to set a custom + * {@link #setWriteHandler(Function) writeHandler}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class MockServerHttpResponse extends AbstractServerHttpResponse { + + private Flux body = Flux.error(new IllegalStateException( + "No content was written nor was setComplete() called on this response.")); + + private Function, Mono> writeHandler; + + + public MockServerHttpResponse() { + this(new DefaultDataBufferFactory()); + } + + public MockServerHttpResponse(DataBufferFactory dataBufferFactory) { + super(dataBufferFactory); + this.writeHandler = body -> { + // Avoid .then() which causes data buffers to be released + MonoProcessor completion = MonoProcessor.create(); + this.body = body.doOnComplete(completion::onComplete).doOnError(completion::onError).cache(); + this.body.subscribe(); + return completion; + }; + } + + + /** + * Configure a custom handler to consume the response body. + *

By default, response body content is consumed in full and cached for + * subsequent access in tests. Use this option to take control over how the + * response body is consumed. + * @param writeHandler the write handler to use returning {@code Mono} + * when the body has been "written" (i.e. consumed). + */ + public void setWriteHandler(Function, Mono> writeHandler) { + Assert.notNull(writeHandler, "'writeHandler' is required"); + this.body = Flux.error(new IllegalStateException("Not available with custom write handler.")); + this.writeHandler = writeHandler; + } + + @Override + public T getNativeResponse() { + throw new IllegalStateException("This is a mock. No running server, no native response."); + } + + + @Override + protected void applyStatusCode() { + } + + @Override + protected void applyHeaders() { + } + + @Override + protected void applyCookies() { + for (List cookies : getCookies().values()) { + for (ResponseCookie cookie : cookies) { + getHeaders().add(HttpHeaders.SET_COOKIE, cookie.toString()); + } + } + } + + @Override + protected Mono writeWithInternal(Publisher body) { + return this.writeHandler.apply(Flux.from(body)); + } + + @Override + protected Mono writeAndFlushWithInternal( + Publisher> body) { + + return this.writeHandler.apply(Flux.from(body).concatMap(Flux::from)); + } + + @Override + public Mono setComplete() { + return doCommit(() -> Mono.defer(() -> this.writeHandler.apply(Flux.empty()))); + } + + /** + * Return the response body or an error stream if the body was not set. + */ + public Flux getBody() { + return this.body; + } + + /** + * Aggregate response data and convert to a String using the "Content-Type" + * charset or "UTF-8" by default. + */ + public Mono getBodyAsString() { + + Charset charset = Optional.ofNullable(getHeaders().getContentType()).map(MimeType::getCharset) + .orElse(StandardCharsets.UTF_8); + + return getBody() + .reduce(bufferFactory().allocateBuffer(), (previous, current) -> { + previous.write(current); + DataBufferUtils.release(current); + return previous; + }) + .map(buffer -> bufferToString(buffer, charset)); + } + + private static String bufferToString(DataBuffer buffer, Charset charset) { + Assert.notNull(charset, "'charset' must not be null"); + byte[] bytes = new byte[buffer.readableByteCount()]; + buffer.read(bytes); + DataBufferUtils.release(buffer); + return new String(bytes, charset); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/http/server/reactive/test/package-info.java b/spring-web/src/test/java/org/springframework/mock/http/server/reactive/test/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..92e8462a0c8bfd9cbc91dce05bbae476cadcf999 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/http/server/reactive/test/package-info.java @@ -0,0 +1,9 @@ + +// For @NonNull annotations on implementation classes + +@NonNullApi +@NonNullFields +package org.springframework.mock.http.server.reactive.test; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/DelegatingServletInputStream.java b/spring-web/src/test/java/org/springframework/mock/web/test/DelegatingServletInputStream.java new file mode 100644 index 0000000000000000000000000000000000000000..577cfc0abc0005079536706aa15245e25805981c --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/DelegatingServletInputStream.java @@ -0,0 +1,96 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.io.IOException; +import java.io.InputStream; + +import javax.servlet.ReadListener; +import javax.servlet.ServletInputStream; + +import org.springframework.util.Assert; + +/** + * Delegating implementation of {@link javax.servlet.ServletInputStream}. + * + *

Used by {@link MockHttpServletRequest}; typically not directly + * used for testing application controllers. + * + * @author Juergen Hoeller + * @since 1.0.2 + * @see MockHttpServletRequest + */ +public class DelegatingServletInputStream extends ServletInputStream { + + private final InputStream sourceStream; + + private boolean finished = false; + + + /** + * Create a DelegatingServletInputStream for the given source stream. + * @param sourceStream the source stream (never {@code null}) + */ + public DelegatingServletInputStream(InputStream sourceStream) { + Assert.notNull(sourceStream, "Source InputStream must not be null"); + this.sourceStream = sourceStream; + } + + /** + * Return the underlying source stream (never {@code null}). + */ + public final InputStream getSourceStream() { + return this.sourceStream; + } + + + @Override + public int read() throws IOException { + int data = this.sourceStream.read(); + if (data == -1) { + this.finished = true; + } + return data; + } + + @Override + public int available() throws IOException { + return this.sourceStream.available(); + } + + @Override + public void close() throws IOException { + super.close(); + this.sourceStream.close(); + } + + @Override + public boolean isFinished() { + return this.finished; + } + + @Override + public boolean isReady() { + return true; + } + + @Override + public void setReadListener(ReadListener readListener) { + throw new UnsupportedOperationException(); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/DelegatingServletOutputStream.java b/spring-web/src/test/java/org/springframework/mock/web/test/DelegatingServletOutputStream.java new file mode 100644 index 0000000000000000000000000000000000000000..e70b2984f681eaa542c6ca8aa57d0b12e5fa661b --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/DelegatingServletOutputStream.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.io.IOException; +import java.io.OutputStream; + +import javax.servlet.ServletOutputStream; +import javax.servlet.WriteListener; + +import org.springframework.util.Assert; + +/** + * Delegating implementation of {@link javax.servlet.ServletOutputStream}. + * + *

Used by {@link MockHttpServletResponse}; typically not directly + * used for testing application controllers. + * + * @author Juergen Hoeller + * @since 1.0.2 + * @see MockHttpServletResponse + */ +public class DelegatingServletOutputStream extends ServletOutputStream { + + private final OutputStream targetStream; + + + /** + * Create a DelegatingServletOutputStream for the given target stream. + * @param targetStream the target stream (never {@code null}) + */ + public DelegatingServletOutputStream(OutputStream targetStream) { + Assert.notNull(targetStream, "Target OutputStream must not be null"); + this.targetStream = targetStream; + } + + /** + * Return the underlying target stream (never {@code null}). + */ + public final OutputStream getTargetStream() { + return this.targetStream; + } + + + @Override + public void write(int b) throws IOException { + this.targetStream.write(b); + } + + @Override + public void flush() throws IOException { + super.flush(); + this.targetStream.flush(); + } + + @Override + public void close() throws IOException { + super.close(); + this.targetStream.close(); + } + + @Override + public boolean isReady() { + return true; + } + + @Override + public void setWriteListener(WriteListener writeListener) { + throw new UnsupportedOperationException(); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/HeaderValueHolder.java b/spring-web/src/test/java/org/springframework/mock/web/test/HeaderValueHolder.java new file mode 100644 index 0000000000000000000000000000000000000000..1259509f22f72712f1cc3c556f07e05fb4259060 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/HeaderValueHolder.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +/** + * Internal helper class that serves as value holder for request headers. + * + * @author Juergen Hoeller + * @author Rick Evans + * @since 2.0.1 + */ +class HeaderValueHolder { + + private final List values = new LinkedList<>(); + + + public void setValue(@Nullable Object value) { + this.values.clear(); + if (value != null) { + this.values.add(value); + } + } + + public void addValue(Object value) { + this.values.add(value); + } + + public void addValues(Collection values) { + this.values.addAll(values); + } + + public void addValueArray(Object values) { + CollectionUtils.mergeArrayIntoCollection(values, this.values); + } + + public List getValues() { + return Collections.unmodifiableList(this.values); + } + + public List getStringValues() { + List stringList = new ArrayList<>(this.values.size()); + for (Object value : this.values) { + stringList.add(value.toString()); + } + return Collections.unmodifiableList(stringList); + } + + @Nullable + public Object getValue() { + return (!this.values.isEmpty() ? this.values.get(0) : null); + } + + @Nullable + public String getStringValue() { + return (!this.values.isEmpty() ? String.valueOf(this.values.get(0)) : null); + } + + @Override + public String toString() { + return this.values.toString(); + } + + + /** + * Find a HeaderValueHolder by name, ignoring casing. + * @param headers the Map of header names to HeaderValueHolders + * @param name the name of the desired header + * @return the corresponding HeaderValueHolder, or {@code null} if none found + */ + @Nullable + public static HeaderValueHolder getByName(Map headers, String name) { + Assert.notNull(name, "Header name must not be null"); + for (String headerName : headers.keySet()) { + if (headerName.equalsIgnoreCase(name)) { + return headers.get(headerName); + } + } + return null; + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockAsyncContext.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockAsyncContext.java new file mode 100644 index 0000000000000000000000000000000000000000..4c60f1e40df010feea94813ff3789b4f3786ed52 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockAsyncContext.java @@ -0,0 +1,179 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import javax.servlet.AsyncContext; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.beans.BeanUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.util.WebUtils; + +/** + * Mock implementation of the {@link AsyncContext} interface. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public class MockAsyncContext implements AsyncContext { + + private final HttpServletRequest request; + + @Nullable + private final HttpServletResponse response; + + private final List listeners = new ArrayList<>(); + + @Nullable + private String dispatchedPath; + + private long timeout = 10 * 1000L; // 10 seconds is Tomcat's default + + private final List dispatchHandlers = new ArrayList<>(); + + + public MockAsyncContext(ServletRequest request, @Nullable ServletResponse response) { + this.request = (HttpServletRequest) request; + this.response = (HttpServletResponse) response; + } + + + public void addDispatchHandler(Runnable handler) { + Assert.notNull(handler, "Dispatch handler must not be null"); + synchronized (this) { + if (this.dispatchedPath == null) { + this.dispatchHandlers.add(handler); + } + else { + handler.run(); + } + } + } + + @Override + public ServletRequest getRequest() { + return this.request; + } + + @Override + @Nullable + public ServletResponse getResponse() { + return this.response; + } + + @Override + public boolean hasOriginalRequestAndResponse() { + return (this.request instanceof MockHttpServletRequest && this.response instanceof MockHttpServletResponse); + } + + @Override + public void dispatch() { + dispatch(this.request.getRequestURI()); + } + + @Override + public void dispatch(String path) { + dispatch(null, path); + } + + @Override + public void dispatch(@Nullable ServletContext context, String path) { + synchronized (this) { + this.dispatchedPath = path; + this.dispatchHandlers.forEach(Runnable::run); + } + } + + @Nullable + public String getDispatchedPath() { + return this.dispatchedPath; + } + + @Override + public void complete() { + MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(this.request, MockHttpServletRequest.class); + if (mockRequest != null) { + mockRequest.setAsyncStarted(false); + } + for (AsyncListener listener : this.listeners) { + try { + listener.onComplete(new AsyncEvent(this, this.request, this.response)); + } + catch (IOException ex) { + throw new IllegalStateException("AsyncListener failure", ex); + } + } + } + + @Override + public void start(Runnable runnable) { + runnable.run(); + } + + @Override + public void addListener(AsyncListener listener) { + this.listeners.add(listener); + } + + @Override + public void addListener(AsyncListener listener, ServletRequest request, ServletResponse response) { + this.listeners.add(listener); + } + + public List getListeners() { + return this.listeners; + } + + @Override + public T createListener(Class clazz) throws ServletException { + return BeanUtils.instantiateClass(clazz); + } + + /** + * By default this is set to 10000 (10 seconds) even though the Servlet API + * specifies a default async request timeout of 30 seconds. Keep in mind the + * timeout could further be impacted by global configuration through the MVC + * Java config or the XML namespace, as well as be overridden per request on + * {@link org.springframework.web.context.request.async.DeferredResult DeferredResult} + * or on + * {@link org.springframework.web.servlet.mvc.method.annotation.SseEmitter SseEmitter}. + * @param timeout the timeout value to use. + * @see AsyncContext#setTimeout(long) + */ + @Override + public void setTimeout(long timeout) { + this.timeout = timeout; + } + + @Override + public long getTimeout() { + return this.timeout; + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockBodyContent.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockBodyContent.java new file mode 100644 index 0000000000000000000000000000000000000000..f38bf8a175e260a5aabd728de78fbca479f9ca01 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockBodyContent.java @@ -0,0 +1,226 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.io.IOException; +import java.io.Reader; +import java.io.StringReader; +import java.io.Writer; + +import javax.servlet.http.HttpServletResponse; +import javax.servlet.jsp.JspWriter; +import javax.servlet.jsp.tagext.BodyContent; + +import org.springframework.lang.Nullable; + +/** + * Mock implementation of the {@link javax.servlet.jsp.tagext.BodyContent} class. + * Only necessary for testing applications when testing custom JSP tags. + * + * @author Juergen Hoeller + * @since 2.5 + */ +public class MockBodyContent extends BodyContent { + + private final String content; + + + /** + * Create a MockBodyContent for the given response. + * @param content the body content to expose + * @param response the servlet response to wrap + */ + public MockBodyContent(String content, HttpServletResponse response) { + this(content, response, null); + } + + /** + * Create a MockBodyContent for the given response. + * @param content the body content to expose + * @param targetWriter the target Writer to wrap + */ + public MockBodyContent(String content, Writer targetWriter) { + this(content, null, targetWriter); + } + + /** + * Create a MockBodyContent for the given response. + * @param content the body content to expose + * @param response the servlet response to wrap + * @param targetWriter the target Writer to wrap + */ + public MockBodyContent(String content, @Nullable HttpServletResponse response, @Nullable Writer targetWriter) { + super(adaptJspWriter(targetWriter, response)); + this.content = content; + } + + private static JspWriter adaptJspWriter(@Nullable Writer targetWriter, @Nullable HttpServletResponse response) { + if (targetWriter instanceof JspWriter) { + return (JspWriter) targetWriter; + } + else { + return new MockJspWriter(response, targetWriter); + } + } + + + @Override + public Reader getReader() { + return new StringReader(this.content); + } + + @Override + public String getString() { + return this.content; + } + + @Override + public void writeOut(Writer writer) throws IOException { + writer.write(this.content); + } + + + //--------------------------------------------------------------------- + // Delegating implementations of JspWriter's abstract methods + //--------------------------------------------------------------------- + + @Override + public void clear() throws IOException { + getEnclosingWriter().clear(); + } + + @Override + public void clearBuffer() throws IOException { + getEnclosingWriter().clearBuffer(); + } + + @Override + public void close() throws IOException { + getEnclosingWriter().close(); + } + + @Override + public int getRemaining() { + return getEnclosingWriter().getRemaining(); + } + + @Override + public void newLine() throws IOException { + getEnclosingWriter().println(); + } + + @Override + public void write(char[] value, int offset, int length) throws IOException { + getEnclosingWriter().write(value, offset, length); + } + + @Override + public void print(boolean value) throws IOException { + getEnclosingWriter().print(value); + } + + @Override + public void print(char value) throws IOException { + getEnclosingWriter().print(value); + } + + @Override + public void print(char[] value) throws IOException { + getEnclosingWriter().print(value); + } + + @Override + public void print(double value) throws IOException { + getEnclosingWriter().print(value); + } + + @Override + public void print(float value) throws IOException { + getEnclosingWriter().print(value); + } + + @Override + public void print(int value) throws IOException { + getEnclosingWriter().print(value); + } + + @Override + public void print(long value) throws IOException { + getEnclosingWriter().print(value); + } + + @Override + public void print(Object value) throws IOException { + getEnclosingWriter().print(value); + } + + @Override + public void print(String value) throws IOException { + getEnclosingWriter().print(value); + } + + @Override + public void println() throws IOException { + getEnclosingWriter().println(); + } + + @Override + public void println(boolean value) throws IOException { + getEnclosingWriter().println(value); + } + + @Override + public void println(char value) throws IOException { + getEnclosingWriter().println(value); + } + + @Override + public void println(char[] value) throws IOException { + getEnclosingWriter().println(value); + } + + @Override + public void println(double value) throws IOException { + getEnclosingWriter().println(value); + } + + @Override + public void println(float value) throws IOException { + getEnclosingWriter().println(value); + } + + @Override + public void println(int value) throws IOException { + getEnclosingWriter().println(value); + } + + @Override + public void println(long value) throws IOException { + getEnclosingWriter().println(value); + } + + @Override + public void println(Object value) throws IOException { + getEnclosingWriter().println(value); + } + + @Override + public void println(String value) throws IOException { + getEnclosingWriter().println(value); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockCookie.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockCookie.java new file mode 100644 index 0000000000000000000000000000000000000000..408e2ecefdd1e30f78a1691e39ea593e271fb3d7 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockCookie.java @@ -0,0 +1,155 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.time.DateTimeException; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; + +import javax.servlet.http.Cookie; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * Extension of {@code Cookie} with extra attributes, as defined in + * RFC 6265. + * + * @author Vedran Pavic + * @author Juergen Hoeller + * @author Sam Brannen + * @since 5.1 + */ +public class MockCookie extends Cookie { + + private static final long serialVersionUID = 4312531139502726325L; + + + @Nullable + private ZonedDateTime expires; + + @Nullable + private String sameSite; + + + /** + * Construct a new {@link MockCookie} with the supplied name and value. + * @param name the name + * @param value the value + * @see Cookie#Cookie(String, String) + */ + public MockCookie(String name, String value) { + super(name, value); + } + + /** + * Set the "Expires" attribute for this cookie. + * @since 5.1.11 + */ + public void setExpires(@Nullable ZonedDateTime expires) { + this.expires = expires; + } + + /** + * Get the "Expires" attribute for this cookie. + * @since 5.1.11 + * @return the "Expires" attribute for this cookie, or {@code null} if not set + */ + @Nullable + public ZonedDateTime getExpires() { + return this.expires; + } + + /** + * Set the "SameSite" attribute for this cookie. + *

This limits the scope of the cookie such that it will only be attached + * to same-site requests if the supplied value is {@code "Strict"} or cross-site + * requests if the supplied value is {@code "Lax"}. + * @see RFC6265 bis + */ + public void setSameSite(@Nullable String sameSite) { + this.sameSite = sameSite; + } + + /** + * Get the "SameSite" attribute for this cookie. + * @return the "SameSite" attribute for this cookie, or {@code null} if not set + */ + @Nullable + public String getSameSite() { + return this.sameSite; + } + + + /** + * Factory method that parses the value of the supplied "Set-Cookie" header. + * @param setCookieHeader the "Set-Cookie" value; never {@code null} or empty + * @return the created cookie + */ + public static MockCookie parse(String setCookieHeader) { + Assert.notNull(setCookieHeader, "Set-Cookie header must not be null"); + String[] cookieParts = setCookieHeader.split("\\s*=\\s*", 2); + Assert.isTrue(cookieParts.length == 2, () -> "Invalid Set-Cookie header '" + setCookieHeader + "'"); + + String name = cookieParts[0]; + String[] valueAndAttributes = cookieParts[1].split("\\s*;\\s*", 2); + String value = valueAndAttributes[0]; + String[] attributes = + (valueAndAttributes.length > 1 ? valueAndAttributes[1].split("\\s*;\\s*") : new String[0]); + + MockCookie cookie = new MockCookie(name, value); + for (String attribute : attributes) { + if (StringUtils.startsWithIgnoreCase(attribute, "Domain")) { + cookie.setDomain(extractAttributeValue(attribute, setCookieHeader)); + } + else if (StringUtils.startsWithIgnoreCase(attribute, "Max-Age")) { + cookie.setMaxAge(Integer.parseInt(extractAttributeValue(attribute, setCookieHeader))); + } + else if (StringUtils.startsWithIgnoreCase(attribute, "Expires")) { + try { + cookie.setExpires(ZonedDateTime.parse(extractAttributeValue(attribute, setCookieHeader), + DateTimeFormatter.RFC_1123_DATE_TIME)); + } + catch (DateTimeException ex) { + // ignore invalid date formats + } + } + else if (StringUtils.startsWithIgnoreCase(attribute, "Path")) { + cookie.setPath(extractAttributeValue(attribute, setCookieHeader)); + } + else if (StringUtils.startsWithIgnoreCase(attribute, "Secure")) { + cookie.setSecure(true); + } + else if (StringUtils.startsWithIgnoreCase(attribute, "HttpOnly")) { + cookie.setHttpOnly(true); + } + else if (StringUtils.startsWithIgnoreCase(attribute, "SameSite")) { + cookie.setSameSite(extractAttributeValue(attribute, setCookieHeader)); + } + } + return cookie; + } + + private static String extractAttributeValue(String attribute, String header) { + String[] nameAndValue = attribute.split("="); + Assert.isTrue(nameAndValue.length == 2, + () -> "No value in attribute '" + nameAndValue[0] + "' for Set-Cookie header '" + header + "'"); + return nameAndValue[1]; + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockExpressionEvaluator.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockExpressionEvaluator.java new file mode 100644 index 0000000000000000000000000000000000000000..0a778efb9163f1cd044d89a05daee0b5d7b64152 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockExpressionEvaluator.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import javax.servlet.jsp.JspException; +import javax.servlet.jsp.PageContext; + +import org.apache.taglibs.standard.lang.support.ExpressionEvaluatorManager; + +/** + * Mock implementation of the JSP 2.0 {@link javax.servlet.jsp.el.ExpressionEvaluator} + * interface, delegating to the Apache JSTL {@link ExpressionEvaluatorManager}. + * Only necessary for testing applications when testing custom JSP tags. + * + *

Note that the Apache JSTL implementation (jstl.jar, standard.jar) has to be + * available on the classpath to use this expression evaluator. + * + * @author Juergen Hoeller + * @since 1.1.5 + * @see org.apache.taglibs.standard.lang.support.ExpressionEvaluatorManager + */ +@SuppressWarnings("deprecation") +public class MockExpressionEvaluator extends javax.servlet.jsp.el.ExpressionEvaluator { + + private final PageContext pageContext; + + + /** + * Create a new MockExpressionEvaluator for the given PageContext. + * @param pageContext the JSP PageContext to run in + */ + public MockExpressionEvaluator(PageContext pageContext) { + this.pageContext = pageContext; + } + + + @Override + @SuppressWarnings("rawtypes") + public javax.servlet.jsp.el.Expression parseExpression(final String expression, final Class expectedType, + final javax.servlet.jsp.el.FunctionMapper functionMapper) throws javax.servlet.jsp.el.ELException { + + return new javax.servlet.jsp.el.Expression() { + @Override + public Object evaluate(javax.servlet.jsp.el.VariableResolver variableResolver) throws javax.servlet.jsp.el.ELException { + return doEvaluate(expression, expectedType, functionMapper); + } + }; + } + + @Override + @SuppressWarnings("rawtypes") + public Object evaluate(String expression, Class expectedType, javax.servlet.jsp.el.VariableResolver variableResolver, + javax.servlet.jsp.el.FunctionMapper functionMapper) throws javax.servlet.jsp.el.ELException { + + return doEvaluate(expression, expectedType, functionMapper); + } + + @SuppressWarnings("rawtypes") + protected Object doEvaluate(String expression, Class expectedType, javax.servlet.jsp.el.FunctionMapper functionMapper) + throws javax.servlet.jsp.el.ELException { + + try { + return ExpressionEvaluatorManager.evaluate("JSP EL expression", expression, expectedType, this.pageContext); + } + catch (JspException ex) { + throw new javax.servlet.jsp.el.ELException("Parsing of JSP EL expression \"" + expression + "\" failed", ex); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockFilterChain.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockFilterChain.java new file mode 100644 index 0000000000000000000000000000000000000000..b834e789cb686d92b24d42f74ffa7a5a57c1e9cb --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockFilterChain.java @@ -0,0 +1,184 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.Servlet; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; + +/** + * Mock implementation of the {@link javax.servlet.FilterChain} interface. + * + *

A {@link MockFilterChain} can be configured with one or more filters and a + * Servlet to invoke. The first time the chain is called, it invokes all filters + * and the Servlet, and saves the request and response. Subsequent invocations + * raise an {@link IllegalStateException} unless {@link #reset()} is called. + * + * @author Juergen Hoeller + * @author Rob Winch + * @author Rossen Stoyanchev + * @since 2.0.3 + * @see MockFilterConfig + * @see PassThroughFilterChain + */ +public class MockFilterChain implements FilterChain { + + @Nullable + private ServletRequest request; + + @Nullable + private ServletResponse response; + + private final List filters; + + @Nullable + private Iterator iterator; + + + /** + * Register a single do-nothing {@link Filter} implementation. The first + * invocation saves the request and response. Subsequent invocations raise + * an {@link IllegalStateException} unless {@link #reset()} is called. + */ + public MockFilterChain() { + this.filters = Collections.emptyList(); + } + + /** + * Create a FilterChain with a Servlet. + * @param servlet the Servlet to invoke + * @since 3.2 + */ + public MockFilterChain(Servlet servlet) { + this.filters = initFilterList(servlet); + } + + /** + * Create a {@code FilterChain} with Filter's and a Servlet. + * @param servlet the {@link Servlet} to invoke in this {@link FilterChain} + * @param filters the {@link Filter}'s to invoke in this {@link FilterChain} + * @since 3.2 + */ + public MockFilterChain(Servlet servlet, Filter... filters) { + Assert.notNull(filters, "filters cannot be null"); + Assert.noNullElements(filters, "filters cannot contain null values"); + this.filters = initFilterList(servlet, filters); + } + + private static List initFilterList(Servlet servlet, Filter... filters) { + Filter[] allFilters = ObjectUtils.addObjectToArray(filters, new ServletFilterProxy(servlet)); + return Arrays.asList(allFilters); + } + + + /** + * Return the request that {@link #doFilter} has been called with. + */ + @Nullable + public ServletRequest getRequest() { + return this.request; + } + + /** + * Return the response that {@link #doFilter} has been called with. + */ + @Nullable + public ServletResponse getResponse() { + return this.response; + } + + /** + * Invoke registered {@link Filter Filters} and/or {@link Servlet} also saving the + * request and response. + */ + @Override + public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { + Assert.notNull(request, "Request must not be null"); + Assert.notNull(response, "Response must not be null"); + Assert.state(this.request == null, "This FilterChain has already been called!"); + + if (this.iterator == null) { + this.iterator = this.filters.iterator(); + } + + if (this.iterator.hasNext()) { + Filter nextFilter = this.iterator.next(); + nextFilter.doFilter(request, response, this); + } + + this.request = request; + this.response = response; + } + + /** + * Reset the {@link MockFilterChain} allowing it to be invoked again. + */ + public void reset() { + this.request = null; + this.response = null; + this.iterator = null; + } + + + /** + * A filter that simply delegates to a Servlet. + */ + private static final class ServletFilterProxy implements Filter { + + private final Servlet delegateServlet; + + private ServletFilterProxy(Servlet servlet) { + Assert.notNull(servlet, "servlet cannot be null"); + this.delegateServlet = servlet; + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException { + + this.delegateServlet.service(request, response); + } + + @Override + public void init(FilterConfig filterConfig) throws ServletException { + } + + @Override + public void destroy() { + } + + @Override + public String toString() { + return this.delegateServlet.toString(); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockFilterConfig.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockFilterConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..6428466ce5306e572fe13b22bb2be8c26529d0d1 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockFilterConfig.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashMap; +import java.util.Map; + +import javax.servlet.FilterConfig; +import javax.servlet.ServletContext; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Mock implementation of the {@link javax.servlet.FilterConfig} interface. + * + *

Used for testing the web framework; also useful for testing + * custom {@link javax.servlet.Filter} implementations. + * + * @author Juergen Hoeller + * @since 1.0.2 + * @see MockFilterChain + * @see PassThroughFilterChain + */ +public class MockFilterConfig implements FilterConfig { + + private final ServletContext servletContext; + + private final String filterName; + + private final Map initParameters = new LinkedHashMap<>(); + + + /** + * Create a new MockFilterConfig with a default {@link MockServletContext}. + */ + public MockFilterConfig() { + this(null, ""); + } + + /** + * Create a new MockFilterConfig with a default {@link MockServletContext}. + * @param filterName the name of the filter + */ + public MockFilterConfig(String filterName) { + this(null, filterName); + } + + /** + * Create a new MockFilterConfig. + * @param servletContext the ServletContext that the servlet runs in + */ + public MockFilterConfig(@Nullable ServletContext servletContext) { + this(servletContext, ""); + } + + /** + * Create a new MockFilterConfig. + * @param servletContext the ServletContext that the servlet runs in + * @param filterName the name of the filter + */ + public MockFilterConfig(@Nullable ServletContext servletContext, String filterName) { + this.servletContext = (servletContext != null ? servletContext : new MockServletContext()); + this.filterName = filterName; + } + + + @Override + public String getFilterName() { + return this.filterName; + } + + @Override + public ServletContext getServletContext() { + return this.servletContext; + } + + public void addInitParameter(String name, String value) { + Assert.notNull(name, "Parameter name must not be null"); + this.initParameters.put(name, value); + } + + @Override + public String getInitParameter(String name) { + Assert.notNull(name, "Parameter name must not be null"); + return this.initParameters.get(name); + } + + @Override + public Enumeration getInitParameterNames() { + return Collections.enumeration(this.initParameters.keySet()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..d21ed80859a50aafe1c7d24dbd2d962c6273bf32 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java @@ -0,0 +1,1398 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.Reader; +import java.io.StringReader; +import java.io.UnsupportedEncodingException; +import java.security.Principal; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Date; +import java.util.Enumeration; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.TimeZone; +import java.util.stream.Collectors; + +import javax.servlet.AsyncContext; +import javax.servlet.DispatcherType; +import javax.servlet.RequestDispatcher; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.ServletInputStream; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; +import javax.servlet.http.HttpUpgradeHandler; +import javax.servlet.http.Part; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StreamUtils; +import org.springframework.util.StringUtils; + +/** + * Mock implementation of the {@link javax.servlet.http.HttpServletRequest} interface. + * + *

The default, preferred {@link Locale} for the server mocked by this request + * is {@link Locale#ENGLISH}. This value can be changed via {@link #addPreferredLocale} + * or {@link #setPreferredLocales}. + * + *

As of Spring Framework 5.0, this set of mocks is designed on a Servlet 4.0 baseline. + * + * @author Juergen Hoeller + * @author Rod Johnson + * @author Rick Evans + * @author Mark Fisher + * @author Chris Beams + * @author Sam Brannen + * @author Brian Clozel + * @since 1.0.2 + */ +public class MockHttpServletRequest implements HttpServletRequest { + + private static final String HTTP = "http"; + + private static final String HTTPS = "https"; + + private static final String CHARSET_PREFIX = "charset="; + + private static final TimeZone GMT = TimeZone.getTimeZone("GMT"); + + private static final ServletInputStream EMPTY_SERVLET_INPUT_STREAM = + new DelegatingServletInputStream(StreamUtils.emptyInput()); + + private static final BufferedReader EMPTY_BUFFERED_READER = + new BufferedReader(new StringReader("")); + + /** + * Date formats as specified in the HTTP RFC. + * @see Section 7.1.1.1 of RFC 7231 + */ + private static final String[] DATE_FORMATS = new String[] { + "EEE, dd MMM yyyy HH:mm:ss zzz", + "EEE, dd-MMM-yy HH:mm:ss zzz", + "EEE MMM dd HH:mm:ss yyyy" + }; + + + // --------------------------------------------------------------------- + // Public constants + // --------------------------------------------------------------------- + + /** + * The default protocol: 'HTTP/1.1'. + * @since 4.3.7 + */ + public static final String DEFAULT_PROTOCOL = "HTTP/1.1"; + + /** + * The default scheme: 'http'. + * @since 4.3.7 + */ + public static final String DEFAULT_SCHEME = HTTP; + + /** + * The default server address: '127.0.0.1'. + */ + public static final String DEFAULT_SERVER_ADDR = "127.0.0.1"; + + /** + * The default server name: 'localhost'. + */ + public static final String DEFAULT_SERVER_NAME = "localhost"; + + /** + * The default server port: '80'. + */ + public static final int DEFAULT_SERVER_PORT = 80; + + /** + * The default remote address: '127.0.0.1'. + */ + public static final String DEFAULT_REMOTE_ADDR = "127.0.0.1"; + + /** + * The default remote host: 'localhost'. + */ + public static final String DEFAULT_REMOTE_HOST = "localhost"; + + + // --------------------------------------------------------------------- + // Lifecycle properties + // --------------------------------------------------------------------- + + private final ServletContext servletContext; + + private boolean active = true; + + + // --------------------------------------------------------------------- + // ServletRequest properties + // --------------------------------------------------------------------- + + private final Map attributes = new LinkedHashMap<>(); + + @Nullable + private String characterEncoding; + + @Nullable + private byte[] content; + + @Nullable + private String contentType; + + @Nullable + private ServletInputStream inputStream; + + @Nullable + private BufferedReader reader; + + private final Map parameters = new LinkedHashMap<>(16); + + private String protocol = DEFAULT_PROTOCOL; + + private String scheme = DEFAULT_SCHEME; + + private String serverName = DEFAULT_SERVER_NAME; + + private int serverPort = DEFAULT_SERVER_PORT; + + private String remoteAddr = DEFAULT_REMOTE_ADDR; + + private String remoteHost = DEFAULT_REMOTE_HOST; + + /** List of locales in descending order. */ + private final LinkedList locales = new LinkedList<>(); + + private boolean secure = false; + + private int remotePort = DEFAULT_SERVER_PORT; + + private String localName = DEFAULT_SERVER_NAME; + + private String localAddr = DEFAULT_SERVER_ADDR; + + private int localPort = DEFAULT_SERVER_PORT; + + private boolean asyncStarted = false; + + private boolean asyncSupported = false; + + @Nullable + private MockAsyncContext asyncContext; + + private DispatcherType dispatcherType = DispatcherType.REQUEST; + + + // --------------------------------------------------------------------- + // HttpServletRequest properties + // --------------------------------------------------------------------- + + @Nullable + private String authType; + + @Nullable + private Cookie[] cookies; + + private final Map headers = new LinkedCaseInsensitiveMap<>(); + + @Nullable + private String method; + + @Nullable + private String pathInfo; + + private String contextPath = ""; + + @Nullable + private String queryString; + + @Nullable + private String remoteUser; + + private final Set userRoles = new HashSet<>(); + + @Nullable + private Principal userPrincipal; + + @Nullable + private String requestedSessionId; + + @Nullable + private String requestURI; + + private String servletPath = ""; + + @Nullable + private HttpSession session; + + private boolean requestedSessionIdValid = true; + + private boolean requestedSessionIdFromCookie = true; + + private boolean requestedSessionIdFromURL = false; + + private final MultiValueMap parts = new LinkedMultiValueMap<>(); + + + // --------------------------------------------------------------------- + // Constructors + // --------------------------------------------------------------------- + + /** + * Create a new {@code MockHttpServletRequest} with a default + * {@link MockServletContext}. + * @see #MockHttpServletRequest(ServletContext, String, String) + */ + public MockHttpServletRequest() { + this(null, "", ""); + } + + /** + * Create a new {@code MockHttpServletRequest} with a default + * {@link MockServletContext}. + * @param method the request method (may be {@code null}) + * @param requestURI the request URI (may be {@code null}) + * @see #setMethod + * @see #setRequestURI + * @see #MockHttpServletRequest(ServletContext, String, String) + */ + public MockHttpServletRequest(@Nullable String method, @Nullable String requestURI) { + this(null, method, requestURI); + } + + /** + * Create a new {@code MockHttpServletRequest} with the supplied {@link ServletContext}. + * @param servletContext the ServletContext that the request runs in + * (may be {@code null} to use a default {@link MockServletContext}) + * @see #MockHttpServletRequest(ServletContext, String, String) + */ + public MockHttpServletRequest(@Nullable ServletContext servletContext) { + this(servletContext, "", ""); + } + + /** + * Create a new {@code MockHttpServletRequest} with the supplied {@link ServletContext}, + * {@code method}, and {@code requestURI}. + *

The preferred locale will be set to {@link Locale#ENGLISH}. + * @param servletContext the ServletContext that the request runs in (may be + * {@code null} to use a default {@link MockServletContext}) + * @param method the request method (may be {@code null}) + * @param requestURI the request URI (may be {@code null}) + * @see #setMethod + * @see #setRequestURI + * @see #setPreferredLocales + * @see MockServletContext + */ + public MockHttpServletRequest(@Nullable ServletContext servletContext, @Nullable String method, @Nullable String requestURI) { + this.servletContext = (servletContext != null ? servletContext : new MockServletContext()); + this.method = method; + this.requestURI = requestURI; + this.locales.add(Locale.ENGLISH); + } + + + // --------------------------------------------------------------------- + // Lifecycle methods + // --------------------------------------------------------------------- + + /** + * Return the ServletContext that this request is associated with. (Not + * available in the standard HttpServletRequest interface for some reason.) + */ + @Override + public ServletContext getServletContext() { + return this.servletContext; + } + + /** + * Return whether this request is still active (that is, not completed yet). + */ + public boolean isActive() { + return this.active; + } + + /** + * Mark this request as completed, keeping its state. + */ + public void close() { + this.active = false; + } + + /** + * Invalidate this request, clearing its state. + */ + public void invalidate() { + close(); + clearAttributes(); + } + + /** + * Check whether this request is still active (that is, not completed yet), + * throwing an IllegalStateException if not active anymore. + */ + protected void checkActive() throws IllegalStateException { + Assert.state(this.active, "Request is not active anymore"); + } + + + // --------------------------------------------------------------------- + // ServletRequest interface + // --------------------------------------------------------------------- + + @Override + public Object getAttribute(String name) { + checkActive(); + return this.attributes.get(name); + } + + @Override + public Enumeration getAttributeNames() { + checkActive(); + return Collections.enumeration(new LinkedHashSet<>(this.attributes.keySet())); + } + + @Override + @Nullable + public String getCharacterEncoding() { + return this.characterEncoding; + } + + @Override + public void setCharacterEncoding(@Nullable String characterEncoding) { + this.characterEncoding = characterEncoding; + updateContentTypeHeader(); + } + + private void updateContentTypeHeader() { + if (StringUtils.hasLength(this.contentType)) { + String value = this.contentType; + if (StringUtils.hasLength(this.characterEncoding) && !this.contentType.toLowerCase().contains(CHARSET_PREFIX)) { + value += ';' + CHARSET_PREFIX + this.characterEncoding; + } + doAddHeaderValue(HttpHeaders.CONTENT_TYPE, value, true); + } + } + + /** + * Set the content of the request body as a byte array. + *

If the supplied byte array represents text such as XML or JSON, the + * {@link #setCharacterEncoding character encoding} should typically be + * set as well. + * @see #setCharacterEncoding(String) + * @see #getContentAsByteArray() + * @see #getContentAsString() + */ + public void setContent(@Nullable byte[] content) { + this.content = content; + this.inputStream = null; + this.reader = null; + } + + /** + * Get the content of the request body as a byte array. + * @return the content as a byte array (potentially {@code null}) + * @since 5.0 + * @see #setContent(byte[]) + * @see #getContentAsString() + */ + @Nullable + public byte[] getContentAsByteArray() { + return this.content; + } + + /** + * Get the content of the request body as a {@code String}, using the configured + * {@linkplain #getCharacterEncoding character encoding}. + * @return the content as a {@code String}, potentially {@code null} + * @throws IllegalStateException if the character encoding has not been set + * @throws UnsupportedEncodingException if the character encoding is not supported + * @since 5.0 + * @see #setContent(byte[]) + * @see #setCharacterEncoding(String) + * @see #getContentAsByteArray() + */ + @Nullable + public String getContentAsString() throws IllegalStateException, UnsupportedEncodingException { + Assert.state(this.characterEncoding != null, + "Cannot get content as a String for a null character encoding. " + + "Consider setting the characterEncoding in the request."); + + if (this.content == null) { + return null; + } + return new String(this.content, this.characterEncoding); + } + + @Override + public int getContentLength() { + return (this.content != null ? this.content.length : -1); + } + + @Override + public long getContentLengthLong() { + return getContentLength(); + } + + public void setContentType(@Nullable String contentType) { + this.contentType = contentType; + if (contentType != null) { + try { + MediaType mediaType = MediaType.parseMediaType(contentType); + if (mediaType.getCharset() != null) { + this.characterEncoding = mediaType.getCharset().name(); + } + } + catch (IllegalArgumentException ex) { + // Try to get charset value anyway + int charsetIndex = contentType.toLowerCase().indexOf(CHARSET_PREFIX); + if (charsetIndex != -1) { + this.characterEncoding = contentType.substring(charsetIndex + CHARSET_PREFIX.length()); + } + } + updateContentTypeHeader(); + } + } + + @Override + @Nullable + public String getContentType() { + return this.contentType; + } + + @Override + public ServletInputStream getInputStream() { + if (this.inputStream != null) { + return this.inputStream; + } + else if (this.reader != null) { + throw new IllegalStateException( + "Cannot call getInputStream() after getReader() has already been called for the current request") ; + } + + this.inputStream = (this.content != null ? + new DelegatingServletInputStream(new ByteArrayInputStream(this.content)) : + EMPTY_SERVLET_INPUT_STREAM); + return this.inputStream; + } + + /** + * Set a single value for the specified HTTP parameter. + *

If there are already one or more values registered for the given + * parameter name, they will be replaced. + */ + public void setParameter(String name, String value) { + setParameter(name, new String[] {value}); + } + + /** + * Set an array of values for the specified HTTP parameter. + *

If there are already one or more values registered for the given + * parameter name, they will be replaced. + */ + public void setParameter(String name, String... values) { + Assert.notNull(name, "Parameter name must not be null"); + this.parameters.put(name, values); + } + + /** + * Set all provided parameters replacing any existing + * values for the provided parameter names. To add without replacing + * existing values, use {@link #addParameters(java.util.Map)}. + */ + public void setParameters(Map params) { + Assert.notNull(params, "Parameter map must not be null"); + params.forEach((key, value) -> { + if (value instanceof String) { + setParameter(key, (String) value); + } + else if (value instanceof String[]) { + setParameter(key, (String[]) value); + } + else { + throw new IllegalArgumentException( + "Parameter map value must be single value " + " or array of type [" + String.class.getName() + "]"); + } + }); + } + + /** + * Add a single value for the specified HTTP parameter. + *

If there are already one or more values registered for the given + * parameter name, the given value will be added to the end of the list. + */ + public void addParameter(String name, @Nullable String value) { + addParameter(name, new String[] {value}); + } + + /** + * Add an array of values for the specified HTTP parameter. + *

If there are already one or more values registered for the given + * parameter name, the given values will be added to the end of the list. + */ + public void addParameter(String name, String... values) { + Assert.notNull(name, "Parameter name must not be null"); + String[] oldArr = this.parameters.get(name); + if (oldArr != null) { + String[] newArr = new String[oldArr.length + values.length]; + System.arraycopy(oldArr, 0, newArr, 0, oldArr.length); + System.arraycopy(values, 0, newArr, oldArr.length, values.length); + this.parameters.put(name, newArr); + } + else { + this.parameters.put(name, values); + } + } + + /** + * Add all provided parameters without replacing any + * existing values. To replace existing values, use + * {@link #setParameters(java.util.Map)}. + */ + public void addParameters(Map params) { + Assert.notNull(params, "Parameter map must not be null"); + params.forEach((key, value) -> { + if (value instanceof String) { + addParameter(key, (String) value); + } + else if (value instanceof String[]) { + addParameter(key, (String[]) value); + } + else { + throw new IllegalArgumentException("Parameter map value must be single value " + + " or array of type [" + String.class.getName() + "]"); + } + }); + } + + /** + * Remove already registered values for the specified HTTP parameter, if any. + */ + public void removeParameter(String name) { + Assert.notNull(name, "Parameter name must not be null"); + this.parameters.remove(name); + } + + /** + * Remove all existing parameters. + */ + public void removeAllParameters() { + this.parameters.clear(); + } + + @Override + @Nullable + public String getParameter(String name) { + Assert.notNull(name, "Parameter name must not be null"); + String[] arr = this.parameters.get(name); + return (arr != null && arr.length > 0 ? arr[0] : null); + } + + @Override + public Enumeration getParameterNames() { + return Collections.enumeration(this.parameters.keySet()); + } + + @Override + public String[] getParameterValues(String name) { + Assert.notNull(name, "Parameter name must not be null"); + return this.parameters.get(name); + } + + @Override + public Map getParameterMap() { + return Collections.unmodifiableMap(this.parameters); + } + + public void setProtocol(String protocol) { + this.protocol = protocol; + } + + @Override + public String getProtocol() { + return this.protocol; + } + + public void setScheme(String scheme) { + this.scheme = scheme; + } + + @Override + public String getScheme() { + return this.scheme; + } + + public void setServerName(String serverName) { + this.serverName = serverName; + } + + @Override + public String getServerName() { + String rawHostHeader = getHeader(HttpHeaders.HOST); + String host = rawHostHeader; + if (host != null) { + host = host.trim(); + if (host.startsWith("[")) { + int indexOfClosingBracket = host.indexOf(']'); + Assert.state(indexOfClosingBracket > -1, () -> "Invalid Host header: " + rawHostHeader); + host = host.substring(0, indexOfClosingBracket + 1); + } + else if (host.contains(":")) { + host = host.substring(0, host.indexOf(':')); + } + return host; + } + + // else + return this.serverName; + } + + public void setServerPort(int serverPort) { + this.serverPort = serverPort; + } + + @Override + public int getServerPort() { + String rawHostHeader = getHeader(HttpHeaders.HOST); + String host = rawHostHeader; + if (host != null) { + host = host.trim(); + int idx; + if (host.startsWith("[")) { + int indexOfClosingBracket = host.indexOf(']'); + Assert.state(indexOfClosingBracket > -1, () -> "Invalid Host header: " + rawHostHeader); + idx = host.indexOf(':', indexOfClosingBracket); + } + else { + idx = host.indexOf(':'); + } + if (idx != -1) { + return Integer.parseInt(host.substring(idx + 1)); + } + } + + // else + return this.serverPort; + } + + @Override + public BufferedReader getReader() throws UnsupportedEncodingException { + if (this.reader != null) { + return this.reader; + } + else if (this.inputStream != null) { + throw new IllegalStateException( + "Cannot call getReader() after getInputStream() has already been called for the current request") ; + } + + if (this.content != null) { + InputStream sourceStream = new ByteArrayInputStream(this.content); + Reader sourceReader = (this.characterEncoding != null) ? + new InputStreamReader(sourceStream, this.characterEncoding) : + new InputStreamReader(sourceStream); + this.reader = new BufferedReader(sourceReader); + } + else { + this.reader = EMPTY_BUFFERED_READER; + } + return this.reader; + } + + public void setRemoteAddr(String remoteAddr) { + this.remoteAddr = remoteAddr; + } + + @Override + public String getRemoteAddr() { + return this.remoteAddr; + } + + public void setRemoteHost(String remoteHost) { + this.remoteHost = remoteHost; + } + + @Override + public String getRemoteHost() { + return this.remoteHost; + } + + @Override + public void setAttribute(String name, @Nullable Object value) { + checkActive(); + Assert.notNull(name, "Attribute name must not be null"); + if (value != null) { + this.attributes.put(name, value); + } + else { + this.attributes.remove(name); + } + } + + @Override + public void removeAttribute(String name) { + checkActive(); + Assert.notNull(name, "Attribute name must not be null"); + this.attributes.remove(name); + } + + /** + * Clear all of this request's attributes. + */ + public void clearAttributes() { + this.attributes.clear(); + } + + /** + * Add a new preferred locale, before any existing locales. + * @see #setPreferredLocales + */ + public void addPreferredLocale(Locale locale) { + Assert.notNull(locale, "Locale must not be null"); + this.locales.addFirst(locale); + updateAcceptLanguageHeader(); + } + + /** + * Set the list of preferred locales, in descending order, effectively replacing + * any existing locales. + * @since 3.2 + * @see #addPreferredLocale + */ + public void setPreferredLocales(List locales) { + Assert.notEmpty(locales, "Locale list must not be empty"); + this.locales.clear(); + this.locales.addAll(locales); + updateAcceptLanguageHeader(); + } + + private void updateAcceptLanguageHeader() { + HttpHeaders headers = new HttpHeaders(); + headers.setAcceptLanguageAsLocales(this.locales); + doAddHeaderValue(HttpHeaders.ACCEPT_LANGUAGE, headers.getFirst(HttpHeaders.ACCEPT_LANGUAGE), true); + } + + /** + * Return the first preferred {@linkplain Locale locale} configured + * in this mock request. + *

If no locales have been explicitly configured, the default, + * preferred {@link Locale} for the server mocked by this + * request is {@link Locale#ENGLISH}. + *

In contrast to the Servlet specification, this mock implementation + * does not take into consideration any locales + * specified via the {@code Accept-Language} header. + * @see javax.servlet.ServletRequest#getLocale() + * @see #addPreferredLocale(Locale) + * @see #setPreferredLocales(List) + */ + @Override + public Locale getLocale() { + return this.locales.getFirst(); + } + + /** + * Return an {@linkplain Enumeration enumeration} of the preferred + * {@linkplain Locale locales} configured in this mock request. + *

If no locales have been explicitly configured, the default, + * preferred {@link Locale} for the server mocked by this + * request is {@link Locale#ENGLISH}. + *

In contrast to the Servlet specification, this mock implementation + * does not take into consideration any locales + * specified via the {@code Accept-Language} header. + * @see javax.servlet.ServletRequest#getLocales() + * @see #addPreferredLocale(Locale) + * @see #setPreferredLocales(List) + */ + @Override + public Enumeration getLocales() { + return Collections.enumeration(this.locales); + } + + /** + * Set the boolean {@code secure} flag indicating whether the mock request + * was made using a secure channel, such as HTTPS. + * @see #isSecure() + * @see #getScheme() + * @see #setScheme(String) + */ + public void setSecure(boolean secure) { + this.secure = secure; + } + + /** + * Return {@code true} if the {@link #setSecure secure} flag has been set + * to {@code true} or if the {@link #getScheme scheme} is {@code https}. + * @see javax.servlet.ServletRequest#isSecure() + */ + @Override + public boolean isSecure() { + return (this.secure || HTTPS.equalsIgnoreCase(this.scheme)); + } + + @Override + public RequestDispatcher getRequestDispatcher(String path) { + return new MockRequestDispatcher(path); + } + + @Override + @Deprecated + public String getRealPath(String path) { + return this.servletContext.getRealPath(path); + } + + public void setRemotePort(int remotePort) { + this.remotePort = remotePort; + } + + @Override + public int getRemotePort() { + return this.remotePort; + } + + public void setLocalName(String localName) { + this.localName = localName; + } + + @Override + public String getLocalName() { + return this.localName; + } + + public void setLocalAddr(String localAddr) { + this.localAddr = localAddr; + } + + @Override + public String getLocalAddr() { + return this.localAddr; + } + + public void setLocalPort(int localPort) { + this.localPort = localPort; + } + + @Override + public int getLocalPort() { + return this.localPort; + } + + @Override + public AsyncContext startAsync() { + return startAsync(this, null); + } + + @Override + public AsyncContext startAsync(ServletRequest request, @Nullable ServletResponse response) { + Assert.state(this.asyncSupported, "Async not supported"); + this.asyncStarted = true; + this.asyncContext = new MockAsyncContext(request, response); + return this.asyncContext; + } + + public void setAsyncStarted(boolean asyncStarted) { + this.asyncStarted = asyncStarted; + } + + @Override + public boolean isAsyncStarted() { + return this.asyncStarted; + } + + public void setAsyncSupported(boolean asyncSupported) { + this.asyncSupported = asyncSupported; + } + + @Override + public boolean isAsyncSupported() { + return this.asyncSupported; + } + + public void setAsyncContext(@Nullable MockAsyncContext asyncContext) { + this.asyncContext = asyncContext; + } + + @Override + @Nullable + public AsyncContext getAsyncContext() { + return this.asyncContext; + } + + public void setDispatcherType(DispatcherType dispatcherType) { + this.dispatcherType = dispatcherType; + } + + @Override + public DispatcherType getDispatcherType() { + return this.dispatcherType; + } + + + // --------------------------------------------------------------------- + // HttpServletRequest interface + // --------------------------------------------------------------------- + + public void setAuthType(@Nullable String authType) { + this.authType = authType; + } + + @Override + @Nullable + public String getAuthType() { + return this.authType; + } + + public void setCookies(@Nullable Cookie... cookies) { + this.cookies = (ObjectUtils.isEmpty(cookies) ? null : cookies); + if (this.cookies == null) { + removeHeader(HttpHeaders.COOKIE); + } + else { + doAddHeaderValue(HttpHeaders.COOKIE, encodeCookies(this.cookies), true); + } + } + + private static String encodeCookies(@NonNull Cookie... cookies) { + return Arrays.stream(cookies) + .map(c -> c.getName() + '=' + (c.getValue() == null ? "" : c.getValue())) + .collect(Collectors.joining("; ")); + } + + @Override + @Nullable + public Cookie[] getCookies() { + return this.cookies; + } + + /** + * Add an HTTP header entry for the given name. + *

While this method can take any {@code Object} as a parameter, + * it is recommended to use the following types: + *

    + *
  • String or any Object to be converted using {@code toString()}; see {@link #getHeader}.
  • + *
  • String, Number, or Date for date headers; see {@link #getDateHeader}.
  • + *
  • String or Number for integer headers; see {@link #getIntHeader}.
  • + *
  • {@code String[]} or {@code Collection} for multiple values; see {@link #getHeaders}.
  • + *
+ * @see #getHeaderNames + * @see #getHeaders + * @see #getHeader + * @see #getDateHeader + */ + public void addHeader(String name, Object value) { + if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name) && + !this.headers.containsKey(HttpHeaders.CONTENT_TYPE)) { + setContentType(value.toString()); + } + else if (HttpHeaders.ACCEPT_LANGUAGE.equalsIgnoreCase(name) && + !this.headers.containsKey(HttpHeaders.ACCEPT_LANGUAGE)) { + try { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ACCEPT_LANGUAGE, value.toString()); + List locales = headers.getAcceptLanguageAsLocales(); + this.locales.clear(); + this.locales.addAll(locales); + if (this.locales.isEmpty()) { + this.locales.add(Locale.ENGLISH); + } + } + catch (IllegalArgumentException ex) { + // Invalid Accept-Language format -> just store plain header + } + doAddHeaderValue(name, value, true); + } + else { + doAddHeaderValue(name, value, false); + } + } + + private void doAddHeaderValue(String name, @Nullable Object value, boolean replace) { + HeaderValueHolder header = this.headers.get(name); + Assert.notNull(value, "Header value must not be null"); + if (header == null || replace) { + header = new HeaderValueHolder(); + this.headers.put(name, header); + } + if (value instanceof Collection) { + header.addValues((Collection) value); + } + else if (value.getClass().isArray()) { + header.addValueArray(value); + } + else { + header.addValue(value); + } + } + + /** + * Remove already registered entries for the specified HTTP header, if any. + * @since 4.3.20 + */ + public void removeHeader(String name) { + Assert.notNull(name, "Header name must not be null"); + this.headers.remove(name); + } + + /** + * Return the long timestamp for the date header with the given {@code name}. + *

If the internal value representation is a String, this method will try + * to parse it as a date using the supported date formats: + *

    + *
  • "EEE, dd MMM yyyy HH:mm:ss zzz"
  • + *
  • "EEE, dd-MMM-yy HH:mm:ss zzz"
  • + *
  • "EEE MMM dd HH:mm:ss yyyy"
  • + *
+ * @param name the header name + * @see Section 7.1.1.1 of RFC 7231 + */ + @Override + public long getDateHeader(String name) { + HeaderValueHolder header = this.headers.get(name); + Object value = (header != null ? header.getValue() : null); + if (value instanceof Date) { + return ((Date) value).getTime(); + } + else if (value instanceof Number) { + return ((Number) value).longValue(); + } + else if (value instanceof String) { + return parseDateHeader(name, (String) value); + } + else if (value != null) { + throw new IllegalArgumentException( + "Value for header '" + name + "' is not a Date, Number, or String: " + value); + } + else { + return -1L; + } + } + + private long parseDateHeader(String name, String value) { + for (String dateFormat : DATE_FORMATS) { + SimpleDateFormat simpleDateFormat = new SimpleDateFormat(dateFormat, Locale.US); + simpleDateFormat.setTimeZone(GMT); + try { + return simpleDateFormat.parse(value).getTime(); + } + catch (ParseException ex) { + // ignore + } + } + throw new IllegalArgumentException("Cannot parse date value '" + value + "' for '" + name + "' header"); + } + + @Override + @Nullable + public String getHeader(String name) { + HeaderValueHolder header = this.headers.get(name); + return (header != null ? header.getStringValue() : null); + } + + @Override + public Enumeration getHeaders(String name) { + HeaderValueHolder header = this.headers.get(name); + return Collections.enumeration(header != null ? header.getStringValues() : new LinkedList<>()); + } + + @Override + public Enumeration getHeaderNames() { + return Collections.enumeration(this.headers.keySet()); + } + + @Override + public int getIntHeader(String name) { + HeaderValueHolder header = this.headers.get(name); + Object value = (header != null ? header.getValue() : null); + if (value instanceof Number) { + return ((Number) value).intValue(); + } + else if (value instanceof String) { + return Integer.parseInt((String) value); + } + else if (value != null) { + throw new NumberFormatException("Value for header '" + name + "' is not a Number: " + value); + } + else { + return -1; + } + } + + public void setMethod(@Nullable String method) { + this.method = method; + } + + @Override + @Nullable + public String getMethod() { + return this.method; + } + + public void setPathInfo(@Nullable String pathInfo) { + this.pathInfo = pathInfo; + } + + @Override + @Nullable + public String getPathInfo() { + return this.pathInfo; + } + + @Override + @Nullable + public String getPathTranslated() { + return (this.pathInfo != null ? getRealPath(this.pathInfo) : null); + } + + public void setContextPath(String contextPath) { + this.contextPath = contextPath; + } + + @Override + public String getContextPath() { + return this.contextPath; + } + + public void setQueryString(@Nullable String queryString) { + this.queryString = queryString; + } + + @Override + @Nullable + public String getQueryString() { + return this.queryString; + } + + public void setRemoteUser(@Nullable String remoteUser) { + this.remoteUser = remoteUser; + } + + @Override + @Nullable + public String getRemoteUser() { + return this.remoteUser; + } + + public void addUserRole(String role) { + this.userRoles.add(role); + } + + @Override + public boolean isUserInRole(String role) { + return (this.userRoles.contains(role) || (this.servletContext instanceof MockServletContext && + ((MockServletContext) this.servletContext).getDeclaredRoles().contains(role))); + } + + public void setUserPrincipal(@Nullable Principal userPrincipal) { + this.userPrincipal = userPrincipal; + } + + @Override + @Nullable + public Principal getUserPrincipal() { + return this.userPrincipal; + } + + public void setRequestedSessionId(@Nullable String requestedSessionId) { + this.requestedSessionId = requestedSessionId; + } + + @Override + @Nullable + public String getRequestedSessionId() { + return this.requestedSessionId; + } + + public void setRequestURI(@Nullable String requestURI) { + this.requestURI = requestURI; + } + + @Override + @Nullable + public String getRequestURI() { + return this.requestURI; + } + + @Override + public StringBuffer getRequestURL() { + String scheme = getScheme(); + String server = getServerName(); + int port = getServerPort(); + String uri = getRequestURI(); + + StringBuffer url = new StringBuffer(scheme).append("://").append(server); + if (port > 0 && ((HTTP.equalsIgnoreCase(scheme) && port != 80) || + (HTTPS.equalsIgnoreCase(scheme) && port != 443))) { + url.append(':').append(port); + } + if (StringUtils.hasText(uri)) { + url.append(uri); + } + return url; + } + + public void setServletPath(String servletPath) { + this.servletPath = servletPath; + } + + @Override + public String getServletPath() { + return this.servletPath; + } + + public void setSession(HttpSession session) { + this.session = session; + if (session instanceof MockHttpSession) { + MockHttpSession mockSession = ((MockHttpSession) session); + mockSession.access(); + } + } + + @Override + @Nullable + public HttpSession getSession(boolean create) { + checkActive(); + // Reset session if invalidated. + if (this.session instanceof MockHttpSession && ((MockHttpSession) this.session).isInvalid()) { + this.session = null; + } + // Create new session if necessary. + if (this.session == null && create) { + this.session = new MockHttpSession(this.servletContext); + } + return this.session; + } + + @Override + @Nullable + public HttpSession getSession() { + return getSession(true); + } + + /** + * The implementation of this (Servlet 3.1+) method calls + * {@link MockHttpSession#changeSessionId()} if the session is a mock session. + * Otherwise it simply returns the current session id. + * @since 4.0.3 + */ + @Override + public String changeSessionId() { + Assert.isTrue(this.session != null, "The request does not have a session"); + if (this.session instanceof MockHttpSession) { + return ((MockHttpSession) this.session).changeSessionId(); + } + return this.session.getId(); + } + + public void setRequestedSessionIdValid(boolean requestedSessionIdValid) { + this.requestedSessionIdValid = requestedSessionIdValid; + } + + @Override + public boolean isRequestedSessionIdValid() { + return this.requestedSessionIdValid; + } + + public void setRequestedSessionIdFromCookie(boolean requestedSessionIdFromCookie) { + this.requestedSessionIdFromCookie = requestedSessionIdFromCookie; + } + + @Override + public boolean isRequestedSessionIdFromCookie() { + return this.requestedSessionIdFromCookie; + } + + public void setRequestedSessionIdFromURL(boolean requestedSessionIdFromURL) { + this.requestedSessionIdFromURL = requestedSessionIdFromURL; + } + + @Override + public boolean isRequestedSessionIdFromURL() { + return this.requestedSessionIdFromURL; + } + + @Override + @Deprecated + public boolean isRequestedSessionIdFromUrl() { + return isRequestedSessionIdFromURL(); + } + + @Override + public boolean authenticate(HttpServletResponse response) throws IOException, ServletException { + throw new UnsupportedOperationException(); + } + + @Override + public void login(String username, String password) throws ServletException { + throw new UnsupportedOperationException(); + } + + @Override + public void logout() throws ServletException { + this.userPrincipal = null; + this.remoteUser = null; + this.authType = null; + } + + public void addPart(Part part) { + this.parts.add(part.getName(), part); + } + + @Override + @Nullable + public Part getPart(String name) throws IOException, ServletException { + return this.parts.getFirst(name); + } + + @Override + public Collection getParts() throws IOException, ServletException { + List result = new LinkedList<>(); + for (List list : this.parts.values()) { + result.addAll(list); + } + return result; + } + + @Override + public T upgrade(Class handlerClass) throws IOException, ServletException { + throw new UnsupportedOperationException(); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletResponse.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..cd54634e42ef14bef34bd379de77e60eefa2547c --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletResponse.java @@ -0,0 +1,799 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.io.UnsupportedEncodingException; +import java.io.Writer; +import java.text.DateFormat; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Date; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.TimeZone; + +import javax.servlet.ServletOutputStream; +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.StringUtils; +import org.springframework.web.util.WebUtils; + +/** + * Mock implementation of the {@link javax.servlet.http.HttpServletResponse} interface. + * + *

As of Spring Framework 5.0, this set of mocks is designed on a Servlet 4.0 baseline. + * + * @author Juergen Hoeller + * @author Rod Johnson + * @author Brian Clozel + * @author Vedran Pavic + * @author Sam Brannen + * @since 1.0.2 + */ +public class MockHttpServletResponse implements HttpServletResponse { + + private static final String CHARSET_PREFIX = "charset="; + + private static final String DATE_FORMAT = "EEE, dd MMM yyyy HH:mm:ss zzz"; + + private static final TimeZone GMT = TimeZone.getTimeZone("GMT"); + + + //--------------------------------------------------------------------- + // ServletResponse properties + //--------------------------------------------------------------------- + + private boolean outputStreamAccessAllowed = true; + + private boolean writerAccessAllowed = true; + + @Nullable + private String characterEncoding = WebUtils.DEFAULT_CHARACTER_ENCODING; + + private boolean charset = false; + + private final ByteArrayOutputStream content = new ByteArrayOutputStream(1024); + + private final ServletOutputStream outputStream = new ResponseServletOutputStream(this.content); + + @Nullable + private PrintWriter writer; + + private long contentLength = 0; + + @Nullable + private String contentType; + + private int bufferSize = 4096; + + private boolean committed; + + private Locale locale = Locale.getDefault(); + + + //--------------------------------------------------------------------- + // HttpServletResponse properties + //--------------------------------------------------------------------- + + private final List cookies = new ArrayList<>(); + + private final Map headers = new LinkedCaseInsensitiveMap<>(); + + private int status = HttpServletResponse.SC_OK; + + @Nullable + private String errorMessage; + + @Nullable + private String forwardedUrl; + + private final List includedUrls = new ArrayList<>(); + + + //--------------------------------------------------------------------- + // ServletResponse interface + //--------------------------------------------------------------------- + + /** + * Set whether {@link #getOutputStream()} access is allowed. + *

Default is {@code true}. + */ + public void setOutputStreamAccessAllowed(boolean outputStreamAccessAllowed) { + this.outputStreamAccessAllowed = outputStreamAccessAllowed; + } + + /** + * Return whether {@link #getOutputStream()} access is allowed. + */ + public boolean isOutputStreamAccessAllowed() { + return this.outputStreamAccessAllowed; + } + + /** + * Set whether {@link #getWriter()} access is allowed. + *

Default is {@code true}. + */ + public void setWriterAccessAllowed(boolean writerAccessAllowed) { + this.writerAccessAllowed = writerAccessAllowed; + } + + /** + * Return whether {@link #getOutputStream()} access is allowed. + */ + public boolean isWriterAccessAllowed() { + return this.writerAccessAllowed; + } + + /** + * Return whether the character encoding has been set. + *

If {@code false}, {@link #getCharacterEncoding()} will return a default encoding value. + */ + public boolean isCharset() { + return this.charset; + } + + @Override + public void setCharacterEncoding(String characterEncoding) { + this.characterEncoding = characterEncoding; + this.charset = true; + updateContentTypeHeader(); + } + + private void updateContentTypeHeader() { + if (this.contentType != null) { + String value = this.contentType; + if (this.charset && !this.contentType.toLowerCase().contains(CHARSET_PREFIX)) { + value = value + ';' + CHARSET_PREFIX + this.characterEncoding; + } + doAddHeaderValue(HttpHeaders.CONTENT_TYPE, value, true); + } + } + + @Override + @Nullable + public String getCharacterEncoding() { + return this.characterEncoding; + } + + @Override + public ServletOutputStream getOutputStream() { + Assert.state(this.outputStreamAccessAllowed, "OutputStream access not allowed"); + return this.outputStream; + } + + @Override + public PrintWriter getWriter() throws UnsupportedEncodingException { + Assert.state(this.writerAccessAllowed, "Writer access not allowed"); + if (this.writer == null) { + Writer targetWriter = (this.characterEncoding != null ? + new OutputStreamWriter(this.content, this.characterEncoding) : + new OutputStreamWriter(this.content)); + this.writer = new ResponsePrintWriter(targetWriter); + } + return this.writer; + } + + public byte[] getContentAsByteArray() { + return this.content.toByteArray(); + } + + public String getContentAsString() throws UnsupportedEncodingException { + return (this.characterEncoding != null ? + this.content.toString(this.characterEncoding) : this.content.toString()); + } + + @Override + public void setContentLength(int contentLength) { + this.contentLength = contentLength; + doAddHeaderValue(HttpHeaders.CONTENT_LENGTH, contentLength, true); + } + + public int getContentLength() { + return (int) this.contentLength; + } + + @Override + public void setContentLengthLong(long contentLength) { + this.contentLength = contentLength; + doAddHeaderValue(HttpHeaders.CONTENT_LENGTH, contentLength, true); + } + + public long getContentLengthLong() { + return this.contentLength; + } + + @Override + public void setContentType(@Nullable String contentType) { + this.contentType = contentType; + if (contentType != null) { + try { + MediaType mediaType = MediaType.parseMediaType(contentType); + if (mediaType.getCharset() != null) { + this.characterEncoding = mediaType.getCharset().name(); + this.charset = true; + } + } + catch (Exception ex) { + // Try to get charset value anyway + int charsetIndex = contentType.toLowerCase().indexOf(CHARSET_PREFIX); + if (charsetIndex != -1) { + this.characterEncoding = contentType.substring(charsetIndex + CHARSET_PREFIX.length()); + this.charset = true; + } + } + updateContentTypeHeader(); + } + } + + @Override + @Nullable + public String getContentType() { + return this.contentType; + } + + @Override + public void setBufferSize(int bufferSize) { + this.bufferSize = bufferSize; + } + + @Override + public int getBufferSize() { + return this.bufferSize; + } + + @Override + public void flushBuffer() { + setCommitted(true); + } + + @Override + public void resetBuffer() { + Assert.state(!isCommitted(), "Cannot reset buffer - response is already committed"); + this.content.reset(); + } + + private void setCommittedIfBufferSizeExceeded() { + int bufSize = getBufferSize(); + if (bufSize > 0 && this.content.size() > bufSize) { + setCommitted(true); + } + } + + public void setCommitted(boolean committed) { + this.committed = committed; + } + + @Override + public boolean isCommitted() { + return this.committed; + } + + @Override + public void reset() { + resetBuffer(); + this.characterEncoding = null; + this.charset = false; + this.contentLength = 0; + this.contentType = null; + this.locale = Locale.getDefault(); + this.cookies.clear(); + this.headers.clear(); + this.status = HttpServletResponse.SC_OK; + this.errorMessage = null; + } + + @Override + public void setLocale(Locale locale) { + this.locale = locale; + doAddHeaderValue(HttpHeaders.CONTENT_LANGUAGE, locale.toLanguageTag(), true); + } + + @Override + public Locale getLocale() { + return this.locale; + } + + + //--------------------------------------------------------------------- + // HttpServletResponse interface + //--------------------------------------------------------------------- + + @Override + public void addCookie(Cookie cookie) { + Assert.notNull(cookie, "Cookie must not be null"); + this.cookies.add(cookie); + doAddHeaderValue(HttpHeaders.SET_COOKIE, getCookieHeader(cookie), false); + } + + private String getCookieHeader(Cookie cookie) { + StringBuilder buf = new StringBuilder(); + buf.append(cookie.getName()).append('=').append(cookie.getValue() == null ? "" : cookie.getValue()); + if (StringUtils.hasText(cookie.getPath())) { + buf.append("; Path=").append(cookie.getPath()); + } + if (StringUtils.hasText(cookie.getDomain())) { + buf.append("; Domain=").append(cookie.getDomain()); + } + int maxAge = cookie.getMaxAge(); + if (maxAge >= 0) { + buf.append("; Max-Age=").append(maxAge); + buf.append("; Expires="); + ZonedDateTime expires = (cookie instanceof MockCookie ? ((MockCookie) cookie).getExpires() : null); + if (expires != null) { + buf.append(expires.format(DateTimeFormatter.RFC_1123_DATE_TIME)); + } + else { + HttpHeaders headers = new HttpHeaders(); + headers.setExpires(maxAge > 0 ? System.currentTimeMillis() + 1000L * maxAge : 0); + buf.append(headers.getFirst(HttpHeaders.EXPIRES)); + } + } + + if (cookie.getSecure()) { + buf.append("; Secure"); + } + if (cookie.isHttpOnly()) { + buf.append("; HttpOnly"); + } + if (cookie instanceof MockCookie) { + MockCookie mockCookie = (MockCookie) cookie; + if (StringUtils.hasText(mockCookie.getSameSite())) { + buf.append("; SameSite=").append(mockCookie.getSameSite()); + } + } + return buf.toString(); + } + + public Cookie[] getCookies() { + return this.cookies.toArray(new Cookie[0]); + } + + @Nullable + public Cookie getCookie(String name) { + Assert.notNull(name, "Cookie name must not be null"); + for (Cookie cookie : this.cookies) { + if (name.equals(cookie.getName())) { + return cookie; + } + } + return null; + } + + @Override + public boolean containsHeader(String name) { + return (this.headers.get(name) != null); + } + + /** + * Return the names of all specified headers as a Set of Strings. + *

As of Servlet 3.0, this method is also defined in {@link HttpServletResponse}. + * @return the {@code Set} of header name {@code Strings}, or an empty {@code Set} if none + */ + @Override + public Collection getHeaderNames() { + return this.headers.keySet(); + } + + /** + * Return the primary value for the given header as a String, if any. + * Will return the first value in case of multiple values. + *

As of Servlet 3.0, this method is also defined in {@link HttpServletResponse}. + * As of Spring 3.1, it returns a stringified value for Servlet 3.0 compatibility. + * Consider using {@link #getHeaderValue(String)} for raw Object access. + * @param name the name of the header + * @return the associated header value, or {@code null} if none + */ + @Override + @Nullable + public String getHeader(String name) { + HeaderValueHolder header = this.headers.get(name); + return (header != null ? header.getStringValue() : null); + } + + /** + * Return all values for the given header as a List of Strings. + *

As of Servlet 3.0, this method is also defined in {@link HttpServletResponse}. + * As of Spring 3.1, it returns a List of stringified values for Servlet 3.0 compatibility. + * Consider using {@link #getHeaderValues(String)} for raw Object access. + * @param name the name of the header + * @return the associated header values, or an empty List if none + */ + @Override + public List getHeaders(String name) { + HeaderValueHolder header = this.headers.get(name); + if (header != null) { + return header.getStringValues(); + } + else { + return Collections.emptyList(); + } + } + + /** + * Return the primary value for the given header, if any. + *

Will return the first value in case of multiple values. + * @param name the name of the header + * @return the associated header value, or {@code null} if none + */ + @Nullable + public Object getHeaderValue(String name) { + HeaderValueHolder header = this.headers.get(name); + return (header != null ? header.getValue() : null); + } + + /** + * Return all values for the given header as a List of value objects. + * @param name the name of the header + * @return the associated header values, or an empty List if none + */ + public List getHeaderValues(String name) { + HeaderValueHolder header = this.headers.get(name); + if (header != null) { + return header.getValues(); + } + else { + return Collections.emptyList(); + } + } + + /** + * The default implementation returns the given URL String as-is. + *

Can be overridden in subclasses, appending a session id or the like. + */ + @Override + public String encodeURL(String url) { + return url; + } + + /** + * The default implementation delegates to {@link #encodeURL}, + * returning the given URL String as-is. + *

Can be overridden in subclasses, appending a session id or the like + * in a redirect-specific fashion. For general URL encoding rules, + * override the common {@link #encodeURL} method instead, applying + * to redirect URLs as well as to general URLs. + */ + @Override + public String encodeRedirectURL(String url) { + return encodeURL(url); + } + + @Override + @Deprecated + public String encodeUrl(String url) { + return encodeURL(url); + } + + @Override + @Deprecated + public String encodeRedirectUrl(String url) { + return encodeRedirectURL(url); + } + + @Override + public void sendError(int status, String errorMessage) throws IOException { + Assert.state(!isCommitted(), "Cannot set error status - response is already committed"); + this.status = status; + this.errorMessage = errorMessage; + setCommitted(true); + } + + @Override + public void sendError(int status) throws IOException { + Assert.state(!isCommitted(), "Cannot set error status - response is already committed"); + this.status = status; + setCommitted(true); + } + + @Override + public void sendRedirect(String url) throws IOException { + Assert.state(!isCommitted(), "Cannot send redirect - response is already committed"); + Assert.notNull(url, "Redirect URL must not be null"); + setHeader(HttpHeaders.LOCATION, url); + setStatus(HttpServletResponse.SC_MOVED_TEMPORARILY); + setCommitted(true); + } + + @Nullable + public String getRedirectedUrl() { + return getHeader(HttpHeaders.LOCATION); + } + + @Override + public void setDateHeader(String name, long value) { + setHeaderValue(name, formatDate(value)); + } + + @Override + public void addDateHeader(String name, long value) { + addHeaderValue(name, formatDate(value)); + } + + public long getDateHeader(String name) { + String headerValue = getHeader(name); + if (headerValue == null) { + return -1; + } + try { + return newDateFormat().parse(getHeader(name)).getTime(); + } + catch (ParseException ex) { + throw new IllegalArgumentException( + "Value for header '" + name + "' is not a valid Date: " + headerValue); + } + } + + private String formatDate(long date) { + return newDateFormat().format(new Date(date)); + } + + private DateFormat newDateFormat() { + SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_FORMAT, Locale.US); + dateFormat.setTimeZone(GMT); + return dateFormat; + } + + @Override + public void setHeader(String name, String value) { + setHeaderValue(name, value); + } + + @Override + public void addHeader(String name, String value) { + addHeaderValue(name, value); + } + + @Override + public void setIntHeader(String name, int value) { + setHeaderValue(name, value); + } + + @Override + public void addIntHeader(String name, int value) { + addHeaderValue(name, value); + } + + private void setHeaderValue(String name, Object value) { + boolean replaceHeader = true; + if (setSpecialHeader(name, value, replaceHeader)) { + return; + } + doAddHeaderValue(name, value, replaceHeader); + } + + private void addHeaderValue(String name, Object value) { + boolean replaceHeader = false; + if (setSpecialHeader(name, value, replaceHeader)) { + return; + } + doAddHeaderValue(name, value, replaceHeader); + } + + private boolean setSpecialHeader(String name, Object value, boolean replaceHeader) { + if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { + setContentType(value.toString()); + return true; + } + else if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { + setContentLength(value instanceof Number ? ((Number) value).intValue() : + Integer.parseInt(value.toString())); + return true; + } + else if (HttpHeaders.CONTENT_LANGUAGE.equalsIgnoreCase(name)) { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.CONTENT_LANGUAGE, value.toString()); + Locale language = headers.getContentLanguage(); + setLocale(language != null ? language : Locale.getDefault()); + return true; + } + else if (HttpHeaders.SET_COOKIE.equalsIgnoreCase(name)) { + MockCookie cookie = MockCookie.parse(value.toString()); + if (replaceHeader) { + setCookie(cookie); + } + else { + addCookie(cookie); + } + return true; + } + else { + return false; + } + } + + private void doAddHeaderValue(String name, Object value, boolean replace) { + HeaderValueHolder header = this.headers.get(name); + Assert.notNull(value, "Header value must not be null"); + if (header == null) { + header = new HeaderValueHolder(); + this.headers.put(name, header); + } + if (replace) { + header.setValue(value); + } + else { + header.addValue(value); + } + } + + /** + * Set the {@code Set-Cookie} header to the supplied {@link Cookie}, + * overwriting any previous cookies. + * @param cookie the {@code Cookie} to set + * @since 5.1.10 + * @see #addCookie(Cookie) + */ + private void setCookie(Cookie cookie) { + Assert.notNull(cookie, "Cookie must not be null"); + this.cookies.clear(); + this.cookies.add(cookie); + doAddHeaderValue(HttpHeaders.SET_COOKIE, getCookieHeader(cookie), true); + } + + @Override + public void setStatus(int status) { + if (!this.isCommitted()) { + this.status = status; + } + } + + @Override + @Deprecated + public void setStatus(int status, String errorMessage) { + if (!this.isCommitted()) { + this.status = status; + this.errorMessage = errorMessage; + } + } + + @Override + public int getStatus() { + return this.status; + } + + @Nullable + public String getErrorMessage() { + return this.errorMessage; + } + + + //--------------------------------------------------------------------- + // Methods for MockRequestDispatcher + //--------------------------------------------------------------------- + + public void setForwardedUrl(@Nullable String forwardedUrl) { + this.forwardedUrl = forwardedUrl; + } + + @Nullable + public String getForwardedUrl() { + return this.forwardedUrl; + } + + public void setIncludedUrl(@Nullable String includedUrl) { + this.includedUrls.clear(); + if (includedUrl != null) { + this.includedUrls.add(includedUrl); + } + } + + @Nullable + public String getIncludedUrl() { + int count = this.includedUrls.size(); + Assert.state(count <= 1, + () -> "More than 1 URL included - check getIncludedUrls instead: " + this.includedUrls); + return (count == 1 ? this.includedUrls.get(0) : null); + } + + public void addIncludedUrl(String includedUrl) { + Assert.notNull(includedUrl, "Included URL must not be null"); + this.includedUrls.add(includedUrl); + } + + public List getIncludedUrls() { + return this.includedUrls; + } + + + /** + * Inner class that adapts the ServletOutputStream to mark the + * response as committed once the buffer size is exceeded. + */ + private class ResponseServletOutputStream extends DelegatingServletOutputStream { + + public ResponseServletOutputStream(OutputStream out) { + super(out); + } + + @Override + public void write(int b) throws IOException { + super.write(b); + super.flush(); + setCommittedIfBufferSizeExceeded(); + } + + @Override + public void flush() throws IOException { + super.flush(); + setCommitted(true); + } + } + + + /** + * Inner class that adapts the PrintWriter to mark the + * response as committed once the buffer size is exceeded. + */ + private class ResponsePrintWriter extends PrintWriter { + + public ResponsePrintWriter(Writer out) { + super(out, true); + } + + @Override + public void write(char[] buf, int off, int len) { + super.write(buf, off, len); + super.flush(); + setCommittedIfBufferSizeExceeded(); + } + + @Override + public void write(String s, int off, int len) { + super.write(s, off, len); + super.flush(); + setCommittedIfBufferSizeExceeded(); + } + + @Override + public void write(int c) { + super.write(c); + super.flush(); + setCommittedIfBufferSizeExceeded(); + } + + @Override + public void flush() { + super.flush(); + setCommitted(true); + } + + @Override + public void close() { + super.flush(); + super.close(); + setCommitted(true); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpSession.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpSession.java new file mode 100644 index 0000000000000000000000000000000000000000..01ed9876d6aef8edba8e4b2982a5d9cd15090250 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpSession.java @@ -0,0 +1,306 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.io.Serializable; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.Map; + +import javax.servlet.ServletContext; +import javax.servlet.http.HttpSession; +import javax.servlet.http.HttpSessionBindingEvent; +import javax.servlet.http.HttpSessionBindingListener; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * Mock implementation of the {@link javax.servlet.http.HttpSession} interface. + * + *

As of Spring 5.0, this set of mocks is designed on a Servlet 4.0 baseline. + * + * @author Juergen Hoeller + * @author Rod Johnson + * @author Mark Fisher + * @author Sam Brannen + * @author Vedran Pavic + * @since 1.0.2 + */ +@SuppressWarnings("deprecation") +public class MockHttpSession implements HttpSession { + + /** + * The session cookie name. + */ + public static final String SESSION_COOKIE_NAME = "JSESSION"; + + + private static int nextId = 1; + + private String id; + + private final long creationTime = System.currentTimeMillis(); + + private int maxInactiveInterval; + + private long lastAccessedTime = System.currentTimeMillis(); + + private final ServletContext servletContext; + + private final Map attributes = new LinkedHashMap<>(); + + private boolean invalid = false; + + private boolean isNew = true; + + + /** + * Create a new MockHttpSession with a default {@link MockServletContext}. + * @see MockServletContext + */ + public MockHttpSession() { + this(null); + } + + /** + * Create a new MockHttpSession. + * @param servletContext the ServletContext that the session runs in + */ + public MockHttpSession(@Nullable ServletContext servletContext) { + this(servletContext, null); + } + + /** + * Create a new MockHttpSession. + * @param servletContext the ServletContext that the session runs in + * @param id a unique identifier for this session + */ + public MockHttpSession(@Nullable ServletContext servletContext, @Nullable String id) { + this.servletContext = (servletContext != null ? servletContext : new MockServletContext()); + this.id = (id != null ? id : Integer.toString(nextId++)); + } + + + @Override + public long getCreationTime() { + assertIsValid(); + return this.creationTime; + } + + @Override + public String getId() { + return this.id; + } + + /** + * As of Servlet 3.1, the id of a session can be changed. + * @return the new session id + * @since 4.0.3 + */ + public String changeSessionId() { + this.id = Integer.toString(nextId++); + return this.id; + } + + public void access() { + this.lastAccessedTime = System.currentTimeMillis(); + this.isNew = false; + } + + @Override + public long getLastAccessedTime() { + assertIsValid(); + return this.lastAccessedTime; + } + + @Override + public ServletContext getServletContext() { + return this.servletContext; + } + + @Override + public void setMaxInactiveInterval(int interval) { + this.maxInactiveInterval = interval; + } + + @Override + public int getMaxInactiveInterval() { + return this.maxInactiveInterval; + } + + @Override + public javax.servlet.http.HttpSessionContext getSessionContext() { + throw new UnsupportedOperationException("getSessionContext"); + } + + @Override + public Object getAttribute(String name) { + assertIsValid(); + Assert.notNull(name, "Attribute name must not be null"); + return this.attributes.get(name); + } + + @Override + public Object getValue(String name) { + return getAttribute(name); + } + + @Override + public Enumeration getAttributeNames() { + assertIsValid(); + return Collections.enumeration(new LinkedHashSet<>(this.attributes.keySet())); + } + + @Override + public String[] getValueNames() { + assertIsValid(); + return StringUtils.toStringArray(this.attributes.keySet()); + } + + @Override + public void setAttribute(String name, @Nullable Object value) { + assertIsValid(); + Assert.notNull(name, "Attribute name must not be null"); + if (value != null) { + Object oldValue = this.attributes.put(name, value); + if (value != oldValue) { + if (oldValue instanceof HttpSessionBindingListener) { + ((HttpSessionBindingListener) oldValue).valueUnbound(new HttpSessionBindingEvent(this, name, oldValue)); + } + if (value instanceof HttpSessionBindingListener) { + ((HttpSessionBindingListener) value).valueBound(new HttpSessionBindingEvent(this, name, value)); + } + } + } + else { + removeAttribute(name); + } + } + + @Override + public void putValue(String name, Object value) { + setAttribute(name, value); + } + + @Override + public void removeAttribute(String name) { + assertIsValid(); + Assert.notNull(name, "Attribute name must not be null"); + Object value = this.attributes.remove(name); + if (value instanceof HttpSessionBindingListener) { + ((HttpSessionBindingListener) value).valueUnbound(new HttpSessionBindingEvent(this, name, value)); + } + } + + @Override + public void removeValue(String name) { + removeAttribute(name); + } + + /** + * Clear all of this session's attributes. + */ + public void clearAttributes() { + for (Iterator> it = this.attributes.entrySet().iterator(); it.hasNext();) { + Map.Entry entry = it.next(); + String name = entry.getKey(); + Object value = entry.getValue(); + it.remove(); + if (value instanceof HttpSessionBindingListener) { + ((HttpSessionBindingListener) value).valueUnbound(new HttpSessionBindingEvent(this, name, value)); + } + } + } + + /** + * Invalidates this session then unbinds any objects bound to it. + * @throws IllegalStateException if this method is called on an already invalidated session + */ + @Override + public void invalidate() { + assertIsValid(); + this.invalid = true; + clearAttributes(); + } + + public boolean isInvalid() { + return this.invalid; + } + + /** + * Convenience method for asserting that this session has not been + * {@linkplain #invalidate() invalidated}. + * @throws IllegalStateException if this session has been invalidated + */ + private void assertIsValid() { + Assert.state(!isInvalid(), "The session has already been invalidated"); + } + + public void setNew(boolean value) { + this.isNew = value; + } + + @Override + public boolean isNew() { + assertIsValid(); + return this.isNew; + } + + /** + * Serialize the attributes of this session into an object that can be + * turned into a byte array with standard Java serialization. + * @return a representation of this session's serialized state + */ + public Serializable serializeState() { + HashMap state = new HashMap<>(); + for (Iterator> it = this.attributes.entrySet().iterator(); it.hasNext();) { + Map.Entry entry = it.next(); + String name = entry.getKey(); + Object value = entry.getValue(); + it.remove(); + if (value instanceof Serializable) { + state.put(name, (Serializable) value); + } + else { + // Not serializable... Servlet containers usually automatically + // unbind the attribute in this case. + if (value instanceof HttpSessionBindingListener) { + ((HttpSessionBindingListener) value).valueUnbound(new HttpSessionBindingEvent(this, name, value)); + } + } + } + return state; + } + + /** + * Deserialize the attributes of this session from a state object created by + * {@link #serializeState()}. + * @param state a representation of this session's serialized state + */ + @SuppressWarnings("unchecked") + public void deserializeState(Serializable state) { + Assert.isTrue(state instanceof Map, "Serialized state needs to be of type [java.util.Map]"); + this.attributes.putAll((Map) state); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockJspWriter.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockJspWriter.java new file mode 100644 index 0000000000000000000000000000000000000000..dbc05ecd618cf7b31ee4e884af1702d3069c6f3b --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockJspWriter.java @@ -0,0 +1,219 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.io.IOException; +import java.io.PrintWriter; +import java.io.Writer; + +import javax.servlet.http.HttpServletResponse; +import javax.servlet.jsp.JspWriter; + +import org.springframework.lang.Nullable; + +/** + * Mock implementation of the {@link javax.servlet.jsp.JspWriter} class. + * Only necessary for testing applications when testing custom JSP tags. + * + * @author Juergen Hoeller + * @since 2.5 + */ +public class MockJspWriter extends JspWriter { + + private final HttpServletResponse response; + + @Nullable + private PrintWriter targetWriter; + + + /** + * Create a MockJspWriter for the given response, + * using the response's default Writer. + * @param response the servlet response to wrap + */ + public MockJspWriter(HttpServletResponse response) { + this(response, null); + } + + /** + * Create a MockJspWriter for the given plain Writer. + * @param targetWriter the target Writer to wrap + */ + public MockJspWriter(Writer targetWriter) { + this(null, targetWriter); + } + + /** + * Create a MockJspWriter for the given response. + * @param response the servlet response to wrap + * @param targetWriter the target Writer to wrap + */ + public MockJspWriter(@Nullable HttpServletResponse response, @Nullable Writer targetWriter) { + super(DEFAULT_BUFFER, true); + this.response = (response != null ? response : new MockHttpServletResponse()); + if (targetWriter instanceof PrintWriter) { + this.targetWriter = (PrintWriter) targetWriter; + } + else if (targetWriter != null) { + this.targetWriter = new PrintWriter(targetWriter); + } + } + + /** + * Lazily initialize the target Writer. + */ + protected PrintWriter getTargetWriter() throws IOException { + if (this.targetWriter == null) { + this.targetWriter = this.response.getWriter(); + } + return this.targetWriter; + } + + + @Override + public void clear() throws IOException { + if (this.response.isCommitted()) { + throw new IOException("Response already committed"); + } + this.response.resetBuffer(); + } + + @Override + public void clearBuffer() throws IOException { + } + + @Override + public void flush() throws IOException { + this.response.flushBuffer(); + } + + @Override + public void close() throws IOException { + flush(); + } + + @Override + public int getRemaining() { + return Integer.MAX_VALUE; + } + + @Override + public void newLine() throws IOException { + getTargetWriter().println(); + } + + @Override + public void write(char[] value, int offset, int length) throws IOException { + getTargetWriter().write(value, offset, length); + } + + @Override + public void print(boolean value) throws IOException { + getTargetWriter().print(value); + } + + @Override + public void print(char value) throws IOException { + getTargetWriter().print(value); + } + + @Override + public void print(char[] value) throws IOException { + getTargetWriter().print(value); + } + + @Override + public void print(double value) throws IOException { + getTargetWriter().print(value); + } + + @Override + public void print(float value) throws IOException { + getTargetWriter().print(value); + } + + @Override + public void print(int value) throws IOException { + getTargetWriter().print(value); + } + + @Override + public void print(long value) throws IOException { + getTargetWriter().print(value); + } + + @Override + public void print(Object value) throws IOException { + getTargetWriter().print(value); + } + + @Override + public void print(String value) throws IOException { + getTargetWriter().print(value); + } + + @Override + public void println() throws IOException { + getTargetWriter().println(); + } + + @Override + public void println(boolean value) throws IOException { + getTargetWriter().println(value); + } + + @Override + public void println(char value) throws IOException { + getTargetWriter().println(value); + } + + @Override + public void println(char[] value) throws IOException { + getTargetWriter().println(value); + } + + @Override + public void println(double value) throws IOException { + getTargetWriter().println(value); + } + + @Override + public void println(float value) throws IOException { + getTargetWriter().println(value); + } + + @Override + public void println(int value) throws IOException { + getTargetWriter().println(value); + } + + @Override + public void println(long value) throws IOException { + getTargetWriter().println(value); + } + + @Override + public void println(Object value) throws IOException { + getTargetWriter().println(value); + } + + @Override + public void println(String value) throws IOException { + getTargetWriter().println(value); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockMultipartFile.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockMultipartFile.java new file mode 100644 index 0000000000000000000000000000000000000000..8813df7e5cbaff73745c3fa5122d7da95c554f2f --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockMultipartFile.java @@ -0,0 +1,146 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.FileCopyUtils; +import org.springframework.web.multipart.MultipartFile; + +/** + * Mock implementation of the {@link org.springframework.web.multipart.MultipartFile} + * interface. + * + *

Useful in conjunction with a {@link MockMultipartHttpServletRequest} + * for testing application controllers that access multipart uploads. + * + * @author Juergen Hoeller + * @author Eric Crampton + * @since 2.0 + * @see MockMultipartHttpServletRequest + */ +public class MockMultipartFile implements MultipartFile { + + private final String name; + + private String originalFilename; + + @Nullable + private String contentType; + + private final byte[] content; + + + /** + * Create a new MockMultipartFile with the given content. + * @param name the name of the file + * @param content the content of the file + */ + public MockMultipartFile(String name, @Nullable byte[] content) { + this(name, "", null, content); + } + + /** + * Create a new MockMultipartFile with the given content. + * @param name the name of the file + * @param contentStream the content of the file as stream + * @throws IOException if reading from the stream failed + */ + public MockMultipartFile(String name, InputStream contentStream) throws IOException { + this(name, "", null, FileCopyUtils.copyToByteArray(contentStream)); + } + + /** + * Create a new MockMultipartFile with the given content. + * @param name the name of the file + * @param originalFilename the original filename (as on the client's machine) + * @param contentType the content type (if known) + * @param content the content of the file + */ + public MockMultipartFile( + String name, @Nullable String originalFilename, @Nullable String contentType, @Nullable byte[] content) { + + Assert.hasLength(name, "Name must not be null"); + this.name = name; + this.originalFilename = (originalFilename != null ? originalFilename : ""); + this.contentType = contentType; + this.content = (content != null ? content : new byte[0]); + } + + /** + * Create a new MockMultipartFile with the given content. + * @param name the name of the file + * @param originalFilename the original filename (as on the client's machine) + * @param contentType the content type (if known) + * @param contentStream the content of the file as stream + * @throws IOException if reading from the stream failed + */ + public MockMultipartFile( + String name, @Nullable String originalFilename, @Nullable String contentType, InputStream contentStream) + throws IOException { + + this(name, originalFilename, contentType, FileCopyUtils.copyToByteArray(contentStream)); + } + + + @Override + public String getName() { + return this.name; + } + + @Override + public String getOriginalFilename() { + return this.originalFilename; + } + + @Override + @Nullable + public String getContentType() { + return this.contentType; + } + + @Override + public boolean isEmpty() { + return (this.content.length == 0); + } + + @Override + public long getSize() { + return this.content.length; + } + + @Override + public byte[] getBytes() throws IOException { + return this.content; + } + + @Override + public InputStream getInputStream() throws IOException { + return new ByteArrayInputStream(this.content); + } + + @Override + public void transferTo(File dest) throws IOException, IllegalStateException { + FileCopyUtils.copy(this.content, dest); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockMultipartHttpServletRequest.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockMultipartHttpServletRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..d96a29b4168f3e655218733b4c8b31fe105f6ed7 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockMultipartHttpServletRequest.java @@ -0,0 +1,169 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.io.IOException; +import java.util.Collections; +import java.util.Enumeration; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.http.Part; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.multipart.MultipartHttpServletRequest; + +/** + * Mock implementation of the + * {@link org.springframework.web.multipart.MultipartHttpServletRequest} interface. + * + *

As of Spring 5.0, this set of mocks is designed on a Servlet 4.0 baseline. + * + *

Useful for testing application controllers that access multipart uploads. + * {@link MockMultipartFile} can be used to populate these mock requests with files. + * + * @author Juergen Hoeller + * @author Eric Crampton + * @author Arjen Poutsma + * @since 2.0 + * @see MockMultipartFile + */ +public class MockMultipartHttpServletRequest extends MockHttpServletRequest implements MultipartHttpServletRequest { + + private final MultiValueMap multipartFiles = new LinkedMultiValueMap<>(); + + + /** + * Create a new {@code MockMultipartHttpServletRequest} with a default + * {@link MockServletContext}. + * @see #MockMultipartHttpServletRequest(ServletContext) + */ + public MockMultipartHttpServletRequest() { + this(null); + } + + /** + * Create a new {@code MockMultipartHttpServletRequest} with the supplied {@link ServletContext}. + * @param servletContext the ServletContext that the request runs in + * (may be {@code null} to use a default {@link MockServletContext}) + */ + public MockMultipartHttpServletRequest(@Nullable ServletContext servletContext) { + super(servletContext); + setMethod("POST"); + setContentType("multipart/form-data"); + } + + + /** + * Add a file to this request. The parameter name from the multipart + * form is taken from the {@link MultipartFile#getName()}. + * @param file multipart file to be added + */ + public void addFile(MultipartFile file) { + Assert.notNull(file, "MultipartFile must not be null"); + this.multipartFiles.add(file.getName(), file); + } + + @Override + public Iterator getFileNames() { + return this.multipartFiles.keySet().iterator(); + } + + @Override + public MultipartFile getFile(String name) { + return this.multipartFiles.getFirst(name); + } + + @Override + public List getFiles(String name) { + List multipartFiles = this.multipartFiles.get(name); + if (multipartFiles != null) { + return multipartFiles; + } + else { + return Collections.emptyList(); + } + } + + @Override + public Map getFileMap() { + return this.multipartFiles.toSingleValueMap(); + } + + @Override + public MultiValueMap getMultiFileMap() { + return new LinkedMultiValueMap<>(this.multipartFiles); + } + + @Override + public String getMultipartContentType(String paramOrFileName) { + MultipartFile file = getFile(paramOrFileName); + if (file != null) { + return file.getContentType(); + } + try { + Part part = getPart(paramOrFileName); + if (part != null) { + return part.getContentType(); + } + } + catch (ServletException | IOException ex) { + // Should never happen (we're not actually parsing) + throw new IllegalStateException(ex); + } + return null; + } + + @Override + public HttpMethod getRequestMethod() { + return HttpMethod.resolve(getMethod()); + } + + @Override + public HttpHeaders getRequestHeaders() { + HttpHeaders headers = new HttpHeaders(); + Enumeration headerNames = getHeaderNames(); + while (headerNames.hasMoreElements()) { + String headerName = headerNames.nextElement(); + headers.put(headerName, Collections.list(getHeaders(headerName))); + } + return headers; + } + + @Override + public HttpHeaders getMultipartHeaders(String paramOrFileName) { + String contentType = getMultipartContentType(paramOrFileName); + if (contentType != null) { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.CONTENT_TYPE, contentType); + return headers; + } + else { + return null; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockPageContext.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockPageContext.java new file mode 100644 index 0000000000000000000000000000000000000000..3959db201f4d1bfa44e74b4c0627435849fe348b --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockPageContext.java @@ -0,0 +1,387 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.Map; + +import javax.el.ELContext; +import javax.servlet.Servlet; +import javax.servlet.ServletConfig; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; +import javax.servlet.jsp.JspWriter; +import javax.servlet.jsp.PageContext; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Mock implementation of the {@link javax.servlet.jsp.PageContext} interface. + * Only necessary for testing applications when testing custom JSP tags. + * + *

Note: Expects initialization via the constructor rather than via the + * {@code PageContext.initialize} method. Does not support writing to a + * JspWriter, request dispatching, or {@code handlePageException} calls. + * + * @author Juergen Hoeller + * @since 1.0.2 + */ +public class MockPageContext extends PageContext { + + private final ServletContext servletContext; + + private final HttpServletRequest request; + + private final HttpServletResponse response; + + private final ServletConfig servletConfig; + + private final Map attributes = new LinkedHashMap<>(); + + @Nullable + private JspWriter out; + + + /** + * Create new MockPageContext with a default {@link MockServletContext}, + * {@link MockHttpServletRequest}, {@link MockHttpServletResponse}, + * {@link MockServletConfig}. + */ + public MockPageContext() { + this(null, null, null, null); + } + + /** + * Create new MockPageContext with a default {@link MockHttpServletRequest}, + * {@link MockHttpServletResponse}, {@link MockServletConfig}. + * @param servletContext the ServletContext that the JSP page runs in + * (only necessary when actually accessing the ServletContext) + */ + public MockPageContext(@Nullable ServletContext servletContext) { + this(servletContext, null, null, null); + } + + /** + * Create new MockPageContext with a MockHttpServletResponse, + * MockServletConfig. + * @param servletContext the ServletContext that the JSP page runs in + * @param request the current HttpServletRequest + * (only necessary when actually accessing the request) + */ + public MockPageContext(@Nullable ServletContext servletContext, @Nullable HttpServletRequest request) { + this(servletContext, request, null, null); + } + + /** + * Create new MockPageContext with a MockServletConfig. + * @param servletContext the ServletContext that the JSP page runs in + * @param request the current HttpServletRequest + * @param response the current HttpServletResponse + * (only necessary when actually writing to the response) + */ + public MockPageContext(@Nullable ServletContext servletContext, @Nullable HttpServletRequest request, + @Nullable HttpServletResponse response) { + + this(servletContext, request, response, null); + } + + /** + * Create new MockServletConfig. + * @param servletContext the ServletContext that the JSP page runs in + * @param request the current HttpServletRequest + * @param response the current HttpServletResponse + * @param servletConfig the ServletConfig (hardly ever accessed from within a tag) + */ + public MockPageContext(@Nullable ServletContext servletContext, @Nullable HttpServletRequest request, + @Nullable HttpServletResponse response, @Nullable ServletConfig servletConfig) { + + this.servletContext = (servletContext != null ? servletContext : new MockServletContext()); + this.request = (request != null ? request : new MockHttpServletRequest(servletContext)); + this.response = (response != null ? response : new MockHttpServletResponse()); + this.servletConfig = (servletConfig != null ? servletConfig : new MockServletConfig(servletContext)); + } + + + @Override + public void initialize( + Servlet servlet, ServletRequest request, ServletResponse response, + String errorPageURL, boolean needsSession, int bufferSize, boolean autoFlush) { + + throw new UnsupportedOperationException("Use appropriate constructor"); + } + + @Override + public void release() { + } + + @Override + public void setAttribute(String name, @Nullable Object value) { + Assert.notNull(name, "Attribute name must not be null"); + if (value != null) { + this.attributes.put(name, value); + } + else { + this.attributes.remove(name); + } + } + + @Override + public void setAttribute(String name, @Nullable Object value, int scope) { + Assert.notNull(name, "Attribute name must not be null"); + switch (scope) { + case PAGE_SCOPE: + setAttribute(name, value); + break; + case REQUEST_SCOPE: + this.request.setAttribute(name, value); + break; + case SESSION_SCOPE: + this.request.getSession().setAttribute(name, value); + break; + case APPLICATION_SCOPE: + this.servletContext.setAttribute(name, value); + break; + default: + throw new IllegalArgumentException("Invalid scope: " + scope); + } + } + + @Override + @Nullable + public Object getAttribute(String name) { + Assert.notNull(name, "Attribute name must not be null"); + return this.attributes.get(name); + } + + @Override + @Nullable + public Object getAttribute(String name, int scope) { + Assert.notNull(name, "Attribute name must not be null"); + switch (scope) { + case PAGE_SCOPE: + return getAttribute(name); + case REQUEST_SCOPE: + return this.request.getAttribute(name); + case SESSION_SCOPE: + HttpSession session = this.request.getSession(false); + return (session != null ? session.getAttribute(name) : null); + case APPLICATION_SCOPE: + return this.servletContext.getAttribute(name); + default: + throw new IllegalArgumentException("Invalid scope: " + scope); + } + } + + @Override + @Nullable + public Object findAttribute(String name) { + Object value = getAttribute(name); + if (value == null) { + value = getAttribute(name, REQUEST_SCOPE); + if (value == null) { + value = getAttribute(name, SESSION_SCOPE); + if (value == null) { + value = getAttribute(name, APPLICATION_SCOPE); + } + } + } + return value; + } + + @Override + public void removeAttribute(String name) { + Assert.notNull(name, "Attribute name must not be null"); + this.removeAttribute(name, PageContext.PAGE_SCOPE); + this.removeAttribute(name, PageContext.REQUEST_SCOPE); + this.removeAttribute(name, PageContext.SESSION_SCOPE); + this.removeAttribute(name, PageContext.APPLICATION_SCOPE); + } + + @Override + public void removeAttribute(String name, int scope) { + Assert.notNull(name, "Attribute name must not be null"); + switch (scope) { + case PAGE_SCOPE: + this.attributes.remove(name); + break; + case REQUEST_SCOPE: + this.request.removeAttribute(name); + break; + case SESSION_SCOPE: + this.request.getSession().removeAttribute(name); + break; + case APPLICATION_SCOPE: + this.servletContext.removeAttribute(name); + break; + default: + throw new IllegalArgumentException("Invalid scope: " + scope); + } + } + + @Override + public int getAttributesScope(String name) { + if (getAttribute(name) != null) { + return PAGE_SCOPE; + } + else if (getAttribute(name, REQUEST_SCOPE) != null) { + return REQUEST_SCOPE; + } + else if (getAttribute(name, SESSION_SCOPE) != null) { + return SESSION_SCOPE; + } + else if (getAttribute(name, APPLICATION_SCOPE) != null) { + return APPLICATION_SCOPE; + } + else { + return 0; + } + } + + public Enumeration getAttributeNames() { + return Collections.enumeration(new LinkedHashSet<>(this.attributes.keySet())); + } + + @Override + public Enumeration getAttributeNamesInScope(int scope) { + switch (scope) { + case PAGE_SCOPE: + return getAttributeNames(); + case REQUEST_SCOPE: + return this.request.getAttributeNames(); + case SESSION_SCOPE: + HttpSession session = this.request.getSession(false); + return (session != null ? session.getAttributeNames() : Collections.emptyEnumeration()); + case APPLICATION_SCOPE: + return this.servletContext.getAttributeNames(); + default: + throw new IllegalArgumentException("Invalid scope: " + scope); + } + } + + @Override + public JspWriter getOut() { + if (this.out == null) { + this.out = new MockJspWriter(this.response); + } + return this.out; + } + + @Override + @Deprecated + public javax.servlet.jsp.el.ExpressionEvaluator getExpressionEvaluator() { + return new MockExpressionEvaluator(this); + } + + @Override + @Nullable + public ELContext getELContext() { + return null; + } + + @Override + @Deprecated + @Nullable + public javax.servlet.jsp.el.VariableResolver getVariableResolver() { + return null; + } + + @Override + public HttpSession getSession() { + return this.request.getSession(); + } + + @Override + public Object getPage() { + return this; + } + + @Override + public ServletRequest getRequest() { + return this.request; + } + + @Override + public ServletResponse getResponse() { + return this.response; + } + + @Override + @Nullable + public Exception getException() { + return null; + } + + @Override + public ServletConfig getServletConfig() { + return this.servletConfig; + } + + @Override + public ServletContext getServletContext() { + return this.servletContext; + } + + @Override + public void forward(String path) throws ServletException, IOException { + this.request.getRequestDispatcher(path).forward(this.request, this.response); + } + + @Override + public void include(String path) throws ServletException, IOException { + this.request.getRequestDispatcher(path).include(this.request, this.response); + } + + @Override + public void include(String path, boolean flush) throws ServletException, IOException { + this.request.getRequestDispatcher(path).include(this.request, this.response); + if (flush) { + this.response.flushBuffer(); + } + } + + public byte[] getContentAsByteArray() { + Assert.state(this.response instanceof MockHttpServletResponse, "MockHttpServletResponse required"); + return ((MockHttpServletResponse) this.response).getContentAsByteArray(); + } + + public String getContentAsString() throws UnsupportedEncodingException { + Assert.state(this.response instanceof MockHttpServletResponse, "MockHttpServletResponse required"); + return ((MockHttpServletResponse) this.response).getContentAsString(); + } + + @Override + public void handlePageException(Exception ex) throws ServletException, IOException { + throw new ServletException("Page exception", ex); + } + + @Override + public void handlePageException(Throwable ex) throws ServletException, IOException { + throw new ServletException("Page exception", ex); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockPart.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockPart.java new file mode 100644 index 0000000000000000000000000000000000000000..62398ffd801ab64272672fe1234cf95832e15bb2 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockPart.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Collection; +import java.util.Collections; + +import javax.servlet.http.Part; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Mock implementation of {@code javax.servlet.http.Part}. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @since 3.1 + * @see MockHttpServletRequest#addPart + * @see MockMultipartFile + */ +public class MockPart implements Part { + + private final String name; + + @Nullable + private final String filename; + + private final byte[] content; + + private final HttpHeaders headers = new HttpHeaders(); + + + /** + * Constructor for a part with byte[] content only. + * @see #getHeaders() + */ + public MockPart(String name, @Nullable byte[] content) { + this(name, null, content); + } + + /** + * Constructor for a part with a filename and byte[] content. + * @see #getHeaders() + */ + public MockPart(String name, @Nullable String filename, @Nullable byte[] content) { + Assert.hasLength(name, "'name' must not be empty"); + this.name = name; + this.filename = filename; + this.content = (content != null ? content : new byte[0]); + this.headers.setContentDispositionFormData(name, filename); + } + + + @Override + public String getName() { + return this.name; + } + + @Override + @Nullable + public String getSubmittedFileName() { + return this.filename; + } + + @Override + @Nullable + public String getContentType() { + MediaType contentType = this.headers.getContentType(); + return (contentType != null ? contentType.toString() : null); + } + + @Override + public long getSize() { + return this.content.length; + } + + @Override + public InputStream getInputStream() throws IOException { + return new ByteArrayInputStream(this.content); + } + + @Override + public void write(String fileName) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void delete() throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + @Nullable + public String getHeader(String name) { + return this.headers.getFirst(name); + } + + @Override + public Collection getHeaders(String name) { + Collection headerValues = this.headers.get(name); + return (headerValues != null ? headerValues : Collections.emptyList()); + } + + @Override + public Collection getHeaderNames() { + return this.headers.keySet(); + } + + /** + * Return the {@link HttpHeaders} backing header related accessor methods, + * allowing for populating selected header entries. + */ + public final HttpHeaders getHeaders() { + return this.headers; + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockRequestDispatcher.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockRequestDispatcher.java new file mode 100644 index 0000000000000000000000000000000000000000..df9dd91dcb0b153bb22df062f6a058a33e90f42e --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockRequestDispatcher.java @@ -0,0 +1,91 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import javax.servlet.RequestDispatcher; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.util.Assert; + +/** + * Mock implementation of the {@link javax.servlet.RequestDispatcher} interface. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @author Sam Brannen + * @since 1.0.2 + * @see MockHttpServletRequest#getRequestDispatcher(String) + */ +public class MockRequestDispatcher implements RequestDispatcher { + + private final Log logger = LogFactory.getLog(getClass()); + + private final String resource; + + + /** + * Create a new MockRequestDispatcher for the given resource. + * @param resource the server resource to dispatch to, located at a + * particular path or given by a particular name + */ + public MockRequestDispatcher(String resource) { + Assert.notNull(resource, "Resource must not be null"); + this.resource = resource; + } + + + @Override + public void forward(ServletRequest request, ServletResponse response) { + Assert.notNull(request, "Request must not be null"); + Assert.notNull(response, "Response must not be null"); + Assert.state(!response.isCommitted(), "Cannot perform forward - response is already committed"); + getMockHttpServletResponse(response).setForwardedUrl(this.resource); + if (logger.isDebugEnabled()) { + logger.debug("MockRequestDispatcher: forwarding to [" + this.resource + "]"); + } + } + + @Override + public void include(ServletRequest request, ServletResponse response) { + Assert.notNull(request, "Request must not be null"); + Assert.notNull(response, "Response must not be null"); + getMockHttpServletResponse(response).addIncludedUrl(this.resource); + if (logger.isDebugEnabled()) { + logger.debug("MockRequestDispatcher: including [" + this.resource + "]"); + } + } + + /** + * Obtain the underlying {@link MockHttpServletResponse}, unwrapping + * {@link HttpServletResponseWrapper} decorators if necessary. + */ + protected MockHttpServletResponse getMockHttpServletResponse(ServletResponse response) { + if (response instanceof MockHttpServletResponse) { + return (MockHttpServletResponse) response; + } + if (response instanceof HttpServletResponseWrapper) { + return getMockHttpServletResponse(((HttpServletResponseWrapper) response).getResponse()); + } + throw new IllegalArgumentException("MockRequestDispatcher requires MockHttpServletResponse"); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockServletConfig.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockServletConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..6bcec47ae480b18307bc366ec1e038da85b05d2f --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockServletConfig.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashMap; +import java.util.Map; + +import javax.servlet.ServletConfig; +import javax.servlet.ServletContext; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Mock implementation of the {@link javax.servlet.ServletConfig} interface. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @since 1.0.2 + */ +public class MockServletConfig implements ServletConfig { + + private final ServletContext servletContext; + + private final String servletName; + + private final Map initParameters = new LinkedHashMap<>(); + + + /** + * Create a new MockServletConfig with a default {@link MockServletContext}. + */ + public MockServletConfig() { + this(null, ""); + } + + /** + * Create a new MockServletConfig with a default {@link MockServletContext}. + * @param servletName the name of the servlet + */ + public MockServletConfig(String servletName) { + this(null, servletName); + } + + /** + * Create a new MockServletConfig. + * @param servletContext the ServletContext that the servlet runs in + */ + public MockServletConfig(@Nullable ServletContext servletContext) { + this(servletContext, ""); + } + + /** + * Create a new MockServletConfig. + * @param servletContext the ServletContext that the servlet runs in + * @param servletName the name of the servlet + */ + public MockServletConfig(@Nullable ServletContext servletContext, String servletName) { + this.servletContext = (servletContext != null ? servletContext : new MockServletContext()); + this.servletName = servletName; + } + + + @Override + public String getServletName() { + return this.servletName; + } + + @Override + public ServletContext getServletContext() { + return this.servletContext; + } + + public void addInitParameter(String name, String value) { + Assert.notNull(name, "Parameter name must not be null"); + this.initParameters.put(name, value); + } + + @Override + public String getInitParameter(String name) { + Assert.notNull(name, "Parameter name must not be null"); + return this.initParameters.get(name); + } + + @Override + public Enumeration getInitParameterNames() { + return Collections.enumeration(this.initParameters.keySet()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockServletContext.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockServletContext.java new file mode 100644 index 0000000000000000000000000000000000000000..32fa355b8662cc3755e66539cf0cfad14483ce41 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockServletContext.java @@ -0,0 +1,747 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.file.InvalidPathException; +import java.util.Collections; +import java.util.Enumeration; +import java.util.EventListener; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + +import javax.servlet.Filter; +import javax.servlet.FilterRegistration; +import javax.servlet.RequestDispatcher; +import javax.servlet.Servlet; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.ServletRegistration; +import javax.servlet.SessionCookieConfig; +import javax.servlet.SessionTrackingMode; +import javax.servlet.descriptor.JspConfigDescriptor; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.io.DefaultResourceLoader; +import org.springframework.core.io.Resource; +import org.springframework.core.io.ResourceLoader; +import org.springframework.http.MediaType; +import org.springframework.http.MediaTypeFactory; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.MimeType; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.util.WebUtils; + +/** + * Mock implementation of the {@link javax.servlet.ServletContext} interface. + * + *

As of Spring 5.0, this set of mocks is designed on a Servlet 4.0 baseline. + * + *

Compatible with Servlet 3.1 but can be configured to expose a specific version + * through {@link #setMajorVersion}/{@link #setMinorVersion}; default is 3.1. + * Note that Servlet 3.1 support is limited: servlet, filter and listener + * registration methods are not supported; neither is JSP configuration. + * We generally do not recommend to unit test your ServletContainerInitializers and + * WebApplicationInitializers which is where those registration methods would be used. + * + *

For setting up a full {@code WebApplicationContext} in a test environment, you can + * use {@code AnnotationConfigWebApplicationContext}, {@code XmlWebApplicationContext}, + * or {@code GenericWebApplicationContext}, passing in a corresponding + * {@code MockServletContext} instance. Consider configuring your + * {@code MockServletContext} with a {@code FileSystemResourceLoader} in order to + * interpret resource paths as relative filesystem locations. + * + * @author Rod Johnson + * @author Juergen Hoeller + * @author Sam Brannen + * @since 1.0.2 + * @see #MockServletContext(org.springframework.core.io.ResourceLoader) + * @see org.springframework.web.context.support.AnnotationConfigWebApplicationContext + * @see org.springframework.web.context.support.XmlWebApplicationContext + * @see org.springframework.web.context.support.GenericWebApplicationContext + */ +public class MockServletContext implements ServletContext { + + /** Default Servlet name used by Tomcat, Jetty, JBoss, and GlassFish: {@value}. */ + private static final String COMMON_DEFAULT_SERVLET_NAME = "default"; + + private static final String TEMP_DIR_SYSTEM_PROPERTY = "java.io.tmpdir"; + + private static final Set DEFAULT_SESSION_TRACKING_MODES = new LinkedHashSet<>(4); + + static { + DEFAULT_SESSION_TRACKING_MODES.add(SessionTrackingMode.COOKIE); + DEFAULT_SESSION_TRACKING_MODES.add(SessionTrackingMode.URL); + DEFAULT_SESSION_TRACKING_MODES.add(SessionTrackingMode.SSL); + } + + + private final Log logger = LogFactory.getLog(getClass()); + + private final ResourceLoader resourceLoader; + + private final String resourceBasePath; + + private String contextPath = ""; + + private final Map contexts = new HashMap<>(); + + private int majorVersion = 3; + + private int minorVersion = 1; + + private int effectiveMajorVersion = 3; + + private int effectiveMinorVersion = 1; + + private final Map namedRequestDispatchers = new HashMap<>(); + + private String defaultServletName = COMMON_DEFAULT_SERVLET_NAME; + + private final Map initParameters = new LinkedHashMap<>(); + + private final Map attributes = new LinkedHashMap<>(); + + private String servletContextName = "MockServletContext"; + + private final Set declaredRoles = new LinkedHashSet<>(); + + @Nullable + private Set sessionTrackingModes; + + private final SessionCookieConfig sessionCookieConfig = new MockSessionCookieConfig(); + + private int sessionTimeout; + + @Nullable + private String requestCharacterEncoding; + + @Nullable + private String responseCharacterEncoding; + + private final Map mimeTypes = new LinkedHashMap<>(); + + + /** + * Create a new {@code MockServletContext}, using no base path and a + * {@link DefaultResourceLoader} (i.e. the classpath root as WAR root). + * @see org.springframework.core.io.DefaultResourceLoader + */ + public MockServletContext() { + this("", null); + } + + /** + * Create a new {@code MockServletContext}, using a {@link DefaultResourceLoader}. + * @param resourceBasePath the root directory of the WAR (should not end with a slash) + * @see org.springframework.core.io.DefaultResourceLoader + */ + public MockServletContext(String resourceBasePath) { + this(resourceBasePath, null); + } + + /** + * Create a new {@code MockServletContext}, using the specified {@link ResourceLoader} + * and no base path. + * @param resourceLoader the ResourceLoader to use (or null for the default) + */ + public MockServletContext(@Nullable ResourceLoader resourceLoader) { + this("", resourceLoader); + } + + /** + * Create a new {@code MockServletContext} using the supplied resource base + * path and resource loader. + *

Registers a {@link MockRequestDispatcher} for the Servlet named + * {@literal 'default'}. + * @param resourceBasePath the root directory of the WAR (should not end with a slash) + * @param resourceLoader the ResourceLoader to use (or null for the default) + * @see #registerNamedDispatcher + */ + public MockServletContext(String resourceBasePath, @Nullable ResourceLoader resourceLoader) { + this.resourceLoader = (resourceLoader != null ? resourceLoader : new DefaultResourceLoader()); + this.resourceBasePath = resourceBasePath; + + // Use JVM temp dir as ServletContext temp dir. + String tempDir = System.getProperty(TEMP_DIR_SYSTEM_PROPERTY); + if (tempDir != null) { + this.attributes.put(WebUtils.TEMP_DIR_CONTEXT_ATTRIBUTE, new File(tempDir)); + } + + registerNamedDispatcher(this.defaultServletName, new MockRequestDispatcher(this.defaultServletName)); + } + + /** + * Build a full resource location for the given path, prepending the resource + * base path of this {@code MockServletContext}. + * @param path the path as specified + * @return the full resource path + */ + protected String getResourceLocation(String path) { + if (!path.startsWith("/")) { + path = "/" + path; + } + return this.resourceBasePath + path; + } + + public void setContextPath(String contextPath) { + this.contextPath = contextPath; + } + + @Override + public String getContextPath() { + return this.contextPath; + } + + public void registerContext(String contextPath, ServletContext context) { + this.contexts.put(contextPath, context); + } + + @Override + public ServletContext getContext(String contextPath) { + if (this.contextPath.equals(contextPath)) { + return this; + } + return this.contexts.get(contextPath); + } + + public void setMajorVersion(int majorVersion) { + this.majorVersion = majorVersion; + } + + @Override + public int getMajorVersion() { + return this.majorVersion; + } + + public void setMinorVersion(int minorVersion) { + this.minorVersion = minorVersion; + } + + @Override + public int getMinorVersion() { + return this.minorVersion; + } + + public void setEffectiveMajorVersion(int effectiveMajorVersion) { + this.effectiveMajorVersion = effectiveMajorVersion; + } + + @Override + public int getEffectiveMajorVersion() { + return this.effectiveMajorVersion; + } + + public void setEffectiveMinorVersion(int effectiveMinorVersion) { + this.effectiveMinorVersion = effectiveMinorVersion; + } + + @Override + public int getEffectiveMinorVersion() { + return this.effectiveMinorVersion; + } + + @Override + @Nullable + public String getMimeType(String filePath) { + String extension = StringUtils.getFilenameExtension(filePath); + if (this.mimeTypes.containsKey(extension)) { + return this.mimeTypes.get(extension).toString(); + } + else { + return MediaTypeFactory.getMediaType(filePath). + map(MimeType::toString) + .orElse(null); + } + } + + /** + * Adds a mime type mapping for use by {@link #getMimeType(String)}. + * @param fileExtension a file extension, such as {@code txt}, {@code gif} + * @param mimeType the mime type + */ + public void addMimeType(String fileExtension, MediaType mimeType) { + Assert.notNull(fileExtension, "'fileExtension' must not be null"); + this.mimeTypes.put(fileExtension, mimeType); + } + + @Override + @Nullable + public Set getResourcePaths(String path) { + String actualPath = (path.endsWith("/") ? path : path + "/"); + String resourceLocation = getResourceLocation(actualPath); + Resource resource = null; + try { + resource = this.resourceLoader.getResource(resourceLocation); + File file = resource.getFile(); + String[] fileList = file.list(); + if (ObjectUtils.isEmpty(fileList)) { + return null; + } + Set resourcePaths = new LinkedHashSet<>(fileList.length); + for (String fileEntry : fileList) { + String resultPath = actualPath + fileEntry; + if (resource.createRelative(fileEntry).getFile().isDirectory()) { + resultPath += "/"; + } + resourcePaths.add(resultPath); + } + return resourcePaths; + } + catch (InvalidPathException | IOException ex ) { + if (logger.isWarnEnabled()) { + logger.warn("Could not get resource paths for " + + (resource != null ? resource : resourceLocation), ex); + } + return null; + } + } + + @Override + @Nullable + public URL getResource(String path) throws MalformedURLException { + String resourceLocation = getResourceLocation(path); + Resource resource = null; + try { + resource = this.resourceLoader.getResource(resourceLocation); + if (!resource.exists()) { + return null; + } + return resource.getURL(); + } + catch (MalformedURLException ex) { + throw ex; + } + catch (InvalidPathException | IOException ex) { + if (logger.isWarnEnabled()) { + logger.warn("Could not get URL for resource " + + (resource != null ? resource : resourceLocation), ex); + } + return null; + } + } + + @Override + @Nullable + public InputStream getResourceAsStream(String path) { + String resourceLocation = getResourceLocation(path); + Resource resource = null; + try { + resource = this.resourceLoader.getResource(resourceLocation); + if (!resource.exists()) { + return null; + } + return resource.getInputStream(); + } + catch (InvalidPathException | IOException ex) { + if (logger.isWarnEnabled()) { + logger.warn("Could not open InputStream for resource " + + (resource != null ? resource : resourceLocation), ex); + } + return null; + } + } + + @Override + public RequestDispatcher getRequestDispatcher(String path) { + Assert.isTrue(path.startsWith("/"), + () -> "RequestDispatcher path [" + path + "] at ServletContext level must start with '/'"); + return new MockRequestDispatcher(path); + } + + @Override + public RequestDispatcher getNamedDispatcher(String path) { + return this.namedRequestDispatchers.get(path); + } + + /** + * Register a {@link RequestDispatcher} (typically a {@link MockRequestDispatcher}) + * that acts as a wrapper for the named Servlet. + * @param name the name of the wrapped Servlet + * @param requestDispatcher the dispatcher that wraps the named Servlet + * @see #getNamedDispatcher + * @see #unregisterNamedDispatcher + */ + public void registerNamedDispatcher(String name, RequestDispatcher requestDispatcher) { + Assert.notNull(name, "RequestDispatcher name must not be null"); + Assert.notNull(requestDispatcher, "RequestDispatcher must not be null"); + this.namedRequestDispatchers.put(name, requestDispatcher); + } + + /** + * Unregister the {@link RequestDispatcher} with the given name. + * @param name the name of the dispatcher to unregister + * @see #getNamedDispatcher + * @see #registerNamedDispatcher + */ + public void unregisterNamedDispatcher(String name) { + Assert.notNull(name, "RequestDispatcher name must not be null"); + this.namedRequestDispatchers.remove(name); + } + + /** + * Get the name of the default {@code Servlet}. + *

Defaults to {@literal 'default'}. + * @see #setDefaultServletName + */ + public String getDefaultServletName() { + return this.defaultServletName; + } + + /** + * Set the name of the default {@code Servlet}. + *

Also {@link #unregisterNamedDispatcher unregisters} the current default + * {@link RequestDispatcher} and {@link #registerNamedDispatcher replaces} + * it with a {@link MockRequestDispatcher} for the provided + * {@code defaultServletName}. + * @param defaultServletName the name of the default {@code Servlet}; + * never {@code null} or empty + * @see #getDefaultServletName + */ + public void setDefaultServletName(String defaultServletName) { + Assert.hasText(defaultServletName, "defaultServletName must not be null or empty"); + unregisterNamedDispatcher(this.defaultServletName); + this.defaultServletName = defaultServletName; + registerNamedDispatcher(this.defaultServletName, new MockRequestDispatcher(this.defaultServletName)); + } + + @Deprecated + @Override + @Nullable + public Servlet getServlet(String name) { + return null; + } + + @Override + @Deprecated + public Enumeration getServlets() { + return Collections.enumeration(Collections.emptySet()); + } + + @Override + @Deprecated + public Enumeration getServletNames() { + return Collections.enumeration(Collections.emptySet()); + } + + @Override + public void log(String message) { + logger.info(message); + } + + @Override + @Deprecated + public void log(Exception ex, String message) { + logger.info(message, ex); + } + + @Override + public void log(String message, Throwable ex) { + logger.info(message, ex); + } + + @Override + @Nullable + public String getRealPath(String path) { + String resourceLocation = getResourceLocation(path); + Resource resource = null; + try { + resource = this.resourceLoader.getResource(resourceLocation); + return resource.getFile().getAbsolutePath(); + } + catch (InvalidPathException | IOException ex) { + if (logger.isWarnEnabled()) { + logger.warn("Could not determine real path of resource " + + (resource != null ? resource : resourceLocation), ex); + } + return null; + } + } + + @Override + public String getServerInfo() { + return "MockServletContext"; + } + + @Override + public String getInitParameter(String name) { + Assert.notNull(name, "Parameter name must not be null"); + return this.initParameters.get(name); + } + + @Override + public Enumeration getInitParameterNames() { + return Collections.enumeration(this.initParameters.keySet()); + } + + @Override + public boolean setInitParameter(String name, String value) { + Assert.notNull(name, "Parameter name must not be null"); + if (this.initParameters.containsKey(name)) { + return false; + } + this.initParameters.put(name, value); + return true; + } + + public void addInitParameter(String name, String value) { + Assert.notNull(name, "Parameter name must not be null"); + this.initParameters.put(name, value); + } + + @Override + @Nullable + public Object getAttribute(String name) { + Assert.notNull(name, "Attribute name must not be null"); + return this.attributes.get(name); + } + + @Override + public Enumeration getAttributeNames() { + return Collections.enumeration(new LinkedHashSet<>(this.attributes.keySet())); + } + + @Override + public void setAttribute(String name, @Nullable Object value) { + Assert.notNull(name, "Attribute name must not be null"); + if (value != null) { + this.attributes.put(name, value); + } + else { + this.attributes.remove(name); + } + } + + @Override + public void removeAttribute(String name) { + Assert.notNull(name, "Attribute name must not be null"); + this.attributes.remove(name); + } + + public void setServletContextName(String servletContextName) { + this.servletContextName = servletContextName; + } + + @Override + public String getServletContextName() { + return this.servletContextName; + } + + @Override + @Nullable + public ClassLoader getClassLoader() { + return ClassUtils.getDefaultClassLoader(); + } + + @Override + public void declareRoles(String... roleNames) { + Assert.notNull(roleNames, "Role names array must not be null"); + for (String roleName : roleNames) { + Assert.hasLength(roleName, "Role name must not be empty"); + this.declaredRoles.add(roleName); + } + } + + public Set getDeclaredRoles() { + return Collections.unmodifiableSet(this.declaredRoles); + } + + @Override + public void setSessionTrackingModes(Set sessionTrackingModes) + throws IllegalStateException, IllegalArgumentException { + this.sessionTrackingModes = sessionTrackingModes; + } + + @Override + public Set getDefaultSessionTrackingModes() { + return DEFAULT_SESSION_TRACKING_MODES; + } + + @Override + public Set getEffectiveSessionTrackingModes() { + return (this.sessionTrackingModes != null ? + Collections.unmodifiableSet(this.sessionTrackingModes) : DEFAULT_SESSION_TRACKING_MODES); + } + + @Override + public SessionCookieConfig getSessionCookieConfig() { + return this.sessionCookieConfig; + } + + @Override // on Servlet 4.0 + public void setSessionTimeout(int sessionTimeout) { + this.sessionTimeout = sessionTimeout; + } + + @Override // on Servlet 4.0 + public int getSessionTimeout() { + return this.sessionTimeout; + } + + @Override // on Servlet 4.0 + public void setRequestCharacterEncoding(@Nullable String requestCharacterEncoding) { + this.requestCharacterEncoding = requestCharacterEncoding; + } + + @Override // on Servlet 4.0 + @Nullable + public String getRequestCharacterEncoding() { + return this.requestCharacterEncoding; + } + + @Override // on Servlet 4.0 + public void setResponseCharacterEncoding(@Nullable String responseCharacterEncoding) { + this.responseCharacterEncoding = responseCharacterEncoding; + } + + @Override // on Servlet 4.0 + @Nullable + public String getResponseCharacterEncoding() { + return this.responseCharacterEncoding; + } + + + //--------------------------------------------------------------------- + // Unsupported Servlet 3.0 registration methods + //--------------------------------------------------------------------- + + @Override + public JspConfigDescriptor getJspConfigDescriptor() { + throw new UnsupportedOperationException(); + } + + @Override // on Servlet 4.0 + public ServletRegistration.Dynamic addJspFile(String servletName, String jspFile) { + throw new UnsupportedOperationException(); + } + + @Override + public ServletRegistration.Dynamic addServlet(String servletName, String className) { + throw new UnsupportedOperationException(); + } + + @Override + public ServletRegistration.Dynamic addServlet(String servletName, Servlet servlet) { + throw new UnsupportedOperationException(); + } + + @Override + public ServletRegistration.Dynamic addServlet(String servletName, Class servletClass) { + throw new UnsupportedOperationException(); + } + + @Override + public T createServlet(Class c) throws ServletException { + throw new UnsupportedOperationException(); + } + + /** + * This method always returns {@code null}. + * @see javax.servlet.ServletContext#getServletRegistration(java.lang.String) + */ + @Override + @Nullable + public ServletRegistration getServletRegistration(String servletName) { + return null; + } + + /** + * This method always returns an {@linkplain Collections#emptyMap empty map}. + * @see javax.servlet.ServletContext#getServletRegistrations() + */ + @Override + public Map getServletRegistrations() { + return Collections.emptyMap(); + } + + @Override + public FilterRegistration.Dynamic addFilter(String filterName, String className) { + throw new UnsupportedOperationException(); + } + + @Override + public FilterRegistration.Dynamic addFilter(String filterName, Filter filter) { + throw new UnsupportedOperationException(); + } + + @Override + public FilterRegistration.Dynamic addFilter(String filterName, Class filterClass) { + throw new UnsupportedOperationException(); + } + + @Override + public T createFilter(Class c) throws ServletException { + throw new UnsupportedOperationException(); + } + + /** + * This method always returns {@code null}. + * @see javax.servlet.ServletContext#getFilterRegistration(java.lang.String) + */ + @Override + @Nullable + public FilterRegistration getFilterRegistration(String filterName) { + return null; + } + + /** + * This method always returns an {@linkplain Collections#emptyMap empty map}. + * @see javax.servlet.ServletContext#getFilterRegistrations() + */ + @Override + public Map getFilterRegistrations() { + return Collections.emptyMap(); + } + + @Override + public void addListener(Class listenerClass) { + throw new UnsupportedOperationException(); + } + + @Override + public void addListener(String className) { + throw new UnsupportedOperationException(); + } + + @Override + public void addListener(T t) { + throw new UnsupportedOperationException(); + } + + @Override + public T createListener(Class c) throws ServletException { + throw new UnsupportedOperationException(); + } + + @Override + public String getVirtualServerName() { + throw new UnsupportedOperationException(); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockSessionCookieConfig.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockSessionCookieConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..54b9b5bf28b3f4f71cd37af2774e7183a003a396 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockSessionCookieConfig.java @@ -0,0 +1,125 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import javax.servlet.SessionCookieConfig; + +import org.springframework.lang.Nullable; + +/** + * Mock implementation of the {@link javax.servlet.SessionCookieConfig} interface. + * + * @author Juergen Hoeller + * @since 4.0 + * @see javax.servlet.ServletContext#getSessionCookieConfig() + */ +public class MockSessionCookieConfig implements SessionCookieConfig { + + @Nullable + private String name; + + @Nullable + private String domain; + + @Nullable + private String path; + + @Nullable + private String comment; + + private boolean httpOnly; + + private boolean secure; + + private int maxAge = -1; + + + @Override + public void setName(@Nullable String name) { + this.name = name; + } + + @Override + @Nullable + public String getName() { + return this.name; + } + + @Override + public void setDomain(@Nullable String domain) { + this.domain = domain; + } + + @Override + @Nullable + public String getDomain() { + return this.domain; + } + + @Override + public void setPath(@Nullable String path) { + this.path = path; + } + + @Override + @Nullable + public String getPath() { + return this.path; + } + + @Override + public void setComment(@Nullable String comment) { + this.comment = comment; + } + + @Override + @Nullable + public String getComment() { + return this.comment; + } + + @Override + public void setHttpOnly(boolean httpOnly) { + this.httpOnly = httpOnly; + } + + @Override + public boolean isHttpOnly() { + return this.httpOnly; + } + + @Override + public void setSecure(boolean secure) { + this.secure = secure; + } + + @Override + public boolean isSecure() { + return this.secure; + } + + @Override + public void setMaxAge(int maxAge) { + this.maxAge = maxAge; + } + + @Override + public int getMaxAge() { + return this.maxAge; + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/PassThroughFilterChain.java b/spring-web/src/test/java/org/springframework/mock/web/test/PassThroughFilterChain.java new file mode 100644 index 0000000000000000000000000000000000000000..c440b7857908a254b9b3059aabd0bc7a0787b724 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/PassThroughFilterChain.java @@ -0,0 +1,92 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test; + +import java.io.IOException; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.Servlet; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Implementation of the {@link javax.servlet.FilterConfig} interface which + * simply passes the call through to a given Filter/FilterChain combination + * (indicating the next Filter in the chain along with the FilterChain that it is + * supposed to work on) or to a given Servlet (indicating the end of the chain). + * + * @author Juergen Hoeller + * @since 2.0.3 + * @see javax.servlet.Filter + * @see javax.servlet.Servlet + * @see MockFilterChain + */ +public class PassThroughFilterChain implements FilterChain { + + @Nullable + private Filter filter; + + @Nullable + private FilterChain nextFilterChain; + + @Nullable + private Servlet servlet; + + + /** + * Create a new PassThroughFilterChain that delegates to the given Filter, + * calling it with the given FilterChain. + * @param filter the Filter to delegate to + * @param nextFilterChain the FilterChain to use for that next Filter + */ + public PassThroughFilterChain(Filter filter, FilterChain nextFilterChain) { + Assert.notNull(filter, "Filter must not be null"); + Assert.notNull(nextFilterChain, "'FilterChain must not be null"); + this.filter = filter; + this.nextFilterChain = nextFilterChain; + } + + /** + * Create a new PassThroughFilterChain that delegates to the given Servlet. + * @param servlet the Servlet to delegate to + */ + public PassThroughFilterChain(Servlet servlet) { + Assert.notNull(servlet, "Servlet must not be null"); + this.servlet = servlet; + } + + + /** + * Pass the call on to the Filter/Servlet. + */ + @Override + public void doFilter(ServletRequest request, ServletResponse response) throws ServletException, IOException { + if (this.filter != null) { + this.filter.doFilter(request, response, this.nextFilterChain); + } + else { + Assert.state(this.servlet != null, "Neither a Filter not a Servlet set"); + this.servlet.service(request, response); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/server/MockServerWebExchange.java b/spring-web/src/test/java/org/springframework/mock/web/test/server/MockServerWebExchange.java new file mode 100644 index 0000000000000000000000000000000000000000..75b51a892bcb6aeed05ce817e010eeb01ff489f6 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/server/MockServerWebExchange.java @@ -0,0 +1,139 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.mock.web.test.server; + +import reactor.core.publisher.Mono; + +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.lang.Nullable; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.web.server.WebSession; +import org.springframework.web.server.adapter.DefaultServerWebExchange; +import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; +import org.springframework.web.server.session.DefaultWebSessionManager; +import org.springframework.web.server.session.WebSessionManager; + +/** + * Extension of {@link DefaultServerWebExchange} for use in tests, along with + * {@link MockServerHttpRequest} and {@link MockServerHttpResponse}. + * + *

See static factory methods to create an instance. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public final class MockServerWebExchange extends DefaultServerWebExchange { + + + private MockServerWebExchange(MockServerHttpRequest request, WebSessionManager sessionManager) { + super(request, new MockServerHttpResponse(), sessionManager, + ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + } + + + @Override + public MockServerHttpResponse getResponse() { + return (MockServerHttpResponse) super.getResponse(); + } + + + /** + * Create a {@link MockServerWebExchange} from the given mock request. + * @param request the request to use. + * @return the exchange + */ + public static MockServerWebExchange from(MockServerHttpRequest request) { + return builder(request).build(); + } + + /** + * Variant of {@link #from(MockServerHttpRequest)} with a mock request builder. + * @param requestBuilder the builder for the mock request. + * @return the exchange + */ + public static MockServerWebExchange from(MockServerHttpRequest.BaseBuilder requestBuilder) { + return builder(requestBuilder).build(); + } + + /** + * Create a {@link Builder} starting with the given mock request. + * @param request the request to use. + * @return the exchange builder + * @since 5.1 + */ + public static MockServerWebExchange.Builder builder(MockServerHttpRequest request) { + return new MockServerWebExchange.Builder(request); + } + + /** + * Variant of {@link #builder(MockServerHttpRequest)} with a mock request builder. + * @param requestBuilder the builder for the mock request. + * @return the exchange builder + * @since 5.1 + */ + public static MockServerWebExchange.Builder builder(MockServerHttpRequest.BaseBuilder requestBuilder) { + return new MockServerWebExchange.Builder(requestBuilder.build()); + } + + + /** + * Builder for a {@link MockServerWebExchange}. + * @since 5.1 + */ + public static class Builder { + + private final MockServerHttpRequest request; + + @Nullable + private WebSessionManager sessionManager; + + + public Builder(MockServerHttpRequest request) { + this.request = request; + } + + + /** + * Set the session to use for the exchange. + *

This is mutually exclusive with {@link #sessionManager(WebSessionManager)}. + * @param session the session to use + */ + public Builder session(WebSession session) { + this.sessionManager = exchange -> Mono.just(session); + return this; + } + + /** + * Provide a {@code WebSessionManager} instance to use with the exchange. + *

This is mutually exclusive with {@link #session(WebSession)}. + * @param sessionManager the session manager to use + */ + public Builder sessionManager(WebSessionManager sessionManager) { + this.sessionManager = sessionManager; + return this; + } + + /** + * Build the {@code MockServerWebExchange} instance. + */ + public MockServerWebExchange build() { + return new MockServerWebExchange(this.request, + this.sessionManager != null ? this.sessionManager : new DefaultWebSessionManager()); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/server/MockWebSession.java b/spring-web/src/test/java/org/springframework/mock/web/test/server/MockWebSession.java new file mode 100644 index 0000000000000000000000000000000000000000..ce043a2b1e1f1501fccf435fd885cbf2b63e0816 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/server/MockWebSession.java @@ -0,0 +1,123 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.mock.web.test.server; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Map; + +import reactor.core.publisher.Mono; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.server.WebSession; +import org.springframework.web.server.session.InMemoryWebSessionStore; + +/** + * Implementation of {@code WebSession} that delegates to a session instance + * obtained via {@link InMemoryWebSessionStore}. + * + *

This is intended for use with the + * {@link MockServerWebExchange.Builder#session(WebSession) session(WebSession)} + * method of the {@code MockServerWebExchange} builder, eliminating the need + * to use {@code WebSessionManager} or {@code WebSessionStore} altogether. + * + * @author Rossen Stoyanchev + * @since 5.1 + */ +public class MockWebSession implements WebSession { + + private final WebSession delegate; + + + public MockWebSession() { + this(null); + } + + public MockWebSession(@Nullable Clock clock) { + InMemoryWebSessionStore sessionStore = new InMemoryWebSessionStore(); + if (clock != null) { + sessionStore.setClock(clock); + } + WebSession session = sessionStore.createWebSession().block(); + Assert.state(session != null, "WebSession must not be null"); + this.delegate = session; + } + + + @Override + public String getId() { + return this.delegate.getId(); + } + + @Override + public Map getAttributes() { + return this.delegate.getAttributes(); + } + + @Override + public void start() { + this.delegate.start(); + } + + @Override + public boolean isStarted() { + return this.delegate.isStarted(); + } + + @Override + public Mono changeSessionId() { + return this.delegate.changeSessionId(); + } + + @Override + public Mono invalidate() { + return this.delegate.invalidate(); + } + + @Override + public Mono save() { + return this.delegate.save(); + } + + @Override + public boolean isExpired() { + return this.delegate.isExpired(); + } + + @Override + public Instant getCreationTime() { + return this.delegate.getCreationTime(); + } + + @Override + public Instant getLastAccessTime() { + return this.delegate.getLastAccessTime(); + } + + @Override + public void setMaxIdleTime(Duration maxIdleTime) { + this.delegate.setMaxIdleTime(maxIdleTime); + } + + @Override + public Duration getMaxIdleTime() { + return this.delegate.getMaxIdleTime(); + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/server/package-info.java b/spring-web/src/test/java/org/springframework/mock/web/test/server/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..1567f1fca26404a4b559764c6b4f5db9dc55fd94 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/test/server/package-info.java @@ -0,0 +1,9 @@ + +// For @NonNull annotations on implementation classes + +@NonNullApi +@NonNullFields +package org.springframework.mock.web.test.server; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-web/src/test/java/org/springframework/protobuf/Msg.java b/spring-web/src/test/java/org/springframework/protobuf/Msg.java new file mode 100644 index 0000000000000000000000000000000000000000..878d8392c451fc22813bf04c5ffda7fdb515f7db --- /dev/null +++ b/spring-web/src/test/java/org/springframework/protobuf/Msg.java @@ -0,0 +1,654 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: sample.proto + +package org.springframework.protobuf; + +/** + * Protobuf type {@code Msg} + */ +public final class Msg extends + com.google.protobuf.GeneratedMessage + implements MsgOrBuilder { + // Use Msg.newBuilder() to construct. + private Msg(com.google.protobuf.GeneratedMessage.Builder builder) { + super(builder); + this.unknownFields = builder.getUnknownFields(); + } + private Msg(boolean noInit) { this.unknownFields = com.google.protobuf.UnknownFieldSet.getDefaultInstance(); } + + private static final Msg defaultInstance; + public static Msg getDefaultInstance() { + return defaultInstance; + } + + public Msg getDefaultInstanceForType() { + return defaultInstance; + } + + private final com.google.protobuf.UnknownFieldSet unknownFields; + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + private Msg( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + initFields(); + @SuppressWarnings("unused") + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + default: { + if (!parseUnknownField(input, unknownFields, + extensionRegistry, tag)) { + done = true; + } + break; + } + case 10: { + bitField0_ |= 0x00000001; + foo_ = input.readBytes(); + break; + } + case 18: { + org.springframework.protobuf.SecondMsg.Builder subBuilder = null; + if (((bitField0_ & 0x00000002) == 0x00000002)) { + subBuilder = blah_.toBuilder(); + } + blah_ = input.readMessage(org.springframework.protobuf.SecondMsg.PARSER, extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom(blah_); + blah_ = subBuilder.buildPartial(); + } + bitField0_ |= 0x00000002; + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException( + e.getMessage()).setUnfinishedMessage(this); + } finally { + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.springframework.protobuf.OuterSample.internal_static_Msg_descriptor; + } + + protected com.google.protobuf.GeneratedMessage.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.springframework.protobuf.OuterSample.internal_static_Msg_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.springframework.protobuf.Msg.class, org.springframework.protobuf.Msg.Builder.class); + } + + public static com.google.protobuf.Parser PARSER = + new com.google.protobuf.AbstractParser() { + public Msg parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new Msg(input, extensionRegistry); + } + }; + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + private int bitField0_; + // optional string foo = 1; + public static final int FOO_FIELD_NUMBER = 1; + private java.lang.Object foo_; + /** + * optional string foo = 1; + */ + public boolean hasFoo() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + * optional string foo = 1; + */ + public java.lang.String getFoo() { + java.lang.Object ref = foo_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + foo_ = s; + } + return s; + } + } + /** + * optional string foo = 1; + */ + public com.google.protobuf.ByteString + getFooBytes() { + java.lang.Object ref = foo_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + foo_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + // optional .SecondMsg blah = 2; + public static final int BLAH_FIELD_NUMBER = 2; + private org.springframework.protobuf.SecondMsg blah_; + /** + * optional .SecondMsg blah = 2; + */ + public boolean hasBlah() { + return ((bitField0_ & 0x00000002) == 0x00000002); + } + /** + * optional .SecondMsg blah = 2; + */ + public org.springframework.protobuf.SecondMsg getBlah() { + return blah_; + } + /** + * optional .SecondMsg blah = 2; + */ + public org.springframework.protobuf.SecondMsgOrBuilder getBlahOrBuilder() { + return blah_; + } + + private void initFields() { + foo_ = ""; + blah_ = org.springframework.protobuf.SecondMsg.getDefaultInstance(); + } + private byte memoizedIsInitialized = -1; + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized != -1) return isInitialized == 1; + + memoizedIsInitialized = 1; + return true; + } + + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + getSerializedSize(); + if (((bitField0_ & 0x00000001) == 0x00000001)) { + output.writeBytes(1, getFooBytes()); + } + if (((bitField0_ & 0x00000002) == 0x00000002)) { + output.writeMessage(2, blah_); + } + getUnknownFields().writeTo(output); + } + + private int memoizedSerializedSize = -1; + public int getSerializedSize() { + int size = memoizedSerializedSize; + if (size != -1) return size; + + size = 0; + if (((bitField0_ & 0x00000001) == 0x00000001)) { + size += com.google.protobuf.CodedOutputStream + .computeBytesSize(1, getFooBytes()); + } + if (((bitField0_ & 0x00000002) == 0x00000002)) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(2, blah_); + } + size += getUnknownFields().getSerializedSize(); + memoizedSerializedSize = size; + return size; + } + + private static final long serialVersionUID = 0L; + @java.lang.Override + protected java.lang.Object writeReplace() + throws java.io.ObjectStreamException { + return super.writeReplace(); + } + + public static org.springframework.protobuf.Msg parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.springframework.protobuf.Msg parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.springframework.protobuf.Msg parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.springframework.protobuf.Msg parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.springframework.protobuf.Msg parseFrom(java.io.InputStream input) + throws java.io.IOException { + return PARSER.parseFrom(input); + } + public static org.springframework.protobuf.Msg parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return PARSER.parseFrom(input, extensionRegistry); + } + public static org.springframework.protobuf.Msg parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return PARSER.parseDelimitedFrom(input); + } + public static org.springframework.protobuf.Msg parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return PARSER.parseDelimitedFrom(input, extensionRegistry); + } + public static org.springframework.protobuf.Msg parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return PARSER.parseFrom(input); + } + public static org.springframework.protobuf.Msg parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return PARSER.parseFrom(input, extensionRegistry); + } + + public static Builder newBuilder() { return Builder.create(); } + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder(org.springframework.protobuf.Msg prototype) { + return newBuilder().mergeFrom(prototype); + } + public Builder toBuilder() { return newBuilder(this); } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessage.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code Msg} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessage.Builder + implements org.springframework.protobuf.MsgOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.springframework.protobuf.OuterSample.internal_static_Msg_descriptor; + } + + protected com.google.protobuf.GeneratedMessage.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.springframework.protobuf.OuterSample.internal_static_Msg_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.springframework.protobuf.Msg.class, org.springframework.protobuf.Msg.Builder.class); + } + + // Construct using org.springframework.protobuf.Msg.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessage.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessage.alwaysUseFieldBuilders) { + getBlahFieldBuilder(); + } + } + private static Builder create() { + return new Builder(); + } + + public Builder clear() { + super.clear(); + foo_ = ""; + bitField0_ = (bitField0_ & ~0x00000001); + if (blahBuilder_ == null) { + blah_ = org.springframework.protobuf.SecondMsg.getDefaultInstance(); + } else { + blahBuilder_.clear(); + } + bitField0_ = (bitField0_ & ~0x00000002); + return this; + } + + public Builder clone() { + return create().mergeFrom(buildPartial()); + } + + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.springframework.protobuf.OuterSample.internal_static_Msg_descriptor; + } + + public org.springframework.protobuf.Msg getDefaultInstanceForType() { + return org.springframework.protobuf.Msg.getDefaultInstance(); + } + + public org.springframework.protobuf.Msg build() { + org.springframework.protobuf.Msg result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + public org.springframework.protobuf.Msg buildPartial() { + org.springframework.protobuf.Msg result = new org.springframework.protobuf.Msg(this); + int from_bitField0_ = bitField0_; + int to_bitField0_ = 0; + if (((from_bitField0_ & 0x00000001) == 0x00000001)) { + to_bitField0_ |= 0x00000001; + } + result.foo_ = foo_; + if (((from_bitField0_ & 0x00000002) == 0x00000002)) { + to_bitField0_ |= 0x00000002; + } + if (blahBuilder_ == null) { + result.blah_ = blah_; + } else { + result.blah_ = blahBuilder_.build(); + } + result.bitField0_ = to_bitField0_; + onBuilt(); + return result; + } + + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.springframework.protobuf.Msg) { + return mergeFrom((org.springframework.protobuf.Msg)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.springframework.protobuf.Msg other) { + if (other == org.springframework.protobuf.Msg.getDefaultInstance()) return this; + if (other.hasFoo()) { + bitField0_ |= 0x00000001; + foo_ = other.foo_; + onChanged(); + } + if (other.hasBlah()) { + mergeBlah(other.getBlah()); + } + this.mergeUnknownFields(other.getUnknownFields()); + return this; + } + + public final boolean isInitialized() { + return true; + } + + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + org.springframework.protobuf.Msg parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (org.springframework.protobuf.Msg) e.getUnfinishedMessage(); + throw e; + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + private int bitField0_; + + // optional string foo = 1; + private java.lang.Object foo_ = ""; + /** + * optional string foo = 1; + */ + public boolean hasFoo() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + * optional string foo = 1; + */ + public java.lang.String getFoo() { + java.lang.Object ref = foo_; + if (!(ref instanceof java.lang.String)) { + java.lang.String s = ((com.google.protobuf.ByteString) ref) + .toStringUtf8(); + foo_ = s; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + * optional string foo = 1; + */ + public com.google.protobuf.ByteString + getFooBytes() { + java.lang.Object ref = foo_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + foo_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + * optional string foo = 1; + */ + public Builder setFoo( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000001; + foo_ = value; + onChanged(); + return this; + } + /** + * optional string foo = 1; + */ + public Builder clearFoo() { + bitField0_ = (bitField0_ & ~0x00000001); + foo_ = getDefaultInstance().getFoo(); + onChanged(); + return this; + } + /** + * optional string foo = 1; + */ + public Builder setFooBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000001; + foo_ = value; + onChanged(); + return this; + } + + // optional .SecondMsg blah = 2; + private org.springframework.protobuf.SecondMsg blah_ = org.springframework.protobuf.SecondMsg.getDefaultInstance(); + private com.google.protobuf.SingleFieldBuilder< + org.springframework.protobuf.SecondMsg, org.springframework.protobuf.SecondMsg.Builder, + org.springframework.protobuf.SecondMsgOrBuilder> blahBuilder_; + /** + * optional .SecondMsg blah = 2; + */ + public boolean hasBlah() { + return ((bitField0_ & 0x00000002) == 0x00000002); + } + /** + * optional .SecondMsg blah = 2; + */ + public org.springframework.protobuf.SecondMsg getBlah() { + if (blahBuilder_ == null) { + return blah_; + } else { + return blahBuilder_.getMessage(); + } + } + /** + * optional .SecondMsg blah = 2; + */ + public Builder setBlah(org.springframework.protobuf.SecondMsg value) { + if (blahBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + blah_ = value; + onChanged(); + } else { + blahBuilder_.setMessage(value); + } + bitField0_ |= 0x00000002; + return this; + } + /** + * optional .SecondMsg blah = 2; + */ + public Builder setBlah( + org.springframework.protobuf.SecondMsg.Builder builderForValue) { + if (blahBuilder_ == null) { + blah_ = builderForValue.build(); + onChanged(); + } else { + blahBuilder_.setMessage(builderForValue.build()); + } + bitField0_ |= 0x00000002; + return this; + } + /** + * optional .SecondMsg blah = 2; + */ + public Builder mergeBlah(org.springframework.protobuf.SecondMsg value) { + if (blahBuilder_ == null) { + if (((bitField0_ & 0x00000002) == 0x00000002) && + blah_ != org.springframework.protobuf.SecondMsg.getDefaultInstance()) { + blah_ = + org.springframework.protobuf.SecondMsg.newBuilder(blah_).mergeFrom(value).buildPartial(); + } else { + blah_ = value; + } + onChanged(); + } else { + blahBuilder_.mergeFrom(value); + } + bitField0_ |= 0x00000002; + return this; + } + /** + * optional .SecondMsg blah = 2; + */ + public Builder clearBlah() { + if (blahBuilder_ == null) { + blah_ = org.springframework.protobuf.SecondMsg.getDefaultInstance(); + onChanged(); + } else { + blahBuilder_.clear(); + } + bitField0_ = (bitField0_ & ~0x00000002); + return this; + } + /** + * optional .SecondMsg blah = 2; + */ + public org.springframework.protobuf.SecondMsg.Builder getBlahBuilder() { + bitField0_ |= 0x00000002; + onChanged(); + return getBlahFieldBuilder().getBuilder(); + } + /** + * optional .SecondMsg blah = 2; + */ + public org.springframework.protobuf.SecondMsgOrBuilder getBlahOrBuilder() { + if (blahBuilder_ != null) { + return blahBuilder_.getMessageOrBuilder(); + } else { + return blah_; + } + } + /** + * optional .SecondMsg blah = 2; + */ + private com.google.protobuf.SingleFieldBuilder< + org.springframework.protobuf.SecondMsg, org.springframework.protobuf.SecondMsg.Builder, + org.springframework.protobuf.SecondMsgOrBuilder> + getBlahFieldBuilder() { + if (blahBuilder_ == null) { + blahBuilder_ = new com.google.protobuf.SingleFieldBuilder<>( + blah_, + getParentForChildren(), + isClean()); + blah_ = null; + } + return blahBuilder_; + } + + // @@protoc_insertion_point(builder_scope:Msg) + } + + static { + defaultInstance = new Msg(true); + defaultInstance.initFields(); + } + + // @@protoc_insertion_point(class_scope:Msg) +} + diff --git a/spring-web/src/test/java/org/springframework/protobuf/MsgOrBuilder.java b/spring-web/src/test/java/org/springframework/protobuf/MsgOrBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..7a25a3ad5565b65c06701568881a42573eed5492 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/protobuf/MsgOrBuilder.java @@ -0,0 +1,37 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: sample.proto + +package org.springframework.protobuf; + +public interface MsgOrBuilder + extends com.google.protobuf.MessageOrBuilder { + + // optional string foo = 1; + /** + * optional string foo = 1; + */ + boolean hasFoo(); + /** + * optional string foo = 1; + */ + java.lang.String getFoo(); + /** + * optional string foo = 1; + */ + com.google.protobuf.ByteString + getFooBytes(); + + // optional .SecondMsg blah = 2; + /** + * optional .SecondMsg blah = 2; + */ + boolean hasBlah(); + /** + * optional .SecondMsg blah = 2; + */ + org.springframework.protobuf.SecondMsg getBlah(); + /** + * optional .SecondMsg blah = 2; + */ + org.springframework.protobuf.SecondMsgOrBuilder getBlahOrBuilder(); +} diff --git a/spring-web/src/test/java/org/springframework/protobuf/OuterSample.java b/spring-web/src/test/java/org/springframework/protobuf/OuterSample.java new file mode 100644 index 0000000000000000000000000000000000000000..b0c36ed5c72ebb80e5836a53af7b6c5838741413 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/protobuf/OuterSample.java @@ -0,0 +1,62 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: sample.proto + +package org.springframework.protobuf; + +public class OuterSample { + private OuterSample() {} + public static void registerAllExtensions( + com.google.protobuf.ExtensionRegistry registry) { + } + static com.google.protobuf.Descriptors.Descriptor + internal_static_Msg_descriptor; + static + com.google.protobuf.GeneratedMessage.FieldAccessorTable + internal_static_Msg_fieldAccessorTable; + static com.google.protobuf.Descriptors.Descriptor + internal_static_SecondMsg_descriptor; + static + com.google.protobuf.GeneratedMessage.FieldAccessorTable + internal_static_SecondMsg_fieldAccessorTable; + + public static com.google.protobuf.Descriptors.FileDescriptor + getDescriptor() { + return descriptor; + } + private static com.google.protobuf.Descriptors.FileDescriptor + descriptor; + static { + java.lang.String[] descriptorData = { + "\n\014sample.proto\",\n\003Msg\022\013\n\003foo\030\001 \001(\t\022\030\n\004bl" + + "ah\030\002 \001(\0132\n.SecondMsg\"\031\n\tSecondMsg\022\014\n\004bla" + + "h\030\001 \001(\005B-\n\034org.springframework.protobufB" + + "\013OuterSampleP\001" + }; + com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = + new com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner() { + public com.google.protobuf.ExtensionRegistry assignDescriptors( + com.google.protobuf.Descriptors.FileDescriptor root) { + descriptor = root; + internal_static_Msg_descriptor = + getDescriptor().getMessageTypes().get(0); + internal_static_Msg_fieldAccessorTable = new + com.google.protobuf.GeneratedMessage.FieldAccessorTable( + internal_static_Msg_descriptor, + new java.lang.String[] { "Foo", "Blah", }); + internal_static_SecondMsg_descriptor = + getDescriptor().getMessageTypes().get(1); + internal_static_SecondMsg_fieldAccessorTable = new + com.google.protobuf.GeneratedMessage.FieldAccessorTable( + internal_static_SecondMsg_descriptor, + new java.lang.String[] { "Blah", }); + return null; + } + }; + com.google.protobuf.Descriptors.FileDescriptor + .internalBuildGeneratedFileFrom(descriptorData, + new com.google.protobuf.Descriptors.FileDescriptor[] { + }, assigner); + } + + // @@protoc_insertion_point(outer_class_scope) +} diff --git a/spring-web/src/test/java/org/springframework/protobuf/SecondMsg.java b/spring-web/src/test/java/org/springframework/protobuf/SecondMsg.java new file mode 100644 index 0000000000000000000000000000000000000000..efd418e3462698976453c7b8ae0a47f50b3219c0 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/protobuf/SecondMsg.java @@ -0,0 +1,389 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: sample.proto + +package org.springframework.protobuf; + +/** + * Protobuf type {@code SecondMsg} + */ +public final class SecondMsg extends + com.google.protobuf.GeneratedMessage + implements SecondMsgOrBuilder { + // Use SecondMsg.newBuilder() to construct. + private SecondMsg(com.google.protobuf.GeneratedMessage.Builder builder) { + super(builder); + this.unknownFields = builder.getUnknownFields(); + } + private SecondMsg(boolean noInit) { this.unknownFields = com.google.protobuf.UnknownFieldSet.getDefaultInstance(); } + + private static final SecondMsg defaultInstance; + public static SecondMsg getDefaultInstance() { + return defaultInstance; + } + + public SecondMsg getDefaultInstanceForType() { + return defaultInstance; + } + + private final com.google.protobuf.UnknownFieldSet unknownFields; + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + private SecondMsg( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + initFields(); + @SuppressWarnings("unused") + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + default: { + if (!parseUnknownField(input, unknownFields, + extensionRegistry, tag)) { + done = true; + } + break; + } + case 8: { + bitField0_ |= 0x00000001; + blah_ = input.readInt32(); + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException( + e.getMessage()).setUnfinishedMessage(this); + } finally { + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.springframework.protobuf.OuterSample.internal_static_SecondMsg_descriptor; + } + + protected com.google.protobuf.GeneratedMessage.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.springframework.protobuf.OuterSample.internal_static_SecondMsg_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.springframework.protobuf.SecondMsg.class, org.springframework.protobuf.SecondMsg.Builder.class); + } + + public static com.google.protobuf.Parser PARSER = + new com.google.protobuf.AbstractParser() { + public SecondMsg parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new SecondMsg(input, extensionRegistry); + } + }; + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + private int bitField0_; + // optional int32 blah = 1; + public static final int BLAH_FIELD_NUMBER = 1; + private int blah_; + /** + * optional int32 blah = 1; + */ + public boolean hasBlah() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + * optional int32 blah = 1; + */ + public int getBlah() { + return blah_; + } + + private void initFields() { + blah_ = 0; + } + private byte memoizedIsInitialized = -1; + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized != -1) return isInitialized == 1; + + memoizedIsInitialized = 1; + return true; + } + + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + getSerializedSize(); + if (((bitField0_ & 0x00000001) == 0x00000001)) { + output.writeInt32(1, blah_); + } + getUnknownFields().writeTo(output); + } + + private int memoizedSerializedSize = -1; + public int getSerializedSize() { + int size = memoizedSerializedSize; + if (size != -1) return size; + + size = 0; + if (((bitField0_ & 0x00000001) == 0x00000001)) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(1, blah_); + } + size += getUnknownFields().getSerializedSize(); + memoizedSerializedSize = size; + return size; + } + + private static final long serialVersionUID = 0L; + @java.lang.Override + protected java.lang.Object writeReplace() + throws java.io.ObjectStreamException { + return super.writeReplace(); + } + + public static org.springframework.protobuf.SecondMsg parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.springframework.protobuf.SecondMsg parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.springframework.protobuf.SecondMsg parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.springframework.protobuf.SecondMsg parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.springframework.protobuf.SecondMsg parseFrom(java.io.InputStream input) + throws java.io.IOException { + return PARSER.parseFrom(input); + } + public static org.springframework.protobuf.SecondMsg parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return PARSER.parseFrom(input, extensionRegistry); + } + public static org.springframework.protobuf.SecondMsg parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return PARSER.parseDelimitedFrom(input); + } + public static org.springframework.protobuf.SecondMsg parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return PARSER.parseDelimitedFrom(input, extensionRegistry); + } + public static org.springframework.protobuf.SecondMsg parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return PARSER.parseFrom(input); + } + public static org.springframework.protobuf.SecondMsg parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return PARSER.parseFrom(input, extensionRegistry); + } + + public static Builder newBuilder() { return Builder.create(); } + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder(org.springframework.protobuf.SecondMsg prototype) { + return newBuilder().mergeFrom(prototype); + } + public Builder toBuilder() { return newBuilder(this); } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessage.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code SecondMsg} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessage.Builder + implements org.springframework.protobuf.SecondMsgOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.springframework.protobuf.OuterSample.internal_static_SecondMsg_descriptor; + } + + protected com.google.protobuf.GeneratedMessage.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.springframework.protobuf.OuterSample.internal_static_SecondMsg_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.springframework.protobuf.SecondMsg.class, org.springframework.protobuf.SecondMsg.Builder.class); + } + + // Construct using org.springframework.protobuf.SecondMsg.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessage.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessage.alwaysUseFieldBuilders) { + } + } + private static Builder create() { + return new Builder(); + } + + public Builder clear() { + super.clear(); + blah_ = 0; + bitField0_ = (bitField0_ & ~0x00000001); + return this; + } + + public Builder clone() { + return create().mergeFrom(buildPartial()); + } + + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.springframework.protobuf.OuterSample.internal_static_SecondMsg_descriptor; + } + + public org.springframework.protobuf.SecondMsg getDefaultInstanceForType() { + return org.springframework.protobuf.SecondMsg.getDefaultInstance(); + } + + public org.springframework.protobuf.SecondMsg build() { + org.springframework.protobuf.SecondMsg result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + public org.springframework.protobuf.SecondMsg buildPartial() { + org.springframework.protobuf.SecondMsg result = new org.springframework.protobuf.SecondMsg(this); + int from_bitField0_ = bitField0_; + int to_bitField0_ = 0; + if (((from_bitField0_ & 0x00000001) == 0x00000001)) { + to_bitField0_ |= 0x00000001; + } + result.blah_ = blah_; + result.bitField0_ = to_bitField0_; + onBuilt(); + return result; + } + + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.springframework.protobuf.SecondMsg) { + return mergeFrom((org.springframework.protobuf.SecondMsg)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.springframework.protobuf.SecondMsg other) { + if (other == org.springframework.protobuf.SecondMsg.getDefaultInstance()) return this; + if (other.hasBlah()) { + setBlah(other.getBlah()); + } + this.mergeUnknownFields(other.getUnknownFields()); + return this; + } + + public final boolean isInitialized() { + return true; + } + + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + org.springframework.protobuf.SecondMsg parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (org.springframework.protobuf.SecondMsg) e.getUnfinishedMessage(); + throw e; + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + private int bitField0_; + + // optional int32 blah = 1; + private int blah_ ; + /** + * optional int32 blah = 1; + */ + public boolean hasBlah() { + return ((bitField0_ & 0x00000001) == 0x00000001); + } + /** + * optional int32 blah = 1; + */ + public int getBlah() { + return blah_; + } + /** + * optional int32 blah = 1; + */ + public Builder setBlah(int value) { + bitField0_ |= 0x00000001; + blah_ = value; + onChanged(); + return this; + } + /** + * optional int32 blah = 1; + */ + public Builder clearBlah() { + bitField0_ = (bitField0_ & ~0x00000001); + blah_ = 0; + onChanged(); + return this; + } + + // @@protoc_insertion_point(builder_scope:SecondMsg) + } + + static { + defaultInstance = new SecondMsg(true); + defaultInstance.initFields(); + } + + // @@protoc_insertion_point(class_scope:SecondMsg) +} + diff --git a/spring-web/src/test/java/org/springframework/protobuf/SecondMsgOrBuilder.java b/spring-web/src/test/java/org/springframework/protobuf/SecondMsgOrBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..baafb872ddf16463d1939f58dd566b28f1e340d5 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/protobuf/SecondMsgOrBuilder.java @@ -0,0 +1,18 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: sample.proto + +package org.springframework.protobuf; + +public interface SecondMsgOrBuilder + extends com.google.protobuf.MessageOrBuilder { + + // optional int32 blah = 1; + /** + * optional int32 blah = 1; + */ + boolean hasBlah(); + /** + * optional int32 blah = 1; + */ + int getBlah(); +} diff --git a/spring-web/src/test/java/org/springframework/remoting/caucho/CauchoRemotingTests.java b/spring-web/src/test/java/org/springframework/remoting/caucho/CauchoRemotingTests.java new file mode 100644 index 0000000000000000000000000000000000000000..fc41bc95a5d1f446e46a467c08ca897d214d07a2 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/remoting/caucho/CauchoRemotingTests.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.caucho; + +import java.io.IOException; +import java.net.InetSocketAddress; + +import com.caucho.hessian.client.HessianProxyFactory; +import com.sun.net.httpserver.HttpServer; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.remoting.RemoteAccessException; +import org.springframework.tests.sample.beans.ITestBean; +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.util.SocketUtils; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + * @author Sam Brannen + * @since 16.05.2003 + */ +public class CauchoRemotingTests { + + @Rule + public final ExpectedException exception = ExpectedException.none(); + + + @Test + public void hessianProxyFactoryBeanWithClassInsteadOfInterface() throws Exception { + HessianProxyFactoryBean factory = new HessianProxyFactoryBean(); + exception.expect(IllegalArgumentException.class); + factory.setServiceInterface(TestBean.class); + } + + @Test + public void hessianProxyFactoryBeanWithAccessError() throws Exception { + HessianProxyFactoryBean factory = new HessianProxyFactoryBean(); + factory.setServiceInterface(ITestBean.class); + factory.setServiceUrl("http://localhosta/testbean"); + factory.afterPropertiesSet(); + + assertTrue("Correct singleton value", factory.isSingleton()); + assertTrue(factory.getObject() instanceof ITestBean); + ITestBean bean = (ITestBean) factory.getObject(); + + exception.expect(RemoteAccessException.class); + bean.setName("test"); + } + + @Test + public void hessianProxyFactoryBeanWithAuthenticationAndAccessError() throws Exception { + HessianProxyFactoryBean factory = new HessianProxyFactoryBean(); + factory.setServiceInterface(ITestBean.class); + factory.setServiceUrl("http://localhosta/testbean"); + factory.setUsername("test"); + factory.setPassword("bean"); + factory.setOverloadEnabled(true); + factory.afterPropertiesSet(); + + assertTrue("Correct singleton value", factory.isSingleton()); + assertTrue(factory.getObject() instanceof ITestBean); + ITestBean bean = (ITestBean) factory.getObject(); + + exception.expect(RemoteAccessException.class); + bean.setName("test"); + } + + @Test + public void hessianProxyFactoryBeanWithCustomProxyFactory() throws Exception { + TestHessianProxyFactory proxyFactory = new TestHessianProxyFactory(); + HessianProxyFactoryBean factory = new HessianProxyFactoryBean(); + factory.setServiceInterface(ITestBean.class); + factory.setServiceUrl("http://localhosta/testbean"); + factory.setProxyFactory(proxyFactory); + factory.setUsername("test"); + factory.setPassword("bean"); + factory.setOverloadEnabled(true); + factory.afterPropertiesSet(); + assertTrue("Correct singleton value", factory.isSingleton()); + assertTrue(factory.getObject() instanceof ITestBean); + ITestBean bean = (ITestBean) factory.getObject(); + + assertEquals("test", proxyFactory.user); + assertEquals("bean", proxyFactory.password); + assertTrue(proxyFactory.overloadEnabled); + + exception.expect(RemoteAccessException.class); + bean.setName("test"); + } + + @Test + public void simpleHessianServiceExporter() throws IOException { + final int port = SocketUtils.findAvailableTcpPort(); + + TestBean tb = new TestBean("tb"); + SimpleHessianServiceExporter exporter = new SimpleHessianServiceExporter(); + exporter.setService(tb); + exporter.setServiceInterface(ITestBean.class); + exporter.setDebug(true); + exporter.prepare(); + + HttpServer server = HttpServer.create(new InetSocketAddress(port), -1); + server.createContext("/hessian", exporter); + server.start(); + try { + HessianClientInterceptor client = new HessianClientInterceptor(); + client.setServiceUrl("http://localhost:" + port + "/hessian"); + client.setServiceInterface(ITestBean.class); + //client.setHessian2(true); + client.prepare(); + ITestBean proxy = ProxyFactory.getProxy(ITestBean.class, client); + assertEquals("tb", proxy.getName()); + proxy.setName("test"); + assertEquals("test", proxy.getName()); + } + finally { + server.stop(Integer.MAX_VALUE); + } + } + + + private static class TestHessianProxyFactory extends HessianProxyFactory { + + private String user; + private String password; + private boolean overloadEnabled; + + @Override + public void setUser(String user) { + this.user = user; + } + + @Override + public void setPassword(String password) { + this.password = password; + } + + @Override + public void setOverloadEnabled(boolean overloadEnabled) { + this.overloadEnabled = overloadEnabled; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/remoting/httpinvoker/HttpComponentsHttpInvokerRequestExecutorTests.java b/spring-web/src/test/java/org/springframework/remoting/httpinvoker/HttpComponentsHttpInvokerRequestExecutorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c0fb15f4d369d018bb2069a775fdb0f6e7f4f2c5 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/remoting/httpinvoker/HttpComponentsHttpInvokerRequestExecutorTests.java @@ -0,0 +1,167 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.httpinvoker; + +import java.io.IOException; + +import org.apache.http.client.HttpClient; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.Configurable; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClientBuilder; +import org.junit.Test; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * @author Stephane Nicoll + */ +public class HttpComponentsHttpInvokerRequestExecutorTests { + + @Test + public void customizeConnectionTimeout() throws IOException { + HttpComponentsHttpInvokerRequestExecutor executor = new HttpComponentsHttpInvokerRequestExecutor(); + executor.setConnectTimeout(5000); + + HttpInvokerClientConfiguration config = mockHttpInvokerClientConfiguration("http://fake-service"); + HttpPost httpPost = executor.createHttpPost(config); + assertEquals(5000, httpPost.getConfig().getConnectTimeout()); + } + + @Test + public void customizeConnectionRequestTimeout() throws IOException { + HttpComponentsHttpInvokerRequestExecutor executor = new HttpComponentsHttpInvokerRequestExecutor(); + executor.setConnectionRequestTimeout(7000); + + HttpInvokerClientConfiguration config = mockHttpInvokerClientConfiguration("http://fake-service"); + HttpPost httpPost = executor.createHttpPost(config); + assertEquals(7000, httpPost.getConfig().getConnectionRequestTimeout()); + } + + @Test + public void customizeReadTimeout() throws IOException { + HttpComponentsHttpInvokerRequestExecutor executor = new HttpComponentsHttpInvokerRequestExecutor(); + executor.setReadTimeout(10000); + + HttpInvokerClientConfiguration config = mockHttpInvokerClientConfiguration("http://fake-service"); + HttpPost httpPost = executor.createHttpPost(config); + assertEquals(10000, httpPost.getConfig().getSocketTimeout()); + } + + @Test + public void defaultSettingsOfHttpClientMergedOnExecutorCustomization() throws IOException { + RequestConfig defaultConfig = RequestConfig.custom().setConnectTimeout(1234).build(); + CloseableHttpClient client = mock(CloseableHttpClient.class, + withSettings().extraInterfaces(Configurable.class)); + Configurable configurable = (Configurable) client; + when(configurable.getConfig()).thenReturn(defaultConfig); + + HttpComponentsHttpInvokerRequestExecutor executor = + new HttpComponentsHttpInvokerRequestExecutor(client); + HttpInvokerClientConfiguration config = mockHttpInvokerClientConfiguration("http://fake-service"); + HttpPost httpPost = executor.createHttpPost(config); + assertSame("Default client configuration is expected", defaultConfig, httpPost.getConfig()); + + executor.setConnectionRequestTimeout(4567); + HttpPost httpPost2 = executor.createHttpPost(config); + assertNotNull(httpPost2.getConfig()); + assertEquals(4567, httpPost2.getConfig().getConnectionRequestTimeout()); + // Default connection timeout merged + assertEquals(1234, httpPost2.getConfig().getConnectTimeout()); + } + + @Test + public void localSettingsOverrideClientDefaultSettings() throws Exception { + RequestConfig defaultConfig = RequestConfig.custom() + .setConnectTimeout(1234).setConnectionRequestTimeout(6789).build(); + CloseableHttpClient client = mock(CloseableHttpClient.class, + withSettings().extraInterfaces(Configurable.class)); + Configurable configurable = (Configurable) client; + when(configurable.getConfig()).thenReturn(defaultConfig); + + HttpComponentsHttpInvokerRequestExecutor executor = + new HttpComponentsHttpInvokerRequestExecutor(client); + executor.setConnectTimeout(5000); + + HttpInvokerClientConfiguration config = mockHttpInvokerClientConfiguration("http://fake-service"); + HttpPost httpPost = executor.createHttpPost(config); + RequestConfig requestConfig = httpPost.getConfig(); + assertEquals(5000, requestConfig.getConnectTimeout()); + assertEquals(6789, requestConfig.getConnectionRequestTimeout()); + assertEquals(-1, requestConfig.getSocketTimeout()); + } + + @Test + public void mergeBasedOnCurrentHttpClient() throws Exception { + RequestConfig defaultConfig = RequestConfig.custom() + .setSocketTimeout(1234).build(); + final CloseableHttpClient client = mock(CloseableHttpClient.class, + withSettings().extraInterfaces(Configurable.class)); + Configurable configurable = (Configurable) client; + when(configurable.getConfig()).thenReturn(defaultConfig); + + HttpComponentsHttpInvokerRequestExecutor executor = + new HttpComponentsHttpInvokerRequestExecutor() { + @Override + public HttpClient getHttpClient() { + return client; + } + }; + executor.setReadTimeout(5000); + HttpInvokerClientConfiguration config = mockHttpInvokerClientConfiguration("http://fake-service"); + HttpPost httpPost = executor.createHttpPost(config); + RequestConfig requestConfig = httpPost.getConfig(); + assertEquals(-1, requestConfig.getConnectTimeout()); + assertEquals(-1, requestConfig.getConnectionRequestTimeout()); + assertEquals(5000, requestConfig.getSocketTimeout()); + + // Update the Http client so that it returns an updated config + RequestConfig updatedDefaultConfig = RequestConfig.custom() + .setConnectTimeout(1234).build(); + when(configurable.getConfig()).thenReturn(updatedDefaultConfig); + executor.setReadTimeout(7000); + HttpPost httpPost2 = executor.createHttpPost(config); + RequestConfig requestConfig2 = httpPost2.getConfig(); + assertEquals(1234, requestConfig2.getConnectTimeout()); + assertEquals(-1, requestConfig2.getConnectionRequestTimeout()); + assertEquals(7000, requestConfig2.getSocketTimeout()); + } + + @Test + public void ignoreFactorySettings() throws IOException { + CloseableHttpClient httpClient = HttpClientBuilder.create().build(); + HttpComponentsHttpInvokerRequestExecutor executor = new HttpComponentsHttpInvokerRequestExecutor(httpClient) { + @Override + protected RequestConfig createRequestConfig(HttpInvokerClientConfiguration config) { + return null; + } + }; + + HttpInvokerClientConfiguration config = mockHttpInvokerClientConfiguration("http://fake-service"); + HttpPost httpPost = executor.createHttpPost(config); + assertNull("custom request config should not be set", httpPost.getConfig()); + } + + private HttpInvokerClientConfiguration mockHttpInvokerClientConfiguration(String serviceUrl) { + HttpInvokerClientConfiguration config = mock(HttpInvokerClientConfiguration.class); + when(config.getServiceUrl()).thenReturn(serviceUrl); + return config; + } + +} diff --git a/spring-web/src/test/java/org/springframework/remoting/httpinvoker/HttpInvokerFactoryBeanIntegrationTests.java b/spring-web/src/test/java/org/springframework/remoting/httpinvoker/HttpInvokerFactoryBeanIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..525a7a3744463b0e22241b7f6d2905a3028aeaaf --- /dev/null +++ b/spring-web/src/test/java/org/springframework/remoting/httpinvoker/HttpInvokerFactoryBeanIntegrationTests.java @@ -0,0 +1,143 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.httpinvoker; + +import org.junit.Test; + +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ComponentScan; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Lazy; +import org.springframework.core.env.Environment; +import org.springframework.remoting.support.RemoteInvocationResult; +import org.springframework.scheduling.annotation.Async; +import org.springframework.scheduling.annotation.AsyncAnnotationBeanPostProcessor; +import org.springframework.stereotype.Component; + +import static org.junit.Assert.*; + +/** + * @author Stephane Nicoll + */ +public class HttpInvokerFactoryBeanIntegrationTests { + + @Test + @SuppressWarnings("resource") + public void testLoadedConfigClass() { + ApplicationContext context = new AnnotationConfigApplicationContext(InvokerAutowiringConfig.class); + MyBean myBean = context.getBean("myBean", MyBean.class); + assertSame(context.getBean("myService"), myBean.myService); + myBean.myService.handle(); + myBean.myService.handleAsync(); + } + + @Test + @SuppressWarnings("resource") + public void testNonLoadedConfigClass() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.registerBeanDefinition("config", new RootBeanDefinition(InvokerAutowiringConfig.class.getName())); + context.refresh(); + MyBean myBean = context.getBean("myBean", MyBean.class); + assertSame(context.getBean("myService"), myBean.myService); + myBean.myService.handle(); + myBean.myService.handleAsync(); + } + + @Test + @SuppressWarnings("resource") + public void withConfigurationClassWithPlainFactoryBean() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(ConfigWithPlainFactoryBean.class); + context.refresh(); + MyBean myBean = context.getBean("myBean", MyBean.class); + assertSame(context.getBean("myService"), myBean.myService); + myBean.myService.handle(); + myBean.myService.handleAsync(); + } + + + public interface MyService { + + public void handle(); + + @Async + public void handleAsync(); + } + + + @Component("myBean") + public static class MyBean { + + @Autowired + public MyService myService; + } + + + @Configuration + @ComponentScan + @Lazy + public static class InvokerAutowiringConfig { + + @Bean + public AsyncAnnotationBeanPostProcessor aabpp() { + return new AsyncAnnotationBeanPostProcessor(); + } + + @Bean + public HttpInvokerProxyFactoryBean myService() { + HttpInvokerProxyFactoryBean factory = new HttpInvokerProxyFactoryBean(); + factory.setServiceUrl("/svc/dummy"); + factory.setServiceInterface(MyService.class); + factory.setHttpInvokerRequestExecutor((config, invocation) -> new RemoteInvocationResult()); + return factory; + } + + @Bean + public FactoryBean myOtherService() { + throw new IllegalStateException("Don't ever call me"); + } + } + + + @Configuration + static class ConfigWithPlainFactoryBean { + + @Autowired + Environment env; + + @Bean + public MyBean myBean() { + return new MyBean(); + } + + @Bean + public HttpInvokerProxyFactoryBean myService() { + String name = env.getProperty("testbean.name"); + HttpInvokerProxyFactoryBean factory = new HttpInvokerProxyFactoryBean(); + factory.setServiceUrl("/svc/" + name); + factory.setServiceInterface(MyService.class); + factory.setHttpInvokerRequestExecutor((config, invocation) -> new RemoteInvocationResult()); + return factory; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/remoting/httpinvoker/HttpInvokerTests.java b/spring-web/src/test/java/org/springframework/remoting/httpinvoker/HttpInvokerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..2781664be58c7e63015789681ab5ecc33b743554 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/remoting/httpinvoker/HttpInvokerTests.java @@ -0,0 +1,514 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.httpinvoker; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.lang.reflect.InvocationTargetException; +import java.util.Arrays; +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.aopalliance.intercept.MethodInvocation; + +import org.junit.Test; + +import org.springframework.beans.factory.BeanClassLoaderAware; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.remoting.RemoteAccessException; +import org.springframework.remoting.support.DefaultRemoteInvocationExecutor; +import org.springframework.remoting.support.RemoteInvocation; +import org.springframework.remoting.support.RemoteInvocationFactory; +import org.springframework.remoting.support.RemoteInvocationResult; +import org.springframework.tests.sample.beans.ITestBean; +import org.springframework.tests.sample.beans.TestBean; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + * @since 09.08.2004 + */ +public class HttpInvokerTests { + + @Test + public void httpInvokerProxyFactoryBeanAndServiceExporter() throws Throwable { + doTestHttpInvokerProxyFactoryBeanAndServiceExporter(false); + } + + @Test + public void httpInvokerProxyFactoryBeanAndServiceExporterWithExplicitClassLoader() throws Throwable { + doTestHttpInvokerProxyFactoryBeanAndServiceExporter(true); + } + + private void doTestHttpInvokerProxyFactoryBeanAndServiceExporter(boolean explicitClassLoader) throws Throwable { + TestBean target = new TestBean("myname", 99); + + final HttpInvokerServiceExporter exporter = new HttpInvokerServiceExporter(); + exporter.setServiceInterface(ITestBean.class); + exporter.setService(target); + exporter.afterPropertiesSet(); + + HttpInvokerProxyFactoryBean pfb = new HttpInvokerProxyFactoryBean(); + pfb.setServiceInterface(ITestBean.class); + pfb.setServiceUrl("http://myurl"); + + pfb.setHttpInvokerRequestExecutor(new AbstractHttpInvokerRequestExecutor() { + @Override + protected RemoteInvocationResult doExecuteRequest( + HttpInvokerClientConfiguration config, ByteArrayOutputStream baos) throws Exception { + assertEquals("http://myurl", config.getServiceUrl()); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + request.setContent(baos.toByteArray()); + exporter.handleRequest(request, response); + return readRemoteInvocationResult( + new ByteArrayInputStream(response.getContentAsByteArray()), config.getCodebaseUrl()); + } + }); + if (explicitClassLoader) { + ((BeanClassLoaderAware) pfb.getHttpInvokerRequestExecutor()).setBeanClassLoader(getClass().getClassLoader()); + } + + pfb.afterPropertiesSet(); + ITestBean proxy = (ITestBean) pfb.getObject(); + assertEquals("myname", proxy.getName()); + assertEquals(99, proxy.getAge()); + proxy.setAge(50); + assertEquals(50, proxy.getAge()); + proxy.setStringArray(new String[] {"str1", "str2"}); + assertTrue(Arrays.equals(new String[] {"str1", "str2"}, proxy.getStringArray())); + proxy.setSomeIntegerArray(new Integer[] {1, 2, 3}); + assertTrue(Arrays.equals(new Integer[] {1, 2, 3}, proxy.getSomeIntegerArray())); + proxy.setNestedIntegerArray(new Integer[][] {{1, 2, 3}, {4, 5, 6}}); + Integer[][] integerArray = proxy.getNestedIntegerArray(); + assertTrue(Arrays.equals(new Integer[] {1, 2, 3}, integerArray[0])); + assertTrue(Arrays.equals(new Integer[] {4, 5, 6}, integerArray[1])); + proxy.setSomeIntArray(new int[] {1, 2, 3}); + assertTrue(Arrays.equals(new int[] {1, 2, 3}, proxy.getSomeIntArray())); + proxy.setNestedIntArray(new int[][] {{1, 2, 3}, {4, 5, 6}}); + int[][] intArray = proxy.getNestedIntArray(); + assertTrue(Arrays.equals(new int[] {1, 2, 3}, intArray[0])); + assertTrue(Arrays.equals(new int[] {4, 5, 6}, intArray[1])); + + try { + proxy.exceptional(new IllegalStateException()); + fail("Should have thrown IllegalStateException"); + } + catch (IllegalStateException ex) { + // expected + } + try { + proxy.exceptional(new IllegalAccessException()); + fail("Should have thrown IllegalAccessException"); + } + catch (IllegalAccessException ex) { + // expected + } + } + + @Test + public void httpInvokerProxyFactoryBeanAndServiceExporterWithIOException() throws Exception { + TestBean target = new TestBean("myname", 99); + + final HttpInvokerServiceExporter exporter = new HttpInvokerServiceExporter(); + exporter.setServiceInterface(ITestBean.class); + exporter.setService(target); + exporter.afterPropertiesSet(); + + HttpInvokerProxyFactoryBean pfb = new HttpInvokerProxyFactoryBean(); + pfb.setServiceInterface(ITestBean.class); + pfb.setServiceUrl("http://myurl"); + + pfb.setHttpInvokerRequestExecutor(new HttpInvokerRequestExecutor() { + @Override + public RemoteInvocationResult executeRequest( + HttpInvokerClientConfiguration config, RemoteInvocation invocation) throws IOException { + throw new IOException("argh"); + } + }); + + pfb.afterPropertiesSet(); + ITestBean proxy = (ITestBean) pfb.getObject(); + try { + proxy.setAge(50); + fail("Should have thrown RemoteAccessException"); + } + catch (RemoteAccessException ex) { + // expected + assertTrue(ex.getCause() instanceof IOException); + } + } + + @Test + public void httpInvokerProxyFactoryBeanAndServiceExporterWithGzipCompression() throws Throwable { + TestBean target = new TestBean("myname", 99); + + final HttpInvokerServiceExporter exporter = new HttpInvokerServiceExporter() { + @Override + protected InputStream decorateInputStream(HttpServletRequest request, InputStream is) throws IOException { + if ("gzip".equals(request.getHeader("Compression"))) { + return new GZIPInputStream(is); + } + else { + return is; + } + } + @Override + protected OutputStream decorateOutputStream( + HttpServletRequest request, HttpServletResponse response, OutputStream os) throws IOException { + if ("gzip".equals(request.getHeader("Compression"))) { + return new GZIPOutputStream(os); + } + else { + return os; + } + } + }; + exporter.setServiceInterface(ITestBean.class); + exporter.setService(target); + exporter.afterPropertiesSet(); + + HttpInvokerProxyFactoryBean pfb = new HttpInvokerProxyFactoryBean(); + pfb.setServiceInterface(ITestBean.class); + pfb.setServiceUrl("http://myurl"); + + pfb.setHttpInvokerRequestExecutor(new AbstractHttpInvokerRequestExecutor() { + @Override + protected RemoteInvocationResult doExecuteRequest( + HttpInvokerClientConfiguration config, ByteArrayOutputStream baos) + throws IOException, ClassNotFoundException { + assertEquals("http://myurl", config.getServiceUrl()); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Compression", "gzip"); + MockHttpServletResponse response = new MockHttpServletResponse(); + request.setContent(baos.toByteArray()); + try { + exporter.handleRequest(request, response); + } + catch (ServletException ex) { + throw new IOException(ex.toString()); + } + return readRemoteInvocationResult( + new ByteArrayInputStream(response.getContentAsByteArray()), config.getCodebaseUrl()); + } + @Override + protected OutputStream decorateOutputStream(OutputStream os) throws IOException { + return new GZIPOutputStream(os); + } + @Override + protected InputStream decorateInputStream(InputStream is) throws IOException { + return new GZIPInputStream(is); + } + }); + + pfb.afterPropertiesSet(); + ITestBean proxy = (ITestBean) pfb.getObject(); + assertEquals("myname", proxy.getName()); + assertEquals(99, proxy.getAge()); + proxy.setAge(50); + assertEquals(50, proxy.getAge()); + + try { + proxy.exceptional(new IllegalStateException()); + fail("Should have thrown IllegalStateException"); + } + catch (IllegalStateException ex) { + // expected + } + try { + proxy.exceptional(new IllegalAccessException()); + fail("Should have thrown IllegalAccessException"); + } + catch (IllegalAccessException ex) { + // expected + } + } + + @Test + public void httpInvokerProxyFactoryBeanAndServiceExporterWithWrappedInvocations() throws Throwable { + TestBean target = new TestBean("myname", 99); + + final HttpInvokerServiceExporter exporter = new HttpInvokerServiceExporter() { + @Override + protected RemoteInvocation doReadRemoteInvocation(ObjectInputStream ois) + throws IOException, ClassNotFoundException { + Object obj = ois.readObject(); + if (!(obj instanceof TestRemoteInvocationWrapper)) { + throw new IOException("Deserialized object needs to be assignable to type [" + + TestRemoteInvocationWrapper.class.getName() + "]: " + obj); + } + return ((TestRemoteInvocationWrapper) obj).remoteInvocation; + } + @Override + protected void doWriteRemoteInvocationResult(RemoteInvocationResult result, ObjectOutputStream oos) + throws IOException { + oos.writeObject(new TestRemoteInvocationResultWrapper(result)); + } + }; + exporter.setServiceInterface(ITestBean.class); + exporter.setService(target); + exporter.afterPropertiesSet(); + + HttpInvokerProxyFactoryBean pfb = new HttpInvokerProxyFactoryBean(); + pfb.setServiceInterface(ITestBean.class); + pfb.setServiceUrl("http://myurl"); + + pfb.setHttpInvokerRequestExecutor(new AbstractHttpInvokerRequestExecutor() { + @Override + protected RemoteInvocationResult doExecuteRequest( + HttpInvokerClientConfiguration config, ByteArrayOutputStream baos) throws Exception { + assertEquals("http://myurl", config.getServiceUrl()); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + request.setContent(baos.toByteArray()); + exporter.handleRequest(request, response); + return readRemoteInvocationResult( + new ByteArrayInputStream(response.getContentAsByteArray()), config.getCodebaseUrl()); + } + @Override + protected void doWriteRemoteInvocation(RemoteInvocation invocation, ObjectOutputStream oos) throws IOException { + oos.writeObject(new TestRemoteInvocationWrapper(invocation)); + } + @Override + protected RemoteInvocationResult doReadRemoteInvocationResult(ObjectInputStream ois) + throws IOException, ClassNotFoundException { + Object obj = ois.readObject(); + if (!(obj instanceof TestRemoteInvocationResultWrapper)) { + throw new IOException("Deserialized object needs to be assignable to type [" + + TestRemoteInvocationResultWrapper.class.getName() + "]: " + obj); + } + return ((TestRemoteInvocationResultWrapper) obj).remoteInvocationResult; + } + }); + + pfb.afterPropertiesSet(); + ITestBean proxy = (ITestBean) pfb.getObject(); + assertEquals("myname", proxy.getName()); + assertEquals(99, proxy.getAge()); + proxy.setAge(50); + assertEquals(50, proxy.getAge()); + + try { + proxy.exceptional(new IllegalStateException()); + fail("Should have thrown IllegalStateException"); + } + catch (IllegalStateException ex) { + // expected + } + try { + proxy.exceptional(new IllegalAccessException()); + fail("Should have thrown IllegalAccessException"); + } + catch (IllegalAccessException ex) { + // expected + } + } + + @Test + public void httpInvokerProxyFactoryBeanAndServiceExporterWithInvocationAttributes() throws Exception { + TestBean target = new TestBean("myname", 99); + + final HttpInvokerServiceExporter exporter = new HttpInvokerServiceExporter(); + exporter.setServiceInterface(ITestBean.class); + exporter.setService(target); + exporter.setRemoteInvocationExecutor(new DefaultRemoteInvocationExecutor() { + @Override + public Object invoke(RemoteInvocation invocation, Object targetObject) + throws NoSuchMethodException, IllegalAccessException, InvocationTargetException { + assertNotNull(invocation.getAttributes()); + assertEquals(1, invocation.getAttributes().size()); + assertEquals("myValue", invocation.getAttributes().get("myKey")); + assertEquals("myValue", invocation.getAttribute("myKey")); + return super.invoke(invocation, targetObject); + } + }); + exporter.afterPropertiesSet(); + + HttpInvokerProxyFactoryBean pfb = new HttpInvokerProxyFactoryBean(); + pfb.setServiceInterface(ITestBean.class); + pfb.setServiceUrl("http://myurl"); + pfb.setRemoteInvocationFactory(new RemoteInvocationFactory() { + @Override + public RemoteInvocation createRemoteInvocation(MethodInvocation methodInvocation) { + RemoteInvocation invocation = new RemoteInvocation(methodInvocation); + invocation.addAttribute("myKey", "myValue"); + try { + invocation.addAttribute("myKey", "myValue"); + fail("Should have thrown IllegalStateException"); + } + catch (IllegalStateException ex) { + // expected: already defined + } + assertNotNull(invocation.getAttributes()); + assertEquals(1, invocation.getAttributes().size()); + assertEquals("myValue", invocation.getAttributes().get("myKey")); + assertEquals("myValue", invocation.getAttribute("myKey")); + return invocation; + } + }); + + pfb.setHttpInvokerRequestExecutor(new AbstractHttpInvokerRequestExecutor() { + @Override + protected RemoteInvocationResult doExecuteRequest( + HttpInvokerClientConfiguration config, ByteArrayOutputStream baos) throws Exception { + assertEquals("http://myurl", config.getServiceUrl()); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + request.setContent(baos.toByteArray()); + exporter.handleRequest(request, response); + return readRemoteInvocationResult( + new ByteArrayInputStream(response.getContentAsByteArray()), config.getCodebaseUrl()); + } + }); + + pfb.afterPropertiesSet(); + ITestBean proxy = (ITestBean) pfb.getObject(); + assertEquals("myname", proxy.getName()); + assertEquals(99, proxy.getAge()); + } + + @Test + public void httpInvokerProxyFactoryBeanAndServiceExporterWithCustomInvocationObject() throws Exception { + TestBean target = new TestBean("myname", 99); + + final HttpInvokerServiceExporter exporter = new HttpInvokerServiceExporter(); + exporter.setServiceInterface(ITestBean.class); + exporter.setService(target); + exporter.setRemoteInvocationExecutor(new DefaultRemoteInvocationExecutor() { + @Override + public Object invoke(RemoteInvocation invocation, Object targetObject) + throws NoSuchMethodException, IllegalAccessException, InvocationTargetException { + assertTrue(invocation instanceof TestRemoteInvocation); + assertNull(invocation.getAttributes()); + assertNull(invocation.getAttribute("myKey")); + return super.invoke(invocation, targetObject); + } + }); + exporter.afterPropertiesSet(); + + HttpInvokerProxyFactoryBean pfb = new HttpInvokerProxyFactoryBean(); + pfb.setServiceInterface(ITestBean.class); + pfb.setServiceUrl("http://myurl"); + pfb.setRemoteInvocationFactory(new RemoteInvocationFactory() { + @Override + public RemoteInvocation createRemoteInvocation(MethodInvocation methodInvocation) { + RemoteInvocation invocation = new TestRemoteInvocation(methodInvocation); + assertNull(invocation.getAttributes()); + assertNull(invocation.getAttribute("myKey")); + return invocation; + } + }); + + pfb.setHttpInvokerRequestExecutor(new AbstractHttpInvokerRequestExecutor() { + @Override + protected RemoteInvocationResult doExecuteRequest( + HttpInvokerClientConfiguration config, ByteArrayOutputStream baos) throws Exception { + assertEquals("http://myurl", config.getServiceUrl()); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + request.setContent(baos.toByteArray()); + exporter.handleRequest(request, response); + return readRemoteInvocationResult( + new ByteArrayInputStream(response.getContentAsByteArray()), config.getCodebaseUrl()); + } + }); + + pfb.afterPropertiesSet(); + ITestBean proxy = (ITestBean) pfb.getObject(); + assertEquals("myname", proxy.getName()); + assertEquals(99, proxy.getAge()); + } + + @Test + public void httpInvokerWithSpecialLocalMethods() throws Exception { + String serviceUrl = "http://myurl"; + HttpInvokerProxyFactoryBean pfb = new HttpInvokerProxyFactoryBean(); + pfb.setServiceInterface(ITestBean.class); + pfb.setServiceUrl(serviceUrl); + + pfb.setHttpInvokerRequestExecutor(new HttpInvokerRequestExecutor() { + @Override + public RemoteInvocationResult executeRequest( + HttpInvokerClientConfiguration config, RemoteInvocation invocation) throws IOException { + throw new IOException("argh"); + } + }); + + pfb.afterPropertiesSet(); + ITestBean proxy = (ITestBean) pfb.getObject(); + + // shouldn't go through to remote service + assertTrue(proxy.toString().contains("HTTP invoker")); + assertTrue(proxy.toString().contains(serviceUrl)); + assertEquals(proxy.hashCode(), proxy.hashCode()); + assertTrue(proxy.equals(proxy)); + + // should go through + try { + proxy.setAge(50); + fail("Should have thrown RemoteAccessException"); + } + catch (RemoteAccessException ex) { + // expected + assertTrue(ex.getCause() instanceof IOException); + } + } + + + @SuppressWarnings("serial") + private static class TestRemoteInvocation extends RemoteInvocation { + + public TestRemoteInvocation(MethodInvocation methodInvocation) { + super(methodInvocation); + } + } + + + @SuppressWarnings("serial") + private static class TestRemoteInvocationWrapper implements Serializable { + + private final RemoteInvocation remoteInvocation; + + public TestRemoteInvocationWrapper(RemoteInvocation remoteInvocation) { + this.remoteInvocation = remoteInvocation; + } + } + + + @SuppressWarnings("serial") + private static class TestRemoteInvocationResultWrapper implements Serializable { + + private final RemoteInvocationResult remoteInvocationResult; + + public TestRemoteInvocationResultWrapper(RemoteInvocationResult remoteInvocationResult) { + this.remoteInvocationResult = remoteInvocationResult; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/remoting/jaxws/JaxWsSupportTests.java b/spring-web/src/test/java/org/springframework/remoting/jaxws/JaxWsSupportTests.java new file mode 100644 index 0000000000000000000000000000000000000000..99e6c22a304235386bbc6204d4cded1f6a4c7a76 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/remoting/jaxws/JaxWsSupportTests.java @@ -0,0 +1,169 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.jaxws; + +import java.net.MalformedURLException; +import java.net.URL; + +import javax.xml.namespace.QName; +import javax.xml.ws.BindingProvider; +import javax.xml.ws.Service; +import javax.xml.ws.WebServiceClient; +import javax.xml.ws.WebServiceException; +import javax.xml.ws.WebServiceFeature; +import javax.xml.ws.WebServiceRef; +import javax.xml.ws.soap.AddressingFeature; + +import org.junit.Test; + +import org.springframework.beans.factory.BeanCreationException; +import org.springframework.beans.factory.support.GenericBeanDefinition; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.context.annotation.AnnotationConfigUtils; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.remoting.RemoteAccessException; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + * @since 2.5 + */ +public class JaxWsSupportTests { + + @Test + public void testJaxWsPortAccess() throws Exception { + doTestJaxWsPortAccess((WebServiceFeature[]) null); + } + + @Test + public void testJaxWsPortAccessWithFeature() throws Exception { + doTestJaxWsPortAccess(new AddressingFeature()); + } + + private void doTestJaxWsPortAccess(WebServiceFeature... features) throws Exception { + GenericApplicationContext ac = new GenericApplicationContext(); + + GenericBeanDefinition serviceDef = new GenericBeanDefinition(); + serviceDef.setBeanClass(OrderServiceImpl.class); + ac.registerBeanDefinition("service", serviceDef); + + GenericBeanDefinition exporterDef = new GenericBeanDefinition(); + exporterDef.setBeanClass(SimpleJaxWsServiceExporter.class); + exporterDef.getPropertyValues().add("baseAddress", "http://localhost:9999/"); + ac.registerBeanDefinition("exporter", exporterDef); + + GenericBeanDefinition clientDef = new GenericBeanDefinition(); + clientDef.setBeanClass(JaxWsPortProxyFactoryBean.class); + clientDef.getPropertyValues().add("wsdlDocumentUrl", "http://localhost:9999/OrderService?wsdl"); + clientDef.getPropertyValues().add("namespaceUri", "http://jaxws.remoting.springframework.org/"); + clientDef.getPropertyValues().add("username", "juergen"); + clientDef.getPropertyValues().add("password", "hoeller"); + clientDef.getPropertyValues().add("serviceName", "OrderService"); + clientDef.getPropertyValues().add("serviceInterface", OrderService.class); + clientDef.getPropertyValues().add("lookupServiceOnStartup", Boolean.FALSE); + if (features != null) { + clientDef.getPropertyValues().add("portFeatures", features); + } + ac.registerBeanDefinition("client", clientDef); + + GenericBeanDefinition serviceFactoryDef = new GenericBeanDefinition(); + serviceFactoryDef.setBeanClass(LocalJaxWsServiceFactoryBean.class); + serviceFactoryDef.getPropertyValues().add("wsdlDocumentUrl", "http://localhost:9999/OrderService?wsdl"); + serviceFactoryDef.getPropertyValues().add("namespaceUri", "http://jaxws.remoting.springframework.org/"); + serviceFactoryDef.getPropertyValues().add("serviceName", "OrderService"); + ac.registerBeanDefinition("orderService", serviceFactoryDef); + + ac.registerBeanDefinition("accessor", new RootBeanDefinition(ServiceAccessor.class)); + AnnotationConfigUtils.registerAnnotationConfigProcessors(ac); + + try { + ac.refresh(); + + OrderService orderService = ac.getBean("client", OrderService.class); + assertTrue(orderService instanceof BindingProvider); + ((BindingProvider) orderService).getRequestContext(); + + String order = orderService.getOrder(1000); + assertEquals("order 1000", order); + try { + orderService.getOrder(0); + fail("Should have thrown OrderNotFoundException"); + } + catch (OrderNotFoundException ex) { + // expected + } + catch (RemoteAccessException ex) { + // ignore - probably setup issue with JAX-WS provider vs JAXB + } + + ServiceAccessor serviceAccessor = ac.getBean("accessor", ServiceAccessor.class); + order = serviceAccessor.orderService.getOrder(1000); + assertEquals("order 1000", order); + try { + serviceAccessor.orderService.getOrder(0); + fail("Should have thrown OrderNotFoundException"); + } + catch (OrderNotFoundException ex) { + // expected + } + catch (WebServiceException ex) { + // ignore - probably setup issue with JAX-WS provider vs JAXB + } + } + catch (BeanCreationException ex) { + if ("exporter".equals(ex.getBeanName()) && ex.getRootCause() instanceof ClassNotFoundException) { + // ignore - probably running on JDK without the JAX-WS impl present + } + else { + throw ex; + } + } + finally { + ac.close(); + } + } + + + public static class ServiceAccessor { + + @WebServiceRef + public OrderService orderService; + + public OrderService myService; + + @WebServiceRef(value = OrderServiceService.class, wsdlLocation = "http://localhost:9999/OrderService?wsdl") + public void setMyService(OrderService myService) { + this.myService = myService; + } + } + + + @WebServiceClient(targetNamespace = "http://jaxws.remoting.springframework.org/", name="OrderService") + public static class OrderServiceService extends Service { + + public OrderServiceService() throws MalformedURLException { + super(new URL("http://localhost:9999/OrderService?wsdl"), + new QName("http://jaxws.remoting.springframework.org/", "OrderService")); + } + + public OrderServiceService(URL wsdlDocumentLocation, QName serviceName) { + super(wsdlDocumentLocation, serviceName); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/remoting/jaxws/OrderNotFoundException.java b/spring-web/src/test/java/org/springframework/remoting/jaxws/OrderNotFoundException.java new file mode 100644 index 0000000000000000000000000000000000000000..add5efc0977de8706ef3d4cb98d773e0d7137c20 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/remoting/jaxws/OrderNotFoundException.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.jaxws; + +import javax.xml.ws.WebFault; + +/** + * @author Juergen Hoeller + */ +@WebFault +@SuppressWarnings("serial") +public class OrderNotFoundException extends Exception { + + private String faultInfo; + + public OrderNotFoundException(String message) { + super(message); + } + + public OrderNotFoundException(String message, String faultInfo) { + super(message); + this.faultInfo = faultInfo; + } + + public String getFaultInfo() { + return this.faultInfo; + } + +} diff --git a/spring-web/src/test/java/org/springframework/remoting/jaxws/OrderService.java b/spring-web/src/test/java/org/springframework/remoting/jaxws/OrderService.java new file mode 100644 index 0000000000000000000000000000000000000000..2b29d13cb3aab3c117143e21fb50bc0c5a1bbadb --- /dev/null +++ b/spring-web/src/test/java/org/springframework/remoting/jaxws/OrderService.java @@ -0,0 +1,31 @@ +/* + * Copyright 2002-2007 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.jaxws; + +import javax.jws.WebService; +import javax.jws.soap.SOAPBinding; + +/** + * @author Juergen Hoeller + */ +@WebService +@SOAPBinding(style = SOAPBinding.Style.RPC) +public interface OrderService { + + String getOrder(int id) throws OrderNotFoundException; + +} diff --git a/spring-web/src/test/java/org/springframework/remoting/jaxws/OrderServiceImpl.java b/spring-web/src/test/java/org/springframework/remoting/jaxws/OrderServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..3f9c9781b38603d6b7dce2d0349d9c76bbe12ee4 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/remoting/jaxws/OrderServiceImpl.java @@ -0,0 +1,44 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.remoting.jaxws; + +import javax.annotation.Resource; +import javax.jws.WebService; +import javax.xml.ws.WebServiceContext; + +import org.springframework.util.Assert; + +/** + * @author Juergen Hoeller + */ +@WebService(serviceName="OrderService", portName="OrderService", + endpointInterface = "org.springframework.remoting.jaxws.OrderService") +public class OrderServiceImpl implements OrderService { + + @Resource + private WebServiceContext webServiceContext; + + @Override + public String getOrder(int id) throws OrderNotFoundException { + Assert.notNull(this.webServiceContext, "WebServiceContext has not been injected"); + if (id == 0) { + throw new OrderNotFoundException("Order 0 not found"); + } + return "order " + id; + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/accept/ContentNegotiationManagerFactoryBeanTests.java b/spring-web/src/test/java/org/springframework/web/accept/ContentNegotiationManagerFactoryBeanTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b32b72b2da02acf05a18aecac764e46b44ddff00 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/accept/ContentNegotiationManagerFactoryBeanTests.java @@ -0,0 +1,242 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.accept; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockServletContext; +import org.springframework.util.StringUtils; +import org.springframework.web.HttpMediaTypeNotAcceptableException; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; + +import static org.junit.Assert.*; + +/** + * Test fixture for {@link ContentNegotiationManagerFactoryBean} tests. + * + * @author Rossen Stoyanchev + */ +public class ContentNegotiationManagerFactoryBeanTests { + + private ContentNegotiationManagerFactoryBean factoryBean; + + private NativeWebRequest webRequest; + + private MockHttpServletRequest servletRequest; + + + @Before + public void setup() { + TestServletContext servletContext = new TestServletContext(); + servletContext.getMimeTypes().put("foo", "application/foo"); + + this.servletRequest = new MockHttpServletRequest(servletContext); + this.webRequest = new ServletWebRequest(this.servletRequest); + + this.factoryBean = new ContentNegotiationManagerFactoryBean(); + this.factoryBean.setServletContext(this.servletRequest.getServletContext()); + } + + + @Test + public void defaultSettings() throws Exception { + this.factoryBean.afterPropertiesSet(); + ContentNegotiationManager manager = this.factoryBean.getObject(); + + this.servletRequest.setRequestURI("/flower.gif"); + + assertEquals("Should be able to resolve file extensions by default", + Collections.singletonList(MediaType.IMAGE_GIF), manager.resolveMediaTypes(this.webRequest)); + + this.servletRequest.setRequestURI("/flower.foobarbaz"); + + assertEquals("Should ignore unknown extensions by default", + ContentNegotiationStrategy.MEDIA_TYPE_ALL_LIST, manager.resolveMediaTypes(this.webRequest)); + + this.servletRequest.setRequestURI("/flower"); + this.servletRequest.setParameter("format", "gif"); + + assertEquals("Should not resolve request parameters by default", + ContentNegotiationStrategy.MEDIA_TYPE_ALL_LIST, manager.resolveMediaTypes(this.webRequest)); + + this.servletRequest.setRequestURI("/flower"); + this.servletRequest.addHeader("Accept", MediaType.IMAGE_GIF_VALUE); + + assertEquals("Should resolve Accept header by default", + Collections.singletonList(MediaType.IMAGE_GIF), manager.resolveMediaTypes(this.webRequest)); + } + + @Test + public void explicitStrategies() throws Exception { + Map mediaTypes = Collections.singletonMap("bar", new MediaType("application", "bar")); + ParameterContentNegotiationStrategy strategy1 = new ParameterContentNegotiationStrategy(mediaTypes); + HeaderContentNegotiationStrategy strategy2 = new HeaderContentNegotiationStrategy(); + List strategies = Arrays.asList(strategy1, strategy2); + this.factoryBean.setStrategies(strategies); + this.factoryBean.afterPropertiesSet(); + ContentNegotiationManager manager = this.factoryBean.getObject(); + + assertEquals(strategies, manager.getStrategies()); + + this.servletRequest.setRequestURI("/flower"); + this.servletRequest.addParameter("format", "bar"); + assertEquals(Collections.singletonList(new MediaType("application", "bar")), + manager.resolveMediaTypes(this.webRequest)); + + } + + @Test + public void favorPath() throws Exception { + this.factoryBean.setFavorPathExtension(true); + this.factoryBean.addMediaTypes(Collections.singletonMap("bar", new MediaType("application", "bar"))); + this.factoryBean.afterPropertiesSet(); + ContentNegotiationManager manager = this.factoryBean.getObject(); + + this.servletRequest.setRequestURI("/flower.foo"); + assertEquals(Collections.singletonList(new MediaType("application", "foo")), + manager.resolveMediaTypes(this.webRequest)); + + this.servletRequest.setRequestURI("/flower.bar"); + assertEquals(Collections.singletonList(new MediaType("application", "bar")), + manager.resolveMediaTypes(this.webRequest)); + + this.servletRequest.setRequestURI("/flower.gif"); + assertEquals(Collections.singletonList(MediaType.IMAGE_GIF), manager.resolveMediaTypes(this.webRequest)); + } + + @Test(expected = HttpMediaTypeNotAcceptableException.class) // SPR-10170 + public void favorPathWithIgnoreUnknownPathExtensionTurnedOff() throws Exception { + this.factoryBean.setFavorPathExtension(true); + this.factoryBean.setIgnoreUnknownPathExtensions(false); + this.factoryBean.afterPropertiesSet(); + ContentNegotiationManager manager = this.factoryBean.getObject(); + + this.servletRequest.setRequestURI("/flower.foobarbaz"); + this.servletRequest.addParameter("format", "json"); + + manager.resolveMediaTypes(this.webRequest); + } + + @Test + public void favorParameter() throws Exception { + this.factoryBean.setFavorParameter(true); + + Map mediaTypes = new HashMap<>(); + mediaTypes.put("json", MediaType.APPLICATION_JSON); + this.factoryBean.addMediaTypes(mediaTypes); + + this.factoryBean.afterPropertiesSet(); + ContentNegotiationManager manager = this.factoryBean.getObject(); + + this.servletRequest.setRequestURI("/flower"); + this.servletRequest.addParameter("format", "json"); + + assertEquals(Collections.singletonList(MediaType.APPLICATION_JSON), + manager.resolveMediaTypes(this.webRequest)); + } + + @Test(expected = HttpMediaTypeNotAcceptableException.class) // SPR-10170 + public void favorParameterWithUnknownMediaType() throws HttpMediaTypeNotAcceptableException { + this.factoryBean.setFavorParameter(true); + this.factoryBean.afterPropertiesSet(); + ContentNegotiationManager manager = this.factoryBean.getObject(); + + this.servletRequest.setRequestURI("/flower"); + this.servletRequest.setParameter("format", "invalid"); + + manager.resolveMediaTypes(this.webRequest); + } + + @Test + public void ignoreAcceptHeader() throws Exception { + this.factoryBean.setIgnoreAcceptHeader(true); + this.factoryBean.afterPropertiesSet(); + ContentNegotiationManager manager = this.factoryBean.getObject(); + + this.servletRequest.setRequestURI("/flower"); + this.servletRequest.addHeader("Accept", MediaType.IMAGE_GIF_VALUE); + + assertEquals(ContentNegotiationStrategy.MEDIA_TYPE_ALL_LIST, manager.resolveMediaTypes(this.webRequest)); + } + + @Test + public void setDefaultContentType() throws Exception { + this.factoryBean.setDefaultContentType(MediaType.APPLICATION_JSON); + this.factoryBean.afterPropertiesSet(); + ContentNegotiationManager manager = this.factoryBean.getObject(); + + assertEquals(MediaType.APPLICATION_JSON, manager.resolveMediaTypes(this.webRequest).get(0)); + + // SPR-10513 + this.servletRequest.addHeader("Accept", MediaType.ALL_VALUE); + assertEquals(MediaType.APPLICATION_JSON, manager.resolveMediaTypes(this.webRequest).get(0)); + } + + @Test // SPR-15367 + public void setDefaultContentTypes() throws Exception { + List mediaTypes = Arrays.asList(MediaType.APPLICATION_JSON, MediaType.ALL); + this.factoryBean.setDefaultContentTypes(mediaTypes); + this.factoryBean.afterPropertiesSet(); + ContentNegotiationManager manager = this.factoryBean.getObject(); + + assertEquals(mediaTypes, manager.resolveMediaTypes(this.webRequest)); + + this.servletRequest.addHeader("Accept", MediaType.ALL_VALUE); + assertEquals(mediaTypes, manager.resolveMediaTypes(this.webRequest)); + } + + @Test // SPR-12286 + public void setDefaultContentTypeWithStrategy() throws Exception { + this.factoryBean.setDefaultContentTypeStrategy(new FixedContentNegotiationStrategy(MediaType.APPLICATION_JSON)); + this.factoryBean.afterPropertiesSet(); + ContentNegotiationManager manager = this.factoryBean.getObject(); + + assertEquals(Collections.singletonList(MediaType.APPLICATION_JSON), + manager.resolveMediaTypes(this.webRequest)); + + this.servletRequest.addHeader("Accept", MediaType.ALL_VALUE); + assertEquals(Collections.singletonList(MediaType.APPLICATION_JSON), + manager.resolveMediaTypes(this.webRequest)); + } + + + private static class TestServletContext extends MockServletContext { + + private final Map mimeTypes = new HashMap<>(); + + public Map getMimeTypes() { + return this.mimeTypes; + } + + @Override + public String getMimeType(String filePath) { + String extension = StringUtils.getFilenameExtension(filePath); + return getMimeTypes().get(extension); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/accept/HeaderContentNegotiationStrategyTests.java b/spring-web/src/test/java/org/springframework/web/accept/HeaderContentNegotiationStrategyTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c307937b7929fa801c3409165383b038a1aee1b5 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/accept/HeaderContentNegotiationStrategyTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.accept; + +import java.util.List; + +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.web.HttpMediaTypeNotAcceptableException; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; + +import static org.junit.Assert.*; + +/** + * Test fixture for HeaderContentNegotiationStrategy tests. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + */ +public class HeaderContentNegotiationStrategyTests { + + private final HeaderContentNegotiationStrategy strategy = new HeaderContentNegotiationStrategy(); + + private final MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + + private final NativeWebRequest webRequest = new ServletWebRequest(this.servletRequest); + + + @Test + public void resolveMediaTypes() throws Exception { + this.servletRequest.addHeader("Accept", "text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c"); + List mediaTypes = this.strategy.resolveMediaTypes(this.webRequest); + + assertEquals(4, mediaTypes.size()); + assertEquals("text/html", mediaTypes.get(0).toString()); + assertEquals("text/x-c", mediaTypes.get(1).toString()); + assertEquals("text/x-dvi;q=0.8", mediaTypes.get(2).toString()); + assertEquals("text/plain;q=0.5", mediaTypes.get(3).toString()); + } + + @Test // SPR-14506 + public void resolveMediaTypesFromMultipleHeaderValues() throws Exception { + this.servletRequest.addHeader("Accept", "text/plain; q=0.5, text/html"); + this.servletRequest.addHeader("Accept", "text/x-dvi; q=0.8, text/x-c"); + List mediaTypes = this.strategy.resolveMediaTypes(this.webRequest); + + assertEquals(4, mediaTypes.size()); + assertEquals("text/html", mediaTypes.get(0).toString()); + assertEquals("text/x-c", mediaTypes.get(1).toString()); + assertEquals("text/x-dvi;q=0.8", mediaTypes.get(2).toString()); + assertEquals("text/plain;q=0.5", mediaTypes.get(3).toString()); + } + + @Test(expected = HttpMediaTypeNotAcceptableException.class) + public void resolveMediaTypesParseError() throws Exception { + this.servletRequest.addHeader("Accept", "textplain; q=0.5"); + this.strategy.resolveMediaTypes(this.webRequest); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/accept/MappingContentNegotiationStrategyTests.java b/spring-web/src/test/java/org/springframework/web/accept/MappingContentNegotiationStrategyTests.java new file mode 100644 index 0000000000000000000000000000000000000000..ab50baba2bdcb1cd5b0041180f57781877298015 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/accept/MappingContentNegotiationStrategyTests.java @@ -0,0 +1,101 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.accept; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.web.context.request.NativeWebRequest; + +import static org.junit.Assert.*; + +/** + * A test fixture with a test sub-class of AbstractMappingContentNegotiationStrategy. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public class MappingContentNegotiationStrategyTests { + + @Test + public void resolveMediaTypes() throws Exception { + Map mapping = Collections.singletonMap("json", MediaType.APPLICATION_JSON); + TestMappingContentNegotiationStrategy strategy = new TestMappingContentNegotiationStrategy("json", mapping); + + List mediaTypes = strategy.resolveMediaTypes(null); + + assertEquals(1, mediaTypes.size()); + assertEquals("application/json", mediaTypes.get(0).toString()); + } + + @Test + public void resolveMediaTypesNoMatch() throws Exception { + Map mapping = null; + TestMappingContentNegotiationStrategy strategy = new TestMappingContentNegotiationStrategy("blah", mapping); + + List mediaTypes = strategy.resolveMediaTypes(null); + + assertEquals(ContentNegotiationStrategy.MEDIA_TYPE_ALL_LIST, mediaTypes); + } + + @Test + public void resolveMediaTypesNoKey() throws Exception { + Map mapping = Collections.singletonMap("json", MediaType.APPLICATION_JSON); + TestMappingContentNegotiationStrategy strategy = new TestMappingContentNegotiationStrategy(null, mapping); + + List mediaTypes = strategy.resolveMediaTypes(null); + + assertEquals(ContentNegotiationStrategy.MEDIA_TYPE_ALL_LIST, mediaTypes); + } + + @Test + public void resolveMediaTypesHandleNoMatch() throws Exception { + Map mapping = null; + TestMappingContentNegotiationStrategy strategy = new TestMappingContentNegotiationStrategy("xml", mapping); + + List mediaTypes = strategy.resolveMediaTypes(null); + + assertEquals(1, mediaTypes.size()); + assertEquals("application/xml", mediaTypes.get(0).toString()); + } + + + private static class TestMappingContentNegotiationStrategy extends AbstractMappingContentNegotiationStrategy { + + private final String extension; + + public TestMappingContentNegotiationStrategy(String extension, Map mapping) { + super(mapping); + this.extension = extension; + } + + @Override + protected String getMediaTypeKey(NativeWebRequest request) { + return this.extension; + } + + @Override + protected MediaType handleNoMatch(NativeWebRequest request, String mappingKey) { + return "xml".equals(mappingKey) ? MediaType.APPLICATION_XML : null; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/accept/MappingMediaTypeFileExtensionResolverTests.java b/spring-web/src/test/java/org/springframework/web/accept/MappingMediaTypeFileExtensionResolverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..32e4a4f8f07f9b50127bbd49a9a769dd9633ffaa --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/accept/MappingMediaTypeFileExtensionResolverTests.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.accept; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Test; + +import org.springframework.http.MediaType; + +import static org.junit.Assert.*; + +/** + * Test fixture for {@link MappingMediaTypeFileExtensionResolver}. + * + * @author Rossen Stoyanchev + * @author Melissa Hartsock + */ +public class MappingMediaTypeFileExtensionResolverTests { + + @Test + public void resolveExtensions() { + Map mapping = Collections.singletonMap("json", MediaType.APPLICATION_JSON); + MappingMediaTypeFileExtensionResolver resolver = new MappingMediaTypeFileExtensionResolver(mapping); + List extensions = resolver.resolveFileExtensions(MediaType.APPLICATION_JSON); + + assertEquals(1, extensions.size()); + assertEquals("json", extensions.get(0)); + } + + @Test + public void resolveExtensionsNoMatch() { + Map mapping = Collections.singletonMap("json", MediaType.APPLICATION_JSON); + MappingMediaTypeFileExtensionResolver resolver = new MappingMediaTypeFileExtensionResolver(mapping); + List extensions = resolver.resolveFileExtensions(MediaType.TEXT_HTML); + + assertTrue(extensions.isEmpty()); + } + + /** + * Unit test for SPR-13747 - ensures that reverse lookup of media type from media + * type key is case-insensitive. + */ + @Test + public void lookupMediaTypeCaseInsensitive() { + Map mapping = Collections.singletonMap("json", MediaType.APPLICATION_JSON); + MappingMediaTypeFileExtensionResolver resolver = new MappingMediaTypeFileExtensionResolver(mapping); + MediaType mediaType = resolver.lookupMediaType("JSON"); + + assertEquals(MediaType.APPLICATION_JSON, mediaType); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/accept/PathExtensionContentNegotiationStrategyTests.java b/spring-web/src/test/java/org/springframework/web/accept/PathExtensionContentNegotiationStrategyTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5cf712dea5027d2799fa033bcf8ab4bbe60261cf --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/accept/PathExtensionContentNegotiationStrategyTests.java @@ -0,0 +1,135 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.accept; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.MediaType; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.web.HttpMediaTypeNotAcceptableException; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; + +import static org.junit.Assert.*; + +/** + * A test fixture for PathExtensionContentNegotiationStrategy. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public class PathExtensionContentNegotiationStrategyTests { + + private NativeWebRequest webRequest; + + private MockHttpServletRequest servletRequest; + + + @Before + public void setup() { + this.servletRequest = new MockHttpServletRequest(); + this.webRequest = new ServletWebRequest(servletRequest); + } + + + @Test + public void resolveMediaTypesFromMapping() throws Exception { + + this.servletRequest.setRequestURI("test.html"); + + PathExtensionContentNegotiationStrategy strategy = new PathExtensionContentNegotiationStrategy(); + List mediaTypes = strategy.resolveMediaTypes(this.webRequest); + + assertEquals(Arrays.asList(new MediaType("text", "html")), mediaTypes); + + Map mapping = Collections.singletonMap("HTML", MediaType.APPLICATION_XHTML_XML); + strategy = new PathExtensionContentNegotiationStrategy(mapping); + mediaTypes = strategy.resolveMediaTypes(this.webRequest); + + assertEquals(Arrays.asList(new MediaType("application", "xhtml+xml")), mediaTypes); + } + + @Test + public void resolveMediaTypesFromMediaTypeFactory() throws Exception { + + this.servletRequest.setRequestURI("test.xls"); + + PathExtensionContentNegotiationStrategy strategy = new PathExtensionContentNegotiationStrategy(); + List mediaTypes = strategy.resolveMediaTypes(this.webRequest); + + assertEquals(Arrays.asList(new MediaType("application", "vnd.ms-excel")), mediaTypes); + } + + // SPR-8678 + + @Test + public void getMediaTypeFilenameWithContextPath() throws Exception { + + PathExtensionContentNegotiationStrategy strategy = new PathExtensionContentNegotiationStrategy(); + + this.servletRequest.setContextPath("/project-1.0.0.M3"); + this.servletRequest.setRequestURI("/project-1.0.0.M3/"); + assertEquals("Context path should be excluded", ContentNegotiationStrategy.MEDIA_TYPE_ALL_LIST, + strategy.resolveMediaTypes(webRequest)); + + this.servletRequest.setRequestURI("/project-1.0.0.M3"); + assertEquals("Context path should be excluded", ContentNegotiationStrategy.MEDIA_TYPE_ALL_LIST, + strategy.resolveMediaTypes(webRequest)); + } + + // SPR-9390 + + @Test + public void getMediaTypeFilenameWithEncodedURI() throws Exception { + + this.servletRequest.setRequestURI("/quo%20vadis%3f.html"); + + PathExtensionContentNegotiationStrategy strategy = new PathExtensionContentNegotiationStrategy(); + List result = strategy.resolveMediaTypes(webRequest); + + assertEquals("Invalid content type", Collections.singletonList(new MediaType("text", "html")), result); + } + + // SPR-10170 + + @Test + public void resolveMediaTypesIgnoreUnknownExtension() throws Exception { + + this.servletRequest.setRequestURI("test.foobar"); + + PathExtensionContentNegotiationStrategy strategy = new PathExtensionContentNegotiationStrategy(); + List mediaTypes = strategy.resolveMediaTypes(this.webRequest); + + assertEquals(ContentNegotiationStrategy.MEDIA_TYPE_ALL_LIST, mediaTypes); + } + + @Test(expected = HttpMediaTypeNotAcceptableException.class) + public void resolveMediaTypesDoNotIgnoreUnknownExtension() throws Exception { + + this.servletRequest.setRequestURI("test.foobar"); + + PathExtensionContentNegotiationStrategy strategy = new PathExtensionContentNegotiationStrategy(); + strategy.setIgnoreUnknownExtensions(false); + strategy.resolveMediaTypes(this.webRequest); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/bind/EscapedErrorsTests.java b/spring-web/src/test/java/org/springframework/web/bind/EscapedErrorsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..770b1cf9e608b4ecaa9ee1ee3f6091dc6f2d55c7 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/bind/EscapedErrorsTests.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +import org.junit.Test; + +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.validation.BindException; +import org.springframework.validation.Errors; +import org.springframework.validation.FieldError; +import org.springframework.validation.ObjectError; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + * @since 02.05.2003 + */ +public class EscapedErrorsTests { + + @Test + public void testEscapedErrors() { + TestBean tb = new TestBean(); + tb.setName("empty &"); + + Errors errors = new EscapedErrors(new BindException(tb, "tb")); + errors.rejectValue("name", "NAME_EMPTY &", null, "message: &"); + errors.rejectValue("age", "AGE_NOT_SET ", null, "message: "); + errors.rejectValue("age", "AGE_NOT_32 ", null, "message: "); + errors.reject("GENERAL_ERROR \" '", null, "message: \" '"); + + assertTrue("Correct errors flag", errors.hasErrors()); + assertTrue("Correct number of errors", errors.getErrorCount() == 4); + assertTrue("Correct object name", "tb".equals(errors.getObjectName())); + + assertTrue("Correct global errors flag", errors.hasGlobalErrors()); + assertTrue("Correct number of global errors", errors.getGlobalErrorCount() == 1); + ObjectError globalError = errors.getGlobalError(); + String defaultMessage = globalError.getDefaultMessage(); + assertTrue("Global error message escaped", "message: " '".equals(defaultMessage)); + assertTrue("Global error code not escaped", "GENERAL_ERROR \" '".equals(globalError.getCode())); + ObjectError globalErrorInList = errors.getGlobalErrors().get(0); + assertTrue("Same global error in list", defaultMessage.equals(globalErrorInList.getDefaultMessage())); + ObjectError globalErrorInAllList = errors.getAllErrors().get(3); + assertTrue("Same global error in list", defaultMessage.equals(globalErrorInAllList.getDefaultMessage())); + + assertTrue("Correct field errors flag", errors.hasFieldErrors()); + assertTrue("Correct number of field errors", errors.getFieldErrorCount() == 3); + assertTrue("Correct number of field errors in list", errors.getFieldErrors().size() == 3); + FieldError fieldError = errors.getFieldError(); + assertTrue("Field error code not escaped", "NAME_EMPTY &".equals(fieldError.getCode())); + assertTrue("Field value escaped", "empty &".equals(errors.getFieldValue("name"))); + FieldError fieldErrorInList = errors.getFieldErrors().get(0); + assertTrue("Same field error in list", fieldError.getDefaultMessage().equals(fieldErrorInList.getDefaultMessage())); + + assertTrue("Correct name errors flag", errors.hasFieldErrors("name")); + assertTrue("Correct number of name errors", errors.getFieldErrorCount("name") == 1); + assertTrue("Correct number of name errors in list", errors.getFieldErrors("name").size() == 1); + FieldError nameError = errors.getFieldError("name"); + assertTrue("Name error message escaped", "message: &".equals(nameError.getDefaultMessage())); + assertTrue("Name error code not escaped", "NAME_EMPTY &".equals(nameError.getCode())); + assertTrue("Name value escaped", "empty &".equals(errors.getFieldValue("name"))); + FieldError nameErrorInList = errors.getFieldErrors("name").get(0); + assertTrue("Same name error in list", nameError.getDefaultMessage().equals(nameErrorInList.getDefaultMessage())); + + assertTrue("Correct age errors flag", errors.hasFieldErrors("age")); + assertTrue("Correct number of age errors", errors.getFieldErrorCount("age") == 2); + assertTrue("Correct number of age errors in list", errors.getFieldErrors("age").size() == 2); + FieldError ageError = errors.getFieldError("age"); + assertTrue("Age error message escaped", "message: <tag>".equals(ageError.getDefaultMessage())); + assertTrue("Age error code not escaped", "AGE_NOT_SET ".equals(ageError.getCode())); + assertTrue("Age value not escaped", (new Integer(0)).equals(errors.getFieldValue("age"))); + FieldError ageErrorInList = errors.getFieldErrors("age").get(0); + assertTrue("Same name error in list", ageError.getDefaultMessage().equals(ageErrorInList.getDefaultMessage())); + FieldError ageError2 = errors.getFieldErrors("age").get(1); + assertTrue("Age error 2 message escaped", "message: <tag>".equals(ageError2.getDefaultMessage())); + assertTrue("Age error 2 code not escaped", "AGE_NOT_32 ".equals(ageError2.getCode())); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/bind/ServletRequestDataBinderTests.java b/spring-web/src/test/java/org/springframework/web/bind/ServletRequestDataBinderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..ab341c1f2488e60846e6986d96b5f17b936cf4a7 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/bind/ServletRequestDataBinderTests.java @@ -0,0 +1,259 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +import java.beans.PropertyEditorSupport; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; + +import org.springframework.beans.PropertyValue; +import org.springframework.beans.PropertyValues; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.tests.sample.beans.ITestBean; +import org.springframework.tests.sample.beans.TestBean; + +import static org.junit.Assert.*; + +/** + * @author Rod Johnson + * @author Juergen Hoeller + * @author Chris Beams + * @author Scott Andrews + */ +public class ServletRequestDataBinderTests { + + @Test + public void testBindingWithNestedObjectCreation() throws Exception { + TestBean tb = new TestBean(); + + ServletRequestDataBinder binder = new ServletRequestDataBinder(tb, "person"); + binder.registerCustomEditor(ITestBean.class, new PropertyEditorSupport() { + @Override + public void setAsText(String text) throws IllegalArgumentException { + setValue(new TestBean()); + } + }); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("spouse", "someValue"); + request.addParameter("spouse.name", "test"); + binder.bind(request); + + assertNotNull(tb.getSpouse()); + assertEquals("test", tb.getSpouse().getName()); + } + + @Test + public void testFieldPrefixCausesFieldReset() throws Exception { + TestBean target = new TestBean(); + ServletRequestDataBinder binder = new ServletRequestDataBinder(target); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("_postProcessed", "visible"); + request.addParameter("postProcessed", "on"); + binder.bind(request); + assertTrue(target.isPostProcessed()); + + request.removeParameter("postProcessed"); + binder.bind(request); + assertFalse(target.isPostProcessed()); + } + + @Test + public void testFieldPrefixCausesFieldResetWithIgnoreUnknownFields() throws Exception { + TestBean target = new TestBean(); + ServletRequestDataBinder binder = new ServletRequestDataBinder(target); + binder.setIgnoreUnknownFields(false); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("_postProcessed", "visible"); + request.addParameter("postProcessed", "on"); + binder.bind(request); + assertTrue(target.isPostProcessed()); + + request.removeParameter("postProcessed"); + binder.bind(request); + assertFalse(target.isPostProcessed()); + } + + @Test + public void testFieldDefault() throws Exception { + TestBean target = new TestBean(); + ServletRequestDataBinder binder = new ServletRequestDataBinder(target); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("!postProcessed", "off"); + request.addParameter("postProcessed", "on"); + binder.bind(request); + assertTrue(target.isPostProcessed()); + + request.removeParameter("postProcessed"); + binder.bind(request); + assertFalse(target.isPostProcessed()); + } + + @Test + public void testFieldDefaultPreemptsFieldMarker() throws Exception { + TestBean target = new TestBean(); + ServletRequestDataBinder binder = new ServletRequestDataBinder(target); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("!postProcessed", "on"); + request.addParameter("_postProcessed", "visible"); + request.addParameter("postProcessed", "on"); + binder.bind(request); + assertTrue(target.isPostProcessed()); + + request.removeParameter("postProcessed"); + binder.bind(request); + assertTrue(target.isPostProcessed()); + + request.removeParameter("!postProcessed"); + binder.bind(request); + assertFalse(target.isPostProcessed()); + } + + @Test + public void testFieldDefaultNonBoolean() throws Exception { + TestBean target = new TestBean(); + ServletRequestDataBinder binder = new ServletRequestDataBinder(target); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("!name", "anonymous"); + request.addParameter("name", "Scott"); + binder.bind(request); + assertEquals("Scott", target.getName()); + + request.removeParameter("name"); + binder.bind(request); + assertEquals("anonymous", target.getName()); + } + + @Test + public void testWithCommaSeparatedStringArray() throws Exception { + TestBean target = new TestBean(); + ServletRequestDataBinder binder = new ServletRequestDataBinder(target); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("stringArray", "bar"); + request.addParameter("stringArray", "abc"); + request.addParameter("stringArray", "123,def"); + binder.bind(request); + assertEquals("Expected all three items to be bound", 3, target.getStringArray().length); + + request.removeParameter("stringArray"); + request.addParameter("stringArray", "123,def"); + binder.bind(request); + assertEquals("Expected only 1 item to be bound", 1, target.getStringArray().length); + } + + @Test + public void testBindingWithNestedObjectCreationAndWrongOrder() throws Exception { + TestBean tb = new TestBean(); + + ServletRequestDataBinder binder = new ServletRequestDataBinder(tb, "person"); + binder.registerCustomEditor(ITestBean.class, new PropertyEditorSupport() { + @Override + public void setAsText(String text) throws IllegalArgumentException { + setValue(new TestBean()); + } + }); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("spouse.name", "test"); + request.addParameter("spouse", "someValue"); + binder.bind(request); + + assertNotNull(tb.getSpouse()); + assertEquals("test", tb.getSpouse().getName()); + } + + @Test + public void testNoPrefix() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("forname", "Tony"); + request.addParameter("surname", "Blair"); + request.addParameter("age", "" + 50); + + ServletRequestParameterPropertyValues pvs = new ServletRequestParameterPropertyValues(request); + doTestTony(pvs); + } + + @Test + public void testPrefix() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("test_forname", "Tony"); + request.addParameter("test_surname", "Blair"); + request.addParameter("test_age", "" + 50); + + ServletRequestParameterPropertyValues pvs = new ServletRequestParameterPropertyValues(request); + assertTrue("Didn't find normal when given prefix", !pvs.contains("forname")); + assertTrue("Did treat prefix as normal when not given prefix", pvs.contains("test_forname")); + + pvs = new ServletRequestParameterPropertyValues(request, "test"); + doTestTony(pvs); + } + + @Test + public void testNoParameters() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + ServletRequestParameterPropertyValues pvs = new ServletRequestParameterPropertyValues(request); + assertTrue("Found no parameters", pvs.getPropertyValues().length == 0); + } + + @Test + public void testMultipleValuesForParameter() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + String[] original = new String[] {"Tony", "Rod"}; + request.addParameter("forname", original); + + ServletRequestParameterPropertyValues pvs = new ServletRequestParameterPropertyValues(request); + assertTrue("Found 1 parameter", pvs.getPropertyValues().length == 1); + assertTrue("Found array value", pvs.getPropertyValue("forname").getValue() instanceof String[]); + String[] values = (String[]) pvs.getPropertyValue("forname").getValue(); + assertEquals("Correct values", Arrays.asList(values), Arrays.asList(original)); + } + + /** + * Must contain: forname=Tony surname=Blair age=50 + */ + protected void doTestTony(PropertyValues pvs) throws Exception { + assertTrue("Contains 3", pvs.getPropertyValues().length == 3); + assertTrue("Contains forname", pvs.contains("forname")); + assertTrue("Contains surname", pvs.contains("surname")); + assertTrue("Contains age", pvs.contains("age")); + assertTrue("Doesn't contain tory", !pvs.contains("tory")); + + PropertyValue[] ps = pvs.getPropertyValues(); + Map m = new HashMap<>(); + m.put("forname", "Tony"); + m.put("surname", "Blair"); + m.put("age", "50"); + for (int i = 0; i < ps.length; i++) { + Object val = m.get(ps[i].getName()); + assertTrue("Can't have unexpected value", val != null); + assertTrue("Val i string", val instanceof String); + assertTrue("val matches expected", val.equals(ps[i].getValue())); + m.remove(ps[i].getName()); + } + assertTrue("Map size is 0", m.size() == 0); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/bind/ServletRequestUtilsTests.java b/spring-web/src/test/java/org/springframework/web/bind/ServletRequestUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..1c31f9988beccd98e400c1a03142e970e37fc8fd --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/bind/ServletRequestUtilsTests.java @@ -0,0 +1,474 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind; + +import org.junit.Test; + +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.tests.Assume; +import org.springframework.tests.TestGroup; +import org.springframework.util.StopWatch; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + * @author Chris Beams + * @since 06.08.2003 + */ +public class ServletRequestUtilsTests { + + @Test + public void testIntParameter() throws ServletRequestBindingException { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("param1", "5"); + request.addParameter("param2", "e"); + request.addParameter("paramEmpty", ""); + + assertEquals(ServletRequestUtils.getIntParameter(request, "param1"), new Integer(5)); + assertEquals(ServletRequestUtils.getIntParameter(request, "param1", 6), 5); + assertEquals(ServletRequestUtils.getRequiredIntParameter(request, "param1"), 5); + + assertEquals(ServletRequestUtils.getIntParameter(request, "param2", 6), 6); + try { + ServletRequestUtils.getRequiredIntParameter(request, "param2"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + + assertEquals(ServletRequestUtils.getIntParameter(request, "param3"), null); + assertEquals(ServletRequestUtils.getIntParameter(request, "param3", 6), 6); + try { + ServletRequestUtils.getRequiredIntParameter(request, "param3"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + + try { + ServletRequestUtils.getRequiredIntParameter(request, "paramEmpty"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + } + + @Test + public void testIntParameters() throws ServletRequestBindingException { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("param", new String[] {"1", "2", "3"}); + + request.addParameter("param2", "1"); + request.addParameter("param2", "2"); + request.addParameter("param2", "bogus"); + + int[] array = new int[] {1, 2, 3}; + int[] values = ServletRequestUtils.getRequiredIntParameters(request, "param"); + assertEquals(3, values.length); + for (int i = 0; i < array.length; i++) { + assertEquals(array[i], values[i]); + } + + try { + ServletRequestUtils.getRequiredIntParameters(request, "param2"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + } + + @Test + public void testLongParameter() throws ServletRequestBindingException { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("param1", "5"); + request.addParameter("param2", "e"); + request.addParameter("paramEmpty", ""); + + assertEquals(ServletRequestUtils.getLongParameter(request, "param1"), new Long(5L)); + assertEquals(ServletRequestUtils.getLongParameter(request, "param1", 6L), 5L); + assertEquals(ServletRequestUtils.getRequiredIntParameter(request, "param1"), 5L); + + assertEquals(ServletRequestUtils.getLongParameter(request, "param2", 6L), 6L); + try { + ServletRequestUtils.getRequiredLongParameter(request, "param2"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + + assertEquals(ServletRequestUtils.getLongParameter(request, "param3"), null); + assertEquals(ServletRequestUtils.getLongParameter(request, "param3", 6L), 6L); + try { + ServletRequestUtils.getRequiredLongParameter(request, "param3"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + + try { + ServletRequestUtils.getRequiredLongParameter(request, "paramEmpty"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + } + + @Test + public void testLongParameters() throws ServletRequestBindingException { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter("param", new String[] {"1", "2", "3"}); + + request.setParameter("param2", "0"); + request.setParameter("param2", "1"); + request.addParameter("param2", "2"); + request.addParameter("param2", "bogus"); + + long[] array = new long[] {1L, 2L, 3L}; + long[] values = ServletRequestUtils.getRequiredLongParameters(request, "param"); + assertEquals(3, values.length); + for (int i = 0; i < array.length; i++) { + assertEquals(array[i], values[i]); + } + + try { + ServletRequestUtils.getRequiredLongParameters(request, "param2"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + + request.setParameter("param2", new String[] {"1", "2"}); + values = ServletRequestUtils.getRequiredLongParameters(request, "param2"); + assertEquals(2, values.length); + assertEquals(1, values[0]); + assertEquals(2, values[1]); + + request.removeParameter("param2"); + try { + ServletRequestUtils.getRequiredLongParameters(request, "param2"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + } + + @Test + public void testFloatParameter() throws ServletRequestBindingException { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("param1", "5.5"); + request.addParameter("param2", "e"); + request.addParameter("paramEmpty", ""); + + assertTrue(ServletRequestUtils.getFloatParameter(request, "param1").equals(new Float(5.5f))); + assertTrue(ServletRequestUtils.getFloatParameter(request, "param1", 6.5f) == 5.5f); + assertTrue(ServletRequestUtils.getRequiredFloatParameter(request, "param1") == 5.5f); + + assertTrue(ServletRequestUtils.getFloatParameter(request, "param2", 6.5f) == 6.5f); + try { + ServletRequestUtils.getRequiredFloatParameter(request, "param2"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + + assertTrue(ServletRequestUtils.getFloatParameter(request, "param3") == null); + assertTrue(ServletRequestUtils.getFloatParameter(request, "param3", 6.5f) == 6.5f); + try { + ServletRequestUtils.getRequiredFloatParameter(request, "param3"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + + try { + ServletRequestUtils.getRequiredFloatParameter(request, "paramEmpty"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + } + + @Test + public void testFloatParameters() throws ServletRequestBindingException { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("param", new String[] {"1.5", "2.5", "3"}); + + request.addParameter("param2", "1.5"); + request.addParameter("param2", "2"); + request.addParameter("param2", "bogus"); + + float[] array = new float[] {1.5F, 2.5F, 3}; + float[] values = ServletRequestUtils.getRequiredFloatParameters(request, "param"); + assertEquals(3, values.length); + for (int i = 0; i < array.length; i++) { + assertEquals(array[i], values[i], 0); + } + + try { + ServletRequestUtils.getRequiredFloatParameters(request, "param2"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + } + + @Test + public void testDoubleParameter() throws ServletRequestBindingException { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("param1", "5.5"); + request.addParameter("param2", "e"); + request.addParameter("paramEmpty", ""); + + assertTrue(ServletRequestUtils.getDoubleParameter(request, "param1").equals(new Double(5.5))); + assertTrue(ServletRequestUtils.getDoubleParameter(request, "param1", 6.5) == 5.5); + assertTrue(ServletRequestUtils.getRequiredDoubleParameter(request, "param1") == 5.5); + + assertTrue(ServletRequestUtils.getDoubleParameter(request, "param2", 6.5) == 6.5); + try { + ServletRequestUtils.getRequiredDoubleParameter(request, "param2"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + + assertTrue(ServletRequestUtils.getDoubleParameter(request, "param3") == null); + assertTrue(ServletRequestUtils.getDoubleParameter(request, "param3", 6.5) == 6.5); + try { + ServletRequestUtils.getRequiredDoubleParameter(request, "param3"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + + try { + ServletRequestUtils.getRequiredDoubleParameter(request, "paramEmpty"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + } + + @Test + public void testDoubleParameters() throws ServletRequestBindingException { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("param", new String[] {"1.5", "2.5", "3"}); + + request.addParameter("param2", "1.5"); + request.addParameter("param2", "2"); + request.addParameter("param2", "bogus"); + + double[] array = new double[] {1.5, 2.5, 3}; + double[] values = ServletRequestUtils.getRequiredDoubleParameters(request, "param"); + assertEquals(3, values.length); + for (int i = 0; i < array.length; i++) { + assertEquals(array[i], values[i], 0); + } + + try { + ServletRequestUtils.getRequiredDoubleParameters(request, "param2"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + } + + @Test + public void testBooleanParameter() throws ServletRequestBindingException { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("param1", "true"); + request.addParameter("param2", "e"); + request.addParameter("param4", "yes"); + request.addParameter("param5", "1"); + request.addParameter("paramEmpty", ""); + + assertTrue(ServletRequestUtils.getBooleanParameter(request, "param1").equals(Boolean.TRUE)); + assertTrue(ServletRequestUtils.getBooleanParameter(request, "param1", false)); + assertTrue(ServletRequestUtils.getRequiredBooleanParameter(request, "param1")); + + assertFalse(ServletRequestUtils.getBooleanParameter(request, "param2", true)); + assertFalse(ServletRequestUtils.getRequiredBooleanParameter(request, "param2")); + + assertTrue(ServletRequestUtils.getBooleanParameter(request, "param3") == null); + assertTrue(ServletRequestUtils.getBooleanParameter(request, "param3", true)); + try { + ServletRequestUtils.getRequiredBooleanParameter(request, "param3"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + + assertTrue(ServletRequestUtils.getBooleanParameter(request, "param4", false)); + assertTrue(ServletRequestUtils.getRequiredBooleanParameter(request, "param4")); + + assertTrue(ServletRequestUtils.getBooleanParameter(request, "param5", false)); + assertTrue(ServletRequestUtils.getRequiredBooleanParameter(request, "param5")); + assertFalse(ServletRequestUtils.getRequiredBooleanParameter(request, "paramEmpty")); + } + + @Test + public void testBooleanParameters() throws ServletRequestBindingException { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("param", new String[] {"true", "yes", "off", "1", "bogus"}); + + request.addParameter("param2", "false"); + request.addParameter("param2", "true"); + request.addParameter("param2", ""); + + boolean[] array = new boolean[] {true, true, false, true, false}; + boolean[] values = ServletRequestUtils.getRequiredBooleanParameters(request, "param"); + assertEquals(array.length, values.length); + for (int i = 0; i < array.length; i++) { + assertEquals(array[i], values[i]); + } + + array = new boolean[] {false, true, false}; + values = ServletRequestUtils.getRequiredBooleanParameters(request, "param2"); + assertEquals(array.length, values.length); + for (int i = 0; i < array.length; i++) { + assertEquals(array[i], values[i]); + } + } + + @Test + public void testStringParameter() throws ServletRequestBindingException { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("param1", "str"); + request.addParameter("paramEmpty", ""); + + assertEquals("str", ServletRequestUtils.getStringParameter(request, "param1")); + assertEquals("str", ServletRequestUtils.getStringParameter(request, "param1", "string")); + assertEquals("str", ServletRequestUtils.getRequiredStringParameter(request, "param1")); + + assertEquals(null, ServletRequestUtils.getStringParameter(request, "param3")); + assertEquals("string", ServletRequestUtils.getStringParameter(request, "param3", "string")); + assertNull(ServletRequestUtils.getStringParameter(request, "param3", null)); + try { + ServletRequestUtils.getRequiredStringParameter(request, "param3"); + fail("Should have thrown ServletRequestBindingException"); + } + catch (ServletRequestBindingException ex) { + // expected + } + + assertEquals("", ServletRequestUtils.getStringParameter(request, "paramEmpty")); + assertEquals("", ServletRequestUtils.getRequiredStringParameter(request, "paramEmpty")); + } + + @Test + public void testGetIntParameterWithDefaultValueHandlingIsFastEnough() { + Assume.group(TestGroup.PERFORMANCE); + MockHttpServletRequest request = new MockHttpServletRequest(); + StopWatch sw = new StopWatch(); + sw.start(); + for (int i = 0; i < 1000000; i++) { + ServletRequestUtils.getIntParameter(request, "nonExistingParam", 0); + } + sw.stop(); + System.out.println(sw.getTotalTimeMillis()); + assertTrue("getStringParameter took too long: " + sw.getTotalTimeMillis(), sw.getTotalTimeMillis() < 250); + } + + @Test + public void testGetLongParameterWithDefaultValueHandlingIsFastEnough() { + Assume.group(TestGroup.PERFORMANCE); + MockHttpServletRequest request = new MockHttpServletRequest(); + StopWatch sw = new StopWatch(); + sw.start(); + for (int i = 0; i < 1000000; i++) { + ServletRequestUtils.getLongParameter(request, "nonExistingParam", 0); + } + sw.stop(); + System.out.println(sw.getTotalTimeMillis()); + assertTrue("getStringParameter took too long: " + sw.getTotalTimeMillis(), sw.getTotalTimeMillis() < 250); + } + + @Test + public void testGetFloatParameterWithDefaultValueHandlingIsFastEnough() { + Assume.group(TestGroup.PERFORMANCE); + MockHttpServletRequest request = new MockHttpServletRequest(); + StopWatch sw = new StopWatch(); + sw.start(); + for (int i = 0; i < 1000000; i++) { + ServletRequestUtils.getFloatParameter(request, "nonExistingParam", 0f); + } + sw.stop(); + System.out.println(sw.getTotalTimeMillis()); + assertTrue("getStringParameter took too long: " + sw.getTotalTimeMillis(), sw.getTotalTimeMillis() < 250); + } + + @Test + public void testGetDoubleParameterWithDefaultValueHandlingIsFastEnough() { + Assume.group(TestGroup.PERFORMANCE); + MockHttpServletRequest request = new MockHttpServletRequest(); + StopWatch sw = new StopWatch(); + sw.start(); + for (int i = 0; i < 1000000; i++) { + ServletRequestUtils.getDoubleParameter(request, "nonExistingParam", 0d); + } + sw.stop(); + System.out.println(sw.getTotalTimeMillis()); + assertTrue("getStringParameter took too long: " + sw.getTotalTimeMillis(), sw.getTotalTimeMillis() < 250); + } + + @Test + public void testGetBooleanParameterWithDefaultValueHandlingIsFastEnough() { + Assume.group(TestGroup.PERFORMANCE); + MockHttpServletRequest request = new MockHttpServletRequest(); + StopWatch sw = new StopWatch(); + sw.start(); + for (int i = 0; i < 1000000; i++) { + ServletRequestUtils.getBooleanParameter(request, "nonExistingParam", false); + } + sw.stop(); + System.out.println(sw.getTotalTimeMillis()); + assertTrue("getStringParameter took too long: " + sw.getTotalTimeMillis(), sw.getTotalTimeMillis() < 250); + } + + @Test + public void testGetStringParameterWithDefaultValueHandlingIsFastEnough() { + Assume.group(TestGroup.PERFORMANCE); + MockHttpServletRequest request = new MockHttpServletRequest(); + StopWatch sw = new StopWatch(); + sw.start(); + for (int i = 0; i < 1000000; i++) { + ServletRequestUtils.getStringParameter(request, "nonExistingParam", "defaultValue"); + } + sw.stop(); + System.out.println(sw.getTotalTimeMillis()); + assertTrue("getStringParameter took too long: " + sw.getTotalTimeMillis(), sw.getTotalTimeMillis() < 250); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/bind/support/WebExchangeDataBinderTests.java b/spring-web/src/test/java/org/springframework/web/bind/support/WebExchangeDataBinderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..dd386e20585f4fb0c30d405b93061760e5fcb009 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/bind/support/WebExchangeDataBinderTests.java @@ -0,0 +1,310 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +import java.beans.PropertyEditorSupport; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.codec.FormHttpMessageWriter; +import org.springframework.http.codec.multipart.FilePart; +import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; +import org.springframework.mock.http.client.reactive.test.MockClientHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.web.test.server.MockServerWebExchange; +import org.springframework.tests.sample.beans.ITestBean; +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.server.ServerWebExchange; + +import static junit.framework.TestCase.assertFalse; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.springframework.core.ResolvableType.forClass; +import static org.springframework.core.ResolvableType.forClassWithGenerics; + +/** + * Unit tests for {@link WebExchangeDataBinder}. + * + * @author Rossen Stoyanchev + */ +public class WebExchangeDataBinderTests { + + private TestBean testBean; + + private WebExchangeDataBinder binder; + + + @Before + public void setup() throws Exception { + this.testBean = new TestBean(); + this.binder = new WebExchangeDataBinder(this.testBean, "person"); + this.binder.registerCustomEditor(ITestBean.class, new TestBeanPropertyEditor()); + } + + + @Test + public void testBindingWithNestedObjectCreation() throws Exception { + MultiValueMap formData = new LinkedMultiValueMap<>(); + formData.add("spouse", "someValue"); + formData.add("spouse.name", "test"); + this.binder.bind(exchange(formData)).block(Duration.ofMillis(5000)); + + assertNotNull(this.testBean.getSpouse()); + assertEquals("test", testBean.getSpouse().getName()); + } + + @Test + public void testFieldPrefixCausesFieldReset() throws Exception { + MultiValueMap formData = new LinkedMultiValueMap<>(); + formData.add("_postProcessed", "visible"); + formData.add("postProcessed", "on"); + this.binder.bind(exchange(formData)).block(Duration.ofMillis(5000)); + assertTrue(this.testBean.isPostProcessed()); + + formData.remove("postProcessed"); + this.binder.bind(exchange(formData)).block(Duration.ofMillis(5000)); + assertFalse(this.testBean.isPostProcessed()); + } + + @Test + public void testFieldPrefixCausesFieldResetWithIgnoreUnknownFields() throws Exception { + this.binder.setIgnoreUnknownFields(false); + + MultiValueMap formData = new LinkedMultiValueMap<>(); + formData.add("_postProcessed", "visible"); + formData.add("postProcessed", "on"); + this.binder.bind(exchange(formData)).block(Duration.ofMillis(5000)); + assertTrue(this.testBean.isPostProcessed()); + + formData.remove("postProcessed"); + this.binder.bind(exchange(formData)).block(Duration.ofMillis(5000)); + assertFalse(this.testBean.isPostProcessed()); + } + + @Test + public void testFieldDefault() throws Exception { + MultiValueMap formData = new LinkedMultiValueMap<>(); + formData.add("!postProcessed", "off"); + formData.add("postProcessed", "on"); + this.binder.bind(exchange(formData)).block(Duration.ofMillis(5000)); + assertTrue(this.testBean.isPostProcessed()); + + formData.remove("postProcessed"); + this.binder.bind(exchange(formData)).block(Duration.ofMillis(5000)); + assertFalse(this.testBean.isPostProcessed()); + } + + @Test + public void testFieldDefaultPreemptsFieldMarker() throws Exception { + MultiValueMap formData = new LinkedMultiValueMap<>(); + formData.add("!postProcessed", "on"); + formData.add("_postProcessed", "visible"); + formData.add("postProcessed", "on"); + this.binder.bind(exchange(formData)).block(Duration.ofMillis(5000)); + assertTrue(this.testBean.isPostProcessed()); + + formData.remove("postProcessed"); + this.binder.bind(exchange(formData)).block(Duration.ofMillis(5000)); + assertTrue(this.testBean.isPostProcessed()); + + formData.remove("!postProcessed"); + this.binder.bind(exchange(formData)).block(Duration.ofMillis(5000)); + assertFalse(this.testBean.isPostProcessed()); + } + + @Test + public void testFieldDefaultNonBoolean() throws Exception { + MultiValueMap formData = new LinkedMultiValueMap<>(); + formData.add("!name", "anonymous"); + formData.add("name", "Scott"); + this.binder.bind(exchange(formData)).block(Duration.ofMillis(5000)); + assertEquals("Scott", this.testBean.getName()); + + formData.remove("name"); + this.binder.bind(exchange(formData)).block(Duration.ofMillis(5000)); + assertEquals("anonymous", this.testBean.getName()); + } + + @Test + public void testWithCommaSeparatedStringArray() throws Exception { + MultiValueMap formData = new LinkedMultiValueMap<>(); + formData.add("stringArray", "bar"); + formData.add("stringArray", "abc"); + formData.add("stringArray", "123,def"); + this.binder.bind(exchange(formData)).block(Duration.ofMillis(5000)); + assertEquals("Expected all three items to be bound", 3, this.testBean.getStringArray().length); + + formData.remove("stringArray"); + formData.add("stringArray", "123,def"); + this.binder.bind(exchange(formData)).block(Duration.ofMillis(5000)); + assertEquals("Expected only 1 item to be bound", 1, this.testBean.getStringArray().length); + } + + @Test + public void testBindingWithNestedObjectCreationAndWrongOrder() throws Exception { + MultiValueMap formData = new LinkedMultiValueMap<>(); + formData.add("spouse.name", "test"); + formData.add("spouse", "someValue"); + this.binder.bind(exchange(formData)).block(Duration.ofMillis(5000)); + + assertNotNull(this.testBean.getSpouse()); + assertEquals("test", this.testBean.getSpouse().getName()); + } + + @Test + public void testBindingWithQueryParams() throws Exception { + String url = "/path?spouse=someValue&spouse.name=test"; + ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.post(url)); + this.binder.bind(exchange).block(Duration.ofSeconds(5)); + + assertNotNull(this.testBean.getSpouse()); + assertEquals("test", this.testBean.getSpouse().getName()); + } + + @Test + public void testMultipart() throws Exception { + + MultipartBean bean = new MultipartBean(); + WebExchangeDataBinder binder = new WebExchangeDataBinder(bean); + + MultiValueMap data = new LinkedMultiValueMap<>(); + data.add("name", "bar"); + data.add("someList", "123"); + data.add("someList", "abc"); + data.add("someArray", "dec"); + data.add("someArray", "456"); + data.add("part", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt")); + data.add("somePartList", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt")); + data.add("somePartList", new ClassPathResource("org/springframework/http/server/reactive/spring.png")); + binder.bind(exchangeMultipart(data)).block(Duration.ofMillis(5000)); + + assertEquals("bar", bean.getName()); + assertEquals(Arrays.asList("123", "abc"), bean.getSomeList()); + assertArrayEquals(new String[] {"dec", "456"}, bean.getSomeArray()); + assertEquals("foo.txt", bean.getPart().filename()); + assertEquals(2, bean.getSomePartList().size()); + assertEquals("foo.txt", bean.getSomePartList().get(0).filename()); + assertEquals("spring.png", bean.getSomePartList().get(1).filename()); + } + + + + private ServerWebExchange exchange(MultiValueMap formData) { + + MockClientHttpRequest request = new MockClientHttpRequest(HttpMethod.POST, "/"); + + new FormHttpMessageWriter().write(Mono.just(formData), + forClassWithGenerics(MultiValueMap.class, String.class, String.class), + MediaType.APPLICATION_FORM_URLENCODED, request, Collections.emptyMap()).block(); + + return MockServerWebExchange.from( + MockServerHttpRequest + .post("/") + .contentType(MediaType.APPLICATION_FORM_URLENCODED) + .body(request.getBody())); + } + + private ServerWebExchange exchangeMultipart(MultiValueMap multipartData) { + + MockClientHttpRequest request = new MockClientHttpRequest(HttpMethod.POST, "/"); + + new MultipartHttpMessageWriter().write(Mono.just(multipartData), forClass(MultiValueMap.class), + MediaType.MULTIPART_FORM_DATA, request, Collections.emptyMap()).block(); + + return MockServerWebExchange.from(MockServerHttpRequest + .post("/") + .contentType(request.getHeaders().getContentType()) + .body(request.getBody())); + } + + + private static class TestBeanPropertyEditor extends PropertyEditorSupport { + + @Override + public void setAsText(String text) { + setValue(new TestBean()); + } + } + + private static class MultipartBean { + + private String name; + + private List someList; + + private String[] someArray; + + private FilePart part; + + private List somePartList; + + + public String getName() { + return this.name; + } + + public void setName(String name) { + this.name = name; + } + + public List getSomeList() { + return this.someList; + } + + public void setSomeList(List someList) { + this.someList = someList; + } + + public String[] getSomeArray() { + return this.someArray; + } + + public void setSomeArray(String[] someArray) { + this.someArray = someArray; + } + + public FilePart getPart() { + return this.part; + } + + public void setPart(FilePart part) { + this.part = part; + } + + public List getSomePartList() { + return this.somePartList; + } + + public void setSomePartList(List somePartList) { + this.somePartList = somePartList; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8ec0492b2b031f6089d9a028b72ca2414d4485ac --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java @@ -0,0 +1,202 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +import java.util.List; + +import javax.servlet.MultipartConfigElement; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.Part; + +import org.eclipse.jetty.server.Connector; +import org.eclipse.jetty.server.NetworkConnector; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.http.MediaType; +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.context.request.ServletWebRequest; + +import static org.junit.Assert.*; + +/** + * @author Brian Clozel + * @author Sam Brannen + */ +public class WebRequestDataBinderIntegrationTests { + + private static Server jettyServer; + + private static final PartsServlet partsServlet = new PartsServlet(); + + private static final PartListServlet partListServlet = new PartListServlet(); + + private final RestTemplate template = new RestTemplate(new HttpComponentsClientHttpRequestFactory()); + + protected static String baseUrl; + + protected static MediaType contentType; + + + @BeforeClass + public static void startJettyServer() throws Exception { + // Let server pick its own random, available port. + jettyServer = new Server(0); + + ServletContextHandler handler = new ServletContextHandler(); + + MultipartConfigElement multipartConfig = new MultipartConfigElement(""); + + ServletHolder holder = new ServletHolder(partsServlet); + holder.getRegistration().setMultipartConfig(multipartConfig); + handler.addServlet(holder, "/parts"); + + holder = new ServletHolder(partListServlet); + holder.getRegistration().setMultipartConfig(multipartConfig); + handler.addServlet(holder, "/partlist"); + + jettyServer.setHandler(handler); + jettyServer.start(); + + Connector[] connectors = jettyServer.getConnectors(); + NetworkConnector connector = (NetworkConnector) connectors[0]; + baseUrl = "http://localhost:" + connector.getLocalPort(); + } + + @AfterClass + public static void stopJettyServer() throws Exception { + if (jettyServer != null) { + jettyServer.stop(); + } + } + + + @Test + public void partsBinding() { + PartsBean bean = new PartsBean(); + partsServlet.setBean(bean); + + MultiValueMap parts = new LinkedMultiValueMap<>(); + Resource firstPart = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); + parts.add("firstPart", firstPart); + parts.add("secondPart", "secondValue"); + + template.postForLocation(baseUrl + "/parts", parts); + + assertNotNull(bean.getFirstPart()); + assertNotNull(bean.getSecondPart()); + } + + @Test + public void partListBinding() { + PartListBean bean = new PartListBean(); + partListServlet.setBean(bean); + + MultiValueMap parts = new LinkedMultiValueMap<>(); + parts.add("partList", "first value"); + parts.add("partList", "second value"); + Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); + parts.add("partList", logo); + + template.postForLocation(baseUrl + "/partlist", parts); + + assertNotNull(bean.getPartList()); + assertEquals(parts.get("partList").size(), bean.getPartList().size()); + } + + + @SuppressWarnings("serial") + private abstract static class AbstractStandardMultipartServlet extends HttpServlet { + + private T bean; + + @Override + public void service(HttpServletRequest request, HttpServletResponse response) { + WebRequestDataBinder binder = new WebRequestDataBinder(bean); + ServletWebRequest webRequest = new ServletWebRequest(request, response); + binder.bind(webRequest); + response.setStatus(HttpServletResponse.SC_OK); + } + + public void setBean(T bean) { + this.bean = bean; + } + } + + + private static class PartsBean { + + public Part firstPart; + + public Part secondPart; + + public Part getFirstPart() { + return firstPart; + } + + @SuppressWarnings("unused") + public void setFirstPart(Part firstPart) { + this.firstPart = firstPart; + } + + public Part getSecondPart() { + return secondPart; + } + + @SuppressWarnings("unused") + public void setSecondPart(Part secondPart) { + this.secondPart = secondPart; + } + } + + + @SuppressWarnings("serial") + private static class PartsServlet extends AbstractStandardMultipartServlet { + } + + + private static class PartListBean { + + public List partList; + + public List getPartList() { + return partList; + } + + @SuppressWarnings("unused") + public void setPartList(List partList) { + this.partList = partList; + } + } + + + @SuppressWarnings("serial") + private static class PartListServlet extends AbstractStandardMultipartServlet { + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderTests.java b/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a387726e7084ed1b61d2f2e8e6aaa5fc0e6e0209 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderTests.java @@ -0,0 +1,385 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; + +import java.beans.PropertyEditorSupport; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.junit.Test; + +import org.springframework.beans.PropertyValue; +import org.springframework.beans.PropertyValues; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockMultipartFile; +import org.springframework.mock.web.test.MockMultipartHttpServletRequest; +import org.springframework.tests.sample.beans.ITestBean; +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.web.bind.ServletRequestParameterPropertyValues; +import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.multipart.support.StringMultipartFileEditor; + +/** + * @author Juergen Hoeller + */ +public class WebRequestDataBinderTests { + + @Test + public void testBindingWithNestedObjectCreation() throws Exception { + TestBean tb = new TestBean(); + + WebRequestDataBinder binder = new WebRequestDataBinder(tb, "person"); + binder.registerCustomEditor(ITestBean.class, new PropertyEditorSupport() { + @Override + public void setAsText(String text) throws IllegalArgumentException { + setValue(new TestBean()); + } + }); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("spouse", "someValue"); + request.addParameter("spouse.name", "test"); + binder.bind(new ServletWebRequest(request)); + + assertNotNull(tb.getSpouse()); + assertEquals("test", tb.getSpouse().getName()); + } + + @Test + public void testBindingWithNestedObjectCreationThroughAutoGrow() throws Exception { + TestBean tb = new TestBeanWithConcreteSpouse(); + + WebRequestDataBinder binder = new WebRequestDataBinder(tb, "person"); + binder.setIgnoreUnknownFields(false); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("concreteSpouse.name", "test"); + binder.bind(new ServletWebRequest(request)); + + assertNotNull(tb.getSpouse()); + assertEquals("test", tb.getSpouse().getName()); + } + + @Test + public void testFieldPrefixCausesFieldReset() throws Exception { + TestBean target = new TestBean(); + WebRequestDataBinder binder = new WebRequestDataBinder(target); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("_postProcessed", "visible"); + request.addParameter("postProcessed", "on"); + binder.bind(new ServletWebRequest(request)); + assertTrue(target.isPostProcessed()); + + request.removeParameter("postProcessed"); + binder.bind(new ServletWebRequest(request)); + assertFalse(target.isPostProcessed()); + } + + @Test + public void testFieldPrefixCausesFieldResetWithIgnoreUnknownFields() throws Exception { + TestBean target = new TestBean(); + WebRequestDataBinder binder = new WebRequestDataBinder(target); + binder.setIgnoreUnknownFields(false); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("_postProcessed", "visible"); + request.addParameter("postProcessed", "on"); + binder.bind(new ServletWebRequest(request)); + assertTrue(target.isPostProcessed()); + + request.removeParameter("postProcessed"); + binder.bind(new ServletWebRequest(request)); + assertFalse(target.isPostProcessed()); + } + + @Test + public void testFieldDefault() throws Exception { + TestBean target = new TestBean(); + WebRequestDataBinder binder = new WebRequestDataBinder(target); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("!postProcessed", "off"); + request.addParameter("postProcessed", "on"); + binder.bind(new ServletWebRequest(request)); + assertTrue(target.isPostProcessed()); + + request.removeParameter("postProcessed"); + binder.bind(new ServletWebRequest(request)); + assertFalse(target.isPostProcessed()); + } + + // SPR-13502 + @Test + public void testCollectionFieldsDefault() throws Exception { + TestBean target = new TestBean(); + target.setSomeSet(null); + target.setSomeList(null); + target.setSomeMap(null); + WebRequestDataBinder binder = new WebRequestDataBinder(target); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("_someSet", "visible"); + request.addParameter("_someList", "visible"); + request.addParameter("_someMap", "visible"); + + binder.bind(new ServletWebRequest(request)); + assertThat(target.getSomeSet(), notNullValue()); + assertThat(target.getSomeSet(), isA(Set.class)); + + assertThat(target.getSomeList(), notNullValue()); + assertThat(target.getSomeList(), isA(List.class)); + + assertThat(target.getSomeMap(), notNullValue()); + assertThat(target.getSomeMap(), isA(Map.class)); + } + + @Test + public void testFieldDefaultPreemptsFieldMarker() throws Exception { + TestBean target = new TestBean(); + WebRequestDataBinder binder = new WebRequestDataBinder(target); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("!postProcessed", "on"); + request.addParameter("_postProcessed", "visible"); + request.addParameter("postProcessed", "on"); + binder.bind(new ServletWebRequest(request)); + assertTrue(target.isPostProcessed()); + + request.removeParameter("postProcessed"); + binder.bind(new ServletWebRequest(request)); + assertTrue(target.isPostProcessed()); + + request.removeParameter("!postProcessed"); + binder.bind(new ServletWebRequest(request)); + assertFalse(target.isPostProcessed()); + } + + @Test + public void testFieldDefaultWithNestedProperty() throws Exception { + TestBean target = new TestBean(); + target.setSpouse(new TestBean()); + WebRequestDataBinder binder = new WebRequestDataBinder(target); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("!spouse.postProcessed", "on"); + request.addParameter("_spouse.postProcessed", "visible"); + request.addParameter("spouse.postProcessed", "on"); + binder.bind(new ServletWebRequest(request)); + assertTrue(((TestBean) target.getSpouse()).isPostProcessed()); + + request.removeParameter("spouse.postProcessed"); + binder.bind(new ServletWebRequest(request)); + assertTrue(((TestBean) target.getSpouse()).isPostProcessed()); + + request.removeParameter("!spouse.postProcessed"); + binder.bind(new ServletWebRequest(request)); + assertFalse(((TestBean) target.getSpouse()).isPostProcessed()); + } + + @Test + public void testFieldDefaultNonBoolean() throws Exception { + TestBean target = new TestBean(); + WebRequestDataBinder binder = new WebRequestDataBinder(target); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("!name", "anonymous"); + request.addParameter("name", "Scott"); + binder.bind(new ServletWebRequest(request)); + assertEquals("Scott", target.getName()); + + request.removeParameter("name"); + binder.bind(new ServletWebRequest(request)); + assertEquals("anonymous", target.getName()); + } + + @Test + public void testWithCommaSeparatedStringArray() throws Exception { + TestBean target = new TestBean(); + WebRequestDataBinder binder = new WebRequestDataBinder(target); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("stringArray", "bar"); + request.addParameter("stringArray", "abc"); + request.addParameter("stringArray", "123,def"); + binder.bind(new ServletWebRequest(request)); + assertEquals("Expected all three items to be bound", 3, target.getStringArray().length); + + request.removeParameter("stringArray"); + request.addParameter("stringArray", "123,def"); + binder.bind(new ServletWebRequest(request)); + assertEquals("Expected only 1 item to be bound", 1, target.getStringArray().length); + } + + @Test + public void testEnumBinding() { + EnumHolder target = new EnumHolder(); + WebRequestDataBinder binder = new WebRequestDataBinder(target); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("myEnum", "FOO"); + binder.bind(new ServletWebRequest(request)); + assertEquals(MyEnum.FOO, target.getMyEnum()); + } + + @Test + public void testMultipartFileAsString() { + TestBean target = new TestBean(); + WebRequestDataBinder binder = new WebRequestDataBinder(target); + binder.registerCustomEditor(String.class, new StringMultipartFileEditor()); + + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); + request.addFile(new MockMultipartFile("name", "Juergen".getBytes())); + binder.bind(new ServletWebRequest(request)); + assertEquals("Juergen", target.getName()); + } + + @Test + public void testMultipartFileAsStringArray() { + TestBean target = new TestBean(); + WebRequestDataBinder binder = new WebRequestDataBinder(target); + binder.registerCustomEditor(String.class, new StringMultipartFileEditor()); + + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); + request.addFile(new MockMultipartFile("stringArray", "Juergen".getBytes())); + binder.bind(new ServletWebRequest(request)); + assertEquals(1, target.getStringArray().length); + assertEquals("Juergen", target.getStringArray()[0]); + } + + @Test + public void testMultipartFilesAsStringArray() { + TestBean target = new TestBean(); + WebRequestDataBinder binder = new WebRequestDataBinder(target); + binder.registerCustomEditor(String.class, new StringMultipartFileEditor()); + + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); + request.addFile(new MockMultipartFile("stringArray", "Juergen".getBytes())); + request.addFile(new MockMultipartFile("stringArray", "Eva".getBytes())); + binder.bind(new ServletWebRequest(request)); + assertEquals(2, target.getStringArray().length); + assertEquals("Juergen", target.getStringArray()[0]); + assertEquals("Eva", target.getStringArray()[1]); + } + + @Test + public void testNoPrefix() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("forname", "Tony"); + request.addParameter("surname", "Blair"); + request.addParameter("age", "" + 50); + + ServletRequestParameterPropertyValues pvs = new ServletRequestParameterPropertyValues(request); + doTestTony(pvs); + } + + @Test + public void testPrefix() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addParameter("test_forname", "Tony"); + request.addParameter("test_surname", "Blair"); + request.addParameter("test_age", "" + 50); + + ServletRequestParameterPropertyValues pvs = new ServletRequestParameterPropertyValues(request); + assertTrue("Didn't find normal when given prefix", !pvs.contains("forname")); + assertTrue("Did treat prefix as normal when not given prefix", pvs.contains("test_forname")); + + pvs = new ServletRequestParameterPropertyValues(request, "test"); + doTestTony(pvs); + } + + /** + * Must contain: forname=Tony surname=Blair age=50 + */ + protected void doTestTony(PropertyValues pvs) throws Exception { + assertTrue("Contains 3", pvs.getPropertyValues().length == 3); + assertTrue("Contains forname", pvs.contains("forname")); + assertTrue("Contains surname", pvs.contains("surname")); + assertTrue("Contains age", pvs.contains("age")); + assertTrue("Doesn't contain tory", !pvs.contains("tory")); + + PropertyValue[] pvArray = pvs.getPropertyValues(); + Map m = new HashMap<>(); + m.put("forname", "Tony"); + m.put("surname", "Blair"); + m.put("age", "50"); + for (PropertyValue pv : pvArray) { + Object val = m.get(pv.getName()); + assertTrue("Can't have unexpected value", val != null); + assertTrue("Val i string", val instanceof String); + assertTrue("val matches expected", val.equals(pv.getValue())); + m.remove(pv.getName()); + } + assertTrue("Map size is 0", m.size() == 0); + } + + @Test + public void testNoParameters() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + ServletRequestParameterPropertyValues pvs = new ServletRequestParameterPropertyValues(request); + assertTrue("Found no parameters", pvs.getPropertyValues().length == 0); + } + + @Test + public void testMultipleValuesForParameter() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + String[] original = new String[] {"Tony", "Rod"}; + request.addParameter("forname", original); + + ServletRequestParameterPropertyValues pvs = new ServletRequestParameterPropertyValues(request); + assertTrue("Found 1 parameter", pvs.getPropertyValues().length == 1); + assertTrue("Found array value", pvs.getPropertyValue("forname").getValue() instanceof String[]); + String[] values = (String[]) pvs.getPropertyValue("forname").getValue(); + assertEquals("Correct values", Arrays.asList(values), Arrays.asList(original)); + } + + + public static class EnumHolder { + + private MyEnum myEnum; + + public MyEnum getMyEnum() { + return myEnum; + } + + public void setMyEnum(MyEnum myEnum) { + this.myEnum = myEnum; + } + } + + public enum MyEnum { + FOO, BAR + } + + static class TestBeanWithConcreteSpouse extends TestBean { + public void setConcreteSpouse(TestBean spouse) { + setSpouse(spouse); + } + + public TestBean getConcreteSpouse() { + return (TestBean) getSpouse(); + } + } + + +} diff --git a/spring-web/src/test/java/org/springframework/web/client/AbstractMockWebServerTestCase.java b/spring-web/src/test/java/org/springframework/web/client/AbstractMockWebServerTestCase.java new file mode 100644 index 0000000000000000000000000000000000000000..ab104179d6d17a63028ffd5abdae103ba146a068 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/client/AbstractMockWebServerTestCase.java @@ -0,0 +1,276 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.EOFException; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Collections; + +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import okio.Buffer; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import org.springframework.http.MediaType; + +import static org.junit.Assert.*; + +/** + * @author Brian Clozel + */ +public class AbstractMockWebServerTestCase { + + protected static final MediaType textContentType = + new MediaType("text", "plain", Collections.singletonMap("charset", "UTF-8")); + + protected static final String helloWorld = "H\u00e9llo W\u00f6rld"; + + private MockWebServer server; + + protected int port; + + protected String baseUrl; + + + @Before + public void setUp() throws Exception { + this.server = new MockWebServer(); + this.server.setDispatcher(new TestDispatcher()); + this.server.start(); + this.port = this.server.getPort(); + this.baseUrl = "http://localhost:" + this.port; + } + + @After + public void tearDown() throws Exception { + this.server.shutdown(); + } + + + private MockResponse getRequest(RecordedRequest request, byte[] body, String contentType) { + if (request.getMethod().equals("OPTIONS")) { + return new MockResponse().setResponseCode(200).setHeader("Allow", "GET, OPTIONS, HEAD, TRACE"); + } + Buffer buf = new Buffer(); + buf.write(body); + MockResponse response = new MockResponse() + .setHeader("Content-Length", body.length) + .setBody(buf) + .setResponseCode(200); + if (contentType != null) { + response = response.setHeader("Content-Type", contentType); + } + return response; + } + + private MockResponse postRequest(RecordedRequest request, String expectedRequestContent, + String location, String contentType, byte[] responseBody) { + + assertEquals(1, request.getHeaders().values("Content-Length").size()); + assertTrue("Invalid request content-length", + Integer.parseInt(request.getHeader("Content-Length")) > 0); + String requestContentType = request.getHeader("Content-Type"); + assertNotNull("No content-type", requestContentType); + Charset charset = StandardCharsets.ISO_8859_1; + if (requestContentType.contains("charset=")) { + String charsetName = requestContentType.split("charset=")[1]; + charset = Charset.forName(charsetName); + } + assertEquals("Invalid request body", expectedRequestContent, request.getBody().readString(charset)); + Buffer buf = new Buffer(); + buf.write(responseBody); + return new MockResponse() + .setHeader("Location", baseUrl + location) + .setHeader("Content-Type", contentType) + .setHeader("Content-Length", responseBody.length) + .setBody(buf) + .setResponseCode(201); + } + + private MockResponse jsonPostRequest(RecordedRequest request, String location, String contentType) { + if (request.getBodySize() > 0) { + assertTrue("Invalid request content-length", + Integer.parseInt(request.getHeader("Content-Length")) > 0); + assertNotNull("No content-type", request.getHeader("Content-Type")); + } + return new MockResponse() + .setHeader("Location", baseUrl + location) + .setHeader("Content-Type", contentType) + .setHeader("Content-Length", request.getBody().size()) + .setBody(request.getBody()) + .setResponseCode(201); + } + + private MockResponse multipartRequest(RecordedRequest request) { + MediaType mediaType = MediaType.parseMediaType(request.getHeader("Content-Type")); + assertTrue(mediaType.isCompatibleWith(MediaType.MULTIPART_FORM_DATA)); + String boundary = mediaType.getParameter("boundary"); + Buffer body = request.getBody(); + try { + assertPart(body, "form-data", boundary, "name 1", "text/plain", "value 1"); + assertPart(body, "form-data", boundary, "name 2", "text/plain", "value 2+1"); + assertPart(body, "form-data", boundary, "name 2", "text/plain", "value 2+2"); + assertFilePart(body, "form-data", boundary, "logo", "logo.jpg", "image/jpeg"); + } + catch (EOFException ex) { + throw new IllegalStateException(ex); + } + return new MockResponse().setResponseCode(200); + } + + private void assertPart(Buffer buffer, String disposition, String boundary, String name, + String contentType, String value) throws EOFException { + + assertTrue(buffer.readUtf8Line().contains("--" + boundary)); + String line = buffer.readUtf8Line(); + assertTrue(line.contains("Content-Disposition: "+ disposition)); + assertTrue(line.contains("name=\""+ name + "\"")); + assertTrue(buffer.readUtf8Line().startsWith("Content-Type: "+contentType)); + assertTrue(buffer.readUtf8Line().equals("Content-Length: " + value.length())); + assertTrue(buffer.readUtf8Line().equals("")); + assertTrue(buffer.readUtf8Line().equals(value)); + } + + private void assertFilePart(Buffer buffer, String disposition, String boundary, String name, + String filename, String contentType) throws EOFException { + + assertTrue(buffer.readUtf8Line().contains("--" + boundary)); + String line = buffer.readUtf8Line(); + assertTrue(line.contains("Content-Disposition: "+ disposition)); + assertTrue(line.contains("name=\""+ name + "\"")); + assertTrue(line.contains("filename=\""+ filename + "\"")); + assertTrue(buffer.readUtf8Line().startsWith("Content-Type: "+contentType)); + assertTrue(buffer.readUtf8Line().startsWith("Content-Length: ")); + assertTrue(buffer.readUtf8Line().equals("")); + assertNotNull(buffer.readUtf8Line()); + } + + private MockResponse formRequest(RecordedRequest request) { + assertEquals("application/x-www-form-urlencoded;charset=UTF-8", request.getHeader("Content-Type")); + String body = request.getBody().readUtf8(); + assertThat(body, Matchers.containsString("name+1=value+1")); + assertThat(body, Matchers.containsString("name+2=value+2%2B1")); + assertThat(body, Matchers.containsString("name+2=value+2%2B2")); + return new MockResponse().setResponseCode(200); + } + + private MockResponse patchRequest(RecordedRequest request, String expectedRequestContent, + String contentType, byte[] responseBody) { + + assertEquals("PATCH", request.getMethod()); + assertTrue("Invalid request content-length", + Integer.parseInt(request.getHeader("Content-Length")) > 0); + String requestContentType = request.getHeader("Content-Type"); + assertNotNull("No content-type", requestContentType); + Charset charset = StandardCharsets.ISO_8859_1; + if (requestContentType.contains("charset=")) { + String charsetName = requestContentType.split("charset=")[1]; + charset = Charset.forName(charsetName); + } + assertEquals("Invalid request body", expectedRequestContent, request.getBody().readString(charset)); + Buffer buf = new Buffer(); + buf.write(responseBody); + return new MockResponse().setResponseCode(201) + .setHeader("Content-Length", responseBody.length) + .setHeader("Content-Type", contentType) + .setBody(buf); + } + + private MockResponse putRequest(RecordedRequest request, String expectedRequestContent) { + assertTrue("Invalid request content-length", + Integer.parseInt(request.getHeader("Content-Length")) > 0); + String requestContentType = request.getHeader("Content-Type"); + assertNotNull("No content-type", requestContentType); + Charset charset = StandardCharsets.ISO_8859_1; + if (requestContentType.contains("charset=")) { + String charsetName = requestContentType.split("charset=")[1]; + charset = Charset.forName(charsetName); + } + assertEquals("Invalid request body", expectedRequestContent, request.getBody().readString(charset)); + return new MockResponse().setResponseCode(202); + } + + + protected class TestDispatcher extends Dispatcher { + + @Override + public MockResponse dispatch(RecordedRequest request) throws InterruptedException { + try { + byte[] helloWorldBytes = helloWorld.getBytes(StandardCharsets.UTF_8); + + if (request.getPath().equals("/get")) { + return getRequest(request, helloWorldBytes, textContentType.toString()); + } + else if (request.getPath().equals("/get/nothing")) { + return getRequest(request, new byte[0], textContentType.toString()); + } + else if (request.getPath().equals("/get/nocontenttype")) { + return getRequest(request, helloWorldBytes, null); + } + else if (request.getPath().equals("/post")) { + return postRequest(request, helloWorld, "/post/1", textContentType.toString(), helloWorldBytes); + } + else if (request.getPath().equals("/jsonpost")) { + return jsonPostRequest(request, "/jsonpost/1", "application/json; charset=utf-8"); + } + else if (request.getPath().equals("/status/nocontent")) { + return new MockResponse().setResponseCode(204); + } + else if (request.getPath().equals("/status/notmodified")) { + return new MockResponse().setResponseCode(304); + } + else if (request.getPath().equals("/status/notfound")) { + return new MockResponse().setResponseCode(404); + } + else if (request.getPath().equals("/status/badrequest")) { + return new MockResponse().setResponseCode(400); + } + else if (request.getPath().equals("/status/server")) { + return new MockResponse().setResponseCode(500); + } + else if (request.getPath().contains("/uri/")) { + return new MockResponse().setBody(request.getPath()).setHeader("Content-Type", "text/plain"); + } + else if (request.getPath().equals("/multipart")) { + return multipartRequest(request); + } + else if (request.getPath().equals("/form")) { + return formRequest(request); + } + else if (request.getPath().equals("/delete")) { + return new MockResponse().setResponseCode(200); + } + else if (request.getPath().equals("/patch")) { + return patchRequest(request, helloWorld, textContentType.toString(), helloWorldBytes); + } + else if (request.getPath().equals("/put")) { + return putRequest(request, helloWorld); + } + return new MockResponse().setResponseCode(404); + } + catch (Throwable ex) { + return new MockResponse().setResponseCode(500).setBody(ex.toString()); + } + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/client/AsyncRestTemplateIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/AsyncRestTemplateIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..87c311d230d147fd008956be799719175d99436d --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/client/AsyncRestTemplateIntegrationTests.java @@ -0,0 +1,652 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.EnumSet; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import org.junit.Assert; +import org.junit.Test; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRequest; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.AsyncClientHttpRequestExecution; +import org.springframework.http.client.AsyncClientHttpRequestInterceptor; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.client.HttpComponentsAsyncClientHttpRequestFactory; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.ListenableFutureCallback; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + * @author Sebastien Deleuze + */ +@SuppressWarnings("deprecation") +public class AsyncRestTemplateIntegrationTests extends AbstractMockWebServerTestCase { + + private final AsyncRestTemplate template = new AsyncRestTemplate( + new HttpComponentsAsyncClientHttpRequestFactory()); + + + @Test + public void getEntity() throws Exception { + Future> future = template.getForEntity(baseUrl + "/{method}", String.class, "get"); + ResponseEntity entity = future.get(); + assertEquals("Invalid content", helloWorld, entity.getBody()); + assertFalse("No headers", entity.getHeaders().isEmpty()); + assertEquals("Invalid content-type", textContentType, entity.getHeaders().getContentType()); + assertEquals("Invalid status code", HttpStatus.OK, entity.getStatusCode()); + } + + @Test + public void getEntityFromCompletable() throws Exception { + ListenableFuture> future = template.getForEntity(baseUrl + "/{method}", String.class, "get"); + ResponseEntity entity = future.completable().get(); + assertEquals("Invalid content", helloWorld, entity.getBody()); + assertFalse("No headers", entity.getHeaders().isEmpty()); + assertEquals("Invalid content-type", textContentType, entity.getHeaders().getContentType()); + assertEquals("Invalid status code", HttpStatus.OK, entity.getStatusCode()); + } + + @Test + public void multipleFutureGets() throws Exception { + Future> future = template.getForEntity(baseUrl + "/{method}", String.class, "get"); + future.get(); + future.get(); + } + + @Test + public void getEntityCallback() throws Exception { + ListenableFuture> futureEntity = + template.getForEntity(baseUrl + "/{method}", String.class, "get"); + futureEntity.addCallback(new ListenableFutureCallback>() { + @Override + public void onSuccess(ResponseEntity entity) { + assertEquals("Invalid content", helloWorld, entity.getBody()); + assertFalse("No headers", entity.getHeaders().isEmpty()); + assertEquals("Invalid content-type", textContentType, entity.getHeaders().getContentType()); + assertEquals("Invalid status code", HttpStatus.OK, entity.getStatusCode()); + } + @Override + public void onFailure(Throwable ex) { + fail(ex.getMessage()); + } + }); + waitTillDone(futureEntity); + } + + @Test + public void getEntityCallbackWithLambdas() throws Exception { + ListenableFuture> futureEntity = + template.getForEntity(baseUrl + "/{method}", String.class, "get"); + futureEntity.addCallback((entity) -> { + assertEquals("Invalid content", helloWorld, entity.getBody()); + assertFalse("No headers", entity.getHeaders().isEmpty()); + assertEquals("Invalid content-type", textContentType, entity.getHeaders().getContentType()); + assertEquals("Invalid status code", HttpStatus.OK, entity.getStatusCode()); + }, ex -> fail(ex.getMessage())); + waitTillDone(futureEntity); + } + + @Test + public void getNoResponse() throws Exception { + Future> futureEntity = template.getForEntity(baseUrl + "/get/nothing", String.class); + ResponseEntity entity = futureEntity.get(); + assertNull("Invalid content", entity.getBody()); + } + + @Test + public void getNoContentTypeHeader() throws Exception { + Future> futureEntity = template.getForEntity(baseUrl + "/get/nocontenttype", byte[].class); + ResponseEntity responseEntity = futureEntity.get(); + assertArrayEquals("Invalid content", helloWorld.getBytes("UTF-8"), responseEntity.getBody()); + } + + @Test + public void getNoContent() throws Exception { + Future> responseFuture = template.getForEntity(baseUrl + "/status/nocontent", String.class); + ResponseEntity entity = responseFuture.get(); + assertEquals("Invalid response code", HttpStatus.NO_CONTENT, entity.getStatusCode()); + assertNull("Invalid content", entity.getBody()); + } + + @Test + public void getNotModified() throws Exception { + Future> responseFuture = template.getForEntity(baseUrl + "/status/notmodified", String.class); + ResponseEntity entity = responseFuture.get(); + assertEquals("Invalid response code", HttpStatus.NOT_MODIFIED, entity.getStatusCode()); + assertNull("Invalid content", entity.getBody()); + } + + @Test + public void headForHeaders() throws Exception { + Future headersFuture = template.headForHeaders(baseUrl + "/get"); + HttpHeaders headers = headersFuture.get(); + assertTrue("No Content-Type header", headers.containsKey("Content-Type")); + } + + @Test + public void headForHeadersCallback() throws Exception { + ListenableFuture headersFuture = template.headForHeaders(baseUrl + "/get"); + headersFuture.addCallback(new ListenableFutureCallback() { + @Override + public void onSuccess(HttpHeaders result) { + assertTrue("No Content-Type header", result.containsKey("Content-Type")); + } + @Override + public void onFailure(Throwable ex) { + fail(ex.getMessage()); + } + }); + waitTillDone(headersFuture); + } + + @Test + public void headForHeadersCallbackWithLambdas() throws Exception { + ListenableFuture headersFuture = template.headForHeaders(baseUrl + "/get"); + headersFuture.addCallback(result -> assertTrue("No Content-Type header", + result.containsKey("Content-Type")), ex -> fail(ex.getMessage())); + waitTillDone(headersFuture); + } + + @Test + public void postForLocation() throws Exception { + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.setContentType(new MediaType("text", "plain", StandardCharsets.ISO_8859_1)); + HttpEntity entity = new HttpEntity<>(helloWorld, entityHeaders); + Future locationFuture = template.postForLocation(baseUrl + "/{method}", entity, "post"); + URI location = locationFuture.get(); + assertEquals("Invalid location", new URI(baseUrl + "/post/1"), location); + } + + @Test + public void postForLocationCallback() throws Exception { + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.setContentType(new MediaType("text", "plain", StandardCharsets.ISO_8859_1)); + HttpEntity entity = new HttpEntity<>(helloWorld, entityHeaders); + final URI expected = new URI(baseUrl + "/post/1"); + ListenableFuture locationFuture = template.postForLocation(baseUrl + "/{method}", entity, "post"); + locationFuture.addCallback(new ListenableFutureCallback() { + @Override + public void onSuccess(URI result) { + assertEquals("Invalid location", expected, result); + } + @Override + public void onFailure(Throwable ex) { + fail(ex.getMessage()); + } + }); + waitTillDone(locationFuture); + } + + @Test + public void postForLocationCallbackWithLambdas() throws Exception { + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.setContentType(new MediaType("text", "plain", StandardCharsets.ISO_8859_1)); + HttpEntity entity = new HttpEntity<>(helloWorld, entityHeaders); + final URI expected = new URI(baseUrl + "/post/1"); + ListenableFuture locationFuture = template.postForLocation(baseUrl + "/{method}", entity, "post"); + locationFuture.addCallback(result -> assertEquals("Invalid location", expected, result), + ex -> fail(ex.getMessage())); + waitTillDone(locationFuture); + } + + @Test + public void postForEntity() throws Exception { + HttpEntity requestEntity = new HttpEntity<>(helloWorld); + Future> responseEntityFuture = + template.postForEntity(baseUrl + "/{method}", requestEntity, String.class, "post"); + ResponseEntity responseEntity = responseEntityFuture.get(); + assertEquals("Invalid content", helloWorld, responseEntity.getBody()); + } + + @Test + public void postForEntityCallback() throws Exception { + HttpEntity requestEntity = new HttpEntity<>(helloWorld); + ListenableFuture> responseEntityFuture = + template.postForEntity(baseUrl + "/{method}", requestEntity, String.class, "post"); + responseEntityFuture.addCallback(new ListenableFutureCallback>() { + @Override + public void onSuccess(ResponseEntity result) { + assertEquals("Invalid content", helloWorld, result.getBody()); + } + @Override + public void onFailure(Throwable ex) { + fail(ex.getMessage()); + } + }); + waitTillDone(responseEntityFuture); + } + + @Test + public void postForEntityCallbackWithLambdas() throws Exception { + HttpEntity requestEntity = new HttpEntity<>(helloWorld); + ListenableFuture> responseEntityFuture = + template.postForEntity(baseUrl + "/{method}", requestEntity, String.class, "post"); + responseEntityFuture.addCallback( + result -> assertEquals("Invalid content", helloWorld, result.getBody()), + ex -> fail(ex.getMessage())); + waitTillDone(responseEntityFuture); + } + + @Test + public void put() throws Exception { + HttpEntity requestEntity = new HttpEntity<>(helloWorld); + Future responseEntityFuture = template.put(baseUrl + "/{method}", requestEntity, "put"); + responseEntityFuture.get(); + } + + @Test + public void putCallback() throws Exception { + HttpEntity requestEntity = new HttpEntity<>(helloWorld); + ListenableFuture responseEntityFuture = template.put(baseUrl + "/{method}", requestEntity, "put"); + responseEntityFuture.addCallback(new ListenableFutureCallback() { + @Override + public void onSuccess(Object result) { + assertNull(result); + } + @Override + public void onFailure(Throwable ex) { + fail(ex.getMessage()); + } + }); + waitTillDone(responseEntityFuture); + } + + @Test + public void delete() throws Exception { + Future deletedFuture = template.delete(new URI(baseUrl + "/delete")); + deletedFuture.get(); + } + + @Test + public void deleteCallback() throws Exception { + ListenableFuture deletedFuture = template.delete(new URI(baseUrl + "/delete")); + deletedFuture.addCallback(new ListenableFutureCallback() { + @Override + public void onSuccess(Object result) { + assertNull(result); + } + @Override + public void onFailure(Throwable ex) { + fail(ex.getMessage()); + } + }); + waitTillDone(deletedFuture); + } + + @Test + public void deleteCallbackWithLambdas() throws Exception { + ListenableFuture deletedFuture = template.delete(new URI(baseUrl + "/delete")); + deletedFuture.addCallback(Assert::assertNull, ex -> fail(ex.getMessage())); + waitTillDone(deletedFuture); + } + + @Test + public void identicalExceptionThroughGetAndCallback() throws Exception { + final HttpClientErrorException[] callbackException = new HttpClientErrorException[1]; + + final CountDownLatch latch = new CountDownLatch(1); + ListenableFuture future = template.execute(baseUrl + "/status/notfound", HttpMethod.GET, null, null); + future.addCallback(new ListenableFutureCallback() { + @Override + public void onSuccess(Object result) { + fail("onSuccess not expected"); + } + @Override + public void onFailure(Throwable ex) { + assertTrue(ex instanceof HttpClientErrorException); + callbackException[0] = (HttpClientErrorException) ex; + latch.countDown(); + } + }); + + try { + future.get(); + fail("Exception expected"); + } + catch (ExecutionException ex) { + Throwable cause = ex.getCause(); + assertTrue(cause instanceof HttpClientErrorException); + latch.await(5, TimeUnit.SECONDS); + assertSame(callbackException[0], cause); + } + } + + @Test + public void notFoundGet() throws Exception { + try { + Future future = template.execute(baseUrl + "/status/notfound", HttpMethod.GET, null, null); + future.get(); + fail("HttpClientErrorException expected"); + } + catch (ExecutionException ex) { + assertTrue(ex.getCause() instanceof HttpClientErrorException); + HttpClientErrorException cause = (HttpClientErrorException)ex.getCause(); + + assertEquals(HttpStatus.NOT_FOUND, cause.getStatusCode()); + assertNotNull(cause.getStatusText()); + assertNotNull(cause.getResponseBodyAsString()); + } + } + + @Test + public void notFoundCallback() throws Exception { + ListenableFuture future = template.execute(baseUrl + "/status/notfound", HttpMethod.GET, null, null); + future.addCallback(new ListenableFutureCallback() { + @Override + public void onSuccess(Object result) { + fail("onSuccess not expected"); + } + @Override + public void onFailure(Throwable t) { + assertTrue(t instanceof HttpClientErrorException); + HttpClientErrorException ex = (HttpClientErrorException) t; + assertEquals(HttpStatus.NOT_FOUND, ex.getStatusCode()); + assertNotNull(ex.getStatusText()); + assertNotNull(ex.getResponseBodyAsString()); + } + }); + waitTillDone(future); + } + + @Test + public void notFoundCallbackWithLambdas() throws Exception { + ListenableFuture future = template.execute(baseUrl + "/status/notfound", HttpMethod.GET, null, null); + future.addCallback(result -> fail("onSuccess not expected"), ex -> { + assertTrue(ex instanceof HttpClientErrorException); + HttpClientErrorException hcex = (HttpClientErrorException) ex; + assertEquals(HttpStatus.NOT_FOUND, hcex.getStatusCode()); + assertNotNull(hcex.getStatusText()); + assertNotNull(hcex.getResponseBodyAsString()); + }); + waitTillDone(future); + } + + @Test + public void serverError() throws Exception { + try { + Future future = template.execute(baseUrl + "/status/server", HttpMethod.GET, null, null); + future.get(); + fail("HttpServerErrorException expected"); + } + catch (ExecutionException ex) { + assertTrue(ex.getCause() instanceof HttpServerErrorException); + HttpServerErrorException cause = (HttpServerErrorException)ex.getCause(); + + assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, cause.getStatusCode()); + assertNotNull(cause.getStatusText()); + assertNotNull(cause.getResponseBodyAsString()); + } + } + + @Test + public void serverErrorCallback() throws Exception { + ListenableFuture future = template.execute(baseUrl + "/status/server", HttpMethod.GET, null, null); + future.addCallback(new ListenableFutureCallback() { + @Override + public void onSuccess(Void result) { + fail("onSuccess not expected"); + } + @Override + public void onFailure(Throwable ex) { + assertTrue(ex instanceof HttpServerErrorException); + HttpServerErrorException hsex = (HttpServerErrorException) ex; + assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, hsex.getStatusCode()); + assertNotNull(hsex.getStatusText()); + assertNotNull(hsex.getResponseBodyAsString()); + } + }); + waitTillDone(future); + } + + @Test + public void serverErrorCallbackWithLambdas() throws Exception { + ListenableFuture future = template.execute(baseUrl + "/status/server", HttpMethod.GET, null, null); + future.addCallback(result -> fail("onSuccess not expected"), ex -> { + assertTrue(ex instanceof HttpServerErrorException); + HttpServerErrorException hsex = (HttpServerErrorException) ex; + assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, hsex.getStatusCode()); + assertNotNull(hsex.getStatusText()); + assertNotNull(hsex.getResponseBodyAsString()); + }); + waitTillDone(future); + } + + @Test + public void optionsForAllow() throws Exception { + Future> allowedFuture = template.optionsForAllow(new URI(baseUrl + "/get")); + Set allowed = allowedFuture.get(); + assertEquals("Invalid response", + EnumSet.of(HttpMethod.GET, HttpMethod.OPTIONS, HttpMethod.HEAD, HttpMethod.TRACE), allowed); + } + + @Test + public void optionsForAllowCallback() throws Exception { + ListenableFuture> allowedFuture = template.optionsForAllow(new URI(baseUrl + "/get")); + allowedFuture.addCallback(new ListenableFutureCallback>() { + @Override + public void onSuccess(Set result) { + assertEquals("Invalid response", EnumSet.of(HttpMethod.GET, HttpMethod.OPTIONS, + HttpMethod.HEAD, HttpMethod.TRACE), result); + } + @Override + public void onFailure(Throwable ex) { + fail(ex.getMessage()); + } + }); + waitTillDone(allowedFuture); + } + + @Test + public void optionsForAllowCallbackWithLambdas() throws Exception{ + ListenableFuture> allowedFuture = template.optionsForAllow(new URI(baseUrl + "/get")); + allowedFuture.addCallback(result -> assertEquals("Invalid response", + EnumSet.of(HttpMethod.GET, HttpMethod.OPTIONS, HttpMethod.HEAD,HttpMethod.TRACE), result), + ex -> fail(ex.getMessage())); + waitTillDone(allowedFuture); + } + + @Test + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void exchangeGet() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + requestHeaders.set("MyHeader", "MyValue"); + HttpEntity requestEntity = new HttpEntity(requestHeaders); + Future> responseFuture = + template.exchange(baseUrl + "/{method}", HttpMethod.GET, requestEntity, String.class, "get"); + ResponseEntity response = responseFuture.get(); + assertEquals("Invalid content", helloWorld, response.getBody()); + } + + @Test + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void exchangeGetCallback() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + requestHeaders.set("MyHeader", "MyValue"); + HttpEntity requestEntity = new HttpEntity(requestHeaders); + ListenableFuture> responseFuture = + template.exchange(baseUrl + "/{method}", HttpMethod.GET, requestEntity, String.class, "get"); + responseFuture.addCallback(new ListenableFutureCallback>() { + @Override + public void onSuccess(ResponseEntity result) { + assertEquals("Invalid content", helloWorld, result.getBody()); + } + @Override + public void onFailure(Throwable ex) { + fail(ex.getMessage()); + } + }); + waitTillDone(responseFuture); + } + + @Test + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void exchangeGetCallbackWithLambdas() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + requestHeaders.set("MyHeader", "MyValue"); + HttpEntity requestEntity = new HttpEntity(requestHeaders); + ListenableFuture> responseFuture = + template.exchange(baseUrl + "/{method}", HttpMethod.GET, requestEntity, String.class, "get"); + responseFuture.addCallback(result -> assertEquals("Invalid content", helloWorld, + result.getBody()), ex -> fail(ex.getMessage())); + waitTillDone(responseFuture); + } + + @Test + public void exchangePost() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + requestHeaders.set("MyHeader", "MyValue"); + requestHeaders.setContentType(MediaType.TEXT_PLAIN); + HttpEntity requestEntity = new HttpEntity<>(helloWorld, requestHeaders); + Future> resultFuture = + template.exchange(baseUrl + "/{method}", HttpMethod.POST, requestEntity, Void.class, "post"); + ResponseEntity result = resultFuture.get(); + assertEquals("Invalid location", new URI(baseUrl + "/post/1"), + result.getHeaders().getLocation()); + assertFalse(result.hasBody()); + } + + @Test + public void exchangePostCallback() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + requestHeaders.set("MyHeader", "MyValue"); + requestHeaders.setContentType(MediaType.TEXT_PLAIN); + HttpEntity requestEntity = new HttpEntity<>(helloWorld, requestHeaders); + ListenableFuture> resultFuture = + template.exchange(baseUrl + "/{method}", HttpMethod.POST, requestEntity, Void.class, "post"); + final URI expected =new URI(baseUrl + "/post/1"); + resultFuture.addCallback(new ListenableFutureCallback>() { + @Override + public void onSuccess(ResponseEntity result) { + assertEquals("Invalid location", expected, result.getHeaders().getLocation()); + assertFalse(result.hasBody()); + } + @Override + public void onFailure(Throwable ex) { + fail(ex.getMessage()); + } + }); + waitTillDone(resultFuture); + } + + @Test + public void exchangePostCallbackWithLambdas() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + requestHeaders.set("MyHeader", "MyValue"); + requestHeaders.setContentType(MediaType.TEXT_PLAIN); + HttpEntity requestEntity = new HttpEntity<>(helloWorld, requestHeaders); + ListenableFuture> resultFuture = + template.exchange(baseUrl + "/{method}", HttpMethod.POST, requestEntity, Void.class, "post"); + final URI expected =new URI(baseUrl + "/post/1"); + resultFuture.addCallback(result -> { + assertEquals("Invalid location", expected, result.getHeaders().getLocation()); + assertFalse(result.hasBody()); + }, ex -> fail(ex.getMessage())); + waitTillDone(resultFuture); + } + + @Test + public void multipart() throws Exception { + MultiValueMap parts = new LinkedMultiValueMap<>(); + parts.add("name 1", "value 1"); + parts.add("name 2", "value 2+1"); + parts.add("name 2", "value 2+2"); + Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); + parts.add("logo", logo); + + HttpEntity> requestBody = new HttpEntity<>(parts); + Future future = template.postForLocation(baseUrl + "/multipart", requestBody); + future.get(); + } + + @Test + public void getAndInterceptResponse() throws Exception { + RequestInterceptor interceptor = new RequestInterceptor(); + template.setInterceptors(Collections.singletonList(interceptor)); + ListenableFuture> future = template.getForEntity(baseUrl + "/get", String.class); + + interceptor.latch.await(5, TimeUnit.SECONDS); + assertNotNull(interceptor.response); + assertEquals(HttpStatus.OK, interceptor.response.getStatusCode()); + assertNull(interceptor.exception); + assertEquals(helloWorld, future.get().getBody()); + } + + @Test + public void getAndInterceptError() throws Exception { + RequestInterceptor interceptor = new RequestInterceptor(); + template.setInterceptors(Collections.singletonList(interceptor)); + template.getForEntity(baseUrl + "/status/notfound", String.class); + + interceptor.latch.await(5, TimeUnit.SECONDS); + assertNotNull(interceptor.response); + assertEquals(HttpStatus.NOT_FOUND, interceptor.response.getStatusCode()); + assertNull(interceptor.exception); + } + + private void waitTillDone(ListenableFuture future) { + while (!future.isDone()) { + } + } + + + private static class RequestInterceptor implements AsyncClientHttpRequestInterceptor { + + private final CountDownLatch latch = new CountDownLatch(1); + + private volatile ClientHttpResponse response; + + private volatile Throwable exception; + + @Override + public ListenableFuture intercept(HttpRequest request, byte[] body, + AsyncClientHttpRequestExecution execution) throws IOException { + + ListenableFuture future = execution.executeAsync(request, body); + future.addCallback( + resp -> { + response = resp; + this.latch.countDown(); + }, + ex -> { + exception = ex; + this.latch.countDown(); + }); + return future; + } + } +} diff --git a/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerHttpStatusTests.java b/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerHttpStatusTests.java new file mode 100644 index 0000000000000000000000000000000000000000..19f650d9f2e050de38380c066cb37f1a561f11f6 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerHttpStatusTests.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.client.ClientHttpResponse; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; +import static org.springframework.http.HttpStatus.*; + +/** + * Unit tests for {@link DefaultResponseErrorHandler} handling of specific + * HTTP status codes. + */ +@RunWith(Parameterized.class) +public class DefaultResponseErrorHandlerHttpStatusTests { + + @Parameters(name = "error: [{0}], exception: [{1}]") + public static Object[][] errorCodes() { + return new Object[][]{ + // 4xx + {BAD_REQUEST, HttpClientErrorException.BadRequest.class}, + {UNAUTHORIZED, HttpClientErrorException.Unauthorized.class}, + {FORBIDDEN, HttpClientErrorException.Forbidden.class}, + {NOT_FOUND, HttpClientErrorException.NotFound.class}, + {METHOD_NOT_ALLOWED, HttpClientErrorException.MethodNotAllowed.class}, + {NOT_ACCEPTABLE, HttpClientErrorException.NotAcceptable.class}, + {CONFLICT, HttpClientErrorException.Conflict.class}, + {TOO_MANY_REQUESTS, HttpClientErrorException.TooManyRequests.class}, + {UNPROCESSABLE_ENTITY, HttpClientErrorException.UnprocessableEntity.class}, + {I_AM_A_TEAPOT, HttpClientErrorException.class}, + // 5xx + {INTERNAL_SERVER_ERROR, HttpServerErrorException.InternalServerError.class}, + {NOT_IMPLEMENTED, HttpServerErrorException.NotImplemented.class}, + {BAD_GATEWAY, HttpServerErrorException.BadGateway.class}, + {SERVICE_UNAVAILABLE, HttpServerErrorException.ServiceUnavailable.class}, + {GATEWAY_TIMEOUT, HttpServerErrorException.GatewayTimeout.class}, + {HTTP_VERSION_NOT_SUPPORTED, HttpServerErrorException.class} + }; + } + + @Parameterized.Parameter + public HttpStatus httpStatus; + + @Parameterized.Parameter(1) + public Class expectedExceptionClass; + + private final DefaultResponseErrorHandler handler = new DefaultResponseErrorHandler(); + + private final ClientHttpResponse response = mock(ClientHttpResponse.class); + + + @Test + public void hasErrorTrue() throws Exception { + given(this.response.getRawStatusCode()).willReturn(this.httpStatus.value()); + assertTrue(this.handler.hasError(this.response)); + } + + @Test + public void handleErrorException() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + + given(this.response.getRawStatusCode()).willReturn(this.httpStatus.value()); + given(this.response.getHeaders()).willReturn(headers); + + try { + this.handler.handleError(this.response); + fail("expected " + this.expectedExceptionClass.getSimpleName()); + } + catch (HttpStatusCodeException ex) { + assertEquals("Expected " + this.expectedExceptionClass.getSimpleName(), + this.expectedExceptionClass, ex.getClass()); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerTests.java b/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3c91d7a58522e88ecadb369a02d560d897216350 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerTests.java @@ -0,0 +1,228 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.util.StreamUtils; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * Unit tests for {@link DefaultResponseErrorHandler}. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @author Denys Ivano + */ +public class DefaultResponseErrorHandlerTests { + + private final DefaultResponseErrorHandler handler = new DefaultResponseErrorHandler(); + + private final ClientHttpResponse response = mock(ClientHttpResponse.class); + + + @Test + public void hasErrorTrue() throws Exception { + given(response.getRawStatusCode()).willReturn(HttpStatus.NOT_FOUND.value()); + assertTrue(handler.hasError(response)); + } + + @Test + public void hasErrorFalse() throws Exception { + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + assertFalse(handler.hasError(response)); + } + + @Test + public void handleError() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + + given(response.getRawStatusCode()).willReturn(HttpStatus.NOT_FOUND.value()); + given(response.getStatusText()).willReturn("Not Found"); + given(response.getHeaders()).willReturn(headers); + given(response.getBody()).willReturn(new ByteArrayInputStream("Hello World".getBytes(StandardCharsets.UTF_8))); + + try { + handler.handleError(response); + fail("expected HttpClientErrorException"); + } + catch (HttpClientErrorException ex) { + assertSame(headers, ex.getResponseHeaders()); + } + } + + @Test(expected = HttpClientErrorException.class) + public void handleErrorIOException() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + + given(response.getRawStatusCode()).willReturn(HttpStatus.NOT_FOUND.value()); + given(response.getStatusText()).willReturn("Not Found"); + given(response.getHeaders()).willReturn(headers); + given(response.getBody()).willThrow(new IOException()); + + handler.handleError(response); + } + + @Test(expected = HttpClientErrorException.class) + public void handleErrorNullResponse() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + + given(response.getRawStatusCode()).willReturn(HttpStatus.NOT_FOUND.value()); + given(response.getStatusText()).willReturn("Not Found"); + given(response.getHeaders()).willReturn(headers); + + handler.handleError(response); + } + + @Test // SPR-16108 + public void hasErrorForUnknownStatusCode() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + + given(response.getRawStatusCode()).willReturn(999); + given(response.getStatusText()).willReturn("Custom status code"); + given(response.getHeaders()).willReturn(headers); + + assertFalse(handler.hasError(response)); + } + + @Test(expected = UnknownHttpStatusCodeException.class) // SPR-9406 + public void handleErrorUnknownStatusCode() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + + given(response.getRawStatusCode()).willReturn(999); + given(response.getStatusText()).willReturn("Custom status code"); + given(response.getHeaders()).willReturn(headers); + + handler.handleError(response); + } + + @Test // SPR-17461 + public void hasErrorForCustomClientError() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + + given(response.getRawStatusCode()).willReturn(499); + given(response.getStatusText()).willReturn("Custom status code"); + given(response.getHeaders()).willReturn(headers); + + assertTrue(handler.hasError(response)); + } + + @Test(expected = UnknownHttpStatusCodeException.class) + public void handleErrorForCustomClientError() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + + given(response.getRawStatusCode()).willReturn(499); + given(response.getStatusText()).willReturn("Custom status code"); + given(response.getHeaders()).willReturn(headers); + + handler.handleError(response); + } + + @Test // SPR-17461 + public void hasErrorForCustomServerError() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + + given(response.getRawStatusCode()).willReturn(599); + given(response.getStatusText()).willReturn("Custom status code"); + given(response.getHeaders()).willReturn(headers); + + assertTrue(handler.hasError(response)); + } + + @Test(expected = UnknownHttpStatusCodeException.class) + public void handleErrorForCustomServerError() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + + given(response.getRawStatusCode()).willReturn(599); + given(response.getStatusText()).willReturn("Custom status code"); + given(response.getHeaders()).willReturn(headers); + + handler.handleError(response); + } + + @Test // SPR-16604 + public void bodyAvailableAfterHasErrorForUnknownStatusCode() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + TestByteArrayInputStream body = new TestByteArrayInputStream("Hello World".getBytes(StandardCharsets.UTF_8)); + + given(response.getRawStatusCode()).willReturn(999); + given(response.getStatusText()).willReturn("Custom status code"); + given(response.getHeaders()).willReturn(headers); + given(response.getBody()).willReturn(body); + + assertFalse(handler.hasError(response)); + assertFalse(body.isClosed()); + assertEquals("Hello World", StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8)); + } + + + private static class TestByteArrayInputStream extends ByteArrayInputStream { + + private boolean closed; + + public TestByteArrayInputStream(byte[] buf) { + super(buf); + this.closed = false; + } + + public boolean isClosed() { + return closed; + } + + @Override + public boolean markSupported() { + return false; + } + + @Override + public synchronized void mark(int readlimit) { + throw new UnsupportedOperationException("mark/reset not supported"); + } + + @Override + public synchronized void reset() { + throw new UnsupportedOperationException("mark/reset not supported"); + } + + @Override + public void close() throws IOException { + super.close(); + this.closed = true; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/client/ExtractingResponseErrorHandlerTests.java b/spring-web/src/test/java/org/springframework/web/client/ExtractingResponseErrorHandlerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..23d1118babb572adb501a293e215d3fbb3e76325 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/client/ExtractingResponseErrorHandlerTests.java @@ -0,0 +1,187 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; +import java.util.Collections; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * @author Arjen Poutsma + */ +public class ExtractingResponseErrorHandlerTests { + + private ExtractingResponseErrorHandler errorHandler; + + private final ClientHttpResponse response = mock(ClientHttpResponse.class); + + + @Before + public void setup() throws Exception { + HttpMessageConverter converter = new MappingJackson2HttpMessageConverter(); + this.errorHandler = new ExtractingResponseErrorHandler( + Collections.singletonList(converter)); + + this.errorHandler.setStatusMapping( + Collections.singletonMap(HttpStatus.I_AM_A_TEAPOT, MyRestClientException.class)); + this.errorHandler.setSeriesMapping(Collections + .singletonMap(HttpStatus.Series.SERVER_ERROR, MyRestClientException.class)); + } + + + @Test + public void hasError() throws Exception { + given(this.response.getRawStatusCode()).willReturn(HttpStatus.I_AM_A_TEAPOT.value()); + assertTrue(this.errorHandler.hasError(this.response)); + + given(this.response.getRawStatusCode()).willReturn(HttpStatus.INTERNAL_SERVER_ERROR.value()); + assertTrue(this.errorHandler.hasError(this.response)); + + given(this.response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + assertFalse(this.errorHandler.hasError(this.response)); + } + + @Test + public void hasErrorOverride() throws Exception { + this.errorHandler.setSeriesMapping(Collections + .singletonMap(HttpStatus.Series.CLIENT_ERROR, null)); + + given(this.response.getRawStatusCode()).willReturn(HttpStatus.I_AM_A_TEAPOT.value()); + assertTrue(this.errorHandler.hasError(this.response)); + + given(this.response.getRawStatusCode()).willReturn(HttpStatus.NOT_FOUND.value()); + assertFalse(this.errorHandler.hasError(this.response)); + + given(this.response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + assertFalse(this.errorHandler.hasError(this.response)); + } + + @Test + public void handleErrorStatusMatch() throws Exception { + given(this.response.getRawStatusCode()).willReturn(HttpStatus.I_AM_A_TEAPOT.value()); + HttpHeaders responseHeaders = new HttpHeaders(); + responseHeaders.setContentType(MediaType.APPLICATION_JSON); + given(this.response.getHeaders()).willReturn(responseHeaders); + + byte[] body = "{\"foo\":\"bar\"}".getBytes(StandardCharsets.UTF_8); + responseHeaders.setContentLength(body.length); + given(this.response.getBody()).willReturn(new ByteArrayInputStream(body)); + + try { + this.errorHandler.handleError(this.response); + fail("MyRestClientException expected"); + } + catch (MyRestClientException ex) { + assertEquals("bar", ex.getFoo()); + } + } + + @Test + public void handleErrorSeriesMatch() throws Exception { + given(this.response.getRawStatusCode()).willReturn(HttpStatus.INTERNAL_SERVER_ERROR.value()); + HttpHeaders responseHeaders = new HttpHeaders(); + responseHeaders.setContentType(MediaType.APPLICATION_JSON); + given(this.response.getHeaders()).willReturn(responseHeaders); + + byte[] body = "{\"foo\":\"bar\"}".getBytes(StandardCharsets.UTF_8); + responseHeaders.setContentLength(body.length); + given(this.response.getBody()).willReturn(new ByteArrayInputStream(body)); + + try { + this.errorHandler.handleError(this.response); + fail("MyRestClientException expected"); + } + catch (MyRestClientException ex) { + assertEquals("bar", ex.getFoo()); + } + } + + @Test + public void handleNoMatch() throws Exception { + given(this.response.getRawStatusCode()).willReturn(HttpStatus.NOT_FOUND.value()); + HttpHeaders responseHeaders = new HttpHeaders(); + responseHeaders.setContentType(MediaType.APPLICATION_JSON); + given(this.response.getHeaders()).willReturn(responseHeaders); + + byte[] body = "{\"foo\":\"bar\"}".getBytes(StandardCharsets.UTF_8); + responseHeaders.setContentLength(body.length); + given(this.response.getBody()).willReturn(new ByteArrayInputStream(body)); + + try { + this.errorHandler.handleError(this.response); + fail("HttpClientErrorException expected"); + } + catch (HttpClientErrorException ex) { + assertEquals(HttpStatus.NOT_FOUND, ex.getStatusCode()); + assertArrayEquals(body, ex.getResponseBodyAsByteArray()); + } + } + + @Test + public void handleNoMatchOverride() throws Exception { + this.errorHandler.setSeriesMapping(Collections + .singletonMap(HttpStatus.Series.CLIENT_ERROR, null)); + + given(this.response.getRawStatusCode()).willReturn(HttpStatus.NOT_FOUND.value()); + HttpHeaders responseHeaders = new HttpHeaders(); + responseHeaders.setContentType(MediaType.APPLICATION_JSON); + given(this.response.getHeaders()).willReturn(responseHeaders); + + byte[] body = "{\"foo\":\"bar\"}".getBytes(StandardCharsets.UTF_8); + responseHeaders.setContentLength(body.length); + given(this.response.getBody()).willReturn(new ByteArrayInputStream(body)); + + this.errorHandler.handleError(this.response); + } + + + @SuppressWarnings("serial") + private static class MyRestClientException extends RestClientException { + + private String foo; + + public MyRestClientException(String msg) { + super(msg); + } + + public MyRestClientException(String msg, Throwable ex) { + super(msg, ex); + } + + public String getFoo() { + return this.foo; + } + + public void setFoo(String foo) { + this.foo = foo; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java b/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..2b472cbfcb1f4febd41e391e00c1e26498133fcd --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java @@ -0,0 +1,235 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.List; + +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.converter.GenericHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.HttpMessageNotReadableException; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * Test fixture for {@link HttpMessageConverter}. + * + * @author Arjen Poutsma + * @author Brian Clozel + */ +public class HttpMessageConverterExtractorTests { + + private HttpMessageConverterExtractor extractor; + + private final ClientHttpResponse response = mock(ClientHttpResponse.class); + + @Rule + public final ExpectedException exception = ExpectedException.none(); + + + @Test + public void noContent() throws IOException { + HttpMessageConverter converter = mock(HttpMessageConverter.class); + extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); + given(response.getRawStatusCode()).willReturn(HttpStatus.NO_CONTENT.value()); + + Object result = extractor.extractData(response); + assertNull(result); + } + + @Test + public void notModified() throws IOException { + HttpMessageConverter converter = mock(HttpMessageConverter.class); + extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); + given(response.getRawStatusCode()).willReturn(HttpStatus.NOT_MODIFIED.value()); + + Object result = extractor.extractData(response); + assertNull(result); + } + + @Test + public void informational() throws IOException { + HttpMessageConverter converter = mock(HttpMessageConverter.class); + extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); + given(response.getRawStatusCode()).willReturn(HttpStatus.CONTINUE.value()); + + Object result = extractor.extractData(response); + assertNull(result); + } + + @Test + public void zeroContentLength() throws IOException { + HttpMessageConverter converter = mock(HttpMessageConverter.class); + HttpHeaders responseHeaders = new HttpHeaders(); + responseHeaders.setContentLength(0); + extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + given(response.getHeaders()).willReturn(responseHeaders); + + Object result = extractor.extractData(response); + assertNull(result); + } + + @Test + @SuppressWarnings("unchecked") + public void emptyMessageBody() throws IOException { + HttpMessageConverter converter = mock(HttpMessageConverter.class); + HttpHeaders responseHeaders = new HttpHeaders(); + extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + given(response.getHeaders()).willReturn(responseHeaders); + given(response.getBody()).willReturn(new ByteArrayInputStream("".getBytes())); + + Object result = extractor.extractData(response); + assertNull(result); + } + + @Test // gh-22265 + @SuppressWarnings("unchecked") + public void nullMessageBody() throws IOException { + HttpMessageConverter converter = mock(HttpMessageConverter.class); + HttpHeaders responseHeaders = new HttpHeaders(); + extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + given(response.getHeaders()).willReturn(responseHeaders); + given(response.getBody()).willReturn(null); + + Object result = extractor.extractData(response); + assertNull(result); + } + + @Test + @SuppressWarnings("unchecked") + public void normal() throws IOException { + HttpMessageConverter converter = mock(HttpMessageConverter.class); + HttpHeaders responseHeaders = new HttpHeaders(); + MediaType contentType = MediaType.TEXT_PLAIN; + responseHeaders.setContentType(contentType); + String expected = "Foo"; + extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + given(response.getHeaders()).willReturn(responseHeaders); + given(response.getBody()).willReturn(new ByteArrayInputStream(expected.getBytes())); + given(converter.canRead(String.class, contentType)).willReturn(true); + given(converter.read(eq(String.class), any(HttpInputMessage.class))).willReturn(expected); + + Object result = extractor.extractData(response); + assertEquals(expected, result); + } + + @Test + @SuppressWarnings("unchecked") + public void cannotRead() throws IOException { + HttpMessageConverter converter = mock(HttpMessageConverter.class); + HttpHeaders responseHeaders = new HttpHeaders(); + MediaType contentType = MediaType.TEXT_PLAIN; + responseHeaders.setContentType(contentType); + extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + given(response.getHeaders()).willReturn(responseHeaders); + given(response.getBody()).willReturn(new ByteArrayInputStream("Foobar".getBytes())); + given(converter.canRead(String.class, contentType)).willReturn(false); + exception.expect(RestClientException.class); + + extractor.extractData(response); + } + + @Test + @SuppressWarnings("unchecked") + public void generics() throws IOException { + GenericHttpMessageConverter converter = mock(GenericHttpMessageConverter.class); + HttpHeaders responseHeaders = new HttpHeaders(); + MediaType contentType = MediaType.TEXT_PLAIN; + responseHeaders.setContentType(contentType); + String expected = "Foo"; + ParameterizedTypeReference> reference = new ParameterizedTypeReference>() {}; + Type type = reference.getType(); + extractor = new HttpMessageConverterExtractor>(type, createConverterList(converter)); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + given(response.getHeaders()).willReturn(responseHeaders); + given(response.getBody()).willReturn(new ByteArrayInputStream(expected.getBytes())); + given(converter.canRead(type, null, contentType)).willReturn(true); + given(converter.read(eq(type), eq(null), any(HttpInputMessage.class))).willReturn(expected); + + Object result = extractor.extractData(response); + assertEquals(expected, result); + } + + @Test // SPR-13592 + @SuppressWarnings("unchecked") + public void converterThrowsIOException() throws IOException { + HttpMessageConverter converter = mock(HttpMessageConverter.class); + HttpHeaders responseHeaders = new HttpHeaders(); + MediaType contentType = MediaType.TEXT_PLAIN; + responseHeaders.setContentType(contentType); + extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + given(response.getHeaders()).willReturn(responseHeaders); + given(response.getBody()).willReturn(new ByteArrayInputStream("Foobar".getBytes())); + given(converter.canRead(String.class, contentType)).willReturn(true); + given(converter.read(eq(String.class), any(HttpInputMessage.class))).willThrow(IOException.class); + exception.expect(RestClientException.class); + exception.expectMessage("Error while extracting response for type " + + "[class java.lang.String] and content type [text/plain]"); + exception.expectCause(Matchers.instanceOf(IOException.class)); + + extractor.extractData(response); + } + + @Test // SPR-13592 + @SuppressWarnings("unchecked") + public void converterThrowsHttpMessageNotReadableException() throws IOException { + HttpMessageConverter converter = mock(HttpMessageConverter.class); + HttpHeaders responseHeaders = new HttpHeaders(); + MediaType contentType = MediaType.TEXT_PLAIN; + responseHeaders.setContentType(contentType); + extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + given(response.getHeaders()).willReturn(responseHeaders); + given(response.getBody()).willReturn(new ByteArrayInputStream("Foobar".getBytes())); + given(converter.canRead(String.class, contentType)).willThrow(HttpMessageNotReadableException.class); + exception.expect(RestClientException.class); + exception.expectMessage("Error while extracting response for type " + + "[class java.lang.String] and content type [text/plain]"); + exception.expectCause(Matchers.instanceOf(HttpMessageNotReadableException.class)); + + extractor.extractData(response); + } + + private List> createConverterList(HttpMessageConverter converter) { + List> converters = new ArrayList<>(1); + converters.add(converter); + return converters; + } + + +} diff --git a/spring-web/src/test/java/org/springframework/web/client/HttpStatusCodeExceptionTests.java b/spring-web/src/test/java/org/springframework/web/client/HttpStatusCodeExceptionTests.java new file mode 100644 index 0000000000000000000000000000000000000000..845d73a0fd95eed1a8a401dbc8041be4ed19402d --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/client/HttpStatusCodeExceptionTests.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.charset.StandardCharsets; + +import org.junit.Test; + +import org.springframework.http.HttpStatus; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +/** + * Unit tests for {@link HttpStatusCodeException} and subclasses. + * + * @author Chris Beams + */ +public class HttpStatusCodeExceptionTests { + + /** + * Corners bug SPR-9273, which reported the fact that following the changes made in + * SPR-7591, {@link HttpStatusCodeException} and subtypes became no longer + * serializable due to the addition of a non-serializable {@code Charset} field. + */ + @Test + public void testSerializability() throws IOException, ClassNotFoundException { + HttpStatusCodeException ex1 = new HttpClientErrorException( + HttpStatus.BAD_REQUEST, null, null, StandardCharsets.US_ASCII); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + new ObjectOutputStream(out).writeObject(ex1); + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + HttpStatusCodeException ex2 = + (HttpStatusCodeException) new ObjectInputStream(in).readObject(); + assertThat(ex2.getResponseBodyAsString(), equalTo(ex1.getResponseBodyAsString())); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..7b5e82017bff79259279a22ce5966c7f696ba714 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java @@ -0,0 +1,421 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.UnsupportedEncodingException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.List; +import java.util.Set; + +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.annotation.JsonTypeName; +import com.fasterxml.jackson.annotation.JsonView; +import org.hamcrest.Matchers; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.http.client.Netty4ClientHttpRequestFactory; +import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; +import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.http.converter.json.MappingJacksonValue; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import static org.junit.Assert.*; +import static org.springframework.http.HttpMethod.POST; + +/** + * @author Arjen Poutsma + * @author Brian Clozel + */ +@RunWith(Parameterized.class) +public class RestTemplateIntegrationTests extends AbstractMockWebServerTestCase { + + private RestTemplate template; + + @Parameter + public ClientHttpRequestFactory clientHttpRequestFactory; + + @SuppressWarnings("deprecation") + @Parameters + public static Iterable data() { + return Arrays.asList( + new SimpleClientHttpRequestFactory(), + new HttpComponentsClientHttpRequestFactory(), + new Netty4ClientHttpRequestFactory(), + new OkHttp3ClientHttpRequestFactory() + ); + } + + + @Before + public void setupClient() { + this.template = new RestTemplate(this.clientHttpRequestFactory); + } + + + @Test + public void getString() { + String s = template.getForObject(baseUrl + "/{method}", String.class, "get"); + assertEquals("Invalid content", helloWorld, s); + } + + @Test + public void getEntity() { + ResponseEntity entity = template.getForEntity(baseUrl + "/{method}", String.class, "get"); + assertEquals("Invalid content", helloWorld, entity.getBody()); + assertFalse("No headers", entity.getHeaders().isEmpty()); + assertEquals("Invalid content-type", textContentType, entity.getHeaders().getContentType()); + assertEquals("Invalid status code", HttpStatus.OK, entity.getStatusCode()); + } + + @Test + public void getNoResponse() { + String s = template.getForObject(baseUrl + "/get/nothing", String.class); + assertNull("Invalid content", s); + } + + @Test + public void getNoContentTypeHeader() throws UnsupportedEncodingException { + byte[] bytes = template.getForObject(baseUrl + "/get/nocontenttype", byte[].class); + assertArrayEquals("Invalid content", helloWorld.getBytes("UTF-8"), bytes); + } + + @Test + public void getNoContent() { + String s = template.getForObject(baseUrl + "/status/nocontent", String.class); + assertNull("Invalid content", s); + + ResponseEntity entity = template.getForEntity(baseUrl + "/status/nocontent", String.class); + assertEquals("Invalid response code", HttpStatus.NO_CONTENT, entity.getStatusCode()); + assertNull("Invalid content", entity.getBody()); + } + + @Test + public void getNotModified() { + String s = template.getForObject(baseUrl + "/status/notmodified", String.class); + assertNull("Invalid content", s); + + ResponseEntity entity = template.getForEntity(baseUrl + "/status/notmodified", String.class); + assertEquals("Invalid response code", HttpStatus.NOT_MODIFIED, entity.getStatusCode()); + assertNull("Invalid content", entity.getBody()); + } + + @Test + public void postForLocation() throws URISyntaxException { + URI location = template.postForLocation(baseUrl + "/{method}", helloWorld, "post"); + assertEquals("Invalid location", new URI(baseUrl + "/post/1"), location); + } + + @Test + public void postForLocationEntity() throws URISyntaxException { + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.setContentType(new MediaType("text", "plain", StandardCharsets.ISO_8859_1)); + HttpEntity entity = new HttpEntity<>(helloWorld, entityHeaders); + URI location = template.postForLocation(baseUrl + "/{method}", entity, "post"); + assertEquals("Invalid location", new URI(baseUrl + "/post/1"), location); + } + + @Test + public void postForObject() throws URISyntaxException { + String s = template.postForObject(baseUrl + "/{method}", helloWorld, String.class, "post"); + assertEquals("Invalid content", helloWorld, s); + } + + @Test + public void patchForObject() throws URISyntaxException { + // JDK client does not support the PATCH method + Assume.assumeThat(this.clientHttpRequestFactory, + Matchers.not(Matchers.instanceOf(SimpleClientHttpRequestFactory.class))); + String s = template.patchForObject(baseUrl + "/{method}", helloWorld, String.class, "patch"); + assertEquals("Invalid content", helloWorld, s); + } + + @Test + public void notFound() { + try { + template.execute(baseUrl + "/status/notfound", HttpMethod.GET, null, null); + fail("HttpClientErrorException expected"); + } + catch (HttpClientErrorException ex) { + assertEquals(HttpStatus.NOT_FOUND, ex.getStatusCode()); + assertNotNull(ex.getStatusText()); + assertNotNull(ex.getResponseBodyAsString()); + } + } + + @Test + public void badRequest() { + try { + template.execute(baseUrl + "/status/badrequest", HttpMethod.GET, null, null); + fail("HttpClientErrorException.BadRequest expected"); + } + catch (HttpClientErrorException.BadRequest ex) { + assertEquals(HttpStatus.BAD_REQUEST, ex.getStatusCode()); + assertEquals("400 Client Error", ex.getMessage()); + } + } + + @Test + public void serverError() { + try { + template.execute(baseUrl + "/status/server", HttpMethod.GET, null, null); + fail("HttpServerErrorException expected"); + } + catch (HttpServerErrorException ex) { + assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, ex.getStatusCode()); + assertNotNull(ex.getStatusText()); + assertNotNull(ex.getResponseBodyAsString()); + } + } + + @Test + public void optionsForAllow() throws URISyntaxException { + Set allowed = template.optionsForAllow(new URI(baseUrl + "/get")); + assertEquals("Invalid response", + EnumSet.of(HttpMethod.GET, HttpMethod.OPTIONS, HttpMethod.HEAD, HttpMethod.TRACE), allowed); + } + + @Test + public void uri() throws InterruptedException, URISyntaxException { + String result = template.getForObject(baseUrl + "/uri/{query}", String.class, "Z\u00fcrich"); + assertEquals("Invalid request URI", "/uri/Z%C3%BCrich", result); + + result = template.getForObject(baseUrl + "/uri/query={query}", String.class, "foo@bar"); + assertEquals("Invalid request URI", "/uri/query=foo@bar", result); + + result = template.getForObject(baseUrl + "/uri/query={query}", String.class, "T\u014dky\u014d"); + assertEquals("Invalid request URI", "/uri/query=T%C5%8Dky%C5%8D", result); + } + + @Test + public void multipart() throws UnsupportedEncodingException { + MultiValueMap parts = new LinkedMultiValueMap<>(); + parts.add("name 1", "value 1"); + parts.add("name 2", "value 2+1"); + parts.add("name 2", "value 2+2"); + Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); + parts.add("logo", logo); + + template.postForLocation(baseUrl + "/multipart", parts); + } + + @Test + public void form() throws UnsupportedEncodingException { + MultiValueMap form = new LinkedMultiValueMap<>(); + form.add("name 1", "value 1"); + form.add("name 2", "value 2+1"); + form.add("name 2", "value 2+2"); + + template.postForLocation(baseUrl + "/form", form); + } + + @Test + public void exchangeGet() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + requestHeaders.set("MyHeader", "MyValue"); + HttpEntity requestEntity = new HttpEntity<>(requestHeaders); + ResponseEntity response = + template.exchange(baseUrl + "/{method}", HttpMethod.GET, requestEntity, String.class, "get"); + assertEquals("Invalid content", helloWorld, response.getBody()); + } + + @Test + public void exchangePost() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + requestHeaders.set("MyHeader", "MyValue"); + requestHeaders.setContentType(MediaType.TEXT_PLAIN); + HttpEntity entity = new HttpEntity<>(helloWorld, requestHeaders); + HttpEntity result = template.exchange(baseUrl + "/{method}", POST, entity, Void.class, "post"); + assertEquals("Invalid location", new URI(baseUrl + "/post/1"), result.getHeaders().getLocation()); + assertFalse(result.hasBody()); + } + + @Test + public void jsonPostForObject() throws URISyntaxException { + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.setContentType(new MediaType("application", "json", StandardCharsets.UTF_8)); + MySampleBean bean = new MySampleBean(); + bean.setWith1("with"); + bean.setWith2("with"); + bean.setWithout("without"); + HttpEntity entity = new HttpEntity<>(bean, entityHeaders); + String s = template.postForObject(baseUrl + "/jsonpost", entity, String.class); + assertTrue(s.contains("\"with1\":\"with\"")); + assertTrue(s.contains("\"with2\":\"with\"")); + assertTrue(s.contains("\"without\":\"without\"")); + } + + @Test + public void jsonPostForObjectWithJacksonView() throws URISyntaxException { + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.setContentType(new MediaType("application", "json", StandardCharsets.UTF_8)); + MySampleBean bean = new MySampleBean("with", "with", "without"); + MappingJacksonValue jacksonValue = new MappingJacksonValue(bean); + jacksonValue.setSerializationView(MyJacksonView1.class); + HttpEntity entity = new HttpEntity<>(jacksonValue, entityHeaders); + String s = template.postForObject(baseUrl + "/jsonpost", entity, String.class); + assertTrue(s.contains("\"with1\":\"with\"")); + assertFalse(s.contains("\"with2\":\"with\"")); + assertFalse(s.contains("\"without\":\"without\"")); + } + + @Test // SPR-12123 + public void serverPort() { + String s = template.getForObject("http://localhost:{port}/get", String.class, port); + assertEquals("Invalid content", helloWorld, s); + } + + @Test // SPR-13154 + public void jsonPostForObjectWithJacksonTypeInfoList() throws URISyntaxException { + List list = new ArrayList<>(); + list.add(new Foo("foo")); + list.add(new Bar("bar")); + ParameterizedTypeReference typeReference = new ParameterizedTypeReference>() {}; + RequestEntity> entity = RequestEntity + .post(new URI(baseUrl + "/jsonpost")) + .contentType(new MediaType("application", "json", StandardCharsets.UTF_8)) + .body(list, typeReference.getType()); + String content = template.exchange(entity, String.class).getBody(); + assertTrue(content.contains("\"type\":\"foo\"")); + assertTrue(content.contains("\"type\":\"bar\"")); + } + + @Test // SPR-15015 + public void postWithoutBody() throws Exception { + assertNull(template.postForObject(baseUrl + "/jsonpost", null, String.class)); + } + + + public interface MyJacksonView1 {} + + public interface MyJacksonView2 {} + + + public static class MySampleBean { + + @JsonView(MyJacksonView1.class) + private String with1; + + @JsonView(MyJacksonView2.class) + private String with2; + + private String without; + + private MySampleBean() { + } + + private MySampleBean(String with1, String with2, String without) { + this.with1 = with1; + this.with2 = with2; + this.without = without; + } + + public String getWith1() { + return with1; + } + + public void setWith1(String with1) { + this.with1 = with1; + } + + public String getWith2() { + return with2; + } + + public void setWith2(String with2) { + this.with2 = with2; + } + + public String getWithout() { + return without; + } + + public void setWithout(String without) { + this.without = without; + } + } + + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type") + public static class ParentClass { + + private String parentProperty; + + public ParentClass() { + } + + public ParentClass(String parentProperty) { + this.parentProperty = parentProperty; + } + + public String getParentProperty() { + return parentProperty; + } + + public void setParentProperty(String parentProperty) { + this.parentProperty = parentProperty; + } + } + + + @JsonTypeName("foo") + public static class Foo extends ParentClass { + + public Foo() { + } + + public Foo(String parentProperty) { + super(parentProperty); + } + } + + + @JsonTypeName("bar") + public static class Bar extends ParentClass { + + public Bar() { + } + + public Bar(String parentProperty) { + super(parentProperty); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java new file mode 100644 index 0000000000000000000000000000000000000000..27134f9f3f3b29290e7c750aad7a0ffa7627fd4e --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java @@ -0,0 +1,730 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.net.URI; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.ClientHttpRequestInterceptor; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.converter.GenericHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.util.StreamUtils; +import org.springframework.web.util.DefaultUriBuilderFactory; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.collection.IsIterableContainingInOrder.contains; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.fail; +import static org.mockito.BDDMockito.any; +import static org.mockito.BDDMockito.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.verify; +import static org.mockito.BDDMockito.willThrow; +import static org.springframework.http.HttpMethod.DELETE; +import static org.springframework.http.HttpMethod.GET; +import static org.springframework.http.HttpMethod.HEAD; +import static org.springframework.http.HttpMethod.OPTIONS; +import static org.springframework.http.HttpMethod.PATCH; +import static org.springframework.http.HttpMethod.POST; +import static org.springframework.http.HttpMethod.PUT; +import static org.springframework.http.MediaType.parseMediaType; + +/** + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Brian Clozel + */ +@SuppressWarnings("unchecked") +public class RestTemplateTests { + + private RestTemplate template; + + private ClientHttpRequestFactory requestFactory; + + private ClientHttpRequest request; + + private ClientHttpResponse response; + + private ResponseErrorHandler errorHandler; + + @SuppressWarnings("rawtypes") + private HttpMessageConverter converter; + + + @Before + public void setup() { + requestFactory = mock(ClientHttpRequestFactory.class); + request = mock(ClientHttpRequest.class); + response = mock(ClientHttpResponse.class); + errorHandler = mock(ResponseErrorHandler.class); + converter = mock(HttpMessageConverter.class); + template = new RestTemplate(Collections.singletonList(converter)); + template.setRequestFactory(requestFactory); + template.setErrorHandler(errorHandler); + } + + + @Test + public void varArgsTemplateVariables() throws Exception { + mockSentRequest(GET, "https://example.com/hotels/42/bookings/21"); + mockResponseStatus(HttpStatus.OK); + + template.execute("https://example.com/hotels/{hotel}/bookings/{booking}", GET, + null, null, "42", "21"); + + verify(response).close(); + } + + @Test + public void varArgsNullTemplateVariable() throws Exception { + mockSentRequest(GET, "https://example.com/-foo"); + mockResponseStatus(HttpStatus.OK); + + template.execute("https://example.com/{first}-{last}", GET, null, null, null, "foo"); + + verify(response).close(); + } + + @Test + public void mapTemplateVariables() throws Exception { + mockSentRequest(GET, "https://example.com/hotels/42/bookings/42"); + mockResponseStatus(HttpStatus.OK); + + Map vars = Collections.singletonMap("hotel", "42"); + template.execute("https://example.com/hotels/{hotel}/bookings/{hotel}", GET, null, null, vars); + + verify(response).close(); + } + + @Test + public void mapNullTemplateVariable() throws Exception { + mockSentRequest(GET, "https://example.com/-foo"); + mockResponseStatus(HttpStatus.OK); + + Map vars = new HashMap<>(2); + vars.put("first", null); + vars.put("last", "foo"); + template.execute("https://example.com/{first}-{last}", GET, null, null, vars); + + verify(response).close(); + } + + @Test // SPR-15201 + public void uriTemplateWithTrailingSlash() throws Exception { + String url = "https://example.com/spring/"; + mockSentRequest(GET, url); + mockResponseStatus(HttpStatus.OK); + + template.execute(url, GET, null, null); + + verify(response).close(); + } + + @Test + public void errorHandling() throws Exception { + String url = "https://example.com"; + mockSentRequest(GET, url); + mockResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR); + willThrow(new HttpServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR)) + .given(errorHandler).handleError(new URI(url), GET, response); + + try { + template.execute(url, GET, null, null); + fail("HttpServerErrorException expected"); + } + catch (HttpServerErrorException ex) { + // expected + } + + verify(response).close(); + } + + @Test + public void getForObject() throws Exception { + String expected = "Hello World"; + mockTextPlainHttpMessageConverter(); + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(GET, "https://example.com", requestHeaders); + mockResponseStatus(HttpStatus.OK); + mockTextResponseBody("Hello World"); + + String result = template.getForObject("https://example.com", String.class); + assertEquals("Invalid GET result", expected, result); + assertEquals("Invalid Accept header", MediaType.TEXT_PLAIN_VALUE, + requestHeaders.getFirst("Accept")); + + verify(response).close(); + } + + @Test + public void getUnsupportedMediaType() throws Exception { + mockSentRequest(GET, "https://example.com/resource"); + mockResponseStatus(HttpStatus.OK); + + given(converter.canRead(String.class, null)).willReturn(true); + MediaType supportedMediaType = new MediaType("foo", "bar"); + given(converter.getSupportedMediaTypes()).willReturn(Collections.singletonList(supportedMediaType)); + + MediaType barBaz = new MediaType("bar", "baz"); + mockResponseBody("Foo", new MediaType("bar", "baz")); + given(converter.canRead(String.class, barBaz)).willReturn(false); + + try { + template.getForObject("https://example.com/{p}", String.class, "resource"); + fail("UnsupportedMediaTypeException expected"); + } + catch (RestClientException ex) { + // expected + } + + verify(response).close(); + } + + @Test + public void requestAvoidsDuplicateAcceptHeaderValues() throws Exception { + HttpMessageConverter firstConverter = mock(HttpMessageConverter.class); + given(firstConverter.canRead(any(), any())).willReturn(true); + given(firstConverter.getSupportedMediaTypes()) + .willReturn(Collections.singletonList(MediaType.TEXT_PLAIN)); + HttpMessageConverter secondConverter = mock(HttpMessageConverter.class); + given(secondConverter.canRead(any(), any())).willReturn(true); + given(secondConverter.getSupportedMediaTypes()) + .willReturn(Collections.singletonList(MediaType.TEXT_PLAIN)); + + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(GET, "https://example.com/", requestHeaders); + mockResponseStatus(HttpStatus.OK); + mockTextResponseBody("Hello World"); + + template.setMessageConverters(Arrays.asList(firstConverter, secondConverter)); + template.getForObject("https://example.com/", String.class); + + assertEquals("Sent duplicate Accept header values", 1, + requestHeaders.getAccept().size()); + } + + @Test + public void getForEntity() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(GET, "https://example.com", requestHeaders); + mockTextPlainHttpMessageConverter(); + mockResponseStatus(HttpStatus.OK); + String expected = "Hello World"; + mockTextResponseBody(expected); + + ResponseEntity result = template.getForEntity("https://example.com", String.class); + assertEquals("Invalid GET result", expected, result.getBody()); + assertEquals("Invalid Accept header", MediaType.TEXT_PLAIN_VALUE, requestHeaders.getFirst("Accept")); + assertEquals("Invalid Content-Type header", MediaType.TEXT_PLAIN, result.getHeaders().getContentType()); + assertEquals("Invalid status code", HttpStatus.OK, result.getStatusCode()); + + verify(response).close(); + } + + @Test + public void getForObjectWithCustomUriTemplateHandler() throws Exception { + DefaultUriBuilderFactory uriTemplateHandler = new DefaultUriBuilderFactory(); + template.setUriTemplateHandler(uriTemplateHandler); + mockSentRequest(GET, "https://example.com/hotels/1/pic/pics%2Flogo.png/size/150x150"); + mockResponseStatus(HttpStatus.OK); + given(response.getHeaders()).willReturn(new HttpHeaders()); + given(response.getBody()).willReturn(StreamUtils.emptyInput()); + + Map uriVariables = new HashMap<>(2); + uriVariables.put("hotel", "1"); + uriVariables.put("publicpath", "pics/logo.png"); + uriVariables.put("scale", "150x150"); + + String url = "https://example.com/hotels/{hotel}/pic/{publicpath}/size/{scale}"; + template.getForObject(url, String.class, uriVariables); + + verify(response).close(); + } + + @Test + public void headForHeaders() throws Exception { + mockSentRequest(HEAD, "https://example.com"); + mockResponseStatus(HttpStatus.OK); + HttpHeaders responseHeaders = new HttpHeaders(); + given(response.getHeaders()).willReturn(responseHeaders); + + HttpHeaders result = template.headForHeaders("https://example.com"); + + assertSame("Invalid headers returned", responseHeaders, result); + + verify(response).close(); + } + + @Test + public void postForLocation() throws Exception { + mockSentRequest(POST, "https://example.com"); + mockTextPlainHttpMessageConverter(); + mockResponseStatus(HttpStatus.OK); + String helloWorld = "Hello World"; + HttpHeaders responseHeaders = new HttpHeaders(); + URI expected = new URI("https://example.com/hotels"); + responseHeaders.setLocation(expected); + given(response.getHeaders()).willReturn(responseHeaders); + + URI result = template.postForLocation("https://example.com", helloWorld); + assertEquals("Invalid POST result", expected, result); + + verify(response).close(); + } + + @Test + public void postForLocationEntityContentType() throws Exception { + mockSentRequest(POST, "https://example.com"); + mockTextPlainHttpMessageConverter(); + mockResponseStatus(HttpStatus.OK); + + String helloWorld = "Hello World"; + HttpHeaders responseHeaders = new HttpHeaders(); + URI expected = new URI("https://example.com/hotels"); + responseHeaders.setLocation(expected); + given(response.getHeaders()).willReturn(responseHeaders); + + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.setContentType(MediaType.TEXT_PLAIN); + HttpEntity entity = new HttpEntity<>(helloWorld, entityHeaders); + + URI result = template.postForLocation("https://example.com", entity); + assertEquals("Invalid POST result", expected, result); + + verify(response).close(); + } + + @Test + public void postForLocationEntityCustomHeader() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(POST, "https://example.com", requestHeaders); + mockTextPlainHttpMessageConverter(); + mockResponseStatus(HttpStatus.OK); + HttpHeaders responseHeaders = new HttpHeaders(); + URI expected = new URI("https://example.com/hotels"); + responseHeaders.setLocation(expected); + given(response.getHeaders()).willReturn(responseHeaders); + + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.set("MyHeader", "MyValue"); + HttpEntity entity = new HttpEntity<>("Hello World", entityHeaders); + + URI result = template.postForLocation("https://example.com", entity); + assertEquals("Invalid POST result", expected, result); + assertEquals("No custom header set", "MyValue", requestHeaders.getFirst("MyHeader")); + + verify(response).close(); + } + + @Test + public void postForLocationNoLocation() throws Exception { + mockSentRequest(POST, "https://example.com"); + mockTextPlainHttpMessageConverter(); + mockResponseStatus(HttpStatus.OK); + + URI result = template.postForLocation("https://example.com", "Hello World"); + assertNull("Invalid POST result", result); + + verify(response).close(); + } + + @Test + public void postForLocationNull() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(POST, "https://example.com", requestHeaders); + mockResponseStatus(HttpStatus.OK); + + template.postForLocation("https://example.com", null); + assertEquals("Invalid content length", 0, requestHeaders.getContentLength()); + + verify(response).close(); + } + + @Test + public void postForObject() throws Exception { + mockTextPlainHttpMessageConverter(); + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(POST, "https://example.com", requestHeaders); + mockResponseStatus(HttpStatus.OK); + String expected = "42"; + mockResponseBody(expected, MediaType.TEXT_PLAIN); + + String result = template.postForObject("https://example.com", "Hello World", String.class); + assertEquals("Invalid POST result", expected, result); + assertEquals("Invalid Accept header", MediaType.TEXT_PLAIN_VALUE, requestHeaders.getFirst("Accept")); + + verify(response).close(); + } + + @Test + public void postForEntity() throws Exception { + mockTextPlainHttpMessageConverter(); + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(POST, "https://example.com", requestHeaders); + mockResponseStatus(HttpStatus.OK); + String expected = "42"; + mockResponseBody(expected, MediaType.TEXT_PLAIN); + + ResponseEntity result = template.postForEntity("https://example.com", "Hello World", String.class); + assertEquals("Invalid POST result", expected, result.getBody()); + assertEquals("Invalid Content-Type", MediaType.TEXT_PLAIN, result.getHeaders().getContentType()); + assertEquals("Invalid Accept header", MediaType.TEXT_PLAIN_VALUE, requestHeaders.getFirst("Accept")); + assertEquals("Invalid status code", HttpStatus.OK, result.getStatusCode()); + + verify(response).close(); + } + + @Test + public void postForObjectNull() throws Exception { + mockTextPlainHttpMessageConverter(); + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(POST, "https://example.com", requestHeaders); + mockResponseStatus(HttpStatus.OK); + HttpHeaders responseHeaders = new HttpHeaders(); + responseHeaders.setContentType(MediaType.TEXT_PLAIN); + responseHeaders.setContentLength(10); + given(response.getHeaders()).willReturn(responseHeaders); + given(response.getBody()).willReturn(StreamUtils.emptyInput()); + given(converter.read(String.class, response)).willReturn(null); + + String result = template.postForObject("https://example.com", null, String.class); + assertNull("Invalid POST result", result); + assertEquals("Invalid content length", 0, requestHeaders.getContentLength()); + + verify(response).close(); + } + + @Test + public void postForEntityNull() throws Exception { + mockTextPlainHttpMessageConverter(); + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(POST, "https://example.com", requestHeaders); + mockResponseStatus(HttpStatus.OK); + HttpHeaders responseHeaders = new HttpHeaders(); + responseHeaders.setContentType(MediaType.TEXT_PLAIN); + responseHeaders.setContentLength(10); + given(response.getHeaders()).willReturn(responseHeaders); + given(response.getBody()).willReturn(StreamUtils.emptyInput()); + given(converter.read(String.class, response)).willReturn(null); + + ResponseEntity result = template.postForEntity("https://example.com", null, String.class); + assertFalse("Invalid POST result", result.hasBody()); + assertEquals("Invalid Content-Type", MediaType.TEXT_PLAIN, result.getHeaders().getContentType()); + assertEquals("Invalid content length", 0, requestHeaders.getContentLength()); + assertEquals("Invalid status code", HttpStatus.OK, result.getStatusCode()); + + verify(response).close(); + } + + @Test + public void put() throws Exception { + mockTextPlainHttpMessageConverter(); + mockSentRequest(PUT, "https://example.com"); + mockResponseStatus(HttpStatus.OK); + + template.put("https://example.com", "Hello World"); + + verify(response).close(); + } + + @Test + public void putNull() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(PUT, "https://example.com", requestHeaders); + mockResponseStatus(HttpStatus.OK); + + template.put("https://example.com", null); + assertEquals("Invalid content length", 0, requestHeaders.getContentLength()); + + verify(response).close(); + } + + @Test + public void patchForObject() throws Exception { + mockTextPlainHttpMessageConverter(); + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(PATCH, "https://example.com", requestHeaders); + mockResponseStatus(HttpStatus.OK); + String expected = "42"; + mockResponseBody("42", MediaType.TEXT_PLAIN); + + String result = template.patchForObject("https://example.com", "Hello World", String.class); + assertEquals("Invalid POST result", expected, result); + assertEquals("Invalid Accept header", MediaType.TEXT_PLAIN_VALUE, requestHeaders.getFirst("Accept")); + + verify(response).close(); + } + + @Test + public void patchForObjectNull() throws Exception { + mockTextPlainHttpMessageConverter(); + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(PATCH, "https://example.com", requestHeaders); + mockResponseStatus(HttpStatus.OK); + HttpHeaders responseHeaders = new HttpHeaders(); + responseHeaders.setContentType(MediaType.TEXT_PLAIN); + responseHeaders.setContentLength(10); + given(response.getHeaders()).willReturn(responseHeaders); + given(response.getBody()).willReturn(StreamUtils.emptyInput()); + + String result = template.patchForObject("https://example.com", null, String.class); + assertNull("Invalid POST result", result); + assertEquals("Invalid content length", 0, requestHeaders.getContentLength()); + + verify(response).close(); + } + + + @Test + public void delete() throws Exception { + mockSentRequest(DELETE, "https://example.com"); + mockResponseStatus(HttpStatus.OK); + + template.delete("https://example.com"); + + verify(response).close(); + } + + @Test + public void optionsForAllow() throws Exception { + mockSentRequest(OPTIONS, "https://example.com"); + mockResponseStatus(HttpStatus.OK); + HttpHeaders responseHeaders = new HttpHeaders(); + EnumSet expected = EnumSet.of(GET, POST); + responseHeaders.setAllow(expected); + given(response.getHeaders()).willReturn(responseHeaders); + + Set result = template.optionsForAllow("https://example.com"); + assertEquals("Invalid OPTIONS result", expected, result); + + verify(response).close(); + } + + @Test // SPR-9325, SPR-13860 + public void ioException() throws Exception { + String url = "https://example.com/resource?access_token=123"; + mockSentRequest(GET, url); + mockHttpMessageConverter(new MediaType("foo", "bar"), String.class); + given(request.execute()).willThrow(new IOException("Socket failure")); + + try { + template.getForObject(url, String.class); + fail("RestClientException expected"); + } + catch (ResourceAccessException ex) { + assertEquals("I/O error on GET request for \"https://example.com/resource\": " + + "Socket failure; nested exception is java.io.IOException: Socket failure", + ex.getMessage()); + } + } + + @Test // SPR-15900 + public void ioExceptionWithEmptyQueryString() throws Exception { + + // https://example.com/resource? + URI uri = new URI("https", "example.com", "/resource", "", null); + + given(converter.canRead(String.class, null)).willReturn(true); + given(converter.getSupportedMediaTypes()).willReturn(Collections.singletonList(parseMediaType("foo/bar"))); + given(requestFactory.createRequest(uri, GET)).willReturn(request); + given(request.getHeaders()).willReturn(new HttpHeaders()); + given(request.execute()).willThrow(new IOException("Socket failure")); + + try { + template.getForObject(uri, String.class); + fail("RestClientException expected"); + } + catch (ResourceAccessException ex) { + assertEquals("I/O error on GET request for \"https://example.com/resource\": " + + "Socket failure; nested exception is java.io.IOException: Socket failure", + ex.getMessage()); + } + } + + @Test + public void exchange() throws Exception { + mockTextPlainHttpMessageConverter(); + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(POST, "https://example.com", requestHeaders); + mockResponseStatus(HttpStatus.OK); + String expected = "42"; + mockResponseBody(expected, MediaType.TEXT_PLAIN); + + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.set("MyHeader", "MyValue"); + HttpEntity entity = new HttpEntity<>("Hello World", entityHeaders); + ResponseEntity result = template.exchange("https://example.com", POST, entity, String.class); + assertEquals("Invalid POST result", expected, result.getBody()); + assertEquals("Invalid Content-Type", MediaType.TEXT_PLAIN, result.getHeaders().getContentType()); + assertEquals("Invalid Accept header", MediaType.TEXT_PLAIN_VALUE, requestHeaders.getFirst("Accept")); + assertEquals("Invalid custom header", "MyValue", requestHeaders.getFirst("MyHeader")); + assertEquals("Invalid status code", HttpStatus.OK, result.getStatusCode()); + + verify(response).close(); + } + + @Test + @SuppressWarnings("rawtypes") + public void exchangeParameterizedType() throws Exception { + GenericHttpMessageConverter converter = mock(GenericHttpMessageConverter.class); + template.setMessageConverters(Collections.>singletonList(converter)); + ParameterizedTypeReference> intList = new ParameterizedTypeReference>() {}; + given(converter.canRead(intList.getType(), null, null)).willReturn(true); + given(converter.getSupportedMediaTypes()).willReturn(Collections.singletonList(MediaType.TEXT_PLAIN)); + given(converter.canWrite(String.class, String.class, null)).willReturn(true); + + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(POST, "https://example.com", requestHeaders); + List expected = Collections.singletonList(42); + HttpHeaders responseHeaders = new HttpHeaders(); + responseHeaders.setContentType(MediaType.TEXT_PLAIN); + responseHeaders.setContentLength(10); + mockResponseStatus(HttpStatus.OK); + given(response.getHeaders()).willReturn(responseHeaders); + given(response.getBody()).willReturn(new ByteArrayInputStream(Integer.toString(42).getBytes())); + given(converter.canRead(intList.getType(), null, MediaType.TEXT_PLAIN)).willReturn(true); + given(converter.read(eq(intList.getType()), eq(null), any(HttpInputMessage.class))).willReturn(expected); + + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.set("MyHeader", "MyValue"); + HttpEntity requestEntity = new HttpEntity<>("Hello World", entityHeaders); + ResponseEntity> result = template.exchange("https://example.com", POST, requestEntity, intList); + assertEquals("Invalid POST result", expected, result.getBody()); + assertEquals("Invalid Content-Type", MediaType.TEXT_PLAIN, result.getHeaders().getContentType()); + assertEquals("Invalid Accept header", MediaType.TEXT_PLAIN_VALUE, requestHeaders.getFirst("Accept")); + assertEquals("Invalid custom header", "MyValue", requestHeaders.getFirst("MyHeader")); + assertEquals("Invalid status code", HttpStatus.OK, result.getStatusCode()); + + verify(response).close(); + } + + @Test // SPR-15066 + public void requestInterceptorCanAddExistingHeaderValueWithoutBody() throws Exception { + ClientHttpRequestInterceptor interceptor = (request, body, execution) -> { + request.getHeaders().add("MyHeader", "MyInterceptorValue"); + return execution.execute(request, body); + }; + template.setInterceptors(Collections.singletonList(interceptor)); + + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(POST, "https://example.com", requestHeaders); + mockResponseStatus(HttpStatus.OK); + + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.add("MyHeader", "MyEntityValue"); + HttpEntity entity = new HttpEntity<>(null, entityHeaders); + template.exchange("https://example.com", POST, entity, Void.class); + assertThat(requestHeaders.get("MyHeader"), contains("MyEntityValue", "MyInterceptorValue")); + + verify(response).close(); + } + + @Test // SPR-15066 + public void requestInterceptorCanAddExistingHeaderValueWithBody() throws Exception { + ClientHttpRequestInterceptor interceptor = (request, body, execution) -> { + request.getHeaders().add("MyHeader", "MyInterceptorValue"); + return execution.execute(request, body); + }; + template.setInterceptors(Collections.singletonList(interceptor)); + + MediaType contentType = MediaType.TEXT_PLAIN; + given(converter.canWrite(String.class, contentType)).willReturn(true); + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(POST, "https://example.com", requestHeaders); + mockResponseStatus(HttpStatus.OK); + + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.setContentType(contentType); + entityHeaders.add("MyHeader", "MyEntityValue"); + HttpEntity entity = new HttpEntity<>("Hello World", entityHeaders); + template.exchange("https://example.com", POST, entity, Void.class); + assertThat(requestHeaders.get("MyHeader"), contains("MyEntityValue", "MyInterceptorValue")); + + verify(response).close(); + } + + private void mockSentRequest(HttpMethod method, String uri) throws Exception { + mockSentRequest(method, uri, new HttpHeaders()); + } + + private void mockSentRequest(HttpMethod method, String uri, HttpHeaders requestHeaders) throws Exception { + given(requestFactory.createRequest(new URI(uri), method)).willReturn(request); + given(request.getHeaders()).willReturn(requestHeaders); + } + + private void mockResponseStatus(HttpStatus responseStatus) throws Exception { + given(request.execute()).willReturn(response); + given(errorHandler.hasError(response)).willReturn(responseStatus.isError()); + given(response.getStatusCode()).willReturn(responseStatus); + given(response.getRawStatusCode()).willReturn(responseStatus.value()); + given(response.getStatusText()).willReturn(responseStatus.getReasonPhrase()); + } + + private void mockTextPlainHttpMessageConverter() { + mockHttpMessageConverter(MediaType.TEXT_PLAIN, String.class); + } + + private void mockHttpMessageConverter(MediaType mediaType, Class type) { + given(converter.canRead(type, null)).willReturn(true); + given(converter.canRead(type, mediaType)).willReturn(true); + given(converter.getSupportedMediaTypes()) + .willReturn(Collections.singletonList(mediaType)); + given(converter.canRead(type, mediaType)).willReturn(true); + given(converter.canWrite(type, null)).willReturn(true); + given(converter.canWrite(type, mediaType)).willReturn(true); + } + + private void mockTextResponseBody(String expectedBody) throws Exception { + mockResponseBody(expectedBody, MediaType.TEXT_PLAIN); + } + + private void mockResponseBody(String expectedBody, MediaType mediaType) throws Exception { + HttpHeaders responseHeaders = new HttpHeaders(); + responseHeaders.setContentType(mediaType); + responseHeaders.setContentLength(expectedBody.length()); + given(response.getHeaders()).willReturn(responseHeaders); + given(response.getBody()).willReturn(new ByteArrayInputStream(expectedBody.getBytes())); + given(converter.read(eq(String.class), any(HttpInputMessage.class))).willReturn(expectedBody); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/ContextLoaderInitializerTests.java b/spring-web/src/test/java/org/springframework/web/context/ContextLoaderInitializerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..12f576905dc8a8fbbc45ba62a9daaa05cd7947bf --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/ContextLoaderInitializerTests.java @@ -0,0 +1,93 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context; + +import java.util.EventListener; + +import javax.servlet.ServletContextEvent; +import javax.servlet.ServletException; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.mock.web.test.MockServletContext; +import org.springframework.web.context.support.StaticWebApplicationContext; +import org.springframework.web.context.support.WebApplicationContextUtils; + +import static org.junit.Assert.*; + +/** + * Test case for {@link AbstractContextLoaderInitializer}. + * + * @author Arjen Poutsma + */ +public class ContextLoaderInitializerTests { + + private static final String BEAN_NAME = "myBean"; + + private AbstractContextLoaderInitializer initializer; + + private MockServletContext servletContext; + + private EventListener eventListener; + + @Before + public void setUp() throws Exception { + servletContext = new MyMockServletContext(); + initializer = new MyContextLoaderInitializer(); + eventListener = null; + } + + @Test + public void register() throws ServletException { + initializer.onStartup(servletContext); + + assertTrue(eventListener instanceof ContextLoaderListener); + ContextLoaderListener cll = (ContextLoaderListener) eventListener; + cll.contextInitialized(new ServletContextEvent(servletContext)); + + WebApplicationContext applicationContext = WebApplicationContextUtils + .getRequiredWebApplicationContext(servletContext); + + assertTrue(applicationContext.containsBean(BEAN_NAME)); + assertTrue(applicationContext.getBean(BEAN_NAME) instanceof MyBean); + } + + private class MyMockServletContext extends MockServletContext { + + @Override + public void addListener(T listener) { + eventListener = listener; + } + + } + + private static class MyContextLoaderInitializer + extends AbstractContextLoaderInitializer { + + @Override + protected WebApplicationContext createRootApplicationContext() { + StaticWebApplicationContext rootContext = new StaticWebApplicationContext(); + rootContext.registerSingleton(BEAN_NAME, MyBean.class); + return rootContext; + } + } + + private static class MyBean { + + } +} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/RequestAndSessionScopedBeanTests.java b/spring-web/src/test/java/org/springframework/web/context/request/RequestAndSessionScopedBeanTests.java new file mode 100644 index 0000000000000000000000000000000000000000..6b1938e2903e5292eb8ec37130ec59151a43deaa --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/RequestAndSessionScopedBeanTests.java @@ -0,0 +1,108 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import javax.servlet.http.HttpServletRequest; + +import org.junit.Test; + +import org.springframework.beans.factory.BeanCreationException; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.StaticWebApplicationContext; + +import static org.junit.Assert.*; + +/** + * @author Rod Johnson + * @author Juergen Hoeller + */ +public class RequestAndSessionScopedBeanTests { + + @Test + @SuppressWarnings("resource") + public void testPutBeanInRequest() throws Exception { + String targetBeanName = "target"; + + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + RootBeanDefinition bd = new RootBeanDefinition(TestBean.class); + bd.setScope(WebApplicationContext.SCOPE_REQUEST); + bd.getPropertyValues().add("name", "abc"); + wac.registerBeanDefinition(targetBeanName, bd); + wac.refresh(); + + HttpServletRequest request = new MockHttpServletRequest(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request)); + TestBean target = (TestBean) wac.getBean(targetBeanName); + assertEquals("abc", target.getName()); + assertSame(target, request.getAttribute(targetBeanName)); + + TestBean target2 = (TestBean) wac.getBean(targetBeanName); + assertEquals("abc", target2.getName()); + assertSame(target2, target); + assertSame(target2, request.getAttribute(targetBeanName)); + + request = new MockHttpServletRequest(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request)); + TestBean target3 = (TestBean) wac.getBean(targetBeanName); + assertEquals("abc", target3.getName()); + assertSame(target3, request.getAttribute(targetBeanName)); + assertNotSame(target3, target); + + RequestContextHolder.setRequestAttributes(null); + try { + wac.getBean(targetBeanName); + fail("Should have thrown BeanCreationException"); + } + catch (BeanCreationException ex) { + // expected + } + } + + @Test + @SuppressWarnings("resource") + public void testPutBeanInSession() throws Exception { + String targetBeanName = "target"; + HttpServletRequest request = new MockHttpServletRequest(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request)); + + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + RootBeanDefinition bd = new RootBeanDefinition(TestBean.class); + bd.setScope(WebApplicationContext.SCOPE_SESSION); + bd.getPropertyValues().add("name", "abc"); + wac.registerBeanDefinition(targetBeanName, bd); + wac.refresh(); + + TestBean target = (TestBean) wac.getBean(targetBeanName); + assertEquals("abc", target.getName()); + assertSame(target, request.getSession().getAttribute(targetBeanName)); + + RequestContextHolder.setRequestAttributes(null); + try { + wac.getBean(targetBeanName); + fail("Should have thrown BeanCreationException"); + } + catch (BeanCreationException ex) { + // expected + } + + + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/RequestContextListenerTests.java b/spring-web/src/test/java/org/springframework/web/context/request/RequestContextListenerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..df9cc847c5a68b6de49609d30bc80e1298bb80ef --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/RequestContextListenerTests.java @@ -0,0 +1,116 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import javax.servlet.ServletRequestEvent; + +import org.junit.Test; + +import org.springframework.core.task.MockRunnable; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockServletContext; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + */ +public class RequestContextListenerTests { + + @Test + public void requestContextListenerWithSameThread() { + RequestContextListener listener = new RequestContextListener(); + MockServletContext context = new MockServletContext(); + MockHttpServletRequest request = new MockHttpServletRequest(context); + request.setAttribute("test", "value"); + + assertNull(RequestContextHolder.getRequestAttributes()); + listener.requestInitialized(new ServletRequestEvent(context, request)); + assertNotNull(RequestContextHolder.getRequestAttributes()); + assertEquals("value", + RequestContextHolder.getRequestAttributes().getAttribute("test", RequestAttributes.SCOPE_REQUEST)); + MockRunnable runnable = new MockRunnable(); + RequestContextHolder.getRequestAttributes().registerDestructionCallback( + "test", runnable, RequestAttributes.SCOPE_REQUEST); + + listener.requestDestroyed(new ServletRequestEvent(context, request)); + assertNull(RequestContextHolder.getRequestAttributes()); + assertTrue(runnable.wasExecuted()); + } + + @Test + public void requestContextListenerWithSameThreadAndAttributesGone() { + RequestContextListener listener = new RequestContextListener(); + MockServletContext context = new MockServletContext(); + MockHttpServletRequest request = new MockHttpServletRequest(context); + request.setAttribute("test", "value"); + + assertNull(RequestContextHolder.getRequestAttributes()); + listener.requestInitialized(new ServletRequestEvent(context, request)); + assertNotNull(RequestContextHolder.getRequestAttributes()); + assertEquals("value", + RequestContextHolder.getRequestAttributes().getAttribute("test", RequestAttributes.SCOPE_REQUEST)); + MockRunnable runnable = new MockRunnable(); + RequestContextHolder.getRequestAttributes().registerDestructionCallback( + "test", runnable, RequestAttributes.SCOPE_REQUEST); + + request.clearAttributes(); + listener.requestDestroyed(new ServletRequestEvent(context, request)); + assertNull(RequestContextHolder.getRequestAttributes()); + assertTrue(runnable.wasExecuted()); + } + + @Test + public void requestContextListenerWithDifferentThread() { + final RequestContextListener listener = new RequestContextListener(); + final MockServletContext context = new MockServletContext(); + final MockHttpServletRequest request = new MockHttpServletRequest(context); + request.setAttribute("test", "value"); + + assertNull(RequestContextHolder.getRequestAttributes()); + listener.requestInitialized(new ServletRequestEvent(context, request)); + assertNotNull(RequestContextHolder.getRequestAttributes()); + assertEquals("value", + RequestContextHolder.getRequestAttributes().getAttribute("test", RequestAttributes.SCOPE_REQUEST)); + MockRunnable runnable = new MockRunnable(); + RequestContextHolder.getRequestAttributes().registerDestructionCallback( + "test", runnable, RequestAttributes.SCOPE_REQUEST); + + // Execute requestDestroyed callback in different thread. + Thread thread = new Thread() { + @Override + public void run() { + listener.requestDestroyed(new ServletRequestEvent(context, request)); + } + }; + thread.start(); + try { + thread.join(); + } + catch (InterruptedException ex) { + } + // Still bound to original thread, but at least completed. + assertNotNull(RequestContextHolder.getRequestAttributes()); + assertTrue(runnable.wasExecuted()); + + // Check that a repeated execution in the same thread works and performs cleanup. + listener.requestInitialized(new ServletRequestEvent(context, request)); + listener.requestDestroyed(new ServletRequestEvent(context, request)); + assertNull(RequestContextHolder.getRequestAttributes()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/RequestScopeTests.java b/spring-web/src/test/java/org/springframework/web/context/request/RequestScopeTests.java new file mode 100644 index 0000000000000000000000000000000000000000..22345be1ba7dc9deed476de10ba8f1ad5b53d21b --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/RequestScopeTests.java @@ -0,0 +1,166 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.beans.factory.BeanCreationException; +import org.springframework.beans.factory.BeanCurrentlyInCreationException; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; +import org.springframework.context.expression.StandardBeanExpressionResolver; +import org.springframework.core.io.ClassPathResource; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.tests.sample.beans.DerivedTestBean; +import org.springframework.tests.sample.beans.TestBean; + +import static org.junit.Assert.*; + +/** + * @author Rob Harrop + * @author Juergen Hoeller + * @author Mark Fisher + * @author Sam Brannen + * @see SessionScopeTests + */ +public class RequestScopeTests { + + private final DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + + + @Before + public void setup() throws Exception { + this.beanFactory.registerScope("request", new RequestScope()); + this.beanFactory.setBeanExpressionResolver(new StandardBeanExpressionResolver()); + XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(this.beanFactory); + reader.loadBeanDefinitions(new ClassPathResource("requestScopeTests.xml", getClass())); + this.beanFactory.preInstantiateSingletons(); + } + + @After + public void resetRequestAttributes() { + RequestContextHolder.setRequestAttributes(null); + } + + + @Test + public void getFromScope() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setContextPath("/path"); + RequestAttributes requestAttributes = new ServletRequestAttributes(request); + RequestContextHolder.setRequestAttributes(requestAttributes); + + String name = "requestScopedObject"; + assertNull(request.getAttribute(name)); + TestBean bean = (TestBean) this.beanFactory.getBean(name); + assertEquals("/path", bean.getName()); + assertSame(bean, request.getAttribute(name)); + assertSame(bean, this.beanFactory.getBean(name)); + } + + @Test + public void destructionAtRequestCompletion() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + ServletRequestAttributes requestAttributes = new ServletRequestAttributes(request); + RequestContextHolder.setRequestAttributes(requestAttributes); + + String name = "requestScopedDisposableObject"; + assertNull(request.getAttribute(name)); + DerivedTestBean bean = (DerivedTestBean) this.beanFactory.getBean(name); + assertSame(bean, request.getAttribute(name)); + assertSame(bean, this.beanFactory.getBean(name)); + + requestAttributes.requestCompleted(); + assertTrue(bean.wasDestroyed()); + } + + @Test + public void getFromFactoryBeanInScope() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + RequestAttributes requestAttributes = new ServletRequestAttributes(request); + RequestContextHolder.setRequestAttributes(requestAttributes); + + String name = "requestScopedFactoryBean"; + assertNull(request.getAttribute(name)); + TestBean bean = (TestBean) this.beanFactory.getBean(name); + assertTrue(request.getAttribute(name) instanceof FactoryBean); + assertSame(bean, this.beanFactory.getBean(name)); + } + + @Test + public void circleLeadsToException() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + RequestAttributes requestAttributes = new ServletRequestAttributes(request); + RequestContextHolder.setRequestAttributes(requestAttributes); + + try { + String name = "requestScopedObjectCircle1"; + assertNull(request.getAttribute(name)); + + this.beanFactory.getBean(name); + fail("Should have thrown BeanCreationException"); + } + catch (BeanCreationException ex) { + assertTrue(ex.contains(BeanCurrentlyInCreationException.class)); + } + } + + @Test + public void innerBeanInheritsContainingBeanScopeByDefault() { + MockHttpServletRequest request = new MockHttpServletRequest(); + ServletRequestAttributes requestAttributes = new ServletRequestAttributes(request); + RequestContextHolder.setRequestAttributes(requestAttributes); + + String outerBeanName = "requestScopedOuterBean"; + assertNull(request.getAttribute(outerBeanName)); + TestBean outer1 = (TestBean) this.beanFactory.getBean(outerBeanName); + assertNotNull(request.getAttribute(outerBeanName)); + TestBean inner1 = (TestBean) outer1.getSpouse(); + assertSame(outer1, this.beanFactory.getBean(outerBeanName)); + requestAttributes.requestCompleted(); + assertTrue(outer1.wasDestroyed()); + assertTrue(inner1.wasDestroyed()); + request = new MockHttpServletRequest(); + requestAttributes = new ServletRequestAttributes(request); + RequestContextHolder.setRequestAttributes(requestAttributes); + TestBean outer2 = (TestBean) this.beanFactory.getBean(outerBeanName); + assertNotSame(outer1, outer2); + assertNotSame(inner1, outer2.getSpouse()); + } + + @Test + public void requestScopedInnerBeanDestroyedWhileContainedBySingleton() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + ServletRequestAttributes requestAttributes = new ServletRequestAttributes(request); + RequestContextHolder.setRequestAttributes(requestAttributes); + + String outerBeanName = "singletonOuterBean"; + TestBean outer1 = (TestBean) this.beanFactory.getBean(outerBeanName); + assertNull(request.getAttribute(outerBeanName)); + TestBean inner1 = (TestBean) outer1.getSpouse(); + TestBean outer2 = (TestBean) this.beanFactory.getBean(outerBeanName); + assertSame(outer1, outer2); + assertSame(inner1, outer2.getSpouse()); + requestAttributes.requestCompleted(); + assertTrue(inner1.wasDestroyed()); + assertFalse(outer1.wasDestroyed()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/RequestScopedProxyTests.java b/spring-web/src/test/java/org/springframework/web/context/request/RequestScopedProxyTests.java new file mode 100644 index 0000000000000000000000000000000000000000..38f26ac710434f6fa9832fe2fc2d13b694f3c5c0 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/RequestScopedProxyTests.java @@ -0,0 +1,202 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.aop.support.AopUtils; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.BeanDefinitionHolder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; +import org.springframework.core.io.ClassPathResource; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.tests.sample.beans.DerivedTestBean; +import org.springframework.tests.sample.beans.ITestBean; +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.tests.sample.beans.factory.DummyFactory; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + */ +public class RequestScopedProxyTests { + + private final DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + + + @Before + public void setup() { + this.beanFactory.registerScope("request", new RequestScope()); + XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(this.beanFactory); + reader.loadBeanDefinitions(new ClassPathResource("requestScopedProxyTests.xml", getClass())); + this.beanFactory.preInstantiateSingletons(); + } + + + @Test + public void testGetFromScope() throws Exception { + String name = "requestScopedObject"; + TestBean bean = (TestBean) this.beanFactory.getBean(name); + assertTrue(AopUtils.isCglibProxy(bean)); + + MockHttpServletRequest request = new MockHttpServletRequest(); + RequestAttributes requestAttributes = new ServletRequestAttributes(request); + RequestContextHolder.setRequestAttributes(requestAttributes); + + try { + assertNull(request.getAttribute("scopedTarget." + name)); + assertEquals("scoped", bean.getName()); + assertNotNull(request.getAttribute("scopedTarget." + name)); + TestBean target = (TestBean) request.getAttribute("scopedTarget." + name); + assertEquals(TestBean.class, target.getClass()); + assertEquals("scoped", target.getName()); + assertSame(bean, this.beanFactory.getBean(name)); + assertEquals(bean.toString(), target.toString()); + } + finally { + RequestContextHolder.setRequestAttributes(null); + } + } + + @Test + public void testGetFromScopeThroughDynamicProxy() throws Exception { + String name = "requestScopedProxy"; + ITestBean bean = (ITestBean) this.beanFactory.getBean(name); + // assertTrue(AopUtils.isJdkDynamicProxy(bean)); + + MockHttpServletRequest request = new MockHttpServletRequest(); + RequestAttributes requestAttributes = new ServletRequestAttributes(request); + RequestContextHolder.setRequestAttributes(requestAttributes); + + try { + assertNull(request.getAttribute("scopedTarget." + name)); + assertEquals("scoped", bean.getName()); + assertNotNull(request.getAttribute("scopedTarget." + name)); + TestBean target = (TestBean) request.getAttribute("scopedTarget." + name); + assertEquals(TestBean.class, target.getClass()); + assertEquals("scoped", target.getName()); + assertSame(bean, this.beanFactory.getBean(name)); + assertEquals(bean.toString(), target.toString()); + } + finally { + RequestContextHolder.setRequestAttributes(null); + } + } + + @Test + public void testDestructionAtRequestCompletion() throws Exception { + String name = "requestScopedDisposableObject"; + DerivedTestBean bean = (DerivedTestBean) this.beanFactory.getBean(name); + assertTrue(AopUtils.isCglibProxy(bean)); + + MockHttpServletRequest request = new MockHttpServletRequest(); + ServletRequestAttributes requestAttributes = new ServletRequestAttributes(request); + RequestContextHolder.setRequestAttributes(requestAttributes); + + try { + assertNull(request.getAttribute("scopedTarget." + name)); + assertEquals("scoped", bean.getName()); + assertNotNull(request.getAttribute("scopedTarget." + name)); + assertEquals(DerivedTestBean.class, request.getAttribute("scopedTarget." + name).getClass()); + assertEquals("scoped", ((TestBean) request.getAttribute("scopedTarget." + name)).getName()); + assertSame(bean, this.beanFactory.getBean(name)); + + requestAttributes.requestCompleted(); + assertTrue(((TestBean) request.getAttribute("scopedTarget." + name)).wasDestroyed()); + } + finally { + RequestContextHolder.setRequestAttributes(null); + } + } + + @Test + public void testGetFromFactoryBeanInScope() throws Exception { + String name = "requestScopedFactoryBean"; + TestBean bean = (TestBean) this.beanFactory.getBean(name); + assertTrue(AopUtils.isCglibProxy(bean)); + + MockHttpServletRequest request = new MockHttpServletRequest(); + RequestAttributes requestAttributes = new ServletRequestAttributes(request); + RequestContextHolder.setRequestAttributes(requestAttributes); + + try { + assertNull(request.getAttribute("scopedTarget." + name)); + assertEquals(DummyFactory.SINGLETON_NAME, bean.getName()); + assertNotNull(request.getAttribute("scopedTarget." + name)); + assertEquals(DummyFactory.class, request.getAttribute("scopedTarget." + name).getClass()); + assertSame(bean, this.beanFactory.getBean(name)); + } + finally { + RequestContextHolder.setRequestAttributes(null); + } + } + + @Test + public void testGetInnerBeanFromScope() throws Exception { + TestBean bean = (TestBean) this.beanFactory.getBean("outerBean"); + assertFalse(AopUtils.isAopProxy(bean)); + assertTrue(AopUtils.isCglibProxy(bean.getSpouse())); + + String name = "scopedInnerBean"; + + MockHttpServletRequest request = new MockHttpServletRequest(); + RequestAttributes requestAttributes = new ServletRequestAttributes(request); + RequestContextHolder.setRequestAttributes(requestAttributes); + + try { + assertNull(request.getAttribute("scopedTarget." + name)); + assertEquals("scoped", bean.getSpouse().getName()); + assertNotNull(request.getAttribute("scopedTarget." + name)); + assertEquals(TestBean.class, request.getAttribute("scopedTarget." + name).getClass()); + assertEquals("scoped", ((TestBean) request.getAttribute("scopedTarget." + name)).getName()); + } + finally { + RequestContextHolder.setRequestAttributes(null); + } + } + + @Test + public void testGetAnonymousInnerBeanFromScope() throws Exception { + TestBean bean = (TestBean) this.beanFactory.getBean("outerBean"); + assertFalse(AopUtils.isAopProxy(bean)); + assertTrue(AopUtils.isCglibProxy(bean.getSpouse())); + + BeanDefinition beanDef = this.beanFactory.getBeanDefinition("outerBean"); + BeanDefinitionHolder innerBeanDef = + (BeanDefinitionHolder) beanDef.getPropertyValues().getPropertyValue("spouse").getValue(); + String name = innerBeanDef.getBeanName(); + + MockHttpServletRequest request = new MockHttpServletRequest(); + RequestAttributes requestAttributes = new ServletRequestAttributes(request); + RequestContextHolder.setRequestAttributes(requestAttributes); + + try { + assertNull(request.getAttribute("scopedTarget." + name)); + assertEquals("scoped", bean.getSpouse().getName()); + assertNotNull(request.getAttribute("scopedTarget." + name)); + assertEquals(TestBean.class, request.getAttribute("scopedTarget." + name).getClass()); + assertEquals("scoped", ((TestBean) request.getAttribute("scopedTarget." + name)).getName()); + } + finally { + RequestContextHolder.setRequestAttributes(null); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/ServletRequestAttributesTests.java b/spring-web/src/test/java/org/springframework/web/context/request/ServletRequestAttributesTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d212403f8936c3325da2d656795f63f651ebf504 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/ServletRequestAttributesTests.java @@ -0,0 +1,190 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import java.io.Serializable; +import java.math.BigInteger; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpSession; + +import org.junit.Test; + +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpSession; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * @author Rick Evans + * @author Juergen Hoeller + */ +public class ServletRequestAttributesTests { + + private static final String KEY = "ThatThingThatThing"; + + @SuppressWarnings("serial") + private static final Serializable VALUE = new Serializable() { + }; + + + @Test(expected = IllegalArgumentException.class) + public void ctorRejectsNullArg() throws Exception { + new ServletRequestAttributes(null); + } + + @Test + public void setRequestScopedAttribute() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + ServletRequestAttributes attrs = new ServletRequestAttributes(request); + attrs.setAttribute(KEY, VALUE, RequestAttributes.SCOPE_REQUEST); + Object value = request.getAttribute(KEY); + assertSame(VALUE, value); + } + + @Test + public void setRequestScopedAttributeAfterCompletion() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + ServletRequestAttributes attrs = new ServletRequestAttributes(request); + request.close(); + try { + attrs.setAttribute(KEY, VALUE, RequestAttributes.SCOPE_REQUEST); + fail("Should have thrown IllegalStateException"); + } + catch (IllegalStateException ex) { + // expected + } + } + + @Test + public void setSessionScopedAttribute() throws Exception { + MockHttpSession session = new MockHttpSession(); + session.setAttribute(KEY, VALUE); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setSession(session); + ServletRequestAttributes attrs = new ServletRequestAttributes(request); + attrs.setAttribute(KEY, VALUE, RequestAttributes.SCOPE_SESSION); + assertSame(VALUE, session.getAttribute(KEY)); + } + + @Test + public void setSessionScopedAttributeAfterCompletion() throws Exception { + MockHttpSession session = new MockHttpSession(); + session.setAttribute(KEY, VALUE); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setSession(session); + ServletRequestAttributes attrs = new ServletRequestAttributes(request); + assertSame(VALUE, attrs.getAttribute(KEY, RequestAttributes.SCOPE_SESSION)); + attrs.requestCompleted(); + request.close(); + attrs.setAttribute(KEY, VALUE, RequestAttributes.SCOPE_SESSION); + assertSame(VALUE, session.getAttribute(KEY)); + } + + @Test + public void getSessionScopedAttributeDoesNotForceCreationOfSession() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + + ServletRequestAttributes attrs = new ServletRequestAttributes(request); + Object value = attrs.getAttribute(KEY, RequestAttributes.SCOPE_SESSION); + assertNull(value); + verify(request).getSession(false); + } + + @Test + public void removeSessionScopedAttribute() throws Exception { + MockHttpSession session = new MockHttpSession(); + session.setAttribute(KEY, VALUE); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setSession(session); + ServletRequestAttributes attrs = new ServletRequestAttributes(request); + attrs.removeAttribute(KEY, RequestAttributes.SCOPE_SESSION); + Object value = session.getAttribute(KEY); + assertNull(value); + } + + @Test + public void removeSessionScopedAttributeDoesNotForceCreationOfSession() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + + ServletRequestAttributes attrs = new ServletRequestAttributes(request); + attrs.removeAttribute(KEY, RequestAttributes.SCOPE_SESSION); + verify(request).getSession(false); + } + + @Test + public void updateAccessedAttributes() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpSession session = mock(HttpSession.class); + given(request.getSession(anyBoolean())).willReturn(session); + given(session.getAttribute(KEY)).willReturn(VALUE); + + ServletRequestAttributes attrs = new ServletRequestAttributes(request); + assertSame(VALUE, attrs.getAttribute(KEY, RequestAttributes.SCOPE_SESSION)); + attrs.requestCompleted(); + + verify(session, times(2)).getAttribute(KEY); + verify(session).setAttribute(KEY, VALUE); + verifyNoMoreInteractions(session); + } + + @Test + public void skipImmutableString() { + doSkipImmutableValue("someString"); + } + + @Test + public void skipImmutableCharacter() { + doSkipImmutableValue(new Character('x')); + } + + @Test + public void skipImmutableBoolean() { + doSkipImmutableValue(Boolean.TRUE); + } + + @Test + public void skipImmutableInteger() { + doSkipImmutableValue(new Integer(1)); + } + + @Test + public void skipImmutableFloat() { + doSkipImmutableValue(new Float(1.1)); + } + + @Test + public void skipImmutableBigInteger() { + doSkipImmutableValue(new BigInteger("1")); + } + + private void doSkipImmutableValue(Object immutableValue) { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpSession session = mock(HttpSession.class); + given(request.getSession(anyBoolean())).willReturn(session); + given(session.getAttribute(KEY)).willReturn(immutableValue); + + ServletRequestAttributes attrs = new ServletRequestAttributes(request); + attrs.getAttribute(KEY, RequestAttributes.SCOPE_SESSION); + attrs.requestCompleted(); + + verify(session, times(2)).getAttribute(KEY); + verifyNoMoreInteractions(session); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/ServletWebRequestHttpMethodsTests.java b/spring-web/src/test/java/org/springframework/web/context/request/ServletWebRequestHttpMethodsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..342d5de076e854df3065f4dc434390ea2476d683 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/ServletWebRequestHttpMethodsTests.java @@ -0,0 +1,319 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import java.time.ZonedDateTime; +import java.util.Arrays; +import java.util.Date; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; + +import static java.time.format.DateTimeFormatter.*; +import static org.junit.Assert.*; + +/** + * Parameterized tests for {@link ServletWebRequest}. + * + * @author Juergen Hoeller + * @author Brian Clozel + * @author Markus Malkusch + */ +@RunWith(Parameterized.class) +public class ServletWebRequestHttpMethodsTests { + + private static final String CURRENT_TIME = "Wed, 9 Apr 2014 09:57:42 GMT"; + + private MockHttpServletRequest servletRequest; + + private MockHttpServletResponse servletResponse; + + private ServletWebRequest request; + + private Date currentDate; + + @Parameter + public String method; + + @Parameters(name = "{0}") + static public Iterable safeMethods() { + return Arrays.asList(new Object[][] { + {"GET"}, {"HEAD"} + }); + } + + + @Before + public void setup() { + currentDate = new Date(); + servletRequest = new MockHttpServletRequest(method, "https://example.org"); + servletResponse = new MockHttpServletResponse(); + request = new ServletWebRequest(servletRequest, servletResponse); + } + + + @Test + public void checkNotModifiedNon2xxStatus() { + long epochTime = currentDate.getTime(); + servletRequest.addHeader("If-Modified-Since", epochTime); + servletResponse.setStatus(304); + + assertFalse(request.checkNotModified(epochTime)); + assertEquals(304, servletResponse.getStatus()); + assertNull(servletResponse.getHeader("Last-Modified")); + } + + @Test // SPR-13516 + public void checkNotModifiedInvalidStatus() { + long epochTime = currentDate.getTime(); + servletRequest.addHeader("If-Modified-Since", epochTime); + servletResponse.setStatus(0); + + assertFalse(request.checkNotModified(epochTime)); + } + + @Test // SPR-14559 + public void checkNotModifiedInvalidIfNoneMatchHeader() { + String etag = "\"etagvalue\""; + servletRequest.addHeader("If-None-Match", "missingquotes"); + assertFalse(request.checkNotModified(etag)); + assertEquals(200, servletResponse.getStatus()); + assertEquals(etag, servletResponse.getHeader("ETag")); + } + + @Test + public void checkNotModifiedHeaderAlreadySet() { + long epochTime = currentDate.getTime(); + servletRequest.addHeader("If-Modified-Since", epochTime); + servletResponse.addHeader("Last-Modified", CURRENT_TIME); + + assertTrue(request.checkNotModified(epochTime)); + assertEquals(304, servletResponse.getStatus()); + assertEquals(1, servletResponse.getHeaders("Last-Modified").size()); + assertEquals(CURRENT_TIME, servletResponse.getHeader("Last-Modified")); + } + + @Test + public void checkNotModifiedTimestamp() { + long epochTime = currentDate.getTime(); + servletRequest.addHeader("If-Modified-Since", epochTime); + + assertTrue(request.checkNotModified(epochTime)); + assertEquals(304, servletResponse.getStatus()); + assertEquals(currentDate.getTime() / 1000, servletResponse.getDateHeader("Last-Modified") / 1000); + } + + @Test + public void checkModifiedTimestamp() { + long oneMinuteAgo = currentDate.getTime() - (1000 * 60); + servletRequest.addHeader("If-Modified-Since", oneMinuteAgo); + + assertFalse(request.checkNotModified(currentDate.getTime())); + assertEquals(200, servletResponse.getStatus()); + assertEquals(currentDate.getTime() / 1000, servletResponse.getDateHeader("Last-Modified") / 1000); + } + + @Test + public void checkNotModifiedETag() { + String etag = "\"Foo\""; + servletRequest.addHeader("If-None-Match", etag); + + assertTrue(request.checkNotModified(etag)); + assertEquals(304, servletResponse.getStatus()); + assertEquals(etag, servletResponse.getHeader("ETag")); + } + + @Test + public void checkNotModifiedETagWithSeparatorChars() { + String etag = "\"Foo, Bar\""; + servletRequest.addHeader("If-None-Match", etag); + + assertTrue(request.checkNotModified(etag)); + assertEquals(304, servletResponse.getStatus()); + assertEquals(etag, servletResponse.getHeader("ETag")); + } + + + @Test + public void checkModifiedETag() { + String currentETag = "\"Foo\""; + String oldETag = "Bar"; + servletRequest.addHeader("If-None-Match", oldETag); + + assertFalse(request.checkNotModified(currentETag)); + assertEquals(200, servletResponse.getStatus()); + assertEquals(currentETag, servletResponse.getHeader("ETag")); + } + + @Test + public void checkNotModifiedUnpaddedETag() { + String etag = "Foo"; + String paddedETag = String.format("\"%s\"", etag); + servletRequest.addHeader("If-None-Match", paddedETag); + + assertTrue(request.checkNotModified(etag)); + assertEquals(304, servletResponse.getStatus()); + assertEquals(paddedETag, servletResponse.getHeader("ETag")); + } + + @Test + public void checkModifiedUnpaddedETag() { + String currentETag = "Foo"; + String oldETag = "Bar"; + servletRequest.addHeader("If-None-Match", oldETag); + + assertFalse(request.checkNotModified(currentETag)); + assertEquals(200, servletResponse.getStatus()); + assertEquals(String.format("\"%s\"", currentETag), servletResponse.getHeader("ETag")); + } + + @Test + public void checkNotModifiedWildcardIsIgnored() { + String etag = "\"Foo\""; + servletRequest.addHeader("If-None-Match", "*"); + + assertFalse(request.checkNotModified(etag)); + assertEquals(200, servletResponse.getStatus()); + assertEquals(etag, servletResponse.getHeader("ETag")); + } + + @Test + public void checkNotModifiedETagAndTimestamp() { + String etag = "\"Foo\""; + servletRequest.addHeader("If-None-Match", etag); + servletRequest.addHeader("If-Modified-Since", currentDate.getTime()); + + assertTrue(request.checkNotModified(etag, currentDate.getTime())); + assertEquals(304, servletResponse.getStatus()); + assertEquals(etag, servletResponse.getHeader("ETag")); + assertEquals(currentDate.getTime() / 1000, servletResponse.getDateHeader("Last-Modified") / 1000); + } + + @Test // SPR-14224 + public void checkNotModifiedETagAndModifiedTimestamp() { + String etag = "\"Foo\""; + servletRequest.addHeader("If-None-Match", etag); + long currentEpoch = currentDate.getTime(); + long oneMinuteAgo = currentEpoch - (1000 * 60); + servletRequest.addHeader("If-Modified-Since", oneMinuteAgo); + + assertTrue(request.checkNotModified(etag, currentEpoch)); + assertEquals(304, servletResponse.getStatus()); + assertEquals(etag, servletResponse.getHeader("ETag")); + assertEquals(currentDate.getTime() / 1000, servletResponse.getDateHeader("Last-Modified") / 1000); + } + + @Test + public void checkModifiedETagAndNotModifiedTimestamp() { + String currentETag = "\"Foo\""; + String oldETag = "\"Bar\""; + servletRequest.addHeader("If-None-Match", oldETag); + long epochTime = currentDate.getTime(); + servletRequest.addHeader("If-Modified-Since", epochTime); + + assertFalse(request.checkNotModified(currentETag, epochTime)); + assertEquals(200, servletResponse.getStatus()); + assertEquals(currentETag, servletResponse.getHeader("ETag")); + assertEquals(currentDate.getTime() / 1000, servletResponse.getDateHeader("Last-Modified") / 1000); + } + + @Test + public void checkNotModifiedETagWeakStrong() { + String etag = "\"Foo\""; + String weakETag = String.format("W/%s", etag); + servletRequest.addHeader("If-None-Match", etag); + + assertTrue(request.checkNotModified(weakETag)); + assertEquals(304, servletResponse.getStatus()); + assertEquals(weakETag, servletResponse.getHeader("ETag")); + } + + @Test + public void checkNotModifiedETagStrongWeak() { + String etag = "\"Foo\""; + servletRequest.addHeader("If-None-Match", String.format("W/%s", etag)); + + assertTrue(request.checkNotModified(etag)); + assertEquals(304, servletResponse.getStatus()); + assertEquals(etag, servletResponse.getHeader("ETag")); + } + + @Test + public void checkNotModifiedMultipleETags() { + String etag = "\"Bar\""; + String multipleETags = String.format("\"Foo\", %s", etag); + servletRequest.addHeader("If-None-Match", multipleETags); + + assertTrue(request.checkNotModified(etag)); + assertEquals(304, servletResponse.getStatus()); + assertEquals(etag, servletResponse.getHeader("ETag")); + } + + @Test + public void checkNotModifiedTimestampWithLengthPart() { + long epochTime = ZonedDateTime.parse(CURRENT_TIME, RFC_1123_DATE_TIME).toInstant().toEpochMilli(); + servletRequest.setMethod("GET"); + servletRequest.addHeader("If-Modified-Since", "Wed, 09 Apr 2014 09:57:42 GMT; length=13774"); + + assertTrue(request.checkNotModified(epochTime)); + assertEquals(304, servletResponse.getStatus()); + assertEquals(epochTime / 1000, servletResponse.getDateHeader("Last-Modified") / 1000); + } + + @Test + public void checkModifiedTimestampWithLengthPart() { + long epochTime = ZonedDateTime.parse(CURRENT_TIME, RFC_1123_DATE_TIME).toInstant().toEpochMilli(); + servletRequest.setMethod("GET"); + servletRequest.addHeader("If-Modified-Since", "Wed, 08 Apr 2014 09:57:42 GMT; length=13774"); + + assertFalse(request.checkNotModified(epochTime)); + assertEquals(200, servletResponse.getStatus()); + assertEquals(epochTime / 1000, servletResponse.getDateHeader("Last-Modified") / 1000); + } + + @Test + public void checkNotModifiedTimestampConditionalPut() { + long currentEpoch = currentDate.getTime(); + long oneMinuteAgo = currentEpoch - (1000 * 60); + servletRequest.setMethod("PUT"); + servletRequest.addHeader("If-UnModified-Since", currentEpoch); + + assertFalse(request.checkNotModified(oneMinuteAgo)); + assertEquals(200, servletResponse.getStatus()); + assertEquals(null, servletResponse.getHeader("Last-Modified")); + } + + @Test + public void checkNotModifiedTimestampConditionalPutConflict() { + long currentEpoch = currentDate.getTime(); + long oneMinuteAgo = currentEpoch - (1000 * 60); + servletRequest.setMethod("PUT"); + servletRequest.addHeader("If-UnModified-Since", oneMinuteAgo); + + assertTrue(request.checkNotModified(currentEpoch)); + assertEquals(412, servletResponse.getStatus()); + assertEquals(null, servletResponse.getHeader("Last-Modified")); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/ServletWebRequestTests.java b/spring-web/src/test/java/org/springframework/web/context/request/ServletWebRequestTests.java new file mode 100644 index 0000000000000000000000000000000000000000..f707e8afd9bcccbe2afe9116e6182c22341427c6 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/ServletWebRequestTests.java @@ -0,0 +1,119 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import java.util.Locale; +import java.util.Map; + +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.web.multipart.MultipartRequest; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + */ +public class ServletWebRequestTests { + + private MockHttpServletRequest servletRequest; + + private MockHttpServletResponse servletResponse; + + private ServletWebRequest request; + + + @Before + public void setup() { + servletRequest = new MockHttpServletRequest(); + servletResponse = new MockHttpServletResponse(); + request = new ServletWebRequest(servletRequest, servletResponse); + } + + + @Test + public void parameters() { + servletRequest.addParameter("param1", "value1"); + servletRequest.addParameter("param2", "value2"); + servletRequest.addParameter("param2", "value2a"); + + assertEquals("value1", request.getParameter("param1")); + assertEquals(1, request.getParameterValues("param1").length); + assertEquals("value1", request.getParameterValues("param1")[0]); + assertEquals("value2", request.getParameter("param2")); + assertEquals(2, request.getParameterValues("param2").length); + assertEquals("value2", request.getParameterValues("param2")[0]); + assertEquals("value2a", request.getParameterValues("param2")[1]); + + Map paramMap = request.getParameterMap(); + assertEquals(2, paramMap.size()); + assertEquals(1, paramMap.get("param1").length); + assertEquals("value1", paramMap.get("param1")[0]); + assertEquals(2, paramMap.get("param2").length); + assertEquals("value2", paramMap.get("param2")[0]); + assertEquals("value2a", paramMap.get("param2")[1]); + } + + @Test + public void locale() { + servletRequest.addPreferredLocale(Locale.UK); + + assertEquals(Locale.UK, request.getLocale()); + } + + @Test + public void nativeRequest() { + assertSame(servletRequest, request.getNativeRequest()); + assertSame(servletRequest, request.getNativeRequest(ServletRequest.class)); + assertSame(servletRequest, request.getNativeRequest(HttpServletRequest.class)); + assertSame(servletRequest, request.getNativeRequest(MockHttpServletRequest.class)); + assertNull(request.getNativeRequest(MultipartRequest.class)); + assertSame(servletResponse, request.getNativeResponse()); + assertSame(servletResponse, request.getNativeResponse(ServletResponse.class)); + assertSame(servletResponse, request.getNativeResponse(HttpServletResponse.class)); + assertSame(servletResponse, request.getNativeResponse(MockHttpServletResponse.class)); + assertNull(request.getNativeResponse(MultipartRequest.class)); + } + + @Test + public void decoratedNativeRequest() { + HttpServletRequest decoratedRequest = new HttpServletRequestWrapper(servletRequest); + HttpServletResponse decoratedResponse = new HttpServletResponseWrapper(servletResponse); + ServletWebRequest request = new ServletWebRequest(decoratedRequest, decoratedResponse); + assertSame(decoratedRequest, request.getNativeRequest()); + assertSame(decoratedRequest, request.getNativeRequest(ServletRequest.class)); + assertSame(decoratedRequest, request.getNativeRequest(HttpServletRequest.class)); + assertSame(servletRequest, request.getNativeRequest(MockHttpServletRequest.class)); + assertNull(request.getNativeRequest(MultipartRequest.class)); + assertSame(decoratedResponse, request.getNativeResponse()); + assertSame(decoratedResponse, request.getNativeResponse(ServletResponse.class)); + assertSame(decoratedResponse, request.getNativeResponse(HttpServletResponse.class)); + assertSame(servletResponse, request.getNativeResponse(MockHttpServletResponse.class)); + assertNull(request.getNativeResponse(MultipartRequest.class)); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/SessionScopeTests.java b/spring-web/src/test/java/org/springframework/web/context/request/SessionScopeTests.java new file mode 100644 index 0000000000000000000000000000000000000000..ae05b9fa0c9d361c2ae6537d0441020d78aabdfd --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/SessionScopeTests.java @@ -0,0 +1,252 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import java.io.Serializable; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanNameAware; +import org.springframework.beans.factory.config.DestructionAwareBeanPostProcessor; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; +import org.springframework.core.io.ClassPathResource; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpSession; +import org.springframework.tests.sample.beans.DerivedTestBean; +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.util.SerializationTestUtils; + +import static org.junit.Assert.*; + +/** + * @author Rob Harrop + * @author Juergen Hoeller + * @author Sam Brannen + * @see RequestScopeTests + */ +public class SessionScopeTests { + + private final DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + + + @Before + public void setup() throws Exception { + this.beanFactory.registerScope("session", new SessionScope()); + XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(this.beanFactory); + reader.loadBeanDefinitions(new ClassPathResource("sessionScopeTests.xml", getClass())); + } + + @After + public void resetRequestAttributes() { + RequestContextHolder.setRequestAttributes(null); + } + + + @Test + public void getFromScope() throws Exception { + AtomicInteger count = new AtomicInteger(); + MockHttpSession session = new MockHttpSession() { + @Override + public void setAttribute(String name, Object value) { + super.setAttribute(name, value); + count.incrementAndGet(); + } + }; + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setSession(session); + ServletRequestAttributes requestAttributes = new ServletRequestAttributes(request); + + RequestContextHolder.setRequestAttributes(requestAttributes); + String name = "sessionScopedObject"; + assertNull(session.getAttribute(name)); + TestBean bean = (TestBean) this.beanFactory.getBean(name); + assertEquals(1, count.intValue()); + assertEquals(session.getAttribute(name), bean); + assertSame(bean, this.beanFactory.getBean(name)); + assertEquals(1, count.intValue()); + + // should re-propagate updated attribute + requestAttributes.requestCompleted(); + assertEquals(session.getAttribute(name), bean); + assertEquals(2, count.intValue()); + } + + @Test + public void getFromScopeWithSingleAccess() throws Exception { + AtomicInteger count = new AtomicInteger(); + MockHttpSession session = new MockHttpSession() { + @Override + public void setAttribute(String name, Object value) { + super.setAttribute(name, value); + count.incrementAndGet(); + } + }; + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setSession(session); + ServletRequestAttributes requestAttributes = new ServletRequestAttributes(request); + + RequestContextHolder.setRequestAttributes(requestAttributes); + String name = "sessionScopedObject"; + assertNull(session.getAttribute(name)); + TestBean bean = (TestBean) this.beanFactory.getBean(name); + assertEquals(1, count.intValue()); + + // should re-propagate updated attribute + requestAttributes.requestCompleted(); + assertEquals(session.getAttribute(name), bean); + assertEquals(2, count.intValue()); + } + + @Test + public void destructionAtSessionTermination() throws Exception { + MockHttpSession session = new MockHttpSession(); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setSession(session); + ServletRequestAttributes requestAttributes = new ServletRequestAttributes(request); + + RequestContextHolder.setRequestAttributes(requestAttributes); + String name = "sessionScopedDisposableObject"; + assertNull(session.getAttribute(name)); + DerivedTestBean bean = (DerivedTestBean) this.beanFactory.getBean(name); + assertEquals(session.getAttribute(name), bean); + assertSame(bean, this.beanFactory.getBean(name)); + + requestAttributes.requestCompleted(); + session.invalidate(); + assertTrue(bean.wasDestroyed()); + } + + @Test + public void destructionWithSessionSerialization() throws Exception { + doTestDestructionWithSessionSerialization(false); + } + + @Test + public void destructionWithSessionSerializationAndBeanPostProcessor() throws Exception { + this.beanFactory.addBeanPostProcessor(new CustomDestructionAwareBeanPostProcessor()); + doTestDestructionWithSessionSerialization(false); + } + + @Test + public void destructionWithSessionSerializationAndSerializableBeanPostProcessor() throws Exception { + this.beanFactory.addBeanPostProcessor(new CustomSerializableDestructionAwareBeanPostProcessor()); + doTestDestructionWithSessionSerialization(true); + } + + private void doTestDestructionWithSessionSerialization(boolean beanNameReset) throws Exception { + Serializable serializedState = null; + + MockHttpSession session = new MockHttpSession(); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setSession(session); + ServletRequestAttributes requestAttributes = new ServletRequestAttributes(request); + + RequestContextHolder.setRequestAttributes(requestAttributes); + String name = "sessionScopedDisposableObject"; + assertNull(session.getAttribute(name)); + DerivedTestBean bean = (DerivedTestBean) this.beanFactory.getBean(name); + assertEquals(session.getAttribute(name), bean); + assertSame(bean, this.beanFactory.getBean(name)); + + requestAttributes.requestCompleted(); + serializedState = session.serializeState(); + assertFalse(bean.wasDestroyed()); + + serializedState = (Serializable) SerializationTestUtils.serializeAndDeserialize(serializedState); + + session = new MockHttpSession(); + session.deserializeState(serializedState); + request = new MockHttpServletRequest(); + request.setSession(session); + requestAttributes = new ServletRequestAttributes(request); + + RequestContextHolder.setRequestAttributes(requestAttributes); + name = "sessionScopedDisposableObject"; + assertNotNull(session.getAttribute(name)); + bean = (DerivedTestBean) this.beanFactory.getBean(name); + assertEquals(session.getAttribute(name), bean); + assertSame(bean, this.beanFactory.getBean(name)); + + requestAttributes.requestCompleted(); + session.invalidate(); + assertTrue(bean.wasDestroyed()); + + if (beanNameReset) { + assertNull(bean.getBeanName()); + } + else { + assertNotNull(bean.getBeanName()); + } + } + + + private static class CustomDestructionAwareBeanPostProcessor implements DestructionAwareBeanPostProcessor { + + @Override + public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { + return bean; + } + + @Override + public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException { + return bean; + } + + @Override + public void postProcessBeforeDestruction(Object bean, String beanName) throws BeansException { + } + + @Override + public boolean requiresDestruction(Object bean) { + return true; + } + } + + + @SuppressWarnings("serial") + private static class CustomSerializableDestructionAwareBeanPostProcessor + implements DestructionAwareBeanPostProcessor, Serializable { + + @Override + public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { + return bean; + } + + @Override + public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException { + return bean; + } + + @Override + public void postProcessBeforeDestruction(Object bean, String beanName) throws BeansException { + if (bean instanceof BeanNameAware) { + ((BeanNameAware) bean).setBeanName(null); + } + } + + @Override + public boolean requiresDestruction(Object bean) { + return true; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/WebApplicationContextScopeTests.java b/spring-web/src/test/java/org/springframework/web/context/request/WebApplicationContextScopeTests.java new file mode 100644 index 0000000000000000000000000000000000000000..ff122a286a38466522547b212f1047c5fd59f404 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/WebApplicationContextScopeTests.java @@ -0,0 +1,101 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request; + +import javax.servlet.ServletContextEvent; + +import org.junit.Test; + +import org.springframework.beans.factory.support.GenericBeanDefinition; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockServletContext; +import org.springframework.tests.sample.beans.DerivedTestBean; +import org.springframework.web.context.ContextCleanupListener; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.GenericWebApplicationContext; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + */ +public class WebApplicationContextScopeTests { + + private static final String NAME = "scoped"; + + + private WebApplicationContext initApplicationContext(String scope) { + MockServletContext sc = new MockServletContext(); + GenericWebApplicationContext ac = new GenericWebApplicationContext(sc); + GenericBeanDefinition bd = new GenericBeanDefinition(); + bd.setBeanClass(DerivedTestBean.class); + bd.setScope(scope); + ac.registerBeanDefinition(NAME, bd); + ac.refresh(); + return ac; + } + + @Test + public void testRequestScope() { + WebApplicationContext ac = initApplicationContext(WebApplicationContext.SCOPE_REQUEST); + MockHttpServletRequest request = new MockHttpServletRequest(); + ServletRequestAttributes requestAttributes = new ServletRequestAttributes(request); + RequestContextHolder.setRequestAttributes(requestAttributes); + try { + assertNull(request.getAttribute(NAME)); + DerivedTestBean bean = ac.getBean(NAME, DerivedTestBean.class); + assertSame(bean, request.getAttribute(NAME)); + assertSame(bean, ac.getBean(NAME)); + requestAttributes.requestCompleted(); + assertTrue(bean.wasDestroyed()); + } + finally { + RequestContextHolder.setRequestAttributes(null); + } + } + + @Test + public void testSessionScope() { + WebApplicationContext ac = initApplicationContext(WebApplicationContext.SCOPE_SESSION); + MockHttpServletRequest request = new MockHttpServletRequest(); + ServletRequestAttributes requestAttributes = new ServletRequestAttributes(request); + RequestContextHolder.setRequestAttributes(requestAttributes); + try { + assertNull(request.getSession().getAttribute(NAME)); + DerivedTestBean bean = ac.getBean(NAME, DerivedTestBean.class); + assertSame(bean, request.getSession().getAttribute(NAME)); + assertSame(bean, ac.getBean(NAME)); + request.getSession().invalidate(); + assertTrue(bean.wasDestroyed()); + } + finally { + RequestContextHolder.setRequestAttributes(null); + } + } + + @Test + public void testApplicationScope() { + WebApplicationContext ac = initApplicationContext(WebApplicationContext.SCOPE_APPLICATION); + assertNull(ac.getServletContext().getAttribute(NAME)); + DerivedTestBean bean = ac.getBean(NAME, DerivedTestBean.class); + assertSame(bean, ac.getServletContext().getAttribute(NAME)); + assertSame(bean, ac.getBean(NAME)); + new ContextCleanupListener().contextDestroyed(new ServletContextEvent(ac.getServletContext())); + assertTrue(bean.wasDestroyed()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/DeferredResultTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/DeferredResultTests.java new file mode 100644 index 0000000000000000000000000000000000000000..259eb77584fe9ae1201af7913c58ed8f3f95c378 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/DeferredResultTests.java @@ -0,0 +1,153 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.util.function.Consumer; + +import org.junit.Test; + +import org.springframework.web.context.request.async.DeferredResult.DeferredResultHandler; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * DeferredResult tests. + * + * @author Rossen Stoyanchev + */ +public class DeferredResultTests { + + @Test + public void setResult() { + DeferredResultHandler handler = mock(DeferredResultHandler.class); + + DeferredResult result = new DeferredResult<>(); + result.setResultHandler(handler); + + assertTrue(result.setResult("hello")); + verify(handler).handleResult("hello"); + } + + @Test + public void setResultTwice() { + DeferredResultHandler handler = mock(DeferredResultHandler.class); + + DeferredResult result = new DeferredResult<>(); + result.setResultHandler(handler); + + assertTrue(result.setResult("hello")); + assertFalse(result.setResult("hi")); + + verify(handler).handleResult("hello"); + } + + @Test + public void isSetOrExpired() { + DeferredResultHandler handler = mock(DeferredResultHandler.class); + + DeferredResult result = new DeferredResult<>(); + result.setResultHandler(handler); + + assertFalse(result.isSetOrExpired()); + + result.setResult("hello"); + + assertTrue(result.isSetOrExpired()); + + verify(handler).handleResult("hello"); + } + + @Test + public void hasResult() { + DeferredResultHandler handler = mock(DeferredResultHandler.class); + + DeferredResult result = new DeferredResult<>(); + result.setResultHandler(handler); + + assertFalse(result.hasResult()); + assertNull(result.getResult()); + + result.setResult("hello"); + + assertEquals("hello", result.getResult()); + } + + @Test + public void onCompletion() throws Exception { + final StringBuilder sb = new StringBuilder(); + + DeferredResult result = new DeferredResult<>(); + result.onCompletion(new Runnable() { + @Override + public void run() { + sb.append("completion event"); + } + }); + + result.getInterceptor().afterCompletion(null, null); + + assertTrue(result.isSetOrExpired()); + assertEquals("completion event", sb.toString()); + } + + @Test + public void onTimeout() throws Exception { + final StringBuilder sb = new StringBuilder(); + + DeferredResultHandler handler = mock(DeferredResultHandler.class); + + DeferredResult result = new DeferredResult<>(null, "timeout result"); + result.setResultHandler(handler); + result.onTimeout(new Runnable() { + @Override + public void run() { + sb.append("timeout event"); + } + }); + + result.getInterceptor().handleTimeout(null, null); + + assertEquals("timeout event", sb.toString()); + assertFalse("Should not be able to set result a second time", result.setResult("hello")); + verify(handler).handleResult("timeout result"); + } + + @Test + public void onError() throws Exception { + final StringBuilder sb = new StringBuilder(); + + DeferredResultHandler handler = mock(DeferredResultHandler.class); + + DeferredResult result = new DeferredResult<>(null, "error result"); + result.setResultHandler(handler); + Exception e = new Exception(); + result.onError(new Consumer() { + @Override + public void accept(Throwable t) { + sb.append("error event"); + } + }); + + result.getInterceptor().handleError(null, null, e); + + assertEquals("error event", sb.toString()); + assertFalse("Should not be able to set result a second time", result.setResult("hello")); + verify(handler).handleResult(e); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a71c9178b60f2f4371808353ca0690b0dcb6c427 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java @@ -0,0 +1,187 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + + +import java.util.function.Consumer; + +import javax.servlet.AsyncEvent; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.mock.web.test.MockAsyncContext; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; + +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.verify; + +/** + * A test fixture with a {@link StandardServletAsyncWebRequest}. + * @author Rossen Stoyanchev + */ +public class StandardServletAsyncWebRequestTests { + + private StandardServletAsyncWebRequest asyncRequest; + + private MockHttpServletRequest request; + + private MockHttpServletResponse response; + + + @Before + public void setup() { + this.request = new MockHttpServletRequest(); + this.request.setAsyncSupported(true); + this.response = new MockHttpServletResponse(); + this.asyncRequest = new StandardServletAsyncWebRequest(this.request, this.response); + this.asyncRequest.setTimeout(44*1000L); + } + + + @Test + public void isAsyncStarted() throws Exception { + assertFalse(this.asyncRequest.isAsyncStarted()); + this.asyncRequest.startAsync(); + assertTrue(this.asyncRequest.isAsyncStarted()); + } + + @Test + public void startAsync() throws Exception { + this.asyncRequest.startAsync(); + + MockAsyncContext context = (MockAsyncContext) this.request.getAsyncContext(); + assertNotNull(context); + assertEquals("Timeout value not set", 44 * 1000, context.getTimeout()); + assertEquals(1, context.getListeners().size()); + assertSame(this.asyncRequest, context.getListeners().get(0)); + } + + @Test + public void startAsyncMultipleTimes() throws Exception { + this.asyncRequest.startAsync(); + this.asyncRequest.startAsync(); + this.asyncRequest.startAsync(); + this.asyncRequest.startAsync(); // idempotent + + MockAsyncContext context = (MockAsyncContext) this.request.getAsyncContext(); + assertNotNull(context); + assertEquals(1, context.getListeners().size()); + } + + @Test + public void startAsyncNotSupported() throws Exception { + this.request.setAsyncSupported(false); + try { + this.asyncRequest.startAsync(); + fail("expected exception"); + } + catch (IllegalStateException ex) { + assertThat(ex.getMessage(), containsString("Async support must be enabled")); + } + } + + @Test + public void startAsyncAfterCompleted() throws Exception { + this.asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(this.request, this.response))); + try { + this.asyncRequest.startAsync(); + fail("expected exception"); + } + catch (IllegalStateException ex) { + assertEquals("Async processing has already completed", ex.getMessage()); + } + } + + @Test + public void onTimeoutDefaultBehavior() throws Exception { + this.asyncRequest.onTimeout(new AsyncEvent(new MockAsyncContext(this.request, this.response))); + assertEquals(200, this.response.getStatus()); + } + + @Test + public void onTimeoutHandler() throws Exception { + Runnable timeoutHandler = mock(Runnable.class); + this.asyncRequest.addTimeoutHandler(timeoutHandler); + this.asyncRequest.onTimeout(new AsyncEvent(new MockAsyncContext(this.request, this.response))); + verify(timeoutHandler).run(); + } + + @SuppressWarnings("unchecked") + @Test + public void onErrorHandler() throws Exception { + Consumer errorHandler = mock(Consumer.class); + this.asyncRequest.addErrorHandler(errorHandler); + Exception e = new Exception(); + this.asyncRequest.onError(new AsyncEvent(new MockAsyncContext(this.request, this.response), e)); + verify(errorHandler).accept(e); + } + + @Test(expected = IllegalStateException.class) + public void setTimeoutDuringConcurrentHandling() { + this.asyncRequest.startAsync(); + this.asyncRequest.setTimeout(25L); + } + + @Test + public void onCompletionHandler() throws Exception { + Runnable handler = mock(Runnable.class); + this.asyncRequest.addCompletionHandler(handler); + + this.asyncRequest.startAsync(); + this.asyncRequest.onComplete(new AsyncEvent(this.request.getAsyncContext())); + + verify(handler).run(); + assertTrue(this.asyncRequest.isAsyncComplete()); + } + + // SPR-13292 + + @SuppressWarnings("unchecked") + @Test + public void onErrorHandlerAfterOnErrorEvent() throws Exception { + Consumer handler = mock(Consumer.class); + this.asyncRequest.addErrorHandler(handler); + + this.asyncRequest.startAsync(); + Exception e = new Exception(); + this.asyncRequest.onError(new AsyncEvent(this.request.getAsyncContext(), e)); + + verify(handler).accept(e); + } + + @Test + public void onCompletionHandlerAfterOnCompleteEvent() throws Exception { + Runnable handler = mock(Runnable.class); + this.asyncRequest.addCompletionHandler(handler); + + this.asyncRequest.startAsync(); + this.asyncRequest.onComplete(new AsyncEvent(this.request.getAsyncContext())); + + verify(handler).run(); + assertTrue(this.asyncRequest.isAsyncComplete()); + } +} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerErrorTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerErrorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..44c345c52061c2f803e324442009adcc2868b495 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerErrorTests.java @@ -0,0 +1,284 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.util.concurrent.Callable; +import java.util.function.Consumer; + +import javax.servlet.AsyncEvent; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.task.AsyncTaskExecutor; +import org.springframework.mock.web.test.MockAsyncContext; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.web.context.request.NativeWebRequest; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.verify; +import static org.springframework.web.context.request.async.CallableProcessingInterceptor.RESULT_NONE; + +/** + * {@link WebAsyncManager} tests where container-triggered error/completion + * events are simulated. + * + * @author Violeta Georgieva + * @since 5.0 + */ +public class WebAsyncManagerErrorTests { + + private WebAsyncManager asyncManager; + + private StandardServletAsyncWebRequest asyncWebRequest; + + private MockHttpServletRequest servletRequest; + + private MockHttpServletResponse servletResponse; + + + @Before + public void setup() { + this.servletRequest = new MockHttpServletRequest("GET", "/test"); + this.servletRequest.setAsyncSupported(true); + this.servletResponse = new MockHttpServletResponse(); + this.asyncWebRequest = new StandardServletAsyncWebRequest(servletRequest, servletResponse); + + AsyncTaskExecutor executor = mock(AsyncTaskExecutor.class); + + this.asyncManager = WebAsyncUtils.getAsyncManager(servletRequest); + this.asyncManager.setTaskExecutor(executor); + this.asyncManager.setAsyncWebRequest(this.asyncWebRequest); + } + + + @Test + public void startCallableProcessingErrorAndComplete() throws Exception { + StubCallable callable = new StubCallable(); + + CallableProcessingInterceptor interceptor = mock(CallableProcessingInterceptor.class); + Exception e = new Exception(); + given(interceptor.handleError(this.asyncWebRequest, callable, e)).willReturn(RESULT_NONE); + + this.asyncManager.registerCallableInterceptor("interceptor", interceptor); + this.asyncManager.startCallableProcessing(callable); + + AsyncEvent event = new AsyncEvent(new MockAsyncContext(this.servletRequest, this.servletResponse), e); + this.asyncWebRequest.onError(event); + this.asyncWebRequest.onComplete(event); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(e, this.asyncManager.getConcurrentResult()); + + verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, callable); + verify(interceptor).afterCompletion(this.asyncWebRequest, callable); + } + + @Test + public void startCallableProcessingErrorAndResumeThroughCallback() throws Exception { + + StubCallable callable = new StubCallable(); + WebAsyncTask webAsyncTask = new WebAsyncTask<>(callable); + webAsyncTask.onError(new Callable() { + @Override + public Object call() throws Exception { + return 7; + } + }); + + this.asyncManager.startCallableProcessing(webAsyncTask); + + Exception e = new Exception(); + AsyncEvent event = new AsyncEvent(new MockAsyncContext(this.servletRequest, this.servletResponse), e); + this.asyncWebRequest.onError(event); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(7, this.asyncManager.getConcurrentResult()); + assertEquals("/test", ((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()); + } + + @Test + public void startCallableProcessingErrorAndResumeThroughInterceptor() throws Exception { + + StubCallable callable = new StubCallable(); + + CallableProcessingInterceptor interceptor = mock(CallableProcessingInterceptor.class); + Exception e = new Exception(); + given(interceptor.handleError(this.asyncWebRequest, callable, e)).willReturn(22); + + this.asyncManager.registerCallableInterceptor("errorInterceptor", interceptor); + this.asyncManager.startCallableProcessing(callable); + + AsyncEvent event = new AsyncEvent(new MockAsyncContext(this.servletRequest, this.servletResponse), e); + this.asyncWebRequest.onError(event); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(22, this.asyncManager.getConcurrentResult()); + assertEquals("/test", ((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()); + + verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, callable); + } + + @Test + public void startCallableProcessingAfterException() throws Exception { + + StubCallable callable = new StubCallable(); + Exception exception = new Exception(); + + CallableProcessingInterceptor interceptor = mock(CallableProcessingInterceptor.class); + Exception e = new Exception(); + given(interceptor.handleError(this.asyncWebRequest, callable, e)).willThrow(exception); + + this.asyncManager.registerCallableInterceptor("errorInterceptor", interceptor); + this.asyncManager.startCallableProcessing(callable); + + AsyncEvent event = new AsyncEvent(new MockAsyncContext(this.servletRequest, this.servletResponse), e); + this.asyncWebRequest.onError(event); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(exception, this.asyncManager.getConcurrentResult()); + assertEquals("/test", ((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()); + + verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, callable); + } + + @Test + public void startDeferredResultProcessingErrorAndComplete() throws Exception { + + DeferredResult deferredResult = new DeferredResult<>(); + + DeferredResultProcessingInterceptor interceptor = mock(DeferredResultProcessingInterceptor.class); + Exception e = new Exception(); + given(interceptor.handleError(this.asyncWebRequest, deferredResult, e)).willReturn(true); + + this.asyncManager.registerDeferredResultInterceptor("interceptor", interceptor); + this.asyncManager.startDeferredResultProcessing(deferredResult); + + AsyncEvent event = new AsyncEvent(new MockAsyncContext(this.servletRequest, this.servletResponse), e); + this.asyncWebRequest.onError(event); + this.asyncWebRequest.onComplete(event); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(e, this.asyncManager.getConcurrentResult()); + + verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, deferredResult); + verify(interceptor).preProcess(this.asyncWebRequest, deferredResult); + verify(interceptor).afterCompletion(this.asyncWebRequest, deferredResult); + } + + @Test + public void startDeferredResultProcessingErrorAndResumeWithDefaultResult() throws Exception { + + Exception e = new Exception(); + DeferredResult deferredResult = new DeferredResult<>(null, e); + this.asyncManager.startDeferredResultProcessing(deferredResult); + + AsyncEvent event = new AsyncEvent(new MockAsyncContext(this.servletRequest, this.servletResponse), e); + this.asyncWebRequest.onError(event); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(e, this.asyncManager.getConcurrentResult()); + assertEquals("/test", ((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()); + } + + @Test + public void startDeferredResultProcessingErrorAndResumeThroughCallback() throws Exception { + + final DeferredResult deferredResult = new DeferredResult<>(); + deferredResult.onError(new Consumer() { + @Override + public void accept(Throwable t) { + deferredResult.setResult(t); + } + }); + + this.asyncManager.startDeferredResultProcessing(deferredResult); + + Exception e = new Exception(); + AsyncEvent event = new AsyncEvent(new MockAsyncContext(this.servletRequest, this.servletResponse), e); + this.asyncWebRequest.onError(event); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(e, this.asyncManager.getConcurrentResult()); + assertEquals("/test", ((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()); + } + + @Test + public void startDeferredResultProcessingErrorAndResumeThroughInterceptor() throws Exception { + + DeferredResult deferredResult = new DeferredResult<>(); + + DeferredResultProcessingInterceptor interceptor = new DeferredResultProcessingInterceptor() { + @Override + public boolean handleError(NativeWebRequest request, DeferredResult result, Throwable t) + throws Exception { + result.setErrorResult(t); + return true; + } + }; + + this.asyncManager.registerDeferredResultInterceptor("interceptor", interceptor); + this.asyncManager.startDeferredResultProcessing(deferredResult); + + Exception e = new Exception(); + AsyncEvent event = new AsyncEvent(new MockAsyncContext(this.servletRequest, this.servletResponse), e); + this.asyncWebRequest.onError(event); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(e, this.asyncManager.getConcurrentResult()); + assertEquals("/test", ((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()); + } + + @Test + public void startDeferredResultProcessingAfterException() throws Exception { + + DeferredResult deferredResult = new DeferredResult<>(); + final Exception exception = new Exception(); + + DeferredResultProcessingInterceptor interceptor = new DeferredResultProcessingInterceptor() { + @Override + public boolean handleError(NativeWebRequest request, DeferredResult result, Throwable t) + throws Exception { + throw exception; + } + }; + + this.asyncManager.registerDeferredResultInterceptor("interceptor", interceptor); + this.asyncManager.startDeferredResultProcessing(deferredResult); + + Exception e = new Exception(); + AsyncEvent event = new AsyncEvent(new MockAsyncContext(this.servletRequest, this.servletResponse), e); + this.asyncWebRequest.onError(event); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(e, this.asyncManager.getConcurrentResult()); + assertEquals("/test", ((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()); + } + + + private final class StubCallable implements Callable { + @Override + public Object call() throws Exception { + return 21; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c168f5571f920a748caef6973191fdcf344f15ba --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerTests.java @@ -0,0 +1,406 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.util.concurrent.Callable; +import java.util.function.Consumer; + +import javax.servlet.http.HttpServletRequest; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.task.AsyncTaskExecutor; +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.mock.web.test.MockHttpServletRequest; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * Test fixture with an {@link WebAsyncManager} with a mock AsyncWebRequest. + * + * @author Rossen Stoyanchev + */ +public class WebAsyncManagerTests { + + private WebAsyncManager asyncManager; + + private AsyncWebRequest asyncWebRequest; + + private MockHttpServletRequest servletRequest; + + + @Before + public void setup() { + this.servletRequest = new MockHttpServletRequest(); + this.asyncManager = WebAsyncUtils.getAsyncManager(servletRequest); + this.asyncManager.setTaskExecutor(new SyncTaskExecutor()); + this.asyncWebRequest = mock(AsyncWebRequest.class); + this.asyncManager.setAsyncWebRequest(this.asyncWebRequest); + verify(this.asyncWebRequest).addCompletionHandler((Runnable) notNull()); + reset(this.asyncWebRequest); + } + + + @Test + public void startAsyncProcessingWithoutAsyncWebRequest() throws Exception { + WebAsyncManager manager = WebAsyncUtils.getAsyncManager(new MockHttpServletRequest()); + + try { + manager.startCallableProcessing(new StubCallable(1)); + fail("Expected exception"); + } + catch (IllegalStateException ex) { + assertEquals("AsyncWebRequest must not be null", ex.getMessage()); + } + + try { + manager.startDeferredResultProcessing(new DeferredResult()); + fail("Expected exception"); + } + catch (IllegalStateException ex) { + assertEquals("AsyncWebRequest must not be null", ex.getMessage()); + } + } + + @Test + public void isConcurrentHandlingStarted() { + given(this.asyncWebRequest.isAsyncStarted()).willReturn(false); + + assertFalse(this.asyncManager.isConcurrentHandlingStarted()); + + reset(this.asyncWebRequest); + given(this.asyncWebRequest.isAsyncStarted()).willReturn(true); + + assertTrue(this.asyncManager.isConcurrentHandlingStarted()); + } + + @Test(expected = IllegalArgumentException.class) + public void setAsyncWebRequestAfterAsyncStarted() { + this.asyncWebRequest.startAsync(); + this.asyncManager.setAsyncWebRequest(null); + } + + @Test + public void startCallableProcessing() throws Exception { + + int concurrentResult = 21; + Callable task = new StubCallable(concurrentResult); + + CallableProcessingInterceptor interceptor = mock(CallableProcessingInterceptor.class); + + setupDefaultAsyncScenario(); + + this.asyncManager.registerCallableInterceptor("interceptor", interceptor); + this.asyncManager.startCallableProcessing(task); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(concurrentResult, this.asyncManager.getConcurrentResult()); + + verifyDefaultAsyncScenario(); + verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, task); + verify(interceptor).preProcess(this.asyncWebRequest, task); + verify(interceptor).postProcess(this.asyncWebRequest, task, concurrentResult); + } + + @Test + public void startCallableProcessingCallableException() throws Exception { + + Exception concurrentResult = new Exception(); + Callable task = new StubCallable(concurrentResult); + + CallableProcessingInterceptor interceptor = mock(CallableProcessingInterceptor.class); + + setupDefaultAsyncScenario(); + + this.asyncManager.registerCallableInterceptor("interceptor", interceptor); + this.asyncManager.startCallableProcessing(task); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(concurrentResult, this.asyncManager.getConcurrentResult()); + + verifyDefaultAsyncScenario(); + verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, task); + verify(interceptor).preProcess(this.asyncWebRequest, task); + verify(interceptor).postProcess(this.asyncWebRequest, task, concurrentResult); + } + + @SuppressWarnings("unchecked") + @Test + public void startCallableProcessingBeforeConcurrentHandlingException() throws Exception { + Callable task = new StubCallable(21); + Exception exception = new Exception(); + + CallableProcessingInterceptor interceptor = mock(CallableProcessingInterceptor.class); + willThrow(exception).given(interceptor).beforeConcurrentHandling(this.asyncWebRequest, task); + + this.asyncManager.registerCallableInterceptor("interceptor", interceptor); + + try { + this.asyncManager.startCallableProcessing(task); + fail("Expected Exception"); + } + catch (Exception ex) { + assertEquals(exception, ex); + } + + assertFalse(this.asyncManager.hasConcurrentResult()); + + verify(this.asyncWebRequest).addTimeoutHandler(notNull()); + verify(this.asyncWebRequest).addErrorHandler(notNull()); + verify(this.asyncWebRequest).addCompletionHandler(notNull()); + } + + @Test + public void startCallableProcessingPreProcessException() throws Exception { + Callable task = new StubCallable(21); + Exception exception = new Exception(); + + CallableProcessingInterceptor interceptor = mock(CallableProcessingInterceptor.class); + willThrow(exception).given(interceptor).preProcess(this.asyncWebRequest, task); + + setupDefaultAsyncScenario(); + + this.asyncManager.registerCallableInterceptor("interceptor", interceptor); + this.asyncManager.startCallableProcessing(task); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(exception, this.asyncManager.getConcurrentResult()); + + verifyDefaultAsyncScenario(); + verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, task); + } + + @Test + public void startCallableProcessingPostProcessException() throws Exception { + Callable task = new StubCallable(21); + Exception exception = new Exception(); + + CallableProcessingInterceptor interceptor = mock(CallableProcessingInterceptor.class); + willThrow(exception).given(interceptor).postProcess(this.asyncWebRequest, task, 21); + + setupDefaultAsyncScenario(); + + this.asyncManager.registerCallableInterceptor("interceptor", interceptor); + this.asyncManager.startCallableProcessing(task); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(exception, this.asyncManager.getConcurrentResult()); + + verifyDefaultAsyncScenario(); + verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, task); + verify(interceptor).preProcess(this.asyncWebRequest, task); + } + + @Test + public void startCallableProcessingPostProcessContinueAfterException() throws Exception { + Callable task = new StubCallable(21); + Exception exception = new Exception(); + + CallableProcessingInterceptor interceptor1 = mock(CallableProcessingInterceptor.class); + CallableProcessingInterceptor interceptor2 = mock(CallableProcessingInterceptor.class); + willThrow(exception).given(interceptor2).postProcess(this.asyncWebRequest, task, 21); + + setupDefaultAsyncScenario(); + + this.asyncManager.registerCallableInterceptors(interceptor1, interceptor2); + this.asyncManager.startCallableProcessing(task); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(exception, this.asyncManager.getConcurrentResult()); + + verifyDefaultAsyncScenario(); + verify(interceptor1).beforeConcurrentHandling(this.asyncWebRequest, task); + verify(interceptor1).preProcess(this.asyncWebRequest, task); + verify(interceptor1).postProcess(this.asyncWebRequest, task, 21); + verify(interceptor2).beforeConcurrentHandling(this.asyncWebRequest, task); + verify(interceptor2).preProcess(this.asyncWebRequest, task); + } + + @SuppressWarnings("unchecked") + @Test + public void startCallableProcessingWithAsyncTask() throws Exception { + AsyncTaskExecutor executor = mock(AsyncTaskExecutor.class); + given(this.asyncWebRequest.getNativeRequest(HttpServletRequest.class)).willReturn(this.servletRequest); + + WebAsyncTask asyncTask = new WebAsyncTask<>(1000L, executor, mock(Callable.class)); + this.asyncManager.startCallableProcessing(asyncTask); + + verify(executor).submit((Runnable) notNull()); + verify(this.asyncWebRequest).setTimeout(1000L); + verify(this.asyncWebRequest).addTimeoutHandler(any(Runnable.class)); + verify(this.asyncWebRequest).addErrorHandler(any(Consumer.class)); + verify(this.asyncWebRequest).addCompletionHandler(any(Runnable.class)); + verify(this.asyncWebRequest).startAsync(); + } + + @Test + public void startCallableProcessingNullInput() throws Exception { + try { + this.asyncManager.startCallableProcessing((Callable) null); + fail("Expected exception"); + } + catch (IllegalArgumentException ex) { + assertEquals("Callable must not be null", ex.getMessage()); + } + } + + @Test + public void startDeferredResultProcessing() throws Exception { + DeferredResult deferredResult = new DeferredResult<>(1000L); + String concurrentResult = "abc"; + + DeferredResultProcessingInterceptor interceptor = mock(DeferredResultProcessingInterceptor.class); + + setupDefaultAsyncScenario(); + + this.asyncManager.registerDeferredResultInterceptor("interceptor", interceptor); + this.asyncManager.startDeferredResultProcessing(deferredResult); + + deferredResult.setResult(concurrentResult); + + assertEquals(concurrentResult, this.asyncManager.getConcurrentResult()); + verifyDefaultAsyncScenario(); + verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, deferredResult); + verify(interceptor).preProcess(this.asyncWebRequest, deferredResult); + verify(interceptor).postProcess(asyncWebRequest, deferredResult, concurrentResult); + verify(this.asyncWebRequest).setTimeout(1000L); + } + + @SuppressWarnings("unchecked") + @Test + public void startDeferredResultProcessingBeforeConcurrentHandlingException() throws Exception { + DeferredResult deferredResult = new DeferredResult<>(); + Exception exception = new Exception(); + + DeferredResultProcessingInterceptor interceptor = mock(DeferredResultProcessingInterceptor.class); + willThrow(exception).given(interceptor).beforeConcurrentHandling(this.asyncWebRequest, deferredResult); + + this.asyncManager.registerDeferredResultInterceptor("interceptor", interceptor); + + try { + this.asyncManager.startDeferredResultProcessing(deferredResult); + fail("Expected Exception"); + } + catch (Exception success) { + assertEquals(exception, success); + } + + assertFalse(this.asyncManager.hasConcurrentResult()); + + verify(this.asyncWebRequest).addTimeoutHandler(notNull()); + verify(this.asyncWebRequest).addErrorHandler(notNull()); + verify(this.asyncWebRequest).addCompletionHandler(notNull()); + } + + @Test + public void startDeferredResultProcessingPreProcessException() throws Exception { + + DeferredResult deferredResult = new DeferredResult<>(); + Exception exception = new Exception(); + + DeferredResultProcessingInterceptor interceptor = mock(DeferredResultProcessingInterceptor.class); + willThrow(exception).given(interceptor).preProcess(this.asyncWebRequest, deferredResult); + + setupDefaultAsyncScenario(); + + this.asyncManager.registerDeferredResultInterceptor("interceptor", interceptor); + this.asyncManager.startDeferredResultProcessing(deferredResult); + + deferredResult.setResult(25); + + assertEquals(exception, this.asyncManager.getConcurrentResult()); + verifyDefaultAsyncScenario(); + verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, deferredResult); + } + + @Test + public void startDeferredResultProcessingPostProcessException() throws Exception { + DeferredResult deferredResult = new DeferredResult<>(); + Exception exception = new Exception(); + + DeferredResultProcessingInterceptor interceptor = mock(DeferredResultProcessingInterceptor.class); + willThrow(exception).given(interceptor).postProcess(this.asyncWebRequest, deferredResult, 25); + + setupDefaultAsyncScenario(); + + this.asyncManager.registerDeferredResultInterceptor("interceptor", interceptor); + this.asyncManager.startDeferredResultProcessing(deferredResult); + + deferredResult.setResult(25); + + assertEquals(exception, this.asyncManager.getConcurrentResult()); + verifyDefaultAsyncScenario(); + verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, deferredResult); + verify(interceptor).preProcess(this.asyncWebRequest, deferredResult); + } + + @Test + public void startDeferredResultProcessingNullInput() throws Exception { + try { + this.asyncManager.startDeferredResultProcessing(null); + fail("Expected exception"); + } + catch (IllegalArgumentException ex) { + assertEquals("DeferredResult must not be null", ex.getMessage()); + } + } + + private void setupDefaultAsyncScenario() { + given(this.asyncWebRequest.getNativeRequest(HttpServletRequest.class)).willReturn(this.servletRequest); + given(this.asyncWebRequest.isAsyncComplete()).willReturn(false); + } + + @SuppressWarnings("unchecked") + private void verifyDefaultAsyncScenario() { + verify(this.asyncWebRequest).addTimeoutHandler(notNull()); + verify(this.asyncWebRequest).addErrorHandler(notNull()); + verify(this.asyncWebRequest).addCompletionHandler(notNull()); + verify(this.asyncWebRequest).startAsync(); + verify(this.asyncWebRequest).dispatch(); + } + + + private final class StubCallable implements Callable { + + private Object value; + + public StubCallable(Object value) { + this.value = value; + } + + @Override + public Object call() throws Exception { + if (this.value instanceof Exception) { + throw ((Exception) this.value); + } + return this.value; + } + } + + + @SuppressWarnings("serial") + private static class SyncTaskExecutor extends SimpleAsyncTaskExecutor { + + @Override + public void execute(Runnable task, long startTimeout) { + task.run(); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerTimeoutTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerTimeoutTests.java new file mode 100644 index 0000000000000000000000000000000000000000..1769293ceeec23a7dd2806f271e96e1ebec033d5 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerTimeoutTests.java @@ -0,0 +1,293 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.util.concurrent.Callable; +import java.util.concurrent.Future; + +import javax.servlet.AsyncEvent; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.task.AsyncTaskExecutor; +import org.springframework.mock.web.test.MockAsyncContext; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.web.context.request.NativeWebRequest; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import static org.springframework.web.context.request.async.CallableProcessingInterceptor.RESULT_NONE; + +/** + * {@link WebAsyncManager} tests where container-triggered timeout/completion + * events are simulated. + * + * @author Rossen Stoyanchev + */ +public class WebAsyncManagerTimeoutTests { + + private static final AsyncEvent ASYNC_EVENT = null; + + private WebAsyncManager asyncManager; + + private StandardServletAsyncWebRequest asyncWebRequest; + + private MockHttpServletRequest servletRequest; + + private MockHttpServletResponse servletResponse; + + + @Before + public void setup() { + this.servletRequest = new MockHttpServletRequest("GET", "/test"); + this.servletRequest.setAsyncSupported(true); + this.servletResponse = new MockHttpServletResponse(); + this.asyncWebRequest = new StandardServletAsyncWebRequest(servletRequest, servletResponse); + + AsyncTaskExecutor executor = mock(AsyncTaskExecutor.class); + + this.asyncManager = WebAsyncUtils.getAsyncManager(servletRequest); + this.asyncManager.setTaskExecutor(executor); + this.asyncManager.setAsyncWebRequest(this.asyncWebRequest); + } + + + @Test + public void startCallableProcessingTimeoutAndComplete() throws Exception { + StubCallable callable = new StubCallable(); + + CallableProcessingInterceptor interceptor = mock(CallableProcessingInterceptor.class); + given(interceptor.handleTimeout(this.asyncWebRequest, callable)).willReturn(RESULT_NONE); + + this.asyncManager.registerCallableInterceptor("interceptor", interceptor); + this.asyncManager.startCallableProcessing(callable); + + this.asyncWebRequest.onTimeout(ASYNC_EVENT); + this.asyncWebRequest.onComplete(ASYNC_EVENT); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(AsyncRequestTimeoutException.class, this.asyncManager.getConcurrentResult().getClass()); + + verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, callable); + verify(interceptor).afterCompletion(this.asyncWebRequest, callable); + } + + @Test + public void startCallableProcessingTimeoutAndResumeThroughCallback() throws Exception { + + StubCallable callable = new StubCallable(); + WebAsyncTask webAsyncTask = new WebAsyncTask<>(callable); + webAsyncTask.onTimeout(new Callable() { + @Override + public Object call() throws Exception { + return 7; + } + }); + + this.asyncManager.startCallableProcessing(webAsyncTask); + + this.asyncWebRequest.onTimeout(ASYNC_EVENT); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(7, this.asyncManager.getConcurrentResult()); + assertEquals("/test", ((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()); + } + + @Test + public void startCallableProcessingTimeoutAndResumeThroughInterceptor() throws Exception { + + StubCallable callable = new StubCallable(); + + CallableProcessingInterceptor interceptor = mock(CallableProcessingInterceptor.class); + given(interceptor.handleTimeout(this.asyncWebRequest, callable)).willReturn(22); + + this.asyncManager.registerCallableInterceptor("timeoutInterceptor", interceptor); + this.asyncManager.startCallableProcessing(callable); + + this.asyncWebRequest.onTimeout(ASYNC_EVENT); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(22, this.asyncManager.getConcurrentResult()); + assertEquals("/test", ((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()); + + verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, callable); + } + + @Test + public void startCallableProcessingAfterTimeoutException() throws Exception { + + StubCallable callable = new StubCallable(); + Exception exception = new Exception(); + + CallableProcessingInterceptor interceptor = mock(CallableProcessingInterceptor.class); + given(interceptor.handleTimeout(this.asyncWebRequest, callable)).willThrow(exception); + + this.asyncManager.registerCallableInterceptor("timeoutInterceptor", interceptor); + this.asyncManager.startCallableProcessing(callable); + + this.asyncWebRequest.onTimeout(ASYNC_EVENT); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(exception, this.asyncManager.getConcurrentResult()); + assertEquals("/test", ((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()); + + verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, callable); + } + + @SuppressWarnings("unchecked") + @Test + public void startCallableProcessingTimeoutAndCheckThreadInterrupted() throws Exception { + + StubCallable callable = new StubCallable(); + Future future = mock(Future.class); + + AsyncTaskExecutor executor = mock(AsyncTaskExecutor.class); + when(executor.submit(any(Runnable.class))).thenReturn(future); + + this.asyncManager.setTaskExecutor(executor); + this.asyncManager.startCallableProcessing(callable); + + this.asyncWebRequest.onTimeout(ASYNC_EVENT); + + assertTrue(this.asyncManager.hasConcurrentResult()); + + verify(future).cancel(true); + verifyNoMoreInteractions(future); + } + + @Test + public void startDeferredResultProcessingTimeoutAndComplete() throws Exception { + + DeferredResult deferredResult = new DeferredResult<>(); + + DeferredResultProcessingInterceptor interceptor = mock(DeferredResultProcessingInterceptor.class); + given(interceptor.handleTimeout(this.asyncWebRequest, deferredResult)).willReturn(true); + + this.asyncManager.registerDeferredResultInterceptor("interceptor", interceptor); + this.asyncManager.startDeferredResultProcessing(deferredResult); + + this.asyncWebRequest.onTimeout(ASYNC_EVENT); + this.asyncWebRequest.onComplete(ASYNC_EVENT); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(AsyncRequestTimeoutException.class, this.asyncManager.getConcurrentResult().getClass()); + + verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, deferredResult); + verify(interceptor).preProcess(this.asyncWebRequest, deferredResult); + verify(interceptor).afterCompletion(this.asyncWebRequest, deferredResult); + } + + @Test + public void startDeferredResultProcessingTimeoutAndResumeWithDefaultResult() throws Exception { + + DeferredResult deferredResult = new DeferredResult<>(null, 23); + this.asyncManager.startDeferredResultProcessing(deferredResult); + + AsyncEvent event = null; + this.asyncWebRequest.onTimeout(event); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(23, this.asyncManager.getConcurrentResult()); + assertEquals("/test", ((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()); + } + + @Test + public void startDeferredResultProcessingTimeoutAndResumeThroughCallback() throws Exception { + + final DeferredResult deferredResult = new DeferredResult<>(); + deferredResult.onTimeout(new Runnable() { + @Override + public void run() { + deferredResult.setResult(23); + } + }); + + this.asyncManager.startDeferredResultProcessing(deferredResult); + + AsyncEvent event = null; + this.asyncWebRequest.onTimeout(event); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(23, this.asyncManager.getConcurrentResult()); + assertEquals("/test", ((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()); + } + + @Test + public void startDeferredResultProcessingTimeoutAndResumeThroughInterceptor() throws Exception { + + DeferredResult deferredResult = new DeferredResult<>(); + + DeferredResultProcessingInterceptor interceptor = new DeferredResultProcessingInterceptor() { + @Override + public boolean handleTimeout(NativeWebRequest request, DeferredResult result) throws Exception { + result.setErrorResult(23); + return true; + } + }; + + this.asyncManager.registerDeferredResultInterceptor("interceptor", interceptor); + this.asyncManager.startDeferredResultProcessing(deferredResult); + + AsyncEvent event = null; + this.asyncWebRequest.onTimeout(event); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(23, this.asyncManager.getConcurrentResult()); + assertEquals("/test", ((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()); + } + + @Test + public void startDeferredResultProcessingAfterTimeoutException() throws Exception { + + DeferredResult deferredResult = new DeferredResult<>(); + final Exception exception = new Exception(); + + DeferredResultProcessingInterceptor interceptor = new DeferredResultProcessingInterceptor() { + @Override + public boolean handleTimeout(NativeWebRequest request, DeferredResult result) throws Exception { + throw exception; + } + }; + + this.asyncManager.registerDeferredResultInterceptor("interceptor", interceptor); + this.asyncManager.startDeferredResultProcessing(deferredResult); + + AsyncEvent event = null; + this.asyncWebRequest.onTimeout(event); + + assertTrue(this.asyncManager.hasConcurrentResult()); + assertEquals(exception, this.asyncManager.getConcurrentResult()); + assertEquals("/test", ((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()); + } + + + private final class StubCallable implements Callable { + @Override + public Object call() throws Exception { + return 21; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/support/AnnotationConfigWebApplicationContextTests.java b/spring-web/src/test/java/org/springframework/web/context/support/AnnotationConfigWebApplicationContextTests.java new file mode 100644 index 0000000000000000000000000000000000000000..99bf408e71c74ff0a82afe0f44e6e24943ceb820 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/support/AnnotationConfigWebApplicationContextTests.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import org.junit.Test; + +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.context.annotation.AnnotationBeanNameGenerator; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +/** + * @author Chris Beams + * @author Juergen Hoeller + */ +public class AnnotationConfigWebApplicationContextTests { + + @Test + @SuppressWarnings("resource") + public void registerSingleClass() { + AnnotationConfigWebApplicationContext ctx = new AnnotationConfigWebApplicationContext(); + ctx.register(Config.class); + ctx.refresh(); + + TestBean bean = ctx.getBean(TestBean.class); + assertNotNull(bean); + } + + @Test + @SuppressWarnings("resource") + public void configLocationWithSingleClass() { + AnnotationConfigWebApplicationContext ctx = new AnnotationConfigWebApplicationContext(); + ctx.setConfigLocation(Config.class.getName()); + ctx.refresh(); + + TestBean bean = ctx.getBean(TestBean.class); + assertNotNull(bean); + } + + @Test + @SuppressWarnings("resource") + public void configLocationWithBasePackage() { + AnnotationConfigWebApplicationContext ctx = new AnnotationConfigWebApplicationContext(); + ctx.setConfigLocation("org.springframework.web.context.support"); + ctx.refresh(); + + TestBean bean = ctx.getBean(TestBean.class); + assertNotNull(bean); + } + + @Test + @SuppressWarnings("resource") + public void withBeanNameGenerator() { + AnnotationConfigWebApplicationContext ctx = new AnnotationConfigWebApplicationContext(); + ctx.setBeanNameGenerator(new AnnotationBeanNameGenerator() { + @Override + public String generateBeanName(BeanDefinition definition, + BeanDefinitionRegistry registry) { + return "custom-" + super.generateBeanName(definition, registry); + } + }); + ctx.setConfigLocation(Config.class.getName()); + ctx.refresh(); + assertThat(ctx.containsBean("custom-myConfig"), is(true)); + } + + + @Configuration("myConfig") + static class Config { + + @Bean + public TestBean myTestBean() { + return new TestBean(); + } + } + + + static class TestBean { + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/support/ResourceTests.java b/spring-web/src/test/java/org/springframework/web/context/support/ResourceTests.java new file mode 100644 index 0000000000000000000000000000000000000000..abd3bf9964dda4af1796b8c158d45ed60a75e54f --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/support/ResourceTests.java @@ -0,0 +1,48 @@ +package org.springframework.web.context.support; + +import java.io.IOException; + +import org.junit.Test; + +import org.springframework.core.io.Resource; +import org.springframework.mock.web.test.MockServletContext; + +import static org.junit.Assert.*; + +/** + * @author Chris Beams + * @see org.springframework.core.io.ResourceTests + */ +public class ResourceTests { + + @Test + public void testServletContextResource() throws IOException { + MockServletContext sc = new MockServletContext(); + Resource resource = new ServletContextResource(sc, "org/springframework/core/io/Resource.class"); + doTestResource(resource); + assertEquals(resource, new ServletContextResource(sc, "org/springframework/core/../core/io/./Resource.class")); + } + + @Test + public void testServletContextResourceWithRelativePath() throws IOException { + MockServletContext sc = new MockServletContext(); + Resource resource = new ServletContextResource(sc, "dir/"); + Resource relative = resource.createRelative("subdir"); + assertEquals(new ServletContextResource(sc, "dir/subdir"), relative); + } + + private void doTestResource(Resource resource) throws IOException { + assertEquals("Resource.class", resource.getFilename()); + assertTrue(resource.getURL().getFile().endsWith("Resource.class")); + + Resource relative1 = resource.createRelative("ClassPathResource.class"); + assertEquals("ClassPathResource.class", relative1.getFilename()); + assertTrue(relative1.getURL().getFile().endsWith("ClassPathResource.class")); + assertTrue(relative1.exists()); + + Resource relative2 = resource.createRelative("support/ResourcePatternResolver.class"); + assertEquals("ResourcePatternResolver.class", relative2.getFilename()); + assertTrue(relative2.getURL().getFile().endsWith("ResourcePatternResolver.class")); + assertTrue(relative2.exists()); + } +} diff --git a/spring-web/src/test/java/org/springframework/web/context/support/Spr8510Tests.java b/spring-web/src/test/java/org/springframework/web/context/support/Spr8510Tests.java new file mode 100644 index 0000000000000000000000000000000000000000..6a01443834cbf4d6165f7a5a5215e055ac64f7c9 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/support/Spr8510Tests.java @@ -0,0 +1,192 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import javax.servlet.ServletContextEvent; + +import org.junit.Test; + +import org.springframework.context.annotation.ClassPathBeanDefinitionScanner; +import org.springframework.mock.web.test.MockServletContext; +import org.springframework.web.context.ContextLoader; +import org.springframework.web.context.ContextLoaderListener; + +import static org.junit.Assert.*; + +/** + * Tests the interaction between a WebApplicationContext and ContextLoaderListener with + * regard to config location precedence, overriding and defaulting in programmatic + * configuration use cases, e.g. with Spring 3.1's WebApplicationInitializer. + * + * @author Chris Beams + * @since 3.1 + * @see org.springframework.web.context.ContextLoaderTests + */ +public class Spr8510Tests { + + @Test + public void abstractRefreshableWAC_respectsProgrammaticConfigLocations() { + XmlWebApplicationContext ctx = new XmlWebApplicationContext(); + ctx.setConfigLocation("programmatic.xml"); + ContextLoaderListener cll = new ContextLoaderListener(ctx); + + MockServletContext sc = new MockServletContext(); + + try { + cll.contextInitialized(new ServletContextEvent(sc)); + fail("expected exception"); + } + catch (Throwable t) { + // assert that an attempt was made to load the correct XML + assertTrue(t.getMessage(), t.getMessage().endsWith( + "Could not open ServletContext resource [/programmatic.xml]")); + } + } + + /** + * If a contextConfigLocation init-param has been specified for the ContextLoaderListener, + * then it should take precedence. This is generally not a recommended practice, but + * when it does happen, the init-param should be considered more specific than the + * programmatic configuration, given that it still quite possibly externalized in + * hybrid web.xml + WebApplicationInitializer cases. + */ + @Test + public void abstractRefreshableWAC_respectsInitParam_overProgrammaticConfigLocations() { + XmlWebApplicationContext ctx = new XmlWebApplicationContext(); + ctx.setConfigLocation("programmatic.xml"); + ContextLoaderListener cll = new ContextLoaderListener(ctx); + + MockServletContext sc = new MockServletContext(); + sc.addInitParameter(ContextLoader.CONFIG_LOCATION_PARAM, "from-init-param.xml"); + + try { + cll.contextInitialized(new ServletContextEvent(sc)); + fail("expected exception"); + } + catch (Throwable t) { + // assert that an attempt was made to load the correct XML + assertTrue(t.getMessage(), t.getMessage().endsWith( + "Could not open ServletContext resource [/from-init-param.xml]")); + } + } + + /** + * If setConfigLocation has not been called explicitly against the application context, + * then fall back to the ContextLoaderListener init-param if present. + */ + @Test + public void abstractRefreshableWAC_fallsBackToInitParam() { + XmlWebApplicationContext ctx = new XmlWebApplicationContext(); + //ctx.setConfigLocation("programmatic.xml"); // nothing set programmatically + ContextLoaderListener cll = new ContextLoaderListener(ctx); + + MockServletContext sc = new MockServletContext(); + sc.addInitParameter(ContextLoader.CONFIG_LOCATION_PARAM, "from-init-param.xml"); + + try { + cll.contextInitialized(new ServletContextEvent(sc)); + fail("expected exception"); + } + catch (Throwable t) { + // assert that an attempt was made to load the correct XML + assertTrue(t.getMessage().endsWith( + "Could not open ServletContext resource [/from-init-param.xml]")); + } + } + + /** + * Ensure that any custom default locations are still respected. + */ + @Test + public void customAbstractRefreshableWAC_fallsBackToInitParam() { + XmlWebApplicationContext ctx = new XmlWebApplicationContext() { + @Override + protected String[] getDefaultConfigLocations() { + return new String[] { "/WEB-INF/custom.xml" }; + } + }; + //ctx.setConfigLocation("programmatic.xml"); // nothing set programmatically + ContextLoaderListener cll = new ContextLoaderListener(ctx); + + MockServletContext sc = new MockServletContext(); + sc.addInitParameter(ContextLoader.CONFIG_LOCATION_PARAM, "from-init-param.xml"); + + try { + cll.contextInitialized(new ServletContextEvent(sc)); + fail("expected exception"); + } + catch (Throwable t) { + // assert that an attempt was made to load the correct XML + System.out.println(t.getMessage()); + assertTrue(t.getMessage().endsWith( + "Could not open ServletContext resource [/from-init-param.xml]")); + } + } + + /** + * If context config locations have been specified neither against the application + * context nor the context loader listener, then fall back to default values. + */ + @Test + public void abstractRefreshableWAC_fallsBackToConventionBasedNaming() { + XmlWebApplicationContext ctx = new XmlWebApplicationContext(); + //ctx.setConfigLocation("programmatic.xml"); // nothing set programmatically + ContextLoaderListener cll = new ContextLoaderListener(ctx); + + MockServletContext sc = new MockServletContext(); + // no init-param set + //sc.addInitParameter(ContextLoader.CONFIG_LOCATION_PARAM, "from-init-param.xml"); + + try { + cll.contextInitialized(new ServletContextEvent(sc)); + fail("expected exception"); + } + catch (Throwable t) { + // assert that an attempt was made to load the correct XML + System.out.println(t.getMessage()); + assertTrue(t.getMessage().endsWith( + "Could not open ServletContext resource [/WEB-INF/applicationContext.xml]")); + } + } + + /** + * Ensure that ContextLoaderListener and GenericWebApplicationContext interact nicely. + */ + @Test + public void genericWAC() { + GenericWebApplicationContext ctx = new GenericWebApplicationContext(); + ContextLoaderListener cll = new ContextLoaderListener(ctx); + + ClassPathBeanDefinitionScanner scanner = new ClassPathBeanDefinitionScanner(ctx); + scanner.scan("bogus.pkg"); + + cll.contextInitialized(new ServletContextEvent(new MockServletContext())); + } + + /** + * Ensure that ContextLoaderListener and AnnotationConfigApplicationContext interact nicely. + */ + @Test + public void annotationConfigWAC() { + AnnotationConfigWebApplicationContext ctx = new AnnotationConfigWebApplicationContext(); + + ctx.scan("does.not.matter"); + + ContextLoaderListener cll = new ContextLoaderListener(ctx); + cll.contextInitialized(new ServletContextEvent(new MockServletContext())); + } +} diff --git a/spring-web/src/test/java/org/springframework/web/context/support/SpringBeanAutowiringSupportTests.java b/spring-web/src/test/java/org/springframework/web/context/support/SpringBeanAutowiringSupportTests.java new file mode 100644 index 0000000000000000000000000000000000000000..88d1dcc974d3a39934862089de4233ee31ffbb35 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/support/SpringBeanAutowiringSupportTests.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import org.junit.Test; + +import org.springframework.beans.MutablePropertyValues; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.AnnotationConfigUtils; +import org.springframework.mock.web.test.MockServletContext; +import org.springframework.tests.sample.beans.ITestBean; +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.web.context.WebApplicationContext; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + */ +public class SpringBeanAutowiringSupportTests { + + @Test + public void testProcessInjectionBasedOnServletContext() { + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + AnnotationConfigUtils.registerAnnotationConfigProcessors(wac); + + MutablePropertyValues pvs = new MutablePropertyValues(); + pvs.add("name", "tb"); + wac.registerSingleton("testBean", TestBean.class, pvs); + + MockServletContext sc = new MockServletContext(); + wac.setServletContext(sc); + wac.refresh(); + sc.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, wac); + + InjectionTarget target = new InjectionTarget(); + SpringBeanAutowiringSupport.processInjectionBasedOnServletContext(target, sc); + assertTrue(target.testBean instanceof TestBean); + assertEquals("tb", target.name); + } + + + public static class InjectionTarget { + + @Autowired + public ITestBean testBean; + + @Value("#{testBean.name}") + public String name; + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/support/StandardServletEnvironmentTests.java b/spring-web/src/test/java/org/springframework/web/context/support/StandardServletEnvironmentTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5f807d660292df736ce5ee8340fa8e223dd74bb2 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/support/StandardServletEnvironmentTests.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2011 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.support; + +import org.junit.Test; + +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.env.MutablePropertySources; +import org.springframework.core.env.PropertySource; +import org.springframework.core.env.StandardEnvironment; +import org.springframework.tests.mock.jndi.SimpleNamingContextBuilder; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +/** + * Unit tests for {@link StandardServletEnvironment}. + * + * @author Chris Beams + * @since 3.1 + */ +public class StandardServletEnvironmentTests { + + @Test + public void propertySourceOrder() throws Exception { + SimpleNamingContextBuilder.emptyActivatedContextBuilder(); + + ConfigurableEnvironment env = new StandardServletEnvironment(); + MutablePropertySources sources = env.getPropertySources(); + + assertThat(sources.precedenceOf(PropertySource.named( + StandardServletEnvironment.SERVLET_CONFIG_PROPERTY_SOURCE_NAME)), equalTo(0)); + assertThat(sources.precedenceOf(PropertySource.named( + StandardServletEnvironment.SERVLET_CONTEXT_PROPERTY_SOURCE_NAME)), equalTo(1)); + assertThat(sources.precedenceOf(PropertySource.named( + StandardServletEnvironment.JNDI_PROPERTY_SOURCE_NAME)), equalTo(2)); + assertThat(sources.precedenceOf(PropertySource.named( + StandardEnvironment.SYSTEM_PROPERTIES_PROPERTY_SOURCE_NAME)), equalTo(3)); + assertThat(sources.precedenceOf(PropertySource.named( + StandardEnvironment.SYSTEM_ENVIRONMENT_PROPERTY_SOURCE_NAME)), equalTo(4)); + assertThat(sources.size(), is(5)); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java b/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..685890cfb4e50c45c92d642565bdebd8e192cf73 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java @@ -0,0 +1,289 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; + +import org.junit.Test; + +import org.springframework.http.HttpMethod; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link CorsConfiguration}. + * + * @author Sebastien Deleuze + * @author Sam Brannen + */ +public class CorsConfigurationTests { + + @Test + public void setNullValues() { + CorsConfiguration config = new CorsConfiguration(); + config.setAllowedOrigins(null); + assertNull(config.getAllowedOrigins()); + config.setAllowedHeaders(null); + assertNull(config.getAllowedHeaders()); + config.setAllowedMethods(null); + assertNull(config.getAllowedMethods()); + config.setExposedHeaders(null); + assertNull(config.getExposedHeaders()); + config.setAllowCredentials(null); + assertNull(config.getAllowCredentials()); + config.setMaxAge(null); + assertNull(config.getMaxAge()); + } + + @Test + public void setValues() { + CorsConfiguration config = new CorsConfiguration(); + config.addAllowedOrigin("*"); + assertEquals(Arrays.asList("*"), config.getAllowedOrigins()); + config.addAllowedHeader("*"); + assertEquals(Arrays.asList("*"), config.getAllowedHeaders()); + config.addAllowedMethod("*"); + assertEquals(Arrays.asList("*"), config.getAllowedMethods()); + config.addExposedHeader("*"); + assertEquals(Arrays.asList("*"), config.getAllowedMethods()); + config.setAllowCredentials(true); + assertTrue(config.getAllowCredentials()); + config.setMaxAge(123L); + assertEquals(new Long(123), config.getMaxAge()); + } + + @Test + public void combineWithNull() { + CorsConfiguration config = new CorsConfiguration(); + config.setAllowedOrigins(Arrays.asList("*")); + config.combine(null); + assertEquals(Arrays.asList("*"), config.getAllowedOrigins()); + } + + @Test + public void combineWithNullProperties() { + CorsConfiguration config = new CorsConfiguration(); + config.addAllowedOrigin("*"); + config.addAllowedHeader("header1"); + config.addExposedHeader("header3"); + config.addAllowedMethod(HttpMethod.GET.name()); + config.setMaxAge(123L); + config.setAllowCredentials(true); + CorsConfiguration other = new CorsConfiguration(); + config = config.combine(other); + assertEquals(Arrays.asList("*"), config.getAllowedOrigins()); + assertEquals(Arrays.asList("header1"), config.getAllowedHeaders()); + assertEquals(Arrays.asList("header3"), config.getExposedHeaders()); + assertEquals(Arrays.asList(HttpMethod.GET.name()), config.getAllowedMethods()); + assertEquals(new Long(123), config.getMaxAge()); + assertTrue(config.getAllowCredentials()); + } + + @Test // SPR-15772 + public void combineWithDefaultPermitValues() { + CorsConfiguration config = new CorsConfiguration().applyPermitDefaultValues(); + CorsConfiguration other = new CorsConfiguration(); + other.addAllowedOrigin("https://domain.com"); + other.addAllowedHeader("header1"); + other.addAllowedMethod(HttpMethod.PUT.name()); + + CorsConfiguration combinedConfig = config.combine(other); + assertEquals(Arrays.asList("https://domain.com"), combinedConfig.getAllowedOrigins()); + assertEquals(Arrays.asList("header1"), combinedConfig.getAllowedHeaders()); + assertEquals(Arrays.asList(HttpMethod.PUT.name()), combinedConfig.getAllowedMethods()); + assertEquals(Collections.emptyList(), combinedConfig.getExposedHeaders()); + + combinedConfig = other.combine(config); + assertEquals(Arrays.asList("https://domain.com"), combinedConfig.getAllowedOrigins()); + assertEquals(Arrays.asList("header1"), combinedConfig.getAllowedHeaders()); + assertEquals(Arrays.asList(HttpMethod.PUT.name()), combinedConfig.getAllowedMethods()); + assertEquals(Collections.emptyList(), combinedConfig.getExposedHeaders()); + + combinedConfig = config.combine(new CorsConfiguration()); + assertEquals(Arrays.asList("*"), config.getAllowedOrigins()); + assertEquals(Arrays.asList("*"), config.getAllowedHeaders()); + assertEquals(Arrays.asList(HttpMethod.GET.name(), HttpMethod.HEAD.name(), + HttpMethod.POST.name()), combinedConfig.getAllowedMethods()); + assertEquals(Collections.emptyList(), combinedConfig.getExposedHeaders()); + + combinedConfig = new CorsConfiguration().combine(config); + assertEquals(Arrays.asList("*"), config.getAllowedOrigins()); + assertEquals(Arrays.asList("*"), config.getAllowedHeaders()); + assertEquals(Arrays.asList(HttpMethod.GET.name(), HttpMethod.HEAD.name(), + HttpMethod.POST.name()), combinedConfig.getAllowedMethods()); + assertEquals(Collections.emptyList(), combinedConfig.getExposedHeaders()); + } + + @Test + public void combineWithAsteriskWildCard() { + CorsConfiguration config = new CorsConfiguration(); + config.addAllowedOrigin("*"); + config.addAllowedHeader("*"); + config.addExposedHeader("*"); + config.addAllowedMethod("*"); + CorsConfiguration other = new CorsConfiguration(); + other.addAllowedOrigin("https://domain.com"); + other.addAllowedHeader("header1"); + other.addExposedHeader("header2"); + other.addAllowedHeader("anotherHeader1"); + other.addExposedHeader("anotherHeader2"); + other.addAllowedMethod(HttpMethod.PUT.name()); + CorsConfiguration combinedConfig = config.combine(other); + assertEquals(Arrays.asList("*"), combinedConfig.getAllowedOrigins()); + assertEquals(Arrays.asList("*"), combinedConfig.getAllowedHeaders()); + assertEquals(Arrays.asList("*"), combinedConfig.getExposedHeaders()); + assertEquals(Arrays.asList("*"), combinedConfig.getAllowedMethods()); + combinedConfig = other.combine(config); + assertEquals(Arrays.asList("*"), combinedConfig.getAllowedOrigins()); + assertEquals(Arrays.asList("*"), combinedConfig.getAllowedHeaders()); + assertEquals(Arrays.asList("*"), combinedConfig.getExposedHeaders()); + assertEquals(Arrays.asList("*"), combinedConfig.getAllowedMethods()); + } + + @Test // SPR-14792 + public void combineWithDuplicatedElements() { + CorsConfiguration config = new CorsConfiguration(); + config.addAllowedOrigin("https://domain1.com"); + config.addAllowedOrigin("https://domain2.com"); + config.addAllowedHeader("header1"); + config.addAllowedHeader("header2"); + config.addExposedHeader("header3"); + config.addExposedHeader("header4"); + config.addAllowedMethod(HttpMethod.GET.name()); + config.addAllowedMethod(HttpMethod.PUT.name()); + CorsConfiguration other = new CorsConfiguration(); + other.addAllowedOrigin("https://domain1.com"); + other.addAllowedHeader("header1"); + other.addExposedHeader("header3"); + other.addAllowedMethod(HttpMethod.GET.name()); + CorsConfiguration combinedConfig = config.combine(other); + assertEquals(Arrays.asList("https://domain1.com", "https://domain2.com"), combinedConfig.getAllowedOrigins()); + assertEquals(Arrays.asList("header1", "header2"), combinedConfig.getAllowedHeaders()); + assertEquals(Arrays.asList("header3", "header4"), combinedConfig.getExposedHeaders()); + assertEquals(Arrays.asList(HttpMethod.GET.name(), HttpMethod.PUT.name()), combinedConfig.getAllowedMethods()); + } + + @Test + public void combine() { + CorsConfiguration config = new CorsConfiguration(); + config.addAllowedOrigin("https://domain1.com"); + config.addAllowedHeader("header1"); + config.addExposedHeader("header3"); + config.addAllowedMethod(HttpMethod.GET.name()); + config.setMaxAge(123L); + config.setAllowCredentials(true); + CorsConfiguration other = new CorsConfiguration(); + other.addAllowedOrigin("https://domain2.com"); + other.addAllowedHeader("header2"); + other.addExposedHeader("header4"); + other.addAllowedMethod(HttpMethod.PUT.name()); + other.setMaxAge(456L); + other.setAllowCredentials(false); + config = config.combine(other); + assertEquals(Arrays.asList("https://domain1.com", "https://domain2.com"), config.getAllowedOrigins()); + assertEquals(Arrays.asList("header1", "header2"), config.getAllowedHeaders()); + assertEquals(Arrays.asList("header3", "header4"), config.getExposedHeaders()); + assertEquals(Arrays.asList(HttpMethod.GET.name(), HttpMethod.PUT.name()), config.getAllowedMethods()); + assertEquals(new Long(456), config.getMaxAge()); + assertFalse(config.getAllowCredentials()); + } + + @Test + public void checkOriginAllowed() { + CorsConfiguration config = new CorsConfiguration(); + config.setAllowedOrigins(Arrays.asList("*")); + assertEquals("*", config.checkOrigin("https://domain.com")); + config.setAllowCredentials(true); + assertEquals("https://domain.com", config.checkOrigin("https://domain.com")); + config.setAllowedOrigins(Arrays.asList("https://domain.com")); + assertEquals("https://domain.com", config.checkOrigin("https://domain.com")); + config.setAllowCredentials(false); + assertEquals("https://domain.com", config.checkOrigin("https://domain.com")); + } + + @Test + public void checkOriginNotAllowed() { + CorsConfiguration config = new CorsConfiguration(); + assertNull(config.checkOrigin(null)); + assertNull(config.checkOrigin("https://domain.com")); + config.addAllowedOrigin("*"); + assertNull(config.checkOrigin(null)); + config.setAllowedOrigins(Arrays.asList("https://domain1.com")); + assertNull(config.checkOrigin("https://domain2.com")); + config.setAllowedOrigins(new ArrayList<>()); + assertNull(config.checkOrigin("https://domain.com")); + } + + @Test + public void checkMethodAllowed() { + CorsConfiguration config = new CorsConfiguration(); + assertEquals(Arrays.asList(HttpMethod.GET, HttpMethod.HEAD), config.checkHttpMethod(HttpMethod.GET)); + config.addAllowedMethod("GET"); + assertEquals(Arrays.asList(HttpMethod.GET), config.checkHttpMethod(HttpMethod.GET)); + config.addAllowedMethod("POST"); + assertEquals(Arrays.asList(HttpMethod.GET, HttpMethod.POST), config.checkHttpMethod(HttpMethod.GET)); + assertEquals(Arrays.asList(HttpMethod.GET, HttpMethod.POST), config.checkHttpMethod(HttpMethod.POST)); + } + + @Test + public void checkMethodNotAllowed() { + CorsConfiguration config = new CorsConfiguration(); + assertNull(config.checkHttpMethod(null)); + assertNull(config.checkHttpMethod(HttpMethod.DELETE)); + config.setAllowedMethods(new ArrayList<>()); + assertNull(config.checkHttpMethod(HttpMethod.POST)); + } + + @Test + public void checkHeadersAllowed() { + CorsConfiguration config = new CorsConfiguration(); + assertEquals(Collections.emptyList(), config.checkHeaders(Collections.emptyList())); + config.addAllowedHeader("header1"); + config.addAllowedHeader("header2"); + assertEquals(Arrays.asList("header1"), config.checkHeaders(Arrays.asList("header1"))); + assertEquals(Arrays.asList("header1", "header2"), + config.checkHeaders(Arrays.asList("header1", "header2"))); + assertEquals(Arrays.asList("header1", "header2"), + config.checkHeaders(Arrays.asList("header1", "header2", "header3"))); + } + + @Test + public void checkHeadersNotAllowed() { + CorsConfiguration config = new CorsConfiguration(); + assertNull(config.checkHeaders(null)); + assertNull(config.checkHeaders(Arrays.asList("header1"))); + config.setAllowedHeaders(Collections.emptyList()); + assertNull(config.checkHeaders(Arrays.asList("header1"))); + config.addAllowedHeader("header2"); + config.addAllowedHeader("header3"); + assertNull(config.checkHeaders(Arrays.asList("header1"))); + } + + @Test // SPR-15772 + public void changePermitDefaultValues() { + CorsConfiguration config = new CorsConfiguration().applyPermitDefaultValues(); + config.addAllowedOrigin("https://domain.com"); + config.addAllowedHeader("header1"); + config.addAllowedMethod("PATCH"); + assertEquals(Arrays.asList("*", "https://domain.com"), config.getAllowedOrigins()); + assertEquals(Arrays.asList("*", "header1"), config.getAllowedHeaders()); + assertEquals(Arrays.asList("GET", "HEAD", "POST", "PATCH"), config.getAllowedMethods()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/cors/CorsUtilsTests.java b/spring-web/src/test/java/org/springframework/web/cors/CorsUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0a923d67ff55e2117470ad1db99424a12a139a73 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/cors/CorsUtilsTests.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors; + +import static org.junit.Assert.*; + +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.mock.web.test.MockHttpServletRequest; + +/** + * Test case for {@link CorsUtils}. + * + * @author Sebastien Deleuze + */ +public class CorsUtilsTests { + + @Test + public void isCorsRequest() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader(HttpHeaders.ORIGIN, "https://domain.com"); + assertTrue(CorsUtils.isCorsRequest(request)); + } + + @Test + public void isNotCorsRequest() { + MockHttpServletRequest request = new MockHttpServletRequest(); + assertFalse(CorsUtils.isCorsRequest(request)); + } + + @Test + public void isPreFlightRequest() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod(HttpMethod.OPTIONS.name()); + request.addHeader(HttpHeaders.ORIGIN, "https://domain.com"); + request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + assertTrue(CorsUtils.isPreFlightRequest(request)); + } + + @Test + public void isNotPreFlightRequest() { + MockHttpServletRequest request = new MockHttpServletRequest(); + assertFalse(CorsUtils.isPreFlightRequest(request)); + + request = new MockHttpServletRequest(); + request.setMethod(HttpMethod.OPTIONS.name()); + request.addHeader(HttpHeaders.ORIGIN, "https://domain.com"); + assertFalse(CorsUtils.isPreFlightRequest(request)); + + request = new MockHttpServletRequest(); + request.setMethod(HttpMethod.OPTIONS.name()); + request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + assertFalse(CorsUtils.isPreFlightRequest(request)); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java b/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..0f49e8b5d06b8e7f10c9c1175dc5a03065160b3d --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java @@ -0,0 +1,383 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors; + +import javax.servlet.http.HttpServletResponse; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; + +import static org.hamcrest.Matchers.contains; +import static org.junit.Assert.*; + +/** + * Test {@link DefaultCorsProcessor} with simple or preflight CORS request. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @author Juergen Hoeller + */ +public class DefaultCorsProcessorTests { + + private MockHttpServletRequest request; + + private MockHttpServletResponse response; + + private DefaultCorsProcessor processor; + + private CorsConfiguration conf; + + + @Before + public void setup() { + this.request = new MockHttpServletRequest(); + this.request.setRequestURI("/test.html"); + this.request.setRemoteHost("domain1.com"); + this.conf = new CorsConfiguration(); + this.response = new MockHttpServletResponse(); + this.response.setStatus(HttpServletResponse.SC_OK); + this.processor = new DefaultCorsProcessor(); + } + + + @Test + public void actualRequestWithOriginHeader() throws Exception { + this.request.setMethod(HttpMethod.GET.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + + this.processor.processRequest(this.conf, this.request, this.response); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); + } + + @Test + public void actualRequestWithOriginHeaderAndNullConfig() throws Exception { + this.request.setMethod(HttpMethod.GET.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + + this.processor.processRequest(null, this.request, this.response); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + } + + @Test + public void actualRequestWithOriginHeaderAndAllowedOrigin() throws Exception { + this.request.setMethod(HttpMethod.GET.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.conf.addAllowedOrigin("*"); + + this.processor.processRequest(this.conf, this.request, this.response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("*", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + } + + @Test + public void actualRequestCredentials() throws Exception { + this.request.setMethod(HttpMethod.GET.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.conf.addAllowedOrigin("https://domain1.com"); + this.conf.addAllowedOrigin("https://domain2.com"); + this.conf.addAllowedOrigin("http://domain3.com"); + this.conf.setAllowCredentials(true); + + this.processor.processRequest(this.conf, this.request, this.response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("https://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertEquals("true", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + } + + @Test + public void actualRequestCredentialsWithOriginWildcard() throws Exception { + this.request.setMethod(HttpMethod.GET.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.conf.addAllowedOrigin("*"); + this.conf.setAllowCredentials(true); + + this.processor.processRequest(this.conf, this.request, this.response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("https://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertEquals("true", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + } + + @Test + public void actualRequestCaseInsensitiveOriginMatch() throws Exception { + this.request.setMethod(HttpMethod.GET.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.conf.addAllowedOrigin("https://DOMAIN2.com"); + + this.processor.processRequest(this.conf, this.request, this.response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + } + + @Test + public void actualRequestExposedHeaders() throws Exception { + this.request.setMethod(HttpMethod.GET.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.conf.addExposedHeader("header1"); + this.conf.addExposedHeader("header2"); + this.conf.addAllowedOrigin("https://domain2.com"); + + this.processor.processRequest(this.conf, this.request, this.response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("https://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); + assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1")); + assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2")); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + } + + @Test + public void preflightRequestAllOriginsAllowed() throws Exception { + this.request.setMethod(HttpMethod.OPTIONS.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.conf.addAllowedOrigin("*"); + + this.processor.processRequest(this.conf, this.request, this.response); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + } + + @Test + public void preflightRequestWrongAllowedMethod() throws Exception { + this.request.setMethod(HttpMethod.OPTIONS.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "DELETE"); + this.conf.addAllowedOrigin("*"); + + this.processor.processRequest(this.conf, this.request, this.response); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); + } + + @Test + public void preflightRequestMatchedAllowedMethod() throws Exception { + this.request.setMethod(HttpMethod.OPTIONS.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.conf.addAllowedOrigin("*"); + + this.processor.processRequest(this.conf, this.request, this.response); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + assertEquals("GET,HEAD", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + } + + @Test + public void preflightRequestTestWithOriginButWithoutOtherHeaders() throws Exception { + this.request.setMethod(HttpMethod.OPTIONS.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + + this.processor.processRequest(this.conf, this.request, this.response); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); + } + + @Test + public void preflightRequestWithoutRequestMethod() throws Exception { + this.request.setMethod(HttpMethod.OPTIONS.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1"); + + this.processor.processRequest(this.conf, this.request, this.response); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); + } + + @Test + public void preflightRequestWithRequestAndMethodHeaderButNoConfig() throws Exception { + this.request.setMethod(HttpMethod.OPTIONS.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1"); + + this.processor.processRequest(this.conf, this.request, this.response); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); + } + + @Test + public void preflightRequestValidRequestAndConfig() throws Exception { + this.request.setMethod(HttpMethod.OPTIONS.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1"); + this.conf.addAllowedOrigin("*"); + this.conf.addAllowedMethod("GET"); + this.conf.addAllowedMethod("PUT"); + this.conf.addAllowedHeader("header1"); + this.conf.addAllowedHeader("header2"); + + this.processor.processRequest(this.conf, this.request, this.response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("*", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); + assertEquals("GET,PUT", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + } + + @Test + public void preflightRequestCredentials() throws Exception { + this.request.setMethod(HttpMethod.OPTIONS.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1"); + this.conf.addAllowedOrigin("https://domain1.com"); + this.conf.addAllowedOrigin("https://domain2.com"); + this.conf.addAllowedOrigin("http://domain3.com"); + this.conf.addAllowedHeader("Header1"); + this.conf.setAllowCredentials(true); + + this.processor.processRequest(this.conf, this.request, this.response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("https://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertEquals("true", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + } + + @Test + public void preflightRequestCredentialsWithOriginWildcard() throws Exception { + this.request.setMethod(HttpMethod.OPTIONS.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1"); + this.conf.addAllowedOrigin("https://domain1.com"); + this.conf.addAllowedOrigin("*"); + this.conf.addAllowedOrigin("http://domain3.com"); + this.conf.addAllowedHeader("Header1"); + this.conf.setAllowCredentials(true); + + this.processor.processRequest(this.conf, this.request, this.response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("https://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + } + + @Test + public void preflightRequestAllowedHeaders() throws Exception { + this.request.setMethod(HttpMethod.OPTIONS.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1, Header2"); + this.conf.addAllowedHeader("Header1"); + this.conf.addAllowedHeader("Header2"); + this.conf.addAllowedHeader("Header3"); + this.conf.addAllowedOrigin("https://domain2.com"); + + this.processor.processRequest(this.conf, this.request, this.response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS)); + assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); + assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); + assertFalse(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header3")); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + } + + @Test + public void preflightRequestAllowsAllHeaders() throws Exception { + this.request.setMethod(HttpMethod.OPTIONS.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1, Header2"); + this.conf.addAllowedHeader("*"); + this.conf.addAllowedOrigin("https://domain2.com"); + + this.processor.processRequest(this.conf, this.request, this.response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS)); + assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); + assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); + assertFalse(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("*")); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + } + + @Test + public void preflightRequestWithEmptyHeaders() throws Exception { + this.request.setMethod(HttpMethod.OPTIONS.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, ""); + this.conf.addAllowedHeader("*"); + this.conf.addAllowedOrigin("https://domain2.com"); + + this.processor.processRequest(this.conf, this.request, this.response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + } + + @Test + public void preflightRequestWithNullConfig() throws Exception { + this.request.setMethod(HttpMethod.OPTIONS.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.conf.addAllowedOrigin("*"); + + this.processor.processRequest(null, this.request, this.response); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/cors/UrlBasedCorsConfigurationSourceTests.java b/spring-web/src/test/java/org/springframework/web/cors/UrlBasedCorsConfigurationSourceTests.java new file mode 100644 index 0000000000000000000000000000000000000000..6c8c91bb90c89f4db02be4b6561e5abeb9d70af6 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/cors/UrlBasedCorsConfigurationSourceTests.java @@ -0,0 +1,56 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors; + +import static org.junit.Assert.*; +import org.junit.Test; + +import org.springframework.http.HttpMethod; +import org.springframework.mock.web.test.MockHttpServletRequest; + +/** + * Unit tests for {@link UrlBasedCorsConfigurationSource}. + * @author Sebastien Deleuze + */ +public class UrlBasedCorsConfigurationSourceTests { + + private final UrlBasedCorsConfigurationSource configSource = new UrlBasedCorsConfigurationSource(); + + @Test + public void empty() { + MockHttpServletRequest request = new MockHttpServletRequest(HttpMethod.GET.name(), "/bar/test.html"); + assertNull(this.configSource.getCorsConfiguration(request)); + } + + @Test + public void registerAndMatch() { + CorsConfiguration config = new CorsConfiguration(); + this.configSource.registerCorsConfiguration("/bar/**", config); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/foo/test.html"); + assertNull(this.configSource.getCorsConfiguration(request)); + + request.setRequestURI("/bar/test.html"); + assertEquals(config, this.configSource.getCorsConfiguration(request)); + } + + @Test(expected = UnsupportedOperationException.class) + public void unmodifiableConfigurationsMap() { + this.configSource.getCorsConfigurations().put("/**", new CorsConfiguration()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b34ccd4d8d889a2ec77b07e09bb35f9ae2499b55 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java @@ -0,0 +1,154 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.web.test.server.MockServerWebExchange; +import org.springframework.web.filter.reactive.ForwardedHeaderFilter; + +import static org.junit.Assert.*; +import static org.springframework.mock.http.server.reactive.test.MockServerHttpRequest.*; + +/** + * Test case for reactive {@link CorsUtils}. + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + */ +public class CorsUtilsTests { + + @Test + public void isCorsRequest() { + ServerHttpRequest request = get("/").header(HttpHeaders.ORIGIN, "https://domain.com").build(); + assertTrue(CorsUtils.isCorsRequest(request)); + } + + @Test + public void isNotCorsRequest() { + ServerHttpRequest request = get("/").build(); + assertFalse(CorsUtils.isCorsRequest(request)); + } + + @Test + public void isPreFlightRequest() { + ServerHttpRequest request = options("/") + .header(HttpHeaders.ORIGIN, "https://domain.com") + .header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET") + .build(); + assertTrue(CorsUtils.isPreFlightRequest(request)); + } + + @Test + public void isNotPreFlightRequest() { + ServerHttpRequest request = get("/").build(); + assertFalse(CorsUtils.isPreFlightRequest(request)); + + request = options("/").header(HttpHeaders.ORIGIN, "https://domain.com").build(); + assertFalse(CorsUtils.isPreFlightRequest(request)); + + request = options("/").header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET").build(); + assertFalse(CorsUtils.isPreFlightRequest(request)); + } + + @Test // SPR-16262 + public void isSameOriginWithXForwardedHeaders() { + String server = "mydomain1.com"; + testWithXForwardedHeaders(server, -1, "https", null, -1, "https://mydomain1.com"); + testWithXForwardedHeaders(server, 123, "https", null, -1, "https://mydomain1.com"); + testWithXForwardedHeaders(server, -1, "https", "mydomain2.com", -1, "https://mydomain2.com"); + testWithXForwardedHeaders(server, 123, "https", "mydomain2.com", -1, "https://mydomain2.com"); + testWithXForwardedHeaders(server, -1, "https", "mydomain2.com", 456, "https://mydomain2.com:456"); + testWithXForwardedHeaders(server, 123, "https", "mydomain2.com", 456, "https://mydomain2.com:456"); + } + + @Test // SPR-16262 + public void isSameOriginWithForwardedHeader() { + String server = "mydomain1.com"; + testWithForwardedHeader(server, -1, "proto=https", "https://mydomain1.com"); + testWithForwardedHeader(server, 123, "proto=https", "https://mydomain1.com"); + testWithForwardedHeader(server, -1, "proto=https; host=mydomain2.com", "https://mydomain2.com"); + testWithForwardedHeader(server, 123, "proto=https; host=mydomain2.com", "https://mydomain2.com"); + testWithForwardedHeader(server, -1, "proto=https; host=mydomain2.com:456", "https://mydomain2.com:456"); + testWithForwardedHeader(server, 123, "proto=https; host=mydomain2.com:456", "https://mydomain2.com:456"); + } + + @Test // SPR-16362 + public void isSameOriginWithDifferentSchemes() { + MockServerHttpRequest request = MockServerHttpRequest + .get("http://mydomain1.com") + .header(HttpHeaders.ORIGIN, "https://mydomain1.com") + .build(); + assertFalse(CorsUtils.isSameOrigin(request)); + } + + private void testWithXForwardedHeaders(String serverName, int port, + String forwardedProto, String forwardedHost, int forwardedPort, String originHeader) { + + String url = "http://" + serverName; + if (port != -1) { + url = url + ":" + port; + } + + MockServerHttpRequest.BaseBuilder builder = get(url).header(HttpHeaders.ORIGIN, originHeader); + if (forwardedProto != null) { + builder.header("X-Forwarded-Proto", forwardedProto); + } + if (forwardedHost != null) { + builder.header("X-Forwarded-Host", forwardedHost); + } + if (forwardedPort != -1) { + builder.header("X-Forwarded-Port", String.valueOf(forwardedPort)); + } + + ServerHttpRequest request = adaptFromForwardedHeaders(builder); + assertTrue(CorsUtils.isSameOrigin(request)); + } + + private void testWithForwardedHeader(String serverName, int port, + String forwardedHeader, String originHeader) { + + String url = "http://" + serverName; + if (port != -1) { + url = url + ":" + port; + } + + MockServerHttpRequest.BaseBuilder builder = get(url) + .header("Forwarded", forwardedHeader) + .header(HttpHeaders.ORIGIN, originHeader); + + ServerHttpRequest request = adaptFromForwardedHeaders(builder); + assertTrue(CorsUtils.isSameOrigin(request)); + } + + // SPR-16668 + private ServerHttpRequest adaptFromForwardedHeaders(MockServerHttpRequest.BaseBuilder builder) { + AtomicReference requestRef = new AtomicReference<>(); + MockServerWebExchange exchange = MockServerWebExchange.from(builder); + new ForwardedHeaderFilter().filter(exchange, exchange2 -> { + requestRef.set(exchange2.getRequest()); + return Mono.empty(); + }).block(); + return requestRef.get(); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsWebFilterTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsWebFilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b1d8df234c6e65f4daf736695acf643b09314d60 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsWebFilterTests.java @@ -0,0 +1,147 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.cors.reactive; + + +import java.io.IOException; +import java.util.Arrays; + +import javax.servlet.ServletException; + +import org.junit.Before; +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.web.test.server.MockServerWebExchange; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.server.WebFilterChain; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS; +import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN; +import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS; +import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_MAX_AGE; +import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS; +import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD; +import static org.springframework.http.HttpHeaders.HOST; +import static org.springframework.http.HttpHeaders.ORIGIN; + +/** + * Unit tests for {@link CorsWebFilter}. + * @author Sebastien Deleuze + */ +public class CorsWebFilterTests { + + private CorsWebFilter filter; + + private final CorsConfiguration config = new CorsConfiguration(); + + @Before + public void setup() throws Exception { + config.setAllowedOrigins(Arrays.asList("https://domain1.com", "https://domain2.com")); + config.setAllowedMethods(Arrays.asList("GET", "POST")); + config.setAllowedHeaders(Arrays.asList("header1", "header2")); + config.setExposedHeaders(Arrays.asList("header3", "header4")); + config.setMaxAge(123L); + config.setAllowCredentials(false); + filter = new CorsWebFilter(r -> config); + } + + @Test + public void validActualRequest() { + WebFilterChain filterChain = (filterExchange) -> { + try { + HttpHeaders headers = filterExchange.getResponse().getHeaders(); + assertEquals("https://domain2.com", headers.getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("header3, header4", headers.getFirst(ACCESS_CONTROL_EXPOSE_HEADERS)); + } catch (AssertionError ex) { + return Mono.error(ex); + } + return Mono.empty(); + + }; + MockServerWebExchange exchange = MockServerWebExchange.from( + MockServerHttpRequest + .get("https://domain1.com/test.html") + .header(HOST, "domain1.com") + .header(ORIGIN, "https://domain2.com") + .header("header2", "foo")); + this.filter.filter(exchange, filterChain); + } + + @Test + public void invalidActualRequest() throws ServletException, IOException { + MockServerWebExchange exchange = MockServerWebExchange.from( + MockServerHttpRequest + .delete("https://domain1.com/test.html") + .header(HOST, "domain1.com") + .header(ORIGIN, "https://domain2.com") + .header("header2", "foo")); + + WebFilterChain filterChain = (filterExchange) -> Mono.error( + new AssertionError("Invalid requests must not be forwarded to the filter chain")); + filter.filter(exchange, filterChain); + + assertNull(exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + } + + @Test + public void validPreFlightRequest() throws ServletException, IOException { + + MockServerWebExchange exchange = MockServerWebExchange.from( + MockServerHttpRequest + .options("https://domain1.com/test.html") + .header(HOST, "domain1.com") + .header(ORIGIN, "https://domain2.com") + .header(ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.GET.name()) + .header(ACCESS_CONTROL_REQUEST_HEADERS, "header1, header2") + ); + + WebFilterChain filterChain = (filterExchange) -> Mono.error( + new AssertionError("Preflight requests must not be forwarded to the filter chain")); + filter.filter(exchange, filterChain); + + HttpHeaders headers = exchange.getResponse().getHeaders(); + assertEquals("https://domain2.com", headers.getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("header1, header2", headers.getFirst(ACCESS_CONTROL_ALLOW_HEADERS)); + assertEquals("header3, header4", headers.getFirst(ACCESS_CONTROL_EXPOSE_HEADERS)); + assertEquals(123L, Long.parseLong(headers.getFirst(ACCESS_CONTROL_MAX_AGE))); + } + + @Test + public void invalidPreFlightRequest() throws ServletException, IOException { + + MockServerWebExchange exchange = MockServerWebExchange.from( + MockServerHttpRequest + .options("https://domain1.com/test.html") + .header(HOST, "domain1.com") + .header(ORIGIN, "https://domain2.com") + .header(ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.DELETE.name()) + .header(ACCESS_CONTROL_REQUEST_HEADERS, "header1, header2")); + + WebFilterChain filterChain = (filterExchange) -> Mono.error( + new AssertionError("Preflight requests must not be forwarded to the filter chain")); + + filter.filter(exchange, filterChain); + + assertNull(exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..675f23e12c31d0887d3095d6901aef2b5aa5ca60 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java @@ -0,0 +1,413 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.web.test.server.MockServerWebExchange; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.server.ServerWebExchange; + +import static org.hamcrest.Matchers.contains; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS; +import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN; +import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS; +import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD; +import static org.springframework.http.HttpHeaders.ORIGIN; +import static org.springframework.http.HttpHeaders.VARY; + +/** + * {@link DefaultCorsProcessor} tests with simple or pre-flight CORS request. + * + * @author Sebastien Deleuze + */ +public class DefaultCorsProcessorTests { + + private DefaultCorsProcessor processor; + + private CorsConfiguration conf; + + + @Before + public void setup() { + this.conf = new CorsConfiguration(); + this.processor = new DefaultCorsProcessor(); + } + + + @Test + public void actualRequestWithOriginHeader() throws Exception { + ServerWebExchange exchange = actualRequest(); + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); + } + + @Test + public void actualRequestWithOriginHeaderAndNullConfig() throws Exception { + ServerWebExchange exchange = actualRequest(); + this.processor.process(null, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertNull(response.getStatusCode()); + } + + @Test + public void actualRequestWithOriginHeaderAndAllowedOrigin() throws Exception { + ServerWebExchange exchange = actualRequest(); + this.conf.addAllowedOrigin("*"); + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("*", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertFalse(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); + assertFalse(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertNull(response.getStatusCode()); + } + + @Test + public void actualRequestCredentials() throws Exception { + ServerWebExchange exchange = actualRequest(); + this.conf.addAllowedOrigin("https://domain1.com"); + this.conf.addAllowedOrigin("https://domain2.com"); + this.conf.addAllowedOrigin("http://domain3.com"); + this.conf.setAllowCredentials(true); + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("https://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertEquals("true", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertNull(response.getStatusCode()); + } + + @Test + public void actualRequestCredentialsWithOriginWildcard() throws Exception { + ServerWebExchange exchange = actualRequest(); + this.conf.addAllowedOrigin("*"); + this.conf.setAllowCredentials(true); + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("https://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertEquals("true", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertNull(response.getStatusCode()); + } + + @Test + public void actualRequestCaseInsensitiveOriginMatch() throws Exception { + ServerWebExchange exchange = actualRequest(); + this.conf.addAllowedOrigin("https://DOMAIN2.com"); + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertNull(response.getStatusCode()); + } + + @Test + public void actualRequestExposedHeaders() throws Exception { + ServerWebExchange exchange = actualRequest(); + this.conf.addExposedHeader("header1"); + this.conf.addExposedHeader("header2"); + this.conf.addAllowedOrigin("https://domain2.com"); + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("https://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); + assertTrue(response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1")); + assertTrue(response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2")); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertNull(response.getStatusCode()); + } + + @Test + public void preflightRequestAllOriginsAllowed() throws Exception { + ServerWebExchange exchange = MockServerWebExchange.from( + preFlightRequest().header(ACCESS_CONTROL_REQUEST_METHOD, "GET")); + this.conf.addAllowedOrigin("*"); + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertNull(response.getStatusCode()); + } + + + @Test + public void preflightRequestWrongAllowedMethod() throws Exception { + ServerWebExchange exchange = MockServerWebExchange.from( + preFlightRequest().header(ACCESS_CONTROL_REQUEST_METHOD, "DELETE")); + this.conf.addAllowedOrigin("*"); + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); + } + + @Test + public void preflightRequestMatchedAllowedMethod() throws Exception { + ServerWebExchange exchange = MockServerWebExchange.from( + preFlightRequest().header(ACCESS_CONTROL_REQUEST_METHOD, "GET")); + this.conf.addAllowedOrigin("*"); + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertNull(response.getStatusCode()); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals("GET,HEAD", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); + } + + @Test + public void preflightRequestTestWithOriginButWithoutOtherHeaders() throws Exception { + ServerWebExchange exchange = MockServerWebExchange.from(preFlightRequest()); + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); + } + + @Test + public void preflightRequestWithoutRequestMethod() throws Exception { + ServerWebExchange exchange = MockServerWebExchange.from( + preFlightRequest().header(ACCESS_CONTROL_REQUEST_HEADERS, "Header1")); + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); + } + + @Test + public void preflightRequestWithRequestAndMethodHeaderButNoConfig() throws Exception { + ServerWebExchange exchange = MockServerWebExchange.from(preFlightRequest() + .header(ACCESS_CONTROL_REQUEST_METHOD, "GET") + .header(ACCESS_CONTROL_REQUEST_HEADERS, "Header1")); + + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); + } + + @Test + public void preflightRequestValidRequestAndConfig() throws Exception { + ServerWebExchange exchange = MockServerWebExchange.from(preFlightRequest() + .header(ACCESS_CONTROL_REQUEST_METHOD, "GET") + .header(ACCESS_CONTROL_REQUEST_HEADERS, "Header1")); + + this.conf.addAllowedOrigin("*"); + this.conf.addAllowedMethod("GET"); + this.conf.addAllowedMethod("PUT"); + this.conf.addAllowedHeader("header1"); + this.conf.addAllowedHeader("header2"); + + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("*", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); + assertEquals("GET,PUT", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); + assertFalse(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertNull(response.getStatusCode()); + } + + @Test + public void preflightRequestCredentials() throws Exception { + ServerWebExchange exchange = MockServerWebExchange.from(preFlightRequest() + .header(ACCESS_CONTROL_REQUEST_METHOD, "GET") + .header(ACCESS_CONTROL_REQUEST_HEADERS, "Header1")); + + this.conf.addAllowedOrigin("https://domain1.com"); + this.conf.addAllowedOrigin("https://domain2.com"); + this.conf.addAllowedOrigin("http://domain3.com"); + this.conf.addAllowedHeader("Header1"); + this.conf.setAllowCredentials(true); + + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("https://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertEquals("true", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertNull(response.getStatusCode()); + } + + @Test + public void preflightRequestCredentialsWithOriginWildcard() throws Exception { + ServerWebExchange exchange = MockServerWebExchange.from(preFlightRequest() + .header(ACCESS_CONTROL_REQUEST_METHOD, "GET") + .header(ACCESS_CONTROL_REQUEST_HEADERS, "Header1")); + + this.conf.addAllowedOrigin("https://domain1.com"); + this.conf.addAllowedOrigin("*"); + this.conf.addAllowedOrigin("http://domain3.com"); + this.conf.addAllowedHeader("Header1"); + this.conf.setAllowCredentials(true); + + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("https://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertNull(response.getStatusCode()); + } + + @Test + public void preflightRequestAllowedHeaders() throws Exception { + ServerWebExchange exchange = MockServerWebExchange.from(preFlightRequest() + .header(ACCESS_CONTROL_REQUEST_METHOD, "GET") + .header(ACCESS_CONTROL_REQUEST_HEADERS, "Header1, Header2")); + + this.conf.addAllowedHeader("Header1"); + this.conf.addAllowedHeader("Header2"); + this.conf.addAllowedHeader("Header3"); + this.conf.addAllowedOrigin("https://domain2.com"); + + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_HEADERS)); + assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); + assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); + assertFalse(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header3")); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertNull(response.getStatusCode()); + } + + @Test + public void preflightRequestAllowsAllHeaders() throws Exception { + ServerWebExchange exchange = MockServerWebExchange.from(preFlightRequest() + .header(ACCESS_CONTROL_REQUEST_METHOD, "GET") + .header(ACCESS_CONTROL_REQUEST_HEADERS, "Header1, Header2")); + + this.conf.addAllowedHeader("*"); + this.conf.addAllowedOrigin("https://domain2.com"); + + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_HEADERS)); + assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); + assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); + assertFalse(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("*")); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertNull(response.getStatusCode()); + } + + @Test + public void preflightRequestWithEmptyHeaders() throws Exception { + ServerWebExchange exchange = MockServerWebExchange.from(preFlightRequest() + .header(ACCESS_CONTROL_REQUEST_METHOD, "GET") + .header(ACCESS_CONTROL_REQUEST_HEADERS, "")); + + this.conf.addAllowedHeader("*"); + this.conf.addAllowedOrigin("https://domain2.com"); + + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_HEADERS)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertNull(response.getStatusCode()); + } + + @Test + public void preflightRequestWithNullConfig() throws Exception { + ServerWebExchange exchange = MockServerWebExchange.from( + preFlightRequest().header(ACCESS_CONTROL_REQUEST_METHOD, "GET")); + this.conf.addAllowedOrigin("*"); + this.processor.process(null, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); + } + + + private ServerWebExchange actualRequest() { + return MockServerWebExchange.from(corsRequest(HttpMethod.GET)); + } + + private MockServerHttpRequest.BaseBuilder preFlightRequest() { + return corsRequest(HttpMethod.OPTIONS); + } + + private MockServerHttpRequest.BaseBuilder corsRequest(HttpMethod method) { + return MockServerHttpRequest + .method(method, "http://localhost/test.html") + .header(HttpHeaders.ORIGIN, "https://domain2.com"); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSourceTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSourceTests.java new file mode 100644 index 0000000000000000000000000000000000000000..1894608d550617fe02a18cf20a43904660255293 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSourceTests.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import org.junit.Test; + +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.web.test.server.MockServerWebExchange; +import org.springframework.web.cors.CorsConfiguration; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** + * Unit tests for {@link UrlBasedCorsConfigurationSource}. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + */ +public class UrlBasedCorsConfigurationSourceTests { + + private final UrlBasedCorsConfigurationSource configSource + = new UrlBasedCorsConfigurationSource(); + + + @Test + public void empty() { + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/bar/test.html")); + assertNull(this.configSource.getCorsConfiguration(exchange)); + } + + @Test + public void registerAndMatch() { + CorsConfiguration config = new CorsConfiguration(); + this.configSource.registerCorsConfiguration("/bar/**", config); + + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/foo/test.html")); + assertNull(this.configSource.getCorsConfiguration(exchange)); + + exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/bar/test.html")); + assertEquals(config, this.configSource.getCorsConfiguration(exchange)); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/CharacterEncodingFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/CharacterEncodingFilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..233041a871c6bfe737a21ae0fd2c0c8a1e285b28 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/CharacterEncodingFilterTests.java @@ -0,0 +1,174 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.junit.Test; + +import org.springframework.mock.web.test.MockFilterConfig; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.mock.web.test.MockServletContext; +import org.springframework.web.util.WebUtils; + +import static org.mockito.BDDMockito.*; + +/** + * @author Rick Evans + * @author Juergen Hoeller + * @author Vedran Pavic + */ +public class CharacterEncodingFilterTests { + + private static final String FILTER_NAME = "boot"; + + private static final String ENCODING = "UTF-8"; + + + @Test + public void forceEncodingAlwaysSetsEncoding() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + request.setCharacterEncoding(ENCODING); + given(request.getAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE)).willReturn(null); + given(request.getAttribute(filteredName(FILTER_NAME))).willReturn(null); + + HttpServletResponse response = mock(HttpServletResponse.class); + FilterChain filterChain = mock(FilterChain.class); + + CharacterEncodingFilter filter = new CharacterEncodingFilter(ENCODING, true); + filter.init(new MockFilterConfig(FILTER_NAME)); + filter.doFilter(request, response, filterChain); + + verify(request).setAttribute(filteredName(FILTER_NAME), Boolean.TRUE); + verify(request).removeAttribute(filteredName(FILTER_NAME)); + verify(response).setCharacterEncoding(ENCODING); + verify(filterChain).doFilter(request, response); + } + + @Test + public void encodingIfEmptyAndNotForced() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + given(request.getCharacterEncoding()).willReturn(null); + given(request.getAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE)).willReturn(null); + given(request.getAttribute(filteredName(FILTER_NAME))).willReturn(null); + + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = mock(FilterChain.class); + + CharacterEncodingFilter filter = new CharacterEncodingFilter(ENCODING); + filter.init(new MockFilterConfig(FILTER_NAME)); + filter.doFilter(request, response, filterChain); + + verify(request).setCharacterEncoding(ENCODING); + verify(request).setAttribute(filteredName(FILTER_NAME), Boolean.TRUE); + verify(request).removeAttribute(filteredName(FILTER_NAME)); + verify(filterChain).doFilter(request, response); + } + + @Test + public void doesNotIfEncodingIsNotEmptyAndNotForced() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + given(request.getCharacterEncoding()).willReturn(ENCODING); + given(request.getAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE)).willReturn(null); + given(request.getAttribute(filteredName(FILTER_NAME))).willReturn(null); + + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = mock(FilterChain.class); + + CharacterEncodingFilter filter = new CharacterEncodingFilter(ENCODING); + filter.init(new MockFilterConfig(FILTER_NAME)); + filter.doFilter(request, response, filterChain); + + verify(request).setAttribute(filteredName(FILTER_NAME), Boolean.TRUE); + verify(request).removeAttribute(filteredName(FILTER_NAME)); + verify(filterChain).doFilter(request, response); + } + + @Test + public void withBeanInitialization() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + given(request.getCharacterEncoding()).willReturn(null); + given(request.getAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE)).willReturn(null); + given(request.getAttribute(filteredName(FILTER_NAME))).willReturn(null); + + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = mock(FilterChain.class); + + CharacterEncodingFilter filter = new CharacterEncodingFilter(); + filter.setEncoding(ENCODING); + filter.setBeanName(FILTER_NAME); + filter.setServletContext(new MockServletContext()); + filter.doFilter(request, response, filterChain); + + verify(request).setCharacterEncoding(ENCODING); + verify(request).setAttribute(filteredName(FILTER_NAME), Boolean.TRUE); + verify(request).removeAttribute(filteredName(FILTER_NAME)); + verify(filterChain).doFilter(request, response); + } + + @Test + public void withIncompleteInitialization() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + given(request.getCharacterEncoding()).willReturn(null); + given(request.getAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE)).willReturn(null); + given(request.getAttribute(filteredName(CharacterEncodingFilter.class.getName()))).willReturn(null); + + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = mock(FilterChain.class); + + CharacterEncodingFilter filter = new CharacterEncodingFilter(ENCODING); + filter.doFilter(request, response, filterChain); + + verify(request).setCharacterEncoding(ENCODING); + verify(request).setAttribute(filteredName(CharacterEncodingFilter.class.getName()), Boolean.TRUE); + verify(request).removeAttribute(filteredName(CharacterEncodingFilter.class.getName())); + verify(filterChain).doFilter(request, response); + } + + // SPR-14240 + @Test + public void setForceEncodingOnRequestOnly() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + request.setCharacterEncoding(ENCODING); + given(request.getAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE)).willReturn(null); + given(request.getAttribute(filteredName(FILTER_NAME))).willReturn(null); + + HttpServletResponse response = mock(HttpServletResponse.class); + FilterChain filterChain = mock(FilterChain.class); + + CharacterEncodingFilter filter = new CharacterEncodingFilter(ENCODING, true, false); + filter.init(new MockFilterConfig(FILTER_NAME)); + filter.doFilter(request, response, filterChain); + + verify(request).setAttribute(filteredName(FILTER_NAME), Boolean.TRUE); + verify(request).removeAttribute(filteredName(FILTER_NAME)); + verify(request, times(2)).setCharacterEncoding(ENCODING); + verify(response, never()).setCharacterEncoding(ENCODING); + verify(filterChain).doFilter(request, response); + } + + private String filteredName(String prefix) { + return prefix + OncePerRequestFilter.ALREADY_FILTERED_SUFFIX; + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/CompositeFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/CompositeFilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e6c28e42aa4fae68a1454bca55a087d3ba8478a1 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/CompositeFilterTests.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; +import java.util.Arrays; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; + +import org.junit.Test; + +import org.springframework.mock.web.test.MockFilterConfig; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.mock.web.test.MockServletContext; + +import static org.junit.Assert.*; + +/** + * @author Dave Syer + */ +public class CompositeFilterTests { + + @Test + public void testCompositeFilter() throws ServletException, IOException { + ServletContext sc = new MockServletContext(); + MockFilter targetFilter = new MockFilter(); + MockFilterConfig proxyConfig = new MockFilterConfig(sc); + + CompositeFilter filterProxy = new CompositeFilter(); + filterProxy.setFilters(Arrays.asList(targetFilter)); + filterProxy.init(proxyConfig); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + filterProxy.doFilter(request, response, null); + + assertNotNull(targetFilter.filterConfig); + assertEquals(Boolean.TRUE, request.getAttribute("called")); + + filterProxy.destroy(); + assertNull(targetFilter.filterConfig); + } + + + public static class MockFilter implements Filter { + + public FilterConfig filterConfig; + + @Override + public void init(FilterConfig filterConfig) throws ServletException { + this.filterConfig = filterConfig; + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) { + request.setAttribute("called", Boolean.TRUE); + } + + @Override + public void destroy() { + this.filterConfig = null; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java b/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java new file mode 100644 index 0000000000000000000000000000000000000000..28ba42bdf752a4af2710ed0620c56775ca07c4e1 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.filter; + +import java.nio.charset.StandardCharsets; + +import javax.servlet.http.HttpServletResponse; + +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.util.FileCopyUtils; +import org.springframework.web.util.ContentCachingResponseWrapper; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** + * Unit tests for {@link ContentCachingResponseWrapper}. + * @author Rossen Stoyanchev + */ +public class ContentCachingResponseWrapperTests { + + @Test + public void copyBodyToResponse() throws Exception { + byte[] responseBody = "Hello World".getBytes(StandardCharsets.UTF_8); + MockHttpServletResponse response = new MockHttpServletResponse(); + + ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response); + responseWrapper.setStatus(HttpServletResponse.SC_OK); + FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream()); + responseWrapper.copyBodyToResponse(); + + assertEquals(200, response.getStatus()); + assertTrue(response.getContentLength() > 0); + assertArrayEquals(responseBody, response.getContentAsByteArray()); + } + + @Test + public void copyBodyToResponseWithTransferEncoding() throws Exception { + byte[] responseBody = "6\r\nHello 5\r\nWorld0\r\n\r\n".getBytes(StandardCharsets.UTF_8); + MockHttpServletResponse response = new MockHttpServletResponse(); + + ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response); + responseWrapper.setStatus(HttpServletResponse.SC_OK); + responseWrapper.setHeader(HttpHeaders.TRANSFER_ENCODING, "chunked"); + FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream()); + responseWrapper.copyBodyToResponse(); + + assertEquals(200, response.getStatus()); + assertEquals("chunked", response.getHeader(HttpHeaders.TRANSFER_ENCODING)); + assertNull(response.getHeader(HttpHeaders.CONTENT_LENGTH)); + assertArrayEquals(responseBody, response.getContentAsByteArray()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/CorsFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/CorsFilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..fa1a49eca37ccedd7e7781c0d328a8f79d4f9da3 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/CorsFilterTests.java @@ -0,0 +1,121 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; +import java.util.Arrays; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; + +import static org.junit.Assert.*; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.web.cors.CorsConfiguration; + +/** + * Unit tests for {@link CorsFilter}. + * @author Sebastien Deleuze + */ +public class CorsFilterTests { + + private CorsFilter filter; + + private final CorsConfiguration config = new CorsConfiguration(); + + @Before + public void setup() throws Exception { + config.setAllowedOrigins(Arrays.asList("https://domain1.com", "https://domain2.com")); + config.setAllowedMethods(Arrays.asList("GET", "POST")); + config.setAllowedHeaders(Arrays.asList("header1", "header2")); + config.setExposedHeaders(Arrays.asList("header3", "header4")); + config.setMaxAge(123L); + config.setAllowCredentials(false); + filter = new CorsFilter(r -> config); + } + + @Test + public void validActualRequest() throws ServletException, IOException { + + MockHttpServletRequest request = new MockHttpServletRequest(HttpMethod.GET.name(), "/test.html"); + request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + request.addHeader("header2", "foo"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("https://domain2.com", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("header3, header4", response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); + }; + filter.doFilter(request, response, filterChain); + } + + @Test + public void invalidActualRequest() throws ServletException, IOException { + + MockHttpServletRequest request = new MockHttpServletRequest(HttpMethod.DELETE.name(), "/test.html"); + request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + request.addHeader("header2", "foo"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = (filterRequest, filterResponse) -> { + fail("Invalid requests must not be forwarded to the filter chain"); + }; + filter.doFilter(request, response, filterChain); + assertNull(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + } + + @Test + public void validPreFlightRequest() throws ServletException, IOException { + + MockHttpServletRequest request = new MockHttpServletRequest(HttpMethod.OPTIONS.name(), "/test.html"); + request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.GET.name()); + request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "header1, header2"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = (filterRequest, filterResponse) -> + fail("Preflight requests must not be forwarded to the filter chain"); + filter.doFilter(request, response, filterChain); + + assertEquals("https://domain2.com", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("header1, header2", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS)); + assertEquals("header3, header4", response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); + assertEquals(123L, Long.parseLong(response.getHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE))); + } + + @Test + public void invalidPreFlightRequest() throws ServletException, IOException { + + MockHttpServletRequest request = new MockHttpServletRequest(HttpMethod.OPTIONS.name(), "/test.html"); + request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.DELETE.name()); + request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "header1, header2"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = (filterRequest, filterResponse) -> + fail("Preflight requests must not be forwarded to the filter chain"); + filter.doFilter(request, response, filterChain); + + assertNull(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/DelegatingFilterProxyTests.java b/spring-web/src/test/java/org/springframework/web/filter/DelegatingFilterProxyTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b8c0bbcaa4471522cf18b2f2438175af3dbfa3ec --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/DelegatingFilterProxyTests.java @@ -0,0 +1,423 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; + +import org.junit.Test; + +import org.springframework.mock.web.test.MockFilterConfig; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.mock.web.test.MockServletContext; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.StaticWebApplicationContext; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + * @author Chris Beams + * @author Rob Winch + * @since 08.05.2005 + */ +public class DelegatingFilterProxyTests { + + @Test + public void testDelegatingFilterProxy() throws ServletException, IOException { + ServletContext sc = new MockServletContext(); + + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + wac.setServletContext(sc); + wac.registerSingleton("targetFilter", MockFilter.class); + wac.refresh(); + sc.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, wac); + + MockFilter targetFilter = (MockFilter) wac.getBean("targetFilter"); + + MockFilterConfig proxyConfig = new MockFilterConfig(sc); + proxyConfig.addInitParameter("targetBeanName", "targetFilter"); + DelegatingFilterProxy filterProxy = new DelegatingFilterProxy(); + filterProxy.init(proxyConfig); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + filterProxy.doFilter(request, response, null); + + assertNull(targetFilter.filterConfig); + assertEquals(Boolean.TRUE, request.getAttribute("called")); + + filterProxy.destroy(); + assertNull(targetFilter.filterConfig); + } + + @Test + public void testDelegatingFilterProxyAndCustomContextAttribute() throws ServletException, IOException { + ServletContext sc = new MockServletContext(); + + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + wac.setServletContext(sc); + wac.registerSingleton("targetFilter", MockFilter.class); + wac.refresh(); + sc.setAttribute("CUSTOM_ATTR", wac); + + MockFilter targetFilter = (MockFilter) wac.getBean("targetFilter"); + + MockFilterConfig proxyConfig = new MockFilterConfig(sc); + proxyConfig.addInitParameter("targetBeanName", "targetFilter"); + proxyConfig.addInitParameter("contextAttribute", "CUSTOM_ATTR"); + DelegatingFilterProxy filterProxy = new DelegatingFilterProxy(); + filterProxy.init(proxyConfig); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + filterProxy.doFilter(request, response, null); + + assertNull(targetFilter.filterConfig); + assertEquals(Boolean.TRUE, request.getAttribute("called")); + + filterProxy.destroy(); + assertNull(targetFilter.filterConfig); + } + + @Test + public void testDelegatingFilterProxyWithFilterDelegateInstance() throws ServletException, IOException { + MockFilter targetFilter = new MockFilter(); + + DelegatingFilterProxy filterProxy = new DelegatingFilterProxy(targetFilter); + filterProxy.init(new MockFilterConfig(new MockServletContext())); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + filterProxy.doFilter(request, response, null); + + assertNull(targetFilter.filterConfig); + assertEquals(Boolean.TRUE, request.getAttribute("called")); + + filterProxy.destroy(); + assertNull(targetFilter.filterConfig); + } + + @Test + public void testDelegatingFilterProxyWithTargetBeanName() throws ServletException, IOException { + MockServletContext sc = new MockServletContext(); + + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + wac.setServletContext(sc); + wac.registerSingleton("targetFilter", MockFilter.class); + wac.refresh(); + sc.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, wac); + + MockFilter targetFilter = (MockFilter) wac.getBean("targetFilter"); + + DelegatingFilterProxy filterProxy = new DelegatingFilterProxy("targetFilter"); + filterProxy.init(new MockFilterConfig(sc)); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + filterProxy.doFilter(request, response, null); + + assertNull(targetFilter.filterConfig); + assertEquals(Boolean.TRUE, request.getAttribute("called")); + + filterProxy.destroy(); + assertNull(targetFilter.filterConfig); + } + + @Test + public void testDelegatingFilterProxyWithTargetBeanNameAndNotYetRefreshedApplicationContext() + throws ServletException, IOException { + + MockServletContext sc = new MockServletContext(); + + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + wac.setServletContext(sc); + wac.registerSingleton("targetFilter", MockFilter.class); + // wac.refresh(); + // note that the context is not set as the ROOT attribute in the ServletContext! + + DelegatingFilterProxy filterProxy = new DelegatingFilterProxy("targetFilter", wac); + filterProxy.init(new MockFilterConfig(sc)); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + filterProxy.doFilter(request, response, null); + + MockFilter targetFilter = (MockFilter) wac.getBean("targetFilter"); + + assertNull(targetFilter.filterConfig); + assertEquals(Boolean.TRUE, request.getAttribute("called")); + + filterProxy.destroy(); + assertNull(targetFilter.filterConfig); + } + + @Test(expected = IllegalStateException.class) + public void testDelegatingFilterProxyWithTargetBeanNameAndNoApplicationContext() + throws ServletException, IOException { + + MockServletContext sc = new MockServletContext(); + + DelegatingFilterProxy filterProxy = new DelegatingFilterProxy("targetFilter", null); + filterProxy.init(new MockFilterConfig(sc)); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + filterProxy.doFilter(request, response, null); // throws + } + + @Test + public void testDelegatingFilterProxyWithFilterName() throws ServletException, IOException { + ServletContext sc = new MockServletContext(); + + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + wac.setServletContext(sc); + wac.registerSingleton("targetFilter", MockFilter.class); + wac.refresh(); + sc.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, wac); + + MockFilter targetFilter = (MockFilter) wac.getBean("targetFilter"); + + MockFilterConfig proxyConfig = new MockFilterConfig(sc, "targetFilter"); + DelegatingFilterProxy filterProxy = new DelegatingFilterProxy(); + filterProxy.init(proxyConfig); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + filterProxy.doFilter(request, response, null); + + assertNull(targetFilter.filterConfig); + assertEquals(Boolean.TRUE, request.getAttribute("called")); + + filterProxy.destroy(); + assertNull(targetFilter.filterConfig); + } + + @Test + public void testDelegatingFilterProxyWithLazyContextStartup() throws ServletException, IOException { + ServletContext sc = new MockServletContext(); + + MockFilterConfig proxyConfig = new MockFilterConfig(sc); + proxyConfig.addInitParameter("targetBeanName", "targetFilter"); + DelegatingFilterProxy filterProxy = new DelegatingFilterProxy(); + filterProxy.init(proxyConfig); + + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + wac.setServletContext(sc); + wac.registerSingleton("targetFilter", MockFilter.class); + wac.refresh(); + sc.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, wac); + + MockFilter targetFilter = (MockFilter) wac.getBean("targetFilter"); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + filterProxy.doFilter(request, response, null); + + assertNull(targetFilter.filterConfig); + assertEquals(Boolean.TRUE, request.getAttribute("called")); + + filterProxy.destroy(); + assertNull(targetFilter.filterConfig); + } + + @Test + public void testDelegatingFilterProxyWithTargetFilterLifecycle() throws ServletException, IOException { + ServletContext sc = new MockServletContext(); + + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + wac.setServletContext(sc); + wac.registerSingleton("targetFilter", MockFilter.class); + wac.refresh(); + sc.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, wac); + + MockFilter targetFilter = (MockFilter) wac.getBean("targetFilter"); + + MockFilterConfig proxyConfig = new MockFilterConfig(sc); + proxyConfig.addInitParameter("targetBeanName", "targetFilter"); + proxyConfig.addInitParameter("targetFilterLifecycle", "true"); + DelegatingFilterProxy filterProxy = new DelegatingFilterProxy(); + filterProxy.init(proxyConfig); + assertEquals(proxyConfig, targetFilter.filterConfig); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + filterProxy.doFilter(request, response, null); + + assertEquals(proxyConfig, targetFilter.filterConfig); + assertEquals(Boolean.TRUE, request.getAttribute("called")); + + filterProxy.destroy(); + assertNull(targetFilter.filterConfig); + } + + @Test + public void testDelegatingFilterProxyWithFrameworkServletContext() throws ServletException, IOException { + ServletContext sc = new MockServletContext(); + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + wac.setServletContext(sc); + wac.registerSingleton("targetFilter", MockFilter.class); + wac.refresh(); + sc.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac); + + MockFilter targetFilter = (MockFilter) wac.getBean("targetFilter"); + + MockFilterConfig proxyConfig = new MockFilterConfig(sc); + proxyConfig.addInitParameter("targetBeanName", "targetFilter"); + DelegatingFilterProxy filterProxy = new DelegatingFilterProxy(); + filterProxy.init(proxyConfig); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + filterProxy.doFilter(request, response, null); + + assertNull(targetFilter.filterConfig); + assertEquals(Boolean.TRUE, request.getAttribute("called")); + + filterProxy.destroy(); + assertNull(targetFilter.filterConfig); + } + + @Test + public void testDelegatingFilterProxyInjectedPreferred() throws ServletException, IOException { + ServletContext sc = new MockServletContext(); + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + wac.setServletContext(sc); + wac.refresh(); + sc.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac); + + StaticWebApplicationContext injectedWac = new StaticWebApplicationContext(); + injectedWac.setServletContext(sc); + String beanName = "targetFilter"; + injectedWac.registerSingleton(beanName, MockFilter.class); + injectedWac.refresh(); + + MockFilter targetFilter = (MockFilter) injectedWac.getBean(beanName); + + DelegatingFilterProxy filterProxy = new DelegatingFilterProxy(beanName, injectedWac); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + filterProxy.doFilter(request, response, null); + + assertNull(targetFilter.filterConfig); + assertEquals(Boolean.TRUE, request.getAttribute("called")); + + filterProxy.destroy(); + assertNull(targetFilter.filterConfig); + } + + @Test + public void testDelegatingFilterProxyNotInjectedWacServletAttrPreferred() + throws ServletException, IOException { + + ServletContext sc = new MockServletContext(); + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + wac.setServletContext(sc); + wac.refresh(); + sc.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, wac); + sc.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac); + + StaticWebApplicationContext wacToUse = new StaticWebApplicationContext(); + wacToUse.setServletContext(sc); + String beanName = "targetFilter"; + String attrName = "customAttrName"; + wacToUse.registerSingleton(beanName, MockFilter.class); + wacToUse.refresh(); + sc.setAttribute(attrName, wacToUse); + + MockFilter targetFilter = (MockFilter) wacToUse.getBean(beanName); + + DelegatingFilterProxy filterProxy = new DelegatingFilterProxy(beanName); + filterProxy.setContextAttribute(attrName); + filterProxy.setServletContext(sc); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + filterProxy.doFilter(request, response, null); + + assertNull(targetFilter.filterConfig); + assertEquals(Boolean.TRUE, request.getAttribute("called")); + + filterProxy.destroy(); + assertNull(targetFilter.filterConfig); + } + + @Test + public void testDelegatingFilterProxyNotInjectedWithRootPreferred() throws ServletException, IOException { + ServletContext sc = new MockServletContext(); + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + wac.setServletContext(sc); + wac.refresh(); + sc.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac); + sc.setAttribute("another", wac); + + StaticWebApplicationContext wacToUse = new StaticWebApplicationContext(); + wacToUse.setServletContext(sc); + String beanName = "targetFilter"; + wacToUse.registerSingleton(beanName, MockFilter.class); + wacToUse.refresh(); + sc.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, wacToUse); + + MockFilter targetFilter = (MockFilter) wacToUse.getBean(beanName); + + DelegatingFilterProxy filterProxy = new DelegatingFilterProxy(beanName); + filterProxy.setServletContext(sc); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + filterProxy.doFilter(request, response, null); + + assertNull(targetFilter.filterConfig); + assertEquals(Boolean.TRUE, request.getAttribute("called")); + + filterProxy.destroy(); + assertNull(targetFilter.filterConfig); + } + + + public static class MockFilter implements Filter { + + public FilterConfig filterConfig; + + @Override + public void init(FilterConfig filterConfig) throws ServletException { + this.filterConfig = filterConfig; + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) + throws IOException, ServletException { + + request.setAttribute("called", Boolean.TRUE); + } + + @Override + public void destroy() { + this.filterConfig = null; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/FormContentFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/FormContentFilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..78e030c6ceea9f23619c3a12ca3e1398dcbb1db1 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/FormContentFilterTests.java @@ -0,0 +1,219 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpMethod; +import org.springframework.mock.web.test.MockFilterChain; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; + +import static org.junit.Assert.*; + +/** + * Test fixture for {@link FormContentFilter}. + * + * @author Rossen Stoyanchev + */ +public class FormContentFilterTests { + + private final FormContentFilter filter = new FormContentFilter(); + + private MockHttpServletRequest request; + + private MockHttpServletResponse response; + + private MockFilterChain filterChain; + + + @Before + public void setup() { + this.request = new MockHttpServletRequest("PUT", "/"); + this.request.setContentType("application/x-www-form-urlencoded; charset=ISO-8859-1"); + this.response = new MockHttpServletResponse(); + this.filterChain = new MockFilterChain(); + } + + + @Test + public void wrapPutPatchAndDeleteOnly() throws Exception { + for (HttpMethod method : HttpMethod.values()) { + MockHttpServletRequest request = new MockHttpServletRequest(method.name(), "/"); + request.setContent("foo=bar".getBytes("ISO-8859-1")); + request.setContentType("application/x-www-form-urlencoded; charset=ISO-8859-1"); + this.filterChain = new MockFilterChain(); + this.filter.doFilter(request, this.response, this.filterChain); + if (method == HttpMethod.PUT || method == HttpMethod.PATCH || method == HttpMethod.DELETE) { + assertNotSame(request, this.filterChain.getRequest()); + } + else { + assertSame(request, this.filterChain.getRequest()); + } + } + } + + @Test + public void wrapFormEncodedOnly() throws Exception { + String[] contentTypes = new String[] {"text/plain", "multipart/form-data"}; + for (String contentType : contentTypes) { + MockHttpServletRequest request = new MockHttpServletRequest("PUT", "/"); + request.setContent("".getBytes("ISO-8859-1")); + request.setContentType(contentType); + this.filterChain = new MockFilterChain(); + this.filter.doFilter(request, this.response, this.filterChain); + assertSame(request, this.filterChain.getRequest()); + } + } + + @Test + public void invalidMediaType() throws Exception { + this.request.setContent("".getBytes("ISO-8859-1")); + this.request.setContentType("foo"); + this.filterChain = new MockFilterChain(); + this.filter.doFilter(this.request, this.response, this.filterChain); + assertSame(this.request, this.filterChain.getRequest()); + } + + @Test + public void getParameter() throws Exception { + this.request.setContent("name=value".getBytes("ISO-8859-1")); + this.filter.doFilter(this.request, this.response, this.filterChain); + + assertEquals("value", this.filterChain.getRequest().getParameter("name")); + } + + @Test + public void getParameterFromQueryString() throws Exception { + this.request.addParameter("name", "value1"); + this.request.setContent("name=value2".getBytes("ISO-8859-1")); + this.filter.doFilter(this.request, this.response, this.filterChain); + + assertNotSame("Request not wrapped", this.request, this.filterChain.getRequest()); + assertEquals("Query string parameters should be listed ahead of form parameters", + "value1", this.filterChain.getRequest().getParameter("name")); + } + + @Test + public void getParameterNullValue() throws Exception { + this.request.setContent("name=value".getBytes("ISO-8859-1")); + this.filter.doFilter(this.request, this.response, this.filterChain); + + assertNotSame("Request not wrapped", this.request, this.filterChain.getRequest()); + assertNull(this.filterChain.getRequest().getParameter("noSuchParam")); + } + + @Test + public void getParameterNames() throws Exception { + this.request.addParameter("name1", "value1"); + this.request.addParameter("name2", "value2"); + this.request.setContent("name1=value1&name3=value3&name4=value4".getBytes("ISO-8859-1")); + + this.filter.doFilter(this.request, this.response, this.filterChain); + List names = Collections.list(this.filterChain.getRequest().getParameterNames()); + + assertNotSame("Request not wrapped", this.request, this.filterChain.getRequest()); + assertEquals(Arrays.asList("name1", "name2", "name3", "name4"), names); + } + + @Test + public void getParameterValues() throws Exception { + this.request.setQueryString("name=value1&name=value2"); + this.request.addParameter("name", "value1"); + this.request.addParameter("name", "value2"); + this.request.setContent("name=value3&name=value4".getBytes("ISO-8859-1")); + + this.filter.doFilter(this.request, this.response, this.filterChain); + String[] values = this.filterChain.getRequest().getParameterValues("name"); + + assertNotSame("Request not wrapped", this.request, filterChain.getRequest()); + assertArrayEquals(new String[] {"value1", "value2", "value3", "value4"}, values); + } + + @Test + public void getParameterValuesFromQueryString() throws Exception { + this.request.setQueryString("name=value1&name=value2"); + this.request.addParameter("name", "value1"); + this.request.addParameter("name", "value2"); + this.request.setContent("anotherName=anotherValue".getBytes("ISO-8859-1")); + + this.filter.doFilter(this.request, this.response, this.filterChain); + String[] values = this.filterChain.getRequest().getParameterValues("name"); + + assertNotSame("Request not wrapped", this.request, this.filterChain.getRequest()); + assertArrayEquals(new String[] {"value1", "value2"}, values); + } + + @Test + public void getParameterValuesFromFormContent() throws Exception { + this.request.addParameter("name", "value1"); + this.request.addParameter("name", "value2"); + this.request.setContent("anotherName=anotherValue".getBytes("ISO-8859-1")); + + this.filter.doFilter(this.request, this.response, this.filterChain); + String[] values = this.filterChain.getRequest().getParameterValues("anotherName"); + + assertNotSame("Request not wrapped", this.request, this.filterChain.getRequest()); + assertArrayEquals(new String[] {"anotherValue"}, values); + } + + @Test + public void getParameterValuesInvalidName() throws Exception { + this.request.addParameter("name", "value1"); + this.request.addParameter("name", "value2"); + this.request.setContent("anotherName=anotherValue".getBytes("ISO-8859-1")); + + this.filter.doFilter(this.request, this.response, this.filterChain); + String[] values = this.filterChain.getRequest().getParameterValues("noSuchParameter"); + + assertNotSame("Request not wrapped", this.request, this.filterChain.getRequest()); + assertNull(values); + } + + @Test + public void getParameterMap() throws Exception { + this.request.setQueryString("name=value1&name=value2"); + this.request.addParameter("name", "value1"); + this.request.addParameter("name", "value2"); + this.request.setContent("name=value3&name4=value4".getBytes("ISO-8859-1")); + + this.filter.doFilter(this.request, this.response, this.filterChain); + Map parameters = this.filterChain.getRequest().getParameterMap(); + + assertNotSame("Request not wrapped", this.request, this.filterChain.getRequest()); + assertEquals(2, parameters.size()); + assertArrayEquals(new String[] {"value1", "value2", "value3"}, parameters.get("name")); + assertArrayEquals(new String[] {"value4"}, parameters.get("name4")); + } + + @Test // SPR-15835 + public void hiddenHttpMethodFilterFollowedByHttpPutFormContentFilter() throws Exception { + this.request.addParameter("_method", "PUT"); + this.request.addParameter("hiddenField", "testHidden"); + this.filter.doFilter(this.request, this.response, this.filterChain); + + assertArrayEquals(new String[] {"testHidden"}, + this.filterChain.getRequest().getParameterValues("hiddenField")); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..beae3a30712bf34d70a8f84a753b153092a05195 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java @@ -0,0 +1,520 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; +import java.net.URI; +import java.util.Enumeration; + +import javax.servlet.DispatcherType; +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.mock.web.test.MockFilterChain; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link ForwardedHeaderFilter}. + * + * @author Rossen Stoyanchev + * @author Eddú Meléndez + * @author Rob Winch + */ +public class ForwardedHeaderFilterTests { + + private static final String X_FORWARDED_PROTO = "x-forwarded-proto"; // SPR-14372 (case insensitive) + private static final String X_FORWARDED_HOST = "x-forwarded-host"; + private static final String X_FORWARDED_PORT = "x-forwarded-port"; + private static final String X_FORWARDED_PREFIX = "x-forwarded-prefix"; + private static final String X_FORWARDED_SSL = "x-forwarded-ssl"; + + + private final ForwardedHeaderFilter filter = new ForwardedHeaderFilter(); + + private MockHttpServletRequest request; + + private MockFilterChain filterChain; + + + @Before + @SuppressWarnings("serial") + public void setup() throws Exception { + this.request = new MockHttpServletRequest(); + this.request.setScheme("http"); + this.request.setServerName("localhost"); + this.request.setServerPort(80); + this.filterChain = new MockFilterChain(new HttpServlet() {}); + } + + + @Test + public void contextPathEmpty() throws Exception { + this.request.addHeader(X_FORWARDED_PREFIX, ""); + assertEquals("", filterAndGetContextPath()); + } + + @Test + public void contextPathWithTrailingSlash() throws Exception { + this.request.addHeader(X_FORWARDED_PREFIX, "/foo/bar/"); + assertEquals("/foo/bar", filterAndGetContextPath()); + } + + @Test + public void contextPathWithTrailingSlashes() throws Exception { + this.request.addHeader(X_FORWARDED_PREFIX, "/foo/bar/baz///"); + assertEquals("/foo/bar/baz", filterAndGetContextPath()); + } + + @Test + public void contextPathWithForwardedPrefix() throws Exception { + this.request.addHeader(X_FORWARDED_PREFIX, "/prefix"); + this.request.setContextPath("/mvc-showcase"); + + String actual = filterAndGetContextPath(); + assertEquals("/prefix", actual); + } + + @Test + public void contextPathWithForwardedPrefixTrailingSlash() throws Exception { + this.request.addHeader(X_FORWARDED_PREFIX, "/prefix/"); + this.request.setContextPath("/mvc-showcase"); + + String actual = filterAndGetContextPath(); + assertEquals("/prefix", actual); + } + + private String filterAndGetContextPath() throws ServletException, IOException { + return filterAndGetWrappedRequest().getContextPath(); + } + + private HttpServletRequest filterAndGetWrappedRequest() throws ServletException, IOException { + MockHttpServletResponse response = new MockHttpServletResponse(); + this.filter.doFilterInternal(this.request, response, this.filterChain); + return (HttpServletRequest) this.filterChain.getRequest(); + } + + + @Test + public void contextPathPreserveEncoding() throws Exception { + this.request.setContextPath("/app%20"); + this.request.setRequestURI("/app%20/path/"); + HttpServletRequest actual = filterAndGetWrappedRequest(); + + assertEquals("/app%20", actual.getContextPath()); + assertEquals("/app%20/path/", actual.getRequestURI()); + assertEquals("http://localhost/app%20/path/", actual.getRequestURL().toString()); + } + + @Test + public void requestUri() throws Exception { + this.request.addHeader(X_FORWARDED_PREFIX, "/"); + this.request.setContextPath("/app"); + this.request.setRequestURI("/app/path"); + HttpServletRequest actual = filterAndGetWrappedRequest(); + + assertEquals("", actual.getContextPath()); + assertEquals("/path", actual.getRequestURI()); + } + + @Test + public void requestUriWithTrailingSlash() throws Exception { + this.request.addHeader(X_FORWARDED_PREFIX, "/"); + this.request.setContextPath("/app"); + this.request.setRequestURI("/app/path/"); + HttpServletRequest actual = filterAndGetWrappedRequest(); + + assertEquals("", actual.getContextPath()); + assertEquals("/path/", actual.getRequestURI()); + } + + @Test + public void requestUriPreserveEncoding() throws Exception { + this.request.setContextPath("/app"); + this.request.setRequestURI("/app/path%20with%20spaces/"); + HttpServletRequest actual = filterAndGetWrappedRequest(); + + assertEquals("/app", actual.getContextPath()); + assertEquals("/app/path%20with%20spaces/", actual.getRequestURI()); + assertEquals("http://localhost/app/path%20with%20spaces/", actual.getRequestURL().toString()); + } + + @Test + public void requestUriEqualsContextPath() throws Exception { + this.request.addHeader(X_FORWARDED_PREFIX, "/"); + this.request.setContextPath("/app"); + this.request.setRequestURI("/app"); + HttpServletRequest actual = filterAndGetWrappedRequest(); + + assertEquals("", actual.getContextPath()); + assertEquals("/", actual.getRequestURI()); + } + + @Test + public void requestUriRootUrl() throws Exception { + this.request.addHeader(X_FORWARDED_PREFIX, "/"); + this.request.setContextPath("/app"); + this.request.setRequestURI("/app/"); + HttpServletRequest actual = filterAndGetWrappedRequest(); + + assertEquals("", actual.getContextPath()); + assertEquals("/", actual.getRequestURI()); + } + + @Test + public void requestUriPreserveSemicolonContent() throws Exception { + this.request.setContextPath(""); + this.request.setRequestURI("/path;a=b/with/semicolon"); + HttpServletRequest actual = filterAndGetWrappedRequest(); + + assertEquals("", actual.getContextPath()); + assertEquals("/path;a=b/with/semicolon", actual.getRequestURI()); + assertEquals("http://localhost/path;a=b/with/semicolon", actual.getRequestURL().toString()); + } + + @Test + public void caseInsensitiveForwardedPrefix() throws Exception { + this.request = new MockHttpServletRequest() { + + @Override // SPR-14372: make it case-sensitive + public String getHeader(String header) { + Enumeration names = getHeaderNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + if (name.equals(header)) { + return super.getHeader(header); + } + } + return null; + } + }; + this.request.addHeader(X_FORWARDED_PREFIX, "/prefix"); + this.request.setRequestURI("/path"); + HttpServletRequest actual = filterAndGetWrappedRequest(); + + assertEquals("/prefix/path", actual.getRequestURI()); + } + + @Test + public void shouldFilter() { + testShouldFilter("Forwarded"); + testShouldFilter(X_FORWARDED_HOST); + testShouldFilter(X_FORWARDED_PORT); + testShouldFilter(X_FORWARDED_PROTO); + testShouldFilter(X_FORWARDED_SSL); + } + + private void testShouldFilter(String headerName) { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader(headerName, "1"); + assertFalse(this.filter.shouldNotFilter(request)); + } + + @Test + public void shouldNotFilter() { + assertTrue(this.filter.shouldNotFilter(new MockHttpServletRequest())); + } + + @Test + public void forwardedRequest() throws Exception { + this.request.setRequestURI("/mvc-showcase"); + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "84.198.58.199"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.addHeader("foo", "bar"); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest(); + + assertEquals("https://84.198.58.199/mvc-showcase", actual.getRequestURL().toString()); + assertEquals("https", actual.getScheme()); + assertEquals("84.198.58.199", actual.getServerName()); + assertEquals(443, actual.getServerPort()); + assertTrue(actual.isSecure()); + + assertNull(actual.getHeader(X_FORWARDED_PROTO)); + assertNull(actual.getHeader(X_FORWARDED_HOST)); + assertNull(actual.getHeader(X_FORWARDED_PORT)); + assertEquals("bar", actual.getHeader("foo")); + } + + @Test + public void forwardedRequestInRemoveOnlyMode() throws Exception { + this.request.setRequestURI("/mvc-showcase"); + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "84.198.58.199"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.addHeader(X_FORWARDED_SSL, "on"); + this.request.addHeader("foo", "bar"); + + this.filter.setRemoveOnly(true); + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest(); + + assertEquals("http://localhost/mvc-showcase", actual.getRequestURL().toString()); + assertEquals("http", actual.getScheme()); + assertEquals("localhost", actual.getServerName()); + assertEquals(80, actual.getServerPort()); + assertFalse(actual.isSecure()); + + assertNull(actual.getHeader(X_FORWARDED_PROTO)); + assertNull(actual.getHeader(X_FORWARDED_HOST)); + assertNull(actual.getHeader(X_FORWARDED_PORT)); + assertNull(actual.getHeader(X_FORWARDED_SSL)); + assertEquals("bar", actual.getHeader("foo")); + } + + @Test + public void forwardedRequestWithSsl() throws Exception { + this.request.setRequestURI("/mvc-showcase"); + this.request.addHeader(X_FORWARDED_SSL, "on"); + this.request.addHeader(X_FORWARDED_HOST, "84.198.58.199"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.addHeader("foo", "bar"); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest(); + + assertEquals("https://84.198.58.199/mvc-showcase", actual.getRequestURL().toString()); + assertEquals("https", actual.getScheme()); + assertEquals("84.198.58.199", actual.getServerName()); + assertEquals(443, actual.getServerPort()); + assertTrue(actual.isSecure()); + + assertNull(actual.getHeader(X_FORWARDED_SSL)); + assertNull(actual.getHeader(X_FORWARDED_HOST)); + assertNull(actual.getHeader(X_FORWARDED_PORT)); + assertEquals("bar", actual.getHeader("foo")); + } + + @Test // SPR-16983 + public void forwardedRequestWithServletForward() throws Exception { + this.request.setRequestURI("/foo"); + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "www.mycompany.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + HttpServletRequest wrappedRequest = (HttpServletRequest) this.filterChain.getRequest(); + + this.request.setDispatcherType(DispatcherType.FORWARD); + this.request.setRequestURI("/bar"); + this.filterChain.reset(); + + this.filter.doFilter(wrappedRequest, new MockHttpServletResponse(), this.filterChain); + HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest(); + + assertNotNull(actual); + assertEquals("/bar", actual.getRequestURI()); + assertEquals("https://www.mycompany.com/bar", actual.getRequestURL().toString()); + } + + @Test + public void requestUriWithForwardedPrefix() throws Exception { + this.request.addHeader(X_FORWARDED_PREFIX, "/prefix"); + this.request.setRequestURI("/mvc-showcase"); + + HttpServletRequest actual = filterAndGetWrappedRequest(); + assertEquals("http://localhost/prefix/mvc-showcase", actual.getRequestURL().toString()); + } + + @Test + public void requestUriWithForwardedPrefixTrailingSlash() throws Exception { + this.request.addHeader(X_FORWARDED_PREFIX, "/prefix/"); + this.request.setRequestURI("/mvc-showcase"); + + HttpServletRequest actual = filterAndGetWrappedRequest(); + assertEquals("http://localhost/prefix/mvc-showcase", actual.getRequestURL().toString()); + } + + @Test + public void requestURLNewStringBuffer() throws Exception { + this.request.addHeader(X_FORWARDED_PREFIX, "/prefix/"); + this.request.setRequestURI("/mvc-showcase"); + + HttpServletRequest actual = filterAndGetWrappedRequest(); + actual.getRequestURL().append("?key=value"); + assertEquals("http://localhost/prefix/mvc-showcase", actual.getRequestURL().toString()); + } + + @Test + public void sendRedirectWithAbsolutePath() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + + String redirectedUrl = sendRedirect("/foo/bar"); + assertEquals("https://example.com/foo/bar", redirectedUrl); + } + + @Test // SPR-16506 + public void sendRedirectWithAbsolutePathQueryParamAndFragment() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.setQueryString("oldqp=1"); + + String redirectedUrl = sendRedirect("/foo/bar?newqp=2#fragment"); + assertEquals("https://example.com/foo/bar?newqp=2#fragment", redirectedUrl); + } + + @Test + public void sendRedirectWithContextPath() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.setContextPath("/context"); + + String redirectedUrl = sendRedirect("/context/foo/bar"); + assertEquals("https://example.com/context/foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWithRelativePath() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.setRequestURI("/parent/"); + + String redirectedUrl = sendRedirect("foo/bar"); + assertEquals("https://example.com/parent/foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWithFileInPathAndRelativeRedirect() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.setRequestURI("/context/a"); + + String redirectedUrl = sendRedirect("foo/bar"); + assertEquals("https://example.com/context/foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWithRelativePathIgnoresFile() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.setRequestURI("/parent"); + + String redirectedUrl = sendRedirect("foo/bar"); + assertEquals("https://example.com/foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWithLocationDotDotPath() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + + String redirectedUrl = sendRedirect("parent/../foo/bar"); + assertEquals("https://example.com/foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWithLocationHasScheme() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + + String location = "https://other.info/foo/bar"; + String redirectedUrl = sendRedirect(location); + assertEquals(location, redirectedUrl); + } + + @Test + public void sendRedirectWithLocationSlashSlash() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + + String location = "//other.info/foo/bar"; + String redirectedUrl = sendRedirect(location); + assertEquals("https:" + location, redirectedUrl); + } + + @Test + public void sendRedirectWithLocationSlashSlashParentDotDot() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + + String location = "//other.info/parent/../foo/bar"; + String redirectedUrl = sendRedirect(location); + assertEquals("https:" + location, redirectedUrl); + } + + @Test + public void sendRedirectWithNoXForwardedAndAbsolutePath() throws Exception { + String redirectedUrl = sendRedirect("/foo/bar"); + assertEquals("/foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWithNoXForwardedAndDotDotPath() throws Exception { + String redirectedUrl = sendRedirect("../foo/bar"); + assertEquals("../foo/bar", redirectedUrl); + } + + @Test + public void sendRedirectWhenRequestOnlyAndXForwardedThenUsesRelativeRedirects() throws Exception { + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "example.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.filter.setRelativeRedirects(true); + String location = sendRedirect("/a"); + + assertEquals("/a", location); + } + + @Test + public void sendRedirectWhenRequestOnlyAndNoXForwardedThenUsesRelativeRedirects() throws Exception { + this.filter.setRelativeRedirects(true); + String location = sendRedirect("/a"); + + assertEquals("/a", location); + } + + private String sendRedirect(final String location) throws ServletException, IOException { + Filter filter = new OncePerRequestFilter() { + @Override + protected void doFilterInternal(HttpServletRequest req, HttpServletResponse res, + FilterChain chain) throws IOException { + + res.sendRedirect(location); + } + }; + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = new MockFilterChain(mock(HttpServlet.class), this.filter, filter); + filterChain.doFilter(request, response); + + return response.getRedirectedUrl(); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/HiddenHttpMethodFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/HiddenHttpMethodFilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a5650a14bd03ff00dd68f8a66b0c22384904c997 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/HiddenHttpMethodFilterTests.java @@ -0,0 +1,83 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; + +import org.junit.Test; + +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; + +import static org.junit.Assert.*; + +/** + * Tests for {@link HiddenHttpMethodFilter}. + * + * @author Arjen Poutsma + * @author Brian Clozel + */ +public class HiddenHttpMethodFilterTests { + + private final HiddenHttpMethodFilter filter = new HiddenHttpMethodFilter(); + + @Test + public void filterWithParameter() throws IOException, ServletException { + filterWithParameterForMethod("delete", "DELETE"); + filterWithParameterForMethod("put", "PUT"); + filterWithParameterForMethod("patch", "PATCH"); + } + + @Test + public void filterWithParameterDisallowedMethods() throws IOException, ServletException { + filterWithParameterForMethod("trace", "POST"); + filterWithParameterForMethod("head", "POST"); + filterWithParameterForMethod("options", "POST"); + } + + @Test + public void filterWithNoParameter() throws IOException, ServletException { + filterWithParameterForMethod(null, "POST"); + } + + private void filterWithParameterForMethod(String methodParam, String expectedMethod) + throws IOException, ServletException { + MockHttpServletRequest request = new MockHttpServletRequest("POST", "/hotels"); + if(methodParam != null) { + request.addParameter("_method", methodParam); + } + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = new FilterChain() { + + @Override + public void doFilter(ServletRequest filterRequest, + ServletResponse filterResponse) throws IOException, ServletException { + assertEquals("Invalid method", expectedMethod, + ((HttpServletRequest) filterRequest).getMethod()); + } + }; + this.filter.doFilter(request, response, filterChain); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/OncePerRequestFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/OncePerRequestFilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3898a078463c0592bafd57789c8dd12b2cc74726 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/OncePerRequestFilterTests.java @@ -0,0 +1,194 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.filter; + +import java.io.IOException; + +import javax.servlet.DispatcherType; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.mock.web.test.MockFilterChain; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.web.util.WebUtils; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * Unit tests for {@link OncePerRequestFilter}. + * @author Rossen Stoyanchev + * @since 5.1.9 + */ +public class OncePerRequestFilterTests { + + private final TestOncePerRequestFilter filter = new TestOncePerRequestFilter(); + + private MockHttpServletRequest request; + + private MockFilterChain filterChain; + + + @Before + @SuppressWarnings("serial") + public void setup() throws Exception { + this.request = new MockHttpServletRequest(); + this.request.setScheme("http"); + this.request.setServerName("localhost"); + this.request.setServerPort(80); + this.filterChain = new MockFilterChain(new HttpServlet() {}); + } + + + @Test + public void filterOnce() throws ServletException, IOException { + + // Already filtered + this.request.setAttribute(this.filter.getAlreadyFilteredAttributeName(), Boolean.TRUE); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + assertFalse(this.filter.didFilter); + assertFalse(this.filter.didFilterNestedErrorDispatch); + + // Remove already filtered + this.request.removeAttribute(this.filter.getAlreadyFilteredAttributeName()); + this.filter.reset(); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + assertTrue(this.filter.didFilter); + assertFalse(this.filter.didFilterNestedErrorDispatch); + } + + @Test + public void shouldNotFilterErrorDispatch() throws ServletException, IOException { + + initErrorDispatch(); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + assertFalse(this.filter.didFilter); + assertFalse(this.filter.didFilterNestedErrorDispatch); + } + + @Test + public void shouldNotFilterNestedErrorDispatch() throws ServletException, IOException { + + initErrorDispatch(); + this.request.setAttribute(this.filter.getAlreadyFilteredAttributeName(), Boolean.TRUE); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + assertFalse(this.filter.didFilter); + assertFalse(this.filter.didFilterNestedErrorDispatch); + } + + @Test // gh-23196 + public void filterNestedErrorDispatch() throws ServletException, IOException { + + // Opt in for ERROR dispatch + this.filter.setShouldNotFilterErrorDispatch(false); + + this.request.setAttribute(this.filter.getAlreadyFilteredAttributeName(), Boolean.TRUE); + initErrorDispatch(); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + assertFalse(this.filter.didFilter); + assertTrue(this.filter.didFilterNestedErrorDispatch); + } + + private void initErrorDispatch() { + this.request.setDispatcherType(DispatcherType.ERROR); + this.request.setAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE, "/error"); + } + + + private static class TestOncePerRequestFilter extends OncePerRequestFilter { + + private boolean shouldNotFilter; + + private boolean shouldNotFilterAsyncDispatch = true; + + private boolean shouldNotFilterErrorDispatch = true; + + private boolean didFilter; + + private boolean didFilterNestedErrorDispatch; + + + public void setShouldNotFilter(boolean shouldNotFilter) { + this.shouldNotFilter = shouldNotFilter; + } + + public void setShouldNotFilterAsyncDispatch(boolean shouldNotFilterAsyncDispatch) { + this.shouldNotFilterAsyncDispatch = shouldNotFilterAsyncDispatch; + } + + public void setShouldNotFilterErrorDispatch(boolean shouldNotFilterErrorDispatch) { + this.shouldNotFilterErrorDispatch = shouldNotFilterErrorDispatch; + } + + + public boolean didFilter() { + return this.didFilter; + } + + public boolean didFilterNestedErrorDispatch() { + return this.didFilterNestedErrorDispatch; + } + + public void reset() { + this.didFilter = false; + this.didFilterNestedErrorDispatch = false; + } + + + @Override + protected boolean shouldNotFilter(HttpServletRequest request) { + return this.shouldNotFilter; + } + + @Override + protected boolean shouldNotFilterAsyncDispatch() { + return this.shouldNotFilterAsyncDispatch; + } + + @Override + protected boolean shouldNotFilterErrorDispatch() { + return this.shouldNotFilterErrorDispatch; + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) { + + this.didFilter = true; + } + + @Override + protected void doFilterNestedErrorDispatch(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + + this.didFilterNestedErrorDispatch = true; + super.doFilterNestedErrorDispatch(request, response, filterChain); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/RelativeRedirectFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/RelativeRedirectFilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..90a18425d4b423bfcfc31a5bf91405460fad353d --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/RelativeRedirectFilterTests.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; + +import org.junit.Test; +import org.mockito.InOrder; +import org.mockito.Mockito; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.mock.web.test.MockFilterChain; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link RelativeRedirectFilter}. + * + * @author Rob Winch + * @author Juergen Hoeller + */ +public class RelativeRedirectFilterTests { + + private RelativeRedirectFilter filter = new RelativeRedirectFilter(); + + private HttpServletResponse response = Mockito.mock(HttpServletResponse.class); + + + @Test(expected = IllegalArgumentException.class) + public void sendRedirectHttpStatusWhenNullThenIllegalArgumentException() { + this.filter.setRedirectStatus(null); + } + + @Test(expected = IllegalArgumentException.class) + public void sendRedirectHttpStatusWhenNot3xxThenIllegalArgumentException() { + this.filter.setRedirectStatus(HttpStatus.OK); + } + + @Test + public void doFilterSendRedirectWhenDefaultsThenLocationAnd303() throws Exception { + String location = "/foo"; + sendRedirect(location); + + InOrder inOrder = Mockito.inOrder(this.response); + inOrder.verify(this.response).setStatus(HttpStatus.SEE_OTHER.value()); + inOrder.verify(this.response).setHeader(HttpHeaders.LOCATION, location); + } + + @Test + public void doFilterSendRedirectWhenCustomSendRedirectHttpStatusThenLocationAnd301() throws Exception { + String location = "/foo"; + HttpStatus status = HttpStatus.MOVED_PERMANENTLY; + this.filter.setRedirectStatus(status); + sendRedirect(location); + + InOrder inOrder = Mockito.inOrder(this.response); + inOrder.verify(this.response).setStatus(status.value()); + inOrder.verify(this.response).setHeader(HttpHeaders.LOCATION, location); + } + + @Test + public void wrapOnceOnly() throws Exception { + HttpServletResponse original = new MockHttpServletResponse(); + + MockFilterChain chain = new MockFilterChain(); + this.filter.doFilterInternal(new MockHttpServletRequest(), original, chain); + + HttpServletResponse wrapped1 = (HttpServletResponse) chain.getResponse(); + assertNotSame(original, wrapped1); + + chain.reset(); + this.filter.doFilterInternal(new MockHttpServletRequest(), wrapped1, chain); + HttpServletResponse current = (HttpServletResponse) chain.getResponse(); + assertSame(wrapped1, current); + + chain.reset(); + HttpServletResponse wrapped2 = new HttpServletResponseWrapper(wrapped1); + this.filter.doFilterInternal(new MockHttpServletRequest(), wrapped2, chain); + current = (HttpServletResponse) chain.getResponse(); + assertSame(wrapped2, current); + } + + + private void sendRedirect(String location) throws Exception { + MockFilterChain chain = new MockFilterChain(); + this.filter.doFilterInternal(new MockHttpServletRequest(), this.response, chain); + + HttpServletResponse wrappedResponse = (HttpServletResponse) chain.getResponse(); + wrappedResponse.sendRedirect(location); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/RequestContextFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/RequestContextFilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..076a06c1a1e521e863d4860943fe8dc87cdff0b9 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/RequestContextFilterTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; + +import org.junit.Test; + +import org.springframework.mock.web.test.MockFilterConfig; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.mock.web.test.MockServletContext; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.RequestContextHolder; + +import static org.junit.Assert.*; + +/** + * @author Rod Johnson + * @author Juergen Hoeller + */ +public class RequestContextFilterTests { + + @Test + public void happyPath() throws Exception { + testFilterInvocation(null); + } + + @Test + public void withException() throws Exception { + testFilterInvocation(new ServletException()); + } + + private void testFilterInvocation(final ServletException sex) throws Exception { + final MockHttpServletRequest req = new MockHttpServletRequest(); + req.setAttribute("myAttr", "myValue"); + final MockHttpServletResponse resp = new MockHttpServletResponse(); + + // Expect one invocation by the filter being tested + class DummyFilterChain implements FilterChain { + public int invocations = 0; + @Override + public void doFilter(ServletRequest req, ServletResponse resp) throws IOException, ServletException { + ++invocations; + if (invocations == 1) { + assertSame("myValue", + RequestContextHolder.currentRequestAttributes().getAttribute("myAttr", RequestAttributes.SCOPE_REQUEST)); + if (sex != null) { + throw sex; + } + } + else { + throw new IllegalStateException("Too many invocations"); + } + } + } + + DummyFilterChain fc = new DummyFilterChain(); + MockFilterConfig mfc = new MockFilterConfig(new MockServletContext(), "foo"); + + RequestContextFilter rbf = new RequestContextFilter(); + rbf.init(mfc); + + try { + rbf.doFilter(req, resp, fc); + if (sex != null) { + fail(); + } + } + catch (ServletException ex) { + assertNotNull(sex); + } + + try { + RequestContextHolder.currentRequestAttributes(); + fail(); + } + catch (IllegalStateException ex) { + // Ok + } + + assertEquals(1, fc.invocations); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/RequestLoggingFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/RequestLoggingFilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..486e464b64ae8c1156b33c5471deade434a5d78a --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/RequestLoggingFilterTests.java @@ -0,0 +1,213 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.junit.Test; + +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.util.FileCopyUtils; +import org.springframework.web.util.ContentCachingRequestWrapper; +import org.springframework.web.util.WebUtils; + +import static org.junit.Assert.*; + +/** + * Test for {@link AbstractRequestLoggingFilter} and subclasses. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + */ +public class RequestLoggingFilterTests { + + private final MyRequestLoggingFilter filter = new MyRequestLoggingFilter(); + + + @Test + public void uri() throws Exception { + final MockHttpServletRequest request = new MockHttpServletRequest("POST", "/hotels"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + request.setQueryString("booking=42"); + + FilterChain filterChain = new NoOpFilterChain(); + filter.doFilter(request, response, filterChain); + + assertNotNull(filter.beforeRequestMessage); + assertTrue(filter.beforeRequestMessage.contains("uri=/hotel")); + assertFalse(filter.beforeRequestMessage.contains("booking=42")); + + assertNotNull(filter.afterRequestMessage); + assertTrue(filter.afterRequestMessage.contains("uri=/hotel")); + assertFalse(filter.afterRequestMessage.contains("booking=42")); + } + + @Test + public void queryStringIncluded() throws Exception { + filter.setIncludeQueryString(true); + + MockHttpServletRequest request = new MockHttpServletRequest("POST", "/hotels"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + request.setQueryString("booking=42"); + + FilterChain filterChain = new NoOpFilterChain(); + filter.doFilter(request, response, filterChain); + + assertNotNull(filter.beforeRequestMessage); + assertTrue(filter.beforeRequestMessage.contains("[uri=/hotels?booking=42]")); + + assertNotNull(filter.afterRequestMessage); + assertTrue(filter.afterRequestMessage.contains("[uri=/hotels?booking=42]")); + } + + @Test + public void noQueryStringAvailable() throws Exception { + filter.setIncludeQueryString(true); + + MockHttpServletRequest request = new MockHttpServletRequest("POST", "/hotels"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = new NoOpFilterChain(); + filter.doFilter(request, response, filterChain); + + assertNotNull(filter.beforeRequestMessage); + assertTrue(filter.beforeRequestMessage.contains("[uri=/hotels]")); + + assertNotNull(filter.afterRequestMessage); + assertTrue(filter.afterRequestMessage.contains("[uri=/hotels]")); + } + + @Test + public void payloadInputStream() throws Exception { + filter.setIncludePayload(true); + + final MockHttpServletRequest request = new MockHttpServletRequest("POST", "/hotels"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + final byte[] requestBody = "Hello World".getBytes("UTF-8"); + request.setContent(requestBody); + + FilterChain filterChain = new FilterChain() { + @Override + public void doFilter(ServletRequest filterRequest, ServletResponse filterResponse) + throws IOException, ServletException { + ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + byte[] buf = FileCopyUtils.copyToByteArray(filterRequest.getInputStream()); + assertArrayEquals(requestBody, buf); + } + }; + + filter.doFilter(request, response, filterChain); + + assertNotNull(filter.afterRequestMessage); + assertTrue(filter.afterRequestMessage.contains("Hello World")); + } + + @Test + public void payloadReader() throws Exception { + filter.setIncludePayload(true); + + final MockHttpServletRequest request = new MockHttpServletRequest("POST", "/hotels"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + final String requestBody = "Hello World"; + request.setContent(requestBody.getBytes("UTF-8")); + + FilterChain filterChain = new FilterChain() { + @Override + public void doFilter(ServletRequest filterRequest, ServletResponse filterResponse) + throws IOException, ServletException { + ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + String buf = FileCopyUtils.copyToString(filterRequest.getReader()); + assertEquals(requestBody, buf); + } + }; + + filter.doFilter(request, response, filterChain); + + assertNotNull(filter.afterRequestMessage); + assertTrue(filter.afterRequestMessage.contains(requestBody)); + } + + @Test + public void payloadMaxLength() throws Exception { + filter.setIncludePayload(true); + filter.setMaxPayloadLength(3); + + final MockHttpServletRequest request = new MockHttpServletRequest("POST", "/hotels"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + final byte[] requestBody = "Hello World".getBytes("UTF-8"); + request.setContent(requestBody); + + FilterChain filterChain = new FilterChain() { + @Override + public void doFilter(ServletRequest filterRequest, ServletResponse filterResponse) + throws IOException, ServletException { + ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + byte[] buf = FileCopyUtils.copyToByteArray(filterRequest.getInputStream()); + assertArrayEquals(requestBody, buf); + ContentCachingRequestWrapper wrapper = + WebUtils.getNativeRequest(filterRequest, ContentCachingRequestWrapper.class); + assertArrayEquals("Hel".getBytes("UTF-8"), wrapper.getContentAsByteArray()); + } + }; + + filter.doFilter(request, response, filterChain); + + assertNotNull(filter.afterRequestMessage); + assertTrue(filter.afterRequestMessage.contains("Hel")); + assertFalse(filter.afterRequestMessage.contains("Hello World")); + } + + + private static class MyRequestLoggingFilter extends AbstractRequestLoggingFilter { + + private String beforeRequestMessage; + + private String afterRequestMessage; + + @Override + protected void beforeRequest(HttpServletRequest request, String message) { + this.beforeRequestMessage = message; + } + + @Override + protected void afterRequest(HttpServletRequest request, String message) { + this.afterRequestMessage = message; + } + } + + + private static class NoOpFilterChain implements FilterChain { + + @Override + public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..55dbadf0a278428a161bb7d4e59ab00c9f666b04 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java @@ -0,0 +1,265 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter; + +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletResponse; + +import org.junit.Test; + +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.util.FileCopyUtils; +import org.springframework.util.StreamUtils; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + * @author Brian Clozel + * @author Juergen Hoeller + */ +public class ShallowEtagHeaderFilterTests { + + private final ShallowEtagHeaderFilter filter = new ShallowEtagHeaderFilter(); + + + @Test + public void isEligibleForEtag() { + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + assertTrue(filter.isEligibleForEtag(request, response, 200, StreamUtils.emptyInput())); + assertFalse(filter.isEligibleForEtag(request, response, 300, StreamUtils.emptyInput())); + + request = new MockHttpServletRequest("HEAD", "/hotels"); + assertFalse(filter.isEligibleForEtag(request, response, 200, StreamUtils.emptyInput())); + + request = new MockHttpServletRequest("POST", "/hotels"); + assertFalse(filter.isEligibleForEtag(request, response, 200, StreamUtils.emptyInput())); + + request = new MockHttpServletRequest("POST", "/hotels"); + request.addHeader("Cache-Control","must-revalidate, no-store"); + assertFalse(filter.isEligibleForEtag(request, response, 200, StreamUtils.emptyInput())); + } + + @Test + public void filterNoMatch() throws Exception { + final MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + final byte[] responseBody = "Hello World".getBytes("UTF-8"); + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); + }; + filter.doFilter(request, response, filterChain); + + assertEquals("Invalid status", 200, response.getStatus()); + assertEquals("Invalid ETag header", "\"0b10a8db164e0754105b7a99be72e3fe5\"", response.getHeader("ETag")); + assertTrue("Invalid Content-Length header", response.getContentLength() > 0); + assertArrayEquals("Invalid content", responseBody, response.getContentAsByteArray()); + } + + @Test + public void filterNoMatchWeakETag() throws Exception { + this.filter.setWriteWeakETag(true); + final MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + final byte[] responseBody = "Hello World".getBytes("UTF-8"); + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); + }; + filter.doFilter(request, response, filterChain); + + assertEquals("Invalid status", 200, response.getStatus()); + assertEquals("Invalid ETag header", "W/\"0b10a8db164e0754105b7a99be72e3fe5\"", response.getHeader("ETag")); + assertTrue("Invalid Content-Length header", response.getContentLength() > 0); + assertArrayEquals("Invalid content", responseBody, response.getContentAsByteArray()); + } + + @Test + public void filterMatch() throws Exception { + final MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); + String etag = "\"0b10a8db164e0754105b7a99be72e3fe5\""; + request.addHeader("If-None-Match", etag); + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + byte[] responseBody = "Hello World".getBytes("UTF-8"); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); + filterResponse.setContentLength(responseBody.length); + }; + filter.doFilter(request, response, filterChain); + + assertEquals("Invalid status", 304, response.getStatus()); + assertEquals("Invalid ETag header", "\"0b10a8db164e0754105b7a99be72e3fe5\"", response.getHeader("ETag")); + assertFalse("Response has Content-Length header", response.containsHeader("Content-Length")); + assertArrayEquals("Invalid content", new byte[0], response.getContentAsByteArray()); + } + + @Test + public void filterMatchWeakEtag() throws Exception { + final MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); + String etag = "\"0b10a8db164e0754105b7a99be72e3fe5\""; + request.addHeader("If-None-Match", "W/" + etag); + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + byte[] responseBody = "Hello World".getBytes("UTF-8"); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); + filterResponse.setContentLength(responseBody.length); + }; + filter.doFilter(request, response, filterChain); + + assertEquals("Invalid status", 304, response.getStatus()); + assertEquals("Invalid ETag header", "\"0b10a8db164e0754105b7a99be72e3fe5\"", response.getHeader("ETag")); + assertFalse("Response has Content-Length header", response.containsHeader("Content-Length")); + assertArrayEquals("Invalid content", new byte[0], response.getContentAsByteArray()); + } + + @Test + public void filterWriter() throws Exception { + final MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); + String etag = "\"0b10a8db164e0754105b7a99be72e3fe5\""; + request.addHeader("If-None-Match", etag); + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + String responseBody = "Hello World"; + FileCopyUtils.copy(responseBody, filterResponse.getWriter()); + }; + filter.doFilter(request, response, filterChain); + + assertEquals("Invalid status", 304, response.getStatus()); + assertEquals("Invalid ETag header", "\"0b10a8db164e0754105b7a99be72e3fe5\"", response.getHeader("ETag")); + assertFalse("Response has Content-Length header", response.containsHeader("Content-Length")); + assertArrayEquals("Invalid content", new byte[0], response.getContentAsByteArray()); + } + + @Test // SPR-12960 + public void filterWriterWithDisabledCaching() throws Exception { + final MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + final byte[] responseBody = "Hello World".getBytes("UTF-8"); + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); + }; + + ShallowEtagHeaderFilter.disableContentCaching(request); + this.filter.doFilter(request, response, filterChain); + + assertEquals(200, response.getStatus()); + assertNull(response.getHeader("ETag")); + assertArrayEquals(responseBody, response.getContentAsByteArray()); + } + + @Test + public void filterSendError() throws Exception { + final MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + final byte[] responseBody = "Hello World".getBytes("UTF-8"); + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + response.setContentLength(100); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); + ((HttpServletResponse) filterResponse).sendError(HttpServletResponse.SC_FORBIDDEN); + }; + filter.doFilter(request, response, filterChain); + + assertEquals("Invalid status", 403, response.getStatus()); + assertNull("Invalid ETag header", response.getHeader("ETag")); + assertEquals("Invalid Content-Length header", 100, response.getContentLength()); + assertArrayEquals("Invalid content", responseBody, response.getContentAsByteArray()); + } + + @Test + public void filterSendErrorMessage() throws Exception { + final MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + final byte[] responseBody = "Hello World".getBytes("UTF-8"); + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + response.setContentLength(100); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); + ((HttpServletResponse) filterResponse).sendError(HttpServletResponse.SC_FORBIDDEN, "ERROR"); + }; + filter.doFilter(request, response, filterChain); + + assertEquals("Invalid status", 403, response.getStatus()); + assertNull("Invalid ETag header", response.getHeader("ETag")); + assertEquals("Invalid Content-Length header", 100, response.getContentLength()); + assertArrayEquals("Invalid content", responseBody, response.getContentAsByteArray()); + assertEquals("Invalid error message", "ERROR", response.getErrorMessage()); + } + + @Test + public void filterSendRedirect() throws Exception { + final MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + final byte[] responseBody = "Hello World".getBytes("UTF-8"); + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + response.setContentLength(100); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); + ((HttpServletResponse) filterResponse).sendRedirect("https://www.google.com"); + }; + filter.doFilter(request, response, filterChain); + + assertEquals("Invalid status", 302, response.getStatus()); + assertNull("Invalid ETag header", response.getHeader("ETag")); + assertEquals("Invalid Content-Length header", 100, response.getContentLength()); + assertArrayEquals("Invalid content", responseBody, response.getContentAsByteArray()); + assertEquals("Invalid redirect URL", "https://www.google.com", response.getRedirectedUrl()); + } + + // SPR-13717 + @Test + public void filterFlushResponse() throws Exception { + final MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + final byte[] responseBody = "Hello World".getBytes("UTF-8"); + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); + filterResponse.flushBuffer(); + }; + filter.doFilter(request, response, filterChain); + + assertEquals("Invalid status", 200, response.getStatus()); + assertEquals("Invalid ETag header", "\"0b10a8db164e0754105b7a99be72e3fe5\"", response.getHeader("ETag")); + assertTrue("Invalid Content-Length header", response.getContentLength() > 0); + assertArrayEquals("Invalid content", responseBody, response.getContentAsByteArray()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/reactive/HiddenHttpMethodFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/reactive/HiddenHttpMethodFilterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c708333748417d519e96e03cabbce39d9411dc40 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/reactive/HiddenHttpMethodFilterTests.java @@ -0,0 +1,130 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter.reactive; + +import java.time.Duration; + +import org.hamcrest.Matchers; +import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.web.test.server.MockServerWebExchange; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilterChain; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +/** + * Tests for {@link HiddenHttpMethodFilter}. + * @author Greg Turnquist + * @author Rossen Stoyanchev + */ +public class HiddenHttpMethodFilterTests { + + private final HiddenHttpMethodFilter filter = new HiddenHttpMethodFilter(); + + private final TestWebFilterChain filterChain = new TestWebFilterChain(); + + + @Test + public void filterWithParameter() { + postForm("_method=DELETE").block(Duration.ZERO); + assertEquals(HttpMethod.DELETE, this.filterChain.getHttpMethod()); + } + + @Test + public void filterWithParameterMethodNotAllowed() { + postForm("_method=TRACE").block(Duration.ZERO); + assertEquals(HttpMethod.POST, this.filterChain.getHttpMethod()); + } + + @Test + public void filterWithNoParameter() { + postForm("").block(Duration.ZERO); + assertEquals(HttpMethod.POST, this.filterChain.getHttpMethod()); + } + + @Test + public void filterWithEmptyStringParameter() { + postForm("_method=").block(Duration.ZERO); + assertEquals(HttpMethod.POST, this.filterChain.getHttpMethod()); + } + + @Test + public void filterWithDifferentMethodParam() { + this.filter.setMethodParamName("_foo"); + postForm("_foo=DELETE").block(Duration.ZERO); + assertEquals(HttpMethod.DELETE, this.filterChain.getHttpMethod()); + } + + @Test + public void filterWithInvalidMethodValue() { + StepVerifier.create(postForm("_method=INVALID")) + .consumeErrorWith(error -> { + assertThat(error, Matchers.instanceOf(IllegalArgumentException.class)); + assertEquals("HttpMethod 'INVALID' not supported", error.getMessage()); + }) + .verify(); + } + + @Test + public void filterWithHttpPut() { + + ServerWebExchange exchange = MockServerWebExchange.from( + MockServerHttpRequest.put("/") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .body("_method=DELETE")); + + this.filter.filter(exchange, this.filterChain).block(Duration.ZERO); + assertEquals(HttpMethod.PUT, this.filterChain.getHttpMethod()); + } + + + private Mono postForm(String body) { + + MockServerWebExchange exchange = MockServerWebExchange.from( + MockServerHttpRequest.post("/") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .body(body)); + + return this.filter.filter(exchange, this.filterChain); + } + + + private static class TestWebFilterChain implements WebFilterChain { + + private HttpMethod httpMethod; + + + public HttpMethod getHttpMethod() { + return this.httpMethod; + } + + @Override + public Mono filter(ServerWebExchange exchange) { + this.httpMethod = exchange.getRequest().getMethod(); + return Mono.empty(); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/jsf/DelegatingNavigationHandlerTests.java b/spring-web/src/test/java/org/springframework/web/jsf/DelegatingNavigationHandlerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..1c230e0cd4244d866b12075c44e4f6006cc4c99a --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/jsf/DelegatingNavigationHandlerTests.java @@ -0,0 +1,105 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.jsf; + +import javax.faces.application.NavigationHandler; +import javax.faces.context.FacesContext; + +import org.junit.Test; + +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.support.StaticListableBeanFactory; +import org.springframework.lang.Nullable; + +import static org.junit.Assert.*; + +/** + * @author Colin Sampaleanu + * @author Juergen Hoeller + */ +public class DelegatingNavigationHandlerTests { + + private final MockFacesContext facesContext = new MockFacesContext(); + + private final StaticListableBeanFactory beanFactory = new StaticListableBeanFactory(); + + private final TestNavigationHandler origNavHandler = new TestNavigationHandler(); + + private final DelegatingNavigationHandlerProxy delNavHandler = new DelegatingNavigationHandlerProxy(origNavHandler) { + @Override + protected BeanFactory getBeanFactory(FacesContext facesContext) { + return beanFactory; + } + }; + + + @Test + public void handleNavigationWithoutDecoration() { + TestNavigationHandler targetHandler = new TestNavigationHandler(); + beanFactory.addBean("jsfNavigationHandler", targetHandler); + + delNavHandler.handleNavigation(facesContext, "fromAction", "myViewId"); + assertEquals("fromAction", targetHandler.lastFromAction); + assertEquals("myViewId", targetHandler.lastOutcome); + } + + @Test + public void handleNavigationWithDecoration() { + TestDecoratingNavigationHandler targetHandler = new TestDecoratingNavigationHandler(); + beanFactory.addBean("jsfNavigationHandler", targetHandler); + + delNavHandler.handleNavigation(facesContext, "fromAction", "myViewId"); + assertEquals("fromAction", targetHandler.lastFromAction); + assertEquals("myViewId", targetHandler.lastOutcome); + + // Original handler must have been invoked as well... + assertEquals("fromAction", origNavHandler.lastFromAction); + assertEquals("myViewId", origNavHandler.lastOutcome); + } + + + static class TestNavigationHandler extends NavigationHandler { + + private String lastFromAction; + private String lastOutcome; + + @Override + public void handleNavigation(FacesContext facesContext, String fromAction, String outcome) { + lastFromAction = fromAction; + lastOutcome = outcome; + } + } + + + static class TestDecoratingNavigationHandler extends DecoratingNavigationHandler { + + private String lastFromAction; + private String lastOutcome; + + @Override + public void handleNavigation(FacesContext facesContext, @Nullable String fromAction, + @Nullable String outcome, @Nullable NavigationHandler originalNavigationHandler) { + + lastFromAction = fromAction; + lastOutcome = outcome; + if (originalNavigationHandler != null) { + originalNavigationHandler.handleNavigation(facesContext, fromAction, outcome); + } + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/jsf/DelegatingPhaseListenerTests.java b/spring-web/src/test/java/org/springframework/web/jsf/DelegatingPhaseListenerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..1c091bc05bc4c030b82f45770d44bb0fa2061ae8 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/jsf/DelegatingPhaseListenerTests.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.jsf; + +import javax.faces.context.FacesContext; +import javax.faces.event.PhaseEvent; +import javax.faces.event.PhaseId; +import javax.faces.event.PhaseListener; + +import org.junit.Test; + +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.beans.factory.support.StaticListableBeanFactory; + +import static org.junit.Assert.*; + +/** + * @author Colin Sampaleanu + * @author Juergen Hoeller + */ +public class DelegatingPhaseListenerTests { + + private final MockFacesContext facesContext = new MockFacesContext(); + + private final StaticListableBeanFactory beanFactory = new StaticListableBeanFactory(); + + @SuppressWarnings("serial") + private final DelegatingPhaseListenerMulticaster delPhaseListener = new DelegatingPhaseListenerMulticaster() { + @Override + protected ListableBeanFactory getBeanFactory(FacesContext facesContext) { + return beanFactory; + } + }; + + @Test + public void beforeAndAfterPhaseWithSingleTarget() { + TestListener target = new TestListener(); + beanFactory.addBean("testListener", target); + + assertEquals(PhaseId.ANY_PHASE, delPhaseListener.getPhaseId()); + PhaseEvent event = new PhaseEvent(facesContext, PhaseId.INVOKE_APPLICATION, new MockLifecycle()); + + delPhaseListener.beforePhase(event); + assertTrue(target.beforeCalled); + + delPhaseListener.afterPhase(event); + assertTrue(target.afterCalled); + } + + @Test + public void beforeAndAfterPhaseWithMultipleTargets() { + TestListener target1 = new TestListener(); + TestListener target2 = new TestListener(); + beanFactory.addBean("testListener1", target1); + beanFactory.addBean("testListener2", target2); + + assertEquals(PhaseId.ANY_PHASE, delPhaseListener.getPhaseId()); + PhaseEvent event = new PhaseEvent(facesContext, PhaseId.INVOKE_APPLICATION, new MockLifecycle()); + + delPhaseListener.beforePhase(event); + assertTrue(target1.beforeCalled); + assertTrue(target2.beforeCalled); + + delPhaseListener.afterPhase(event); + assertTrue(target1.afterCalled); + assertTrue(target2.afterCalled); + } + + + @SuppressWarnings("serial") + public static class TestListener implements PhaseListener { + + boolean beforeCalled = false; + boolean afterCalled = false; + + @Override + public PhaseId getPhaseId() { + return PhaseId.ANY_PHASE; + } + + @Override + public void beforePhase(PhaseEvent arg0) { + beforeCalled = true; + } + + @Override + public void afterPhase(PhaseEvent arg0) { + afterCalled = true; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/jsf/MockFacesContext.java b/spring-web/src/test/java/org/springframework/web/jsf/MockFacesContext.java new file mode 100644 index 0000000000000000000000000000000000000000..77ef2ad1e585359157df5be7ca6bb7836c800f4f --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/jsf/MockFacesContext.java @@ -0,0 +1,144 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.jsf; + +import java.util.Iterator; + +import javax.faces.application.Application; +import javax.faces.application.FacesMessage; +import javax.faces.application.FacesMessage.Severity; +import javax.faces.component.UIViewRoot; +import javax.faces.context.ExternalContext; +import javax.faces.context.FacesContext; +import javax.faces.context.ResponseStream; +import javax.faces.context.ResponseWriter; +import javax.faces.render.RenderKit; + +/** + * Mock implementation of the {@code FacesContext} class to facilitate + * standalone Action unit tests. + * + * @author Ulrik Sandberg + * @see javax.faces.context.FacesContext + */ +public class MockFacesContext extends FacesContext { + + private ExternalContext externalContext; + + private Application application; + + private UIViewRoot viewRoot; + + + @Override + public Application getApplication() { + return application; + } + + public void setApplication(Application application) { + this.application = application; + } + + @Override + public Iterator getClientIdsWithMessages() { + return null; + } + + @Override + public ExternalContext getExternalContext() { + return externalContext; + } + + public void setExternalContext(ExternalContext externalContext) { + this.externalContext = externalContext; + } + + @Override + public Severity getMaximumSeverity() { + return null; + } + + @Override + public Iterator getMessages() { + return null; + } + + @Override + public Iterator getMessages(String clientId) { + return null; + } + + @Override + public RenderKit getRenderKit() { + return null; + } + + @Override + public boolean getRenderResponse() { + return false; + } + + @Override + public boolean getResponseComplete() { + return false; + } + + @Override + public ResponseStream getResponseStream() { + return null; + } + + @Override + public void setResponseStream(ResponseStream arg0) { + } + + @Override + public ResponseWriter getResponseWriter() { + return null; + } + + @Override + public void setResponseWriter(ResponseWriter arg0) { + } + + @Override + public UIViewRoot getViewRoot() { + return viewRoot; + } + + @Override + public void setViewRoot(UIViewRoot viewRoot) { + this.viewRoot = viewRoot; + } + + @Override + public void addMessage(String arg0, FacesMessage arg1) { + } + + @Override + public void release() { + } + + @Override + public void renderResponse() { + } + + @Override + public void responseComplete() { + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/jsf/MockLifecycle.java b/spring-web/src/test/java/org/springframework/web/jsf/MockLifecycle.java new file mode 100644 index 0000000000000000000000000000000000000000..39685816ff55ddb6777dd5be10de0fbe00e4bfcb --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/jsf/MockLifecycle.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.jsf; + +import javax.faces.FacesException; +import javax.faces.context.FacesContext; +import javax.faces.event.PhaseListener; +import javax.faces.lifecycle.Lifecycle; + +/** + * @author Juergen Hoeller + * @since 29.01.2006 + */ +public class MockLifecycle extends Lifecycle { + + @Override + public void addPhaseListener(PhaseListener phaseListener) { + } + + @Override + public void execute(FacesContext facesContext) throws FacesException { + } + + @Override + public PhaseListener[] getPhaseListeners() { + return null; + } + + @Override + public void removePhaseListener(PhaseListener phaseListener) { + } + + @Override + public void render(FacesContext facesContext) throws FacesException { + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/ControllerAdviceBeanTests.java b/spring-web/src/test/java/org/springframework/web/method/ControllerAdviceBeanTests.java new file mode 100644 index 0000000000000000000000000000000000000000..66042e12186fed4f5d96ad98d7fdff0695778296 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/ControllerAdviceBeanTests.java @@ -0,0 +1,160 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +import org.junit.Test; + +import org.springframework.web.bind.annotation.ControllerAdvice; +import org.springframework.web.bind.annotation.RestController; + +import static org.junit.Assert.*; + +/** + * @author Brian Clozel + */ +public class ControllerAdviceBeanTests { + + @Test + public void shouldMatchAll() { + ControllerAdviceBean bean = new ControllerAdviceBean(new SimpleControllerAdvice()); + assertApplicable("should match all", bean, AnnotatedController.class); + assertApplicable("should match all", bean, ImplementationController.class); + assertApplicable("should match all", bean, InheritanceController.class); + assertApplicable("should match all", bean, String.class); + } + + @Test + public void basePackageSupport() { + ControllerAdviceBean bean = new ControllerAdviceBean(new BasePackageSupport()); + assertApplicable("base package support", bean, AnnotatedController.class); + assertApplicable("base package support", bean, ImplementationController.class); + assertApplicable("base package support", bean, InheritanceController.class); + assertNotApplicable("bean not in package", bean, String.class); + } + + @Test + public void basePackageValueSupport() { + ControllerAdviceBean bean = new ControllerAdviceBean(new BasePackageValueSupport()); + assertApplicable("base package support", bean, AnnotatedController.class); + assertApplicable("base package support", bean, ImplementationController.class); + assertApplicable("base package support", bean, InheritanceController.class); + assertNotApplicable("bean not in package", bean, String.class); + } + + @Test + public void annotationSupport() { + ControllerAdviceBean bean = new ControllerAdviceBean(new AnnotationSupport()); + assertApplicable("annotation support", bean, AnnotatedController.class); + assertNotApplicable("this bean is not annotated", bean, InheritanceController.class); + } + + @Test + public void markerClassSupport() { + ControllerAdviceBean bean = new ControllerAdviceBean(new MarkerClassSupport()); + assertApplicable("base package class support", bean, AnnotatedController.class); + assertApplicable("base package class support", bean, ImplementationController.class); + assertApplicable("base package class support", bean, InheritanceController.class); + assertNotApplicable("bean not in package", bean, String.class); + } + + @Test + public void shouldNotMatch() { + ControllerAdviceBean bean = new ControllerAdviceBean(new ShouldNotMatch()); + assertNotApplicable("should not match", bean, AnnotatedController.class); + assertNotApplicable("should not match", bean, ImplementationController.class); + assertNotApplicable("should not match", bean, InheritanceController.class); + assertNotApplicable("should not match", bean, String.class); + } + + @Test + public void assignableTypesSupport() { + ControllerAdviceBean bean = new ControllerAdviceBean(new AssignableTypesSupport()); + assertApplicable("controller implements assignable", bean, ImplementationController.class); + assertApplicable("controller inherits assignable", bean, InheritanceController.class); + assertNotApplicable("not assignable", bean, AnnotatedController.class); + assertNotApplicable("not assignable", bean, String.class); + } + + @Test + public void multipleMatch() { + ControllerAdviceBean bean = new ControllerAdviceBean(new MultipleSelectorsSupport()); + assertApplicable("controller implements assignable", bean, ImplementationController.class); + assertApplicable("controller is annotated", bean, AnnotatedController.class); + assertNotApplicable("should not match", bean, InheritanceController.class); + } + + private void assertApplicable(String message, ControllerAdviceBean controllerAdvice, Class controllerBeanType) { + assertNotNull(controllerAdvice); + assertTrue(message, controllerAdvice.isApplicableToBeanType(controllerBeanType)); + } + + private void assertNotApplicable(String message, ControllerAdviceBean controllerAdvice, Class controllerBeanType) { + assertNotNull(controllerAdvice); + assertFalse(message, controllerAdvice.isApplicableToBeanType(controllerBeanType)); + } + + + // ControllerAdvice classes + + @ControllerAdvice + static class SimpleControllerAdvice {} + + @ControllerAdvice(annotations = ControllerAnnotation.class) + static class AnnotationSupport {} + + @ControllerAdvice(basePackageClasses = MarkerClass.class) + static class MarkerClassSupport {} + + @ControllerAdvice(assignableTypes = {ControllerInterface.class, + AbstractController.class}) + static class AssignableTypesSupport {} + + @ControllerAdvice(basePackages = "org.springframework.web.method") + static class BasePackageSupport {} + + @ControllerAdvice("org.springframework.web.method") + static class BasePackageValueSupport {} + + @ControllerAdvice(annotations = ControllerAnnotation.class, assignableTypes = ControllerInterface.class) + static class MultipleSelectorsSupport {} + + @ControllerAdvice(basePackages = "java.util", annotations = {RestController.class}) + static class ShouldNotMatch {} + + + // Support classes + + static class MarkerClass {} + + @Retention(RetentionPolicy.RUNTIME) + static @interface ControllerAnnotation {} + + @ControllerAnnotation + public static class AnnotatedController {} + + static interface ControllerInterface {} + + static class ImplementationController implements ControllerInterface {} + + static abstract class AbstractController {} + + static class InheritanceController extends AbstractController {} + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/HandlerTypePredicateTests.java b/spring-web/src/test/java/org/springframework/web/method/HandlerTypePredicateTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a40ed740be7291a7cf35816e09bbd3f27c06bb9d --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/HandlerTypePredicateTests.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.method; + +import java.util.function.Predicate; + +import org.junit.Test; + +import org.springframework.stereotype.Controller; +import org.springframework.web.bind.annotation.RestController; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link HandlerTypePredicate}. + * @author Rossen Stoyanchev + */ +public class HandlerTypePredicateTests { + + @Test + public void forAnnotation() { + + Predicate> predicate = HandlerTypePredicate.forAnnotation(Controller.class); + + assertTrue(predicate.test(HtmlController.class)); + assertTrue(predicate.test(ApiController.class)); + assertTrue(predicate.test(AnotherApiController.class)); + } + + @Test + public void forAnnotationWithException() { + + Predicate> predicate = HandlerTypePredicate.forAnnotation(Controller.class) + .and(HandlerTypePredicate.forAssignableType(Special.class)); + + assertFalse(predicate.test(HtmlController.class)); + assertFalse(predicate.test(ApiController.class)); + assertTrue(predicate.test(AnotherApiController.class)); + } + + + @Controller + private static class HtmlController {} + + @RestController + private static class ApiController {} + + @RestController + private static class AnotherApiController implements Special {} + + interface Special {} + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/MvcAnnotationPredicates.java b/spring-web/src/test/java/org/springframework/web/method/MvcAnnotationPredicates.java new file mode 100644 index 0000000000000000000000000000000000000000..1a95d49a95924002e71bbabb00628273881b2439 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/MvcAnnotationPredicates.java @@ -0,0 +1,385 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.method; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.function.Predicate; + +import org.springframework.core.MethodParameter; +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.http.HttpStatus; +import org.springframework.web.bind.annotation.MatrixVariable; +import org.springframework.web.bind.annotation.ModelAttribute; +import org.springframework.web.bind.annotation.RequestAttribute; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RequestPart; +import org.springframework.web.bind.annotation.ResponseStatus; +import org.springframework.web.bind.annotation.ValueConstants; + +/** + * Predicates for {@code @MVC} annotations. + * + * @author Rossen Stoyanchev + * @since 5.0 + * + * @see ResolvableMethod#annot(Predicate[]) + * @see ResolvableMethod.Builder#annot(Predicate[]) + */ +public class MvcAnnotationPredicates { + + + // Method parameter predicates + + public static ModelAttributePredicate modelAttribute() { + return new ModelAttributePredicate(); + } + + public static RequestBodyPredicate requestBody() { + return new RequestBodyPredicate(); + } + + public static RequestParamPredicate requestParam() { + return new RequestParamPredicate(); + } + + public static RequestPartPredicate requestPart() { + return new RequestPartPredicate(); + } + + public static RequestAttributePredicate requestAttribute() { + return new RequestAttributePredicate(); + } + + public static MatrixVariablePredicate matrixAttribute() { + return new MatrixVariablePredicate(); + } + + // Method predicates + + public static ModelAttributeMethodPredicate modelMethod() { + return new ModelAttributeMethodPredicate(); + } + + public static ResponseStatusPredicate responseStatus() { + return new ResponseStatusPredicate(); + } + + public static ResponseStatusPredicate responseStatus(HttpStatus code) { + return new ResponseStatusPredicate(code); + } + + public static RequestMappingPredicate requestMapping(String... path) { + return new RequestMappingPredicate(path); + } + + public static RequestMappingPredicate getMapping(String... path) { + return new RequestMappingPredicate(path).method(RequestMethod.GET); + } + + public static RequestMappingPredicate postMapping(String... path) { + return new RequestMappingPredicate(path).method(RequestMethod.POST); + } + + public static RequestMappingPredicate putMapping(String... path) { + return new RequestMappingPredicate(path).method(RequestMethod.PUT); + } + + public static RequestMappingPredicate deleteMapping(String... path) { + return new RequestMappingPredicate(path).method(RequestMethod.DELETE); + } + + public static RequestMappingPredicate optionsMapping(String... path) { + return new RequestMappingPredicate(path).method(RequestMethod.OPTIONS); + } + + public static RequestMappingPredicate headMapping(String... path) { + return new RequestMappingPredicate(path).method(RequestMethod.HEAD); + } + + + + public static class ModelAttributePredicate implements Predicate { + + private String name; + + private boolean binding = true; + + + public ModelAttributePredicate name(String name) { + this.name = name; + return this; + } + + public ModelAttributePredicate noName() { + this.name = ""; + return this; + } + + public ModelAttributePredicate noBinding() { + this.binding = false; + return this; + } + + + @Override + public boolean test(MethodParameter parameter) { + ModelAttribute annotation = parameter.getParameterAnnotation(ModelAttribute.class); + return annotation != null && + (this.name == null || annotation.name().equals(this.name)) && + annotation.binding() == this.binding; + } + } + + public static class RequestBodyPredicate implements Predicate { + + private boolean required = true; + + + public RequestBodyPredicate notRequired() { + this.required = false; + return this; + } + + + @Override + public boolean test(MethodParameter parameter) { + RequestBody annotation = parameter.getParameterAnnotation(RequestBody.class); + return annotation != null && annotation.required() == this.required; + } + } + + public static class RequestParamPredicate implements Predicate { + + private String name; + + private boolean required = true; + + private String defaultValue = ValueConstants.DEFAULT_NONE; + + + + public RequestParamPredicate name(String name) { + this.name = name; + return this; + } + + public RequestParamPredicate noName() { + this.name = ""; + return this; + } + + public RequestParamPredicate notRequired() { + this.required = false; + return this; + } + + public RequestParamPredicate notRequired(String defaultValue) { + this.defaultValue = defaultValue; + return this; + } + + + @Override + public boolean test(MethodParameter parameter) { + RequestParam annotation = parameter.getParameterAnnotation(RequestParam.class); + return annotation != null && + (this.name == null || annotation.name().equals(this.name)) && + annotation.required() == this.required && + annotation.defaultValue().equals(this.defaultValue); + } + } + + + public static class RequestPartPredicate implements Predicate { + + private String name; + + private boolean required = true; + + + public RequestPartPredicate name(String name) { + this.name = name; + return this; + } + + public RequestPartPredicate noName() { + this.name = ""; + return this; + } + + public RequestPartPredicate notRequired() { + this.required = false; + return this; + } + + + @Override + public boolean test(MethodParameter parameter) { + RequestPart annotation = parameter.getParameterAnnotation(RequestPart.class); + return annotation != null && + (this.name == null || annotation.name().equals(this.name)) && + annotation.required() == this.required; + } + } + + public static class ModelAttributeMethodPredicate implements Predicate { + + private String name; + + + public ModelAttributeMethodPredicate name(String name) { + this.name = name; + return this; + } + + public ModelAttributeMethodPredicate noName() { + this.name = ""; + return this; + } + + @Override + public boolean test(Method method) { + ModelAttribute annot = AnnotatedElementUtils.findMergedAnnotation(method, ModelAttribute.class); + return annot != null && (this.name == null || annot.name().equals(this.name)); + } + } + + public static class RequestAttributePredicate implements Predicate { + + private String name; + + private boolean required = true; + + + public RequestAttributePredicate name(String name) { + this.name = name; + return this; + } + + public RequestAttributePredicate noName() { + this.name = ""; + return this; + } + + public RequestAttributePredicate notRequired() { + this.required = false; + return this; + } + + + @Override + public boolean test(MethodParameter parameter) { + RequestAttribute annotation = parameter.getParameterAnnotation(RequestAttribute.class); + return annotation != null && + (this.name == null || annotation.name().equals(this.name)) && + annotation.required() == this.required; + } + } + + public static class ResponseStatusPredicate implements Predicate { + + private HttpStatus code = HttpStatus.INTERNAL_SERVER_ERROR; + + + private ResponseStatusPredicate() { + } + + private ResponseStatusPredicate(HttpStatus code) { + this.code = code; + } + + @Override + public boolean test(Method method) { + ResponseStatus annot = AnnotatedElementUtils.findMergedAnnotation(method, ResponseStatus.class); + return annot != null && annot.code().equals(this.code); + } + } + + public static class RequestMappingPredicate implements Predicate { + + private String[] path; + + private RequestMethod[] method = {}; + + private String[] params; + + + private RequestMappingPredicate(String... path) { + this.path = path; + } + + + public RequestMappingPredicate method(RequestMethod... methods) { + this.method = methods; + return this; + } + + public RequestMappingPredicate params(String... params) { + this.params = params; + return this; + } + + @Override + public boolean test(Method method) { + RequestMapping annot = AnnotatedElementUtils.findMergedAnnotation(method, RequestMapping.class); + return annot != null && + Arrays.equals(this.path, annot.path()) && + Arrays.equals(this.method, annot.method()) && + (this.params == null || Arrays.equals(this.params, annot.params())); + } + } + + public static class MatrixVariablePredicate implements Predicate { + + private String name; + + private String pathVar; + + + public MatrixVariablePredicate name(String name) { + this.name = name; + return this; + } + + public MatrixVariablePredicate noName() { + this.name = ""; + return this; + } + + public MatrixVariablePredicate pathVar(String name) { + this.pathVar = name; + return this; + } + + public MatrixVariablePredicate noPathVar() { + this.pathVar = ValueConstants.DEFAULT_NONE; + return this; + } + + @Override + public boolean test(MethodParameter parameter) { + MatrixVariable annotation = parameter.getParameterAnnotation(MatrixVariable.class); + return annotation != null && + (this.name == null || this.name.equalsIgnoreCase(annotation.name())) && + (this.pathVar == null || this.pathVar.equalsIgnoreCase(annotation.pathVar())); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/ResolvableMethod.java b/spring-web/src/test/java/org/springframework/web/method/ResolvableMethod.java new file mode 100644 index 0000000000000000000000000000000000000000..55cd6ccdfae6e5d50bee58f4d4fe32337a8836bd --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/ResolvableMethod.java @@ -0,0 +1,689 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Method; +import java.lang.reflect.Parameter; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import org.aopalliance.intercept.MethodInterceptor; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.aop.target.EmptyTargetSource; +import org.springframework.cglib.core.SpringNamingPolicy; +import org.springframework.cglib.proxy.Callback; +import org.springframework.cglib.proxy.Enhancer; +import org.springframework.cglib.proxy.Factory; +import org.springframework.cglib.proxy.MethodProxy; +import org.springframework.core.LocalVariableTableParameterNameDiscoverer; +import org.springframework.core.MethodIntrospector; +import org.springframework.core.MethodParameter; +import org.springframework.core.ParameterNameDiscoverer; +import org.springframework.core.ResolvableType; +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.core.annotation.SynthesizingMethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.objenesis.ObjenesisException; +import org.springframework.objenesis.SpringObjenesis; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; +import org.springframework.util.ReflectionUtils; + +import static java.util.stream.Collectors.*; + +/** + * Convenience class to resolve method parameters from hints. + * + *

Background

+ * + *

When testing annotated methods we create test classes such as + * "TestController" with a diverse range of method signatures representing + * supported annotations and argument types. It becomes challenging to use + * naming strategies to keep track of methods and arguments especially in + * combination with variables for reflection metadata. + * + *

The idea with {@link ResolvableMethod} is NOT to rely on naming techniques + * but to use hints to zero in on method parameters. Such hints can be strongly + * typed and explicit about what is being tested. + * + *

1. Declared Return Type

+ * + * When testing return types it's likely to have many methods with a unique + * return type, possibly with or without an annotation. + * + *
+ * import static org.springframework.web.method.ResolvableMethod.on;
+ * import static org.springframework.web.method.MvcAnnotationPredicates.requestMapping;
+ *
+ * // Return type
+ * on(TestController.class).resolveReturnType(Foo.class);
+ * on(TestController.class).resolveReturnType(List.class, Foo.class);
+ * on(TestController.class).resolveReturnType(Mono.class, responseEntity(Foo.class));
+ *
+ * // Annotation + return type
+ * on(TestController.class).annotPresent(RequestMapping.class).resolveReturnType(Bar.class);
+ *
+ * // Annotation not present
+ * on(TestController.class).annotNotPresent(RequestMapping.class).resolveReturnType();
+ *
+ * // Annotation with attributes
+ * on(TestController.class).annot(requestMapping("/foo").params("p")).resolveReturnType();
+ * 
+ * + *

2. Method Arguments

+ * + * When testing method arguments it's more likely to have one or a small number + * of methods with a wide array of argument types and parameter annotations. + * + *
+ * import static org.springframework.web.method.MvcAnnotationPredicates.requestParam;
+ *
+ * ResolvableMethod testMethod = ResolvableMethod.on(getClass()).named("handle").build();
+ *
+ * testMethod.arg(Foo.class);
+ * testMethod.annotPresent(RequestParam.class).arg(Integer.class);
+ * testMethod.annotNotPresent(RequestParam.class)).arg(Integer.class);
+ * testMethod.annot(requestParam().name("c").notRequired()).arg(Integer.class);
+ * 
+ * + *

3. Mock Handler Method Invocation

+ * + * Locate a method by invoking it through a proxy of the target handler: + * + *
+ * ResolvableMethod.on(TestController.class).mockCall(o -> o.handle(null)).method();
+ * 
+ * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ResolvableMethod { + + private static final Log logger = LogFactory.getLog(ResolvableMethod.class); + + private static final SpringObjenesis objenesis = new SpringObjenesis(); + + private static final ParameterNameDiscoverer nameDiscoverer = new LocalVariableTableParameterNameDiscoverer(); + + // Matches ValueConstants.DEFAULT_NONE (spring-web and spring-messaging) + private static final String DEFAULT_VALUE_NONE = "\n\t\t\n\t\t\n\uE000\uE001\uE002\n\t\t\t\t\n"; + + + private final Method method; + + + private ResolvableMethod(Method method) { + Assert.notNull(method, "'method' is required"); + this.method = method; + } + + + /** + * Return the resolved method. + */ + public Method method() { + return this.method; + } + + /** + * Return the declared return type of the resolved method. + */ + public MethodParameter returnType() { + return new SynthesizingMethodParameter(this.method, -1); + } + + /** + * Find a unique argument matching the given type. + * @param type the expected type + * @param generics optional array of generic types + */ + public MethodParameter arg(Class type, Class... generics) { + return new ArgResolver().arg(type, generics); + } + + /** + * Find a unique argument matching the given type. + * @param type the expected type + * @param generic at least one generic type + * @param generics optional array of generic types + */ + public MethodParameter arg(Class type, ResolvableType generic, ResolvableType... generics) { + return new ArgResolver().arg(type, generic, generics); + } + + /** + * Find a unique argument matching the given type. + * @param type the expected type + */ + public MethodParameter arg(ResolvableType type) { + return new ArgResolver().arg(type); + } + + /** + * Filter on method arguments with annotation. + * See {@link org.springframework.web.method.MvcAnnotationPredicates}. + */ + @SafeVarargs + public final ArgResolver annot(Predicate... filter) { + return new ArgResolver(filter); + } + + @SafeVarargs + public final ArgResolver annotPresent(Class... annotationTypes) { + return new ArgResolver().annotPresent(annotationTypes); + } + + /** + * Filter on method arguments that don't have the given annotation type(s). + * @param annotationTypes the annotation types + */ + @SafeVarargs + public final ArgResolver annotNotPresent(Class... annotationTypes) { + return new ArgResolver().annotNotPresent(annotationTypes); + } + + + @Override + public String toString() { + return "ResolvableMethod=" + formatMethod(); + } + + + private String formatMethod() { + return (method().getName() + + Arrays.stream(this.method.getParameters()) + .map(this::formatParameter) + .collect(joining(",\n\t", "(\n\t", "\n)"))); + } + + private String formatParameter(Parameter param) { + Annotation[] anns = param.getAnnotations(); + return (anns.length > 0 ? + Arrays.stream(anns).map(this::formatAnnotation).collect(joining(",", "[", "]")) + " " + param : + param.toString()); + } + + private String formatAnnotation(Annotation annotation) { + Map map = AnnotationUtils.getAnnotationAttributes(annotation); + map.forEach((key, value) -> { + if (value.equals(DEFAULT_VALUE_NONE)) { + map.put(key, "NONE"); + } + }); + return annotation.annotationType().getName() + map; + } + + private static ResolvableType toResolvableType(Class type, Class... generics) { + return (ObjectUtils.isEmpty(generics) ? ResolvableType.forClass(type) : + ResolvableType.forClassWithGenerics(type, generics)); + } + + private static ResolvableType toResolvableType(Class type, ResolvableType generic, ResolvableType... generics) { + ResolvableType[] genericTypes = new ResolvableType[generics.length + 1]; + genericTypes[0] = generic; + System.arraycopy(generics, 0, genericTypes, 1, generics.length); + return ResolvableType.forClassWithGenerics(type, genericTypes); + } + + + /** + * Create a {@code ResolvableMethod} builder for the given handler class. + */ + public static Builder on(Class objectClass) { + return new Builder<>(objectClass); + } + + + /** + * Builder for {@code ResolvableMethod}. + */ + public static class Builder { + + private final Class objectClass; + + private final List> filters = new ArrayList<>(4); + + + private Builder(Class objectClass) { + Assert.notNull(objectClass, "Class must not be null"); + this.objectClass = objectClass; + } + + + private void addFilter(String message, Predicate filter) { + this.filters.add(new LabeledPredicate<>(message, filter)); + } + + /** + * Filter on methods with the given name. + */ + public Builder named(String methodName) { + addFilter("methodName=" + methodName, method -> method.getName().equals(methodName)); + return this; + } + + /** + * Filter on methods with the given parameter types. + */ + public Builder argTypes(Class... argTypes) { + addFilter("argTypes=" + Arrays.toString(argTypes), method -> + ObjectUtils.isEmpty(argTypes) ? method.getParameterCount() == 0 : + Arrays.equals(method.getParameterTypes(), argTypes)); + return this; + } + + /** + * Filter on annotated methods. + * See {@link org.springframework.web.method.MvcAnnotationPredicates}. + */ + @SafeVarargs + public final Builder annot(Predicate... filters) { + this.filters.addAll(Arrays.asList(filters)); + return this; + } + + /** + * Filter on methods annotated with the given annotation type. + * @see #annot(Predicate[]) + * See {@link org.springframework.web.method.MvcAnnotationPredicates}. + */ + @SafeVarargs + public final Builder annotPresent(Class... annotationTypes) { + String message = "annotationPresent=" + Arrays.toString(annotationTypes); + addFilter(message, method -> + Arrays.stream(annotationTypes).allMatch(annotType -> + AnnotatedElementUtils.findMergedAnnotation(method, annotType) != null)); + return this; + } + + /** + * Filter on methods not annotated with the given annotation type. + */ + @SafeVarargs + public final Builder annotNotPresent(Class... annotationTypes) { + String message = "annotationNotPresent=" + Arrays.toString(annotationTypes); + addFilter(message, method -> { + if (annotationTypes.length != 0) { + return Arrays.stream(annotationTypes).noneMatch(annotType -> + AnnotatedElementUtils.findMergedAnnotation(method, annotType) != null); + } + else { + return method.getAnnotations().length == 0; + } + }); + return this; + } + + /** + * Filter on methods returning the given type. + * @param returnType the return type + * @param generics optional array of generic types + */ + public Builder returning(Class returnType, Class... generics) { + return returning(toResolvableType(returnType, generics)); + } + + /** + * Filter on methods returning the given type with generics. + * @param returnType the return type + * @param generic at least one generic type + * @param generics optional extra generic types + */ + public Builder returning(Class returnType, ResolvableType generic, ResolvableType... generics) { + return returning(toResolvableType(returnType, generic, generics)); + } + + /** + * Filter on methods returning the given type. + * @param returnType the return type + */ + public Builder returning(ResolvableType returnType) { + String expected = returnType.toString(); + String message = "returnType=" + expected; + addFilter(message, m -> expected.equals(ResolvableType.forMethodReturnType(m).toString())); + return this; + } + + /** + * Build a {@code ResolvableMethod} from the provided filters which must + * resolve to a unique, single method. + *

See additional resolveXxx shortcut methods going directly to + * {@link Method} or return type parameter. + * @throws IllegalStateException for no match or multiple matches + */ + public ResolvableMethod build() { + Set methods = MethodIntrospector.selectMethods(this.objectClass, this::isMatch); + Assert.state(!methods.isEmpty(), () -> "No matching method: " + this); + Assert.state(methods.size() == 1, () -> "Multiple matching methods: " + this + formatMethods(methods)); + return new ResolvableMethod(methods.iterator().next()); + } + + private boolean isMatch(Method method) { + return this.filters.stream().allMatch(p -> p.test(method)); + } + + private String formatMethods(Set methods) { + return "\nMatched:\n" + methods.stream() + .map(Method::toGenericString).collect(joining(",\n\t", "[\n\t", "\n]")); + } + + public ResolvableMethod mockCall(Consumer invoker) { + MethodInvocationInterceptor interceptor = new MethodInvocationInterceptor(); + T proxy = initProxy(this.objectClass, interceptor); + invoker.accept(proxy); + Method method = interceptor.getInvokedMethod(); + return new ResolvableMethod(method); + } + + + // Build & resolve shortcuts... + + /** + * Resolve and return the {@code Method} equivalent to: + *

{@code build().method()} + */ + public final Method resolveMethod() { + return build().method(); + } + + /** + * Resolve and return the {@code Method} equivalent to: + *

{@code named(methodName).build().method()} + */ + public Method resolveMethod(String methodName) { + return named(methodName).build().method(); + } + + /** + * Resolve and return the declared return type equivalent to: + *

{@code build().returnType()} + */ + public final MethodParameter resolveReturnType() { + return build().returnType(); + } + + /** + * Shortcut to the unique return type equivalent to: + *

{@code returning(returnType).build().returnType()} + * @param returnType the return type + * @param generics optional array of generic types + */ + public MethodParameter resolveReturnType(Class returnType, Class... generics) { + return returning(returnType, generics).build().returnType(); + } + + /** + * Shortcut to the unique return type equivalent to: + *

{@code returning(returnType).build().returnType()} + * @param returnType the return type + * @param generic at least one generic type + * @param generics optional extra generic types + */ + public MethodParameter resolveReturnType(Class returnType, ResolvableType generic, + ResolvableType... generics) { + + return returning(returnType, generic, generics).build().returnType(); + } + + public MethodParameter resolveReturnType(ResolvableType returnType) { + return returning(returnType).build().returnType(); + } + + + @Override + public String toString() { + return "ResolvableMethod.Builder[\n" + + "\tobjectClass = " + this.objectClass.getName() + ",\n" + + "\tfilters = " + formatFilters() + "\n]"; + } + + private String formatFilters() { + return this.filters.stream().map(Object::toString) + .collect(joining(",\n\t\t", "[\n\t\t", "\n\t]")); + } + } + + + /** + * Predicate with a descriptive label. + */ + private static class LabeledPredicate implements Predicate { + + private final String label; + + private final Predicate delegate; + + + private LabeledPredicate(String label, Predicate delegate) { + this.label = label; + this.delegate = delegate; + } + + + @Override + public boolean test(T method) { + return this.delegate.test(method); + } + + @Override + public Predicate and(Predicate other) { + return this.delegate.and(other); + } + + @Override + public Predicate negate() { + return this.delegate.negate(); + } + + @Override + public Predicate or(Predicate other) { + return this.delegate.or(other); + } + + @Override + public String toString() { + return this.label; + } + } + + + /** + * Resolver for method arguments. + */ + public class ArgResolver { + + private final List> filters = new ArrayList<>(4); + + + @SafeVarargs + private ArgResolver(Predicate... filter) { + this.filters.addAll(Arrays.asList(filter)); + } + + /** + * Filter on method arguments with annotations. + * See {@link org.springframework.web.method.MvcAnnotationPredicates}. + */ + @SafeVarargs + public final ArgResolver annot(Predicate... filters) { + this.filters.addAll(Arrays.asList(filters)); + return this; + } + + /** + * Filter on method arguments that have the given annotations. + * @param annotationTypes the annotation types + * @see #annot(Predicate[]) + * See {@link org.springframework.web.method.MvcAnnotationPredicates}. + */ + @SafeVarargs + public final ArgResolver annotPresent(Class... annotationTypes) { + this.filters.add(param -> Arrays.stream(annotationTypes).allMatch(param::hasParameterAnnotation)); + return this; + } + + /** + * Filter on method arguments that don't have the given annotations. + * @param annotationTypes the annotation types + */ + @SafeVarargs + public final ArgResolver annotNotPresent(Class... annotationTypes) { + this.filters.add(param -> + (annotationTypes.length > 0 ? + Arrays.stream(annotationTypes).noneMatch(param::hasParameterAnnotation) : + param.getParameterAnnotations().length == 0)); + return this; + } + + /** + * Resolve the argument also matching to the given type. + * @param type the expected type + */ + public MethodParameter arg(Class type, Class... generics) { + return arg(toResolvableType(type, generics)); + } + + /** + * Resolve the argument also matching to the given type. + * @param type the expected type + */ + public MethodParameter arg(Class type, ResolvableType generic, ResolvableType... generics) { + return arg(toResolvableType(type, generic, generics)); + } + + /** + * Resolve the argument also matching to the given type. + * @param type the expected type + */ + public MethodParameter arg(ResolvableType type) { + this.filters.add(p -> type.toString().equals(ResolvableType.forMethodParameter(p).toString())); + return arg(); + } + + /** + * Resolve the argument. + */ + public final MethodParameter arg() { + List matches = applyFilters(); + Assert.state(!matches.isEmpty(), () -> + "No matching arg in method\n" + formatMethod()); + Assert.state(matches.size() == 1, () -> + "Multiple matching args in method\n" + formatMethod() + "\nMatches:\n\t" + matches); + return matches.get(0); + } + + + private List applyFilters() { + List matches = new ArrayList<>(); + for (int i = 0; i < method.getParameterCount(); i++) { + MethodParameter param = new SynthesizingMethodParameter(method, i); + param.initParameterNameDiscovery(nameDiscoverer); + if (this.filters.stream().allMatch(p -> p.test(param))) { + matches.add(param); + } + } + return matches; + } + } + + + private static class MethodInvocationInterceptor + implements org.springframework.cglib.proxy.MethodInterceptor, MethodInterceptor { + + private Method invokedMethod; + + + Method getInvokedMethod() { + return this.invokedMethod; + } + + @Override + @Nullable + public Object intercept(Object object, Method method, Object[] args, MethodProxy proxy) { + if (ReflectionUtils.isObjectMethod(method)) { + return ReflectionUtils.invokeMethod(method, object, args); + } + else { + this.invokedMethod = method; + return null; + } + } + + @Override + @Nullable + public Object invoke(org.aopalliance.intercept.MethodInvocation inv) throws Throwable { + return intercept(inv.getThis(), inv.getMethod(), inv.getArguments(), null); + } + } + + @SuppressWarnings("unchecked") + private static T initProxy(Class type, MethodInvocationInterceptor interceptor) { + Assert.notNull(type, "'type' must not be null"); + if (type.isInterface()) { + ProxyFactory factory = new ProxyFactory(EmptyTargetSource.INSTANCE); + factory.addInterface(type); + factory.addInterface(Supplier.class); + factory.addAdvice(interceptor); + return (T) factory.getProxy(); + } + + else { + Enhancer enhancer = new Enhancer(); + enhancer.setSuperclass(type); + enhancer.setInterfaces(new Class[] {Supplier.class}); + enhancer.setNamingPolicy(SpringNamingPolicy.INSTANCE); + enhancer.setCallbackType(org.springframework.cglib.proxy.MethodInterceptor.class); + + Class proxyClass = enhancer.createClass(); + Object proxy = null; + + if (objenesis.isWorthTrying()) { + try { + proxy = objenesis.newInstance(proxyClass, enhancer.getUseCache()); + } + catch (ObjenesisException ex) { + logger.debug("Objenesis failed, falling back to default constructor", ex); + } + } + + if (proxy == null) { + try { + proxy = ReflectionUtils.accessibleConstructor(proxyClass).newInstance(); + } + catch (Throwable ex) { + throw new IllegalStateException("Unable to instantiate proxy " + + "via both Objenesis and default constructor fails as well", ex); + } + } + + ((Factory) proxy).setCallbacks(new Callback[] {interceptor}); + return (T) proxy; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/CookieValueMethodArgumentResolverTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/CookieValueMethodArgumentResolverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..814a3187fbcc208ac988705a667dace8e8c0e21b --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/CookieValueMethodArgumentResolverTests.java @@ -0,0 +1,111 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.lang.reflect.Method; + +import javax.servlet.http.Cookie; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.MethodParameter; +import org.springframework.core.annotation.SynthesizingMethodParameter; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.web.bind.ServletRequestBindingException; +import org.springframework.web.bind.annotation.CookieValue; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; + +import static org.junit.Assert.*; + +/** + * Test fixture with {@link org.springframework.web.method.annotation.AbstractCookieValueMethodArgumentResolver}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + */ +public class CookieValueMethodArgumentResolverTests { + + private AbstractCookieValueMethodArgumentResolver resolver; + + private MethodParameter paramNamedCookie; + + private MethodParameter paramNamedDefaultValueString; + + private MethodParameter paramString; + + private ServletWebRequest webRequest; + + private MockHttpServletRequest request; + + + @Before + public void setUp() throws Exception { + resolver = new TestCookieValueMethodArgumentResolver(); + + Method method = getClass().getMethod("params", Cookie.class, String.class, String.class); + paramNamedCookie = new SynthesizingMethodParameter(method, 0); + paramNamedDefaultValueString = new SynthesizingMethodParameter(method, 1); + paramString = new SynthesizingMethodParameter(method, 2); + + request = new MockHttpServletRequest(); + webRequest = new ServletWebRequest(request, new MockHttpServletResponse()); + } + + + @Test + public void supportsParameter() { + assertTrue("Cookie parameter not supported", resolver.supportsParameter(paramNamedCookie)); + assertTrue("Cookie string parameter not supported", resolver.supportsParameter(paramNamedDefaultValueString)); + assertFalse("non-@CookieValue parameter supported", resolver.supportsParameter(paramString)); + } + + @Test + public void resolveCookieDefaultValue() throws Exception { + Object result = resolver.resolveArgument(paramNamedDefaultValueString, null, webRequest, null); + + assertTrue(result instanceof String); + assertEquals("Invalid result", "bar", result); + } + + @Test(expected = ServletRequestBindingException.class) + public void notFound() throws Exception { + resolver.resolveArgument(paramNamedCookie, null, webRequest, null); + fail("Expected exception"); + } + + private static class TestCookieValueMethodArgumentResolver extends AbstractCookieValueMethodArgumentResolver { + + public TestCookieValueMethodArgumentResolver() { + super(null); + } + + @Override + protected Object resolveName(String name, MethodParameter parameter, NativeWebRequest request) throws Exception { + return null; + } + } + + + public void params(@CookieValue("name") Cookie param1, + @CookieValue(name = "name", defaultValue = "bar") String param2, + String param3) { + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/ErrorsMethodArgumentResolverTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/ErrorsMethodArgumentResolverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..400d30a25fdafb4f8338ee6ae2044fd37d981e30 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/ErrorsMethodArgumentResolverTests.java @@ -0,0 +1,95 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.MethodParameter; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.validation.BindingResult; +import org.springframework.validation.Errors; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.method.support.ModelAndViewContainer; + +import static org.junit.Assert.*; + +/** + * Test fixture with {@link ErrorsMethodArgumentResolver}. + * + * @author Rossen Stoyanchev + */ +public class ErrorsMethodArgumentResolverTests { + + private final ErrorsMethodArgumentResolver resolver = new ErrorsMethodArgumentResolver(); + + private BindingResult bindingResult; + + private MethodParameter paramErrors; + + private NativeWebRequest webRequest; + + + @Before + public void setup() throws Exception { + paramErrors = new MethodParameter(getClass().getDeclaredMethod("handle", Errors.class), 0); + bindingResult = new WebDataBinder(new Object(), "attr").getBindingResult(); + webRequest = new ServletWebRequest(new MockHttpServletRequest()); + } + + + @Test + public void supports() { + resolver.supportsParameter(paramErrors); + } + + @Test + public void bindingResult() throws Exception { + ModelAndViewContainer mavContainer = new ModelAndViewContainer(); + mavContainer.addAttribute("ignore1", "value1"); + mavContainer.addAttribute("ignore2", "value2"); + mavContainer.addAttribute("ignore3", "value3"); + mavContainer.addAttribute("ignore4", "value4"); + mavContainer.addAttribute("ignore5", "value5"); + mavContainer.addAllAttributes(bindingResult.getModel()); + + Object actual = resolver.resolveArgument(paramErrors, mavContainer, webRequest, null); + assertSame(actual, bindingResult); + } + + @Test(expected = IllegalStateException.class) + public void bindingResultNotFound() throws Exception { + ModelAndViewContainer mavContainer = new ModelAndViewContainer(); + mavContainer.addAllAttributes(bindingResult.getModel()); + mavContainer.addAttribute("ignore1", "value1"); + + resolver.resolveArgument(paramErrors, mavContainer, webRequest, null); + } + + @Test(expected = IllegalStateException.class) + public void noBindingResult() throws Exception { + resolver.resolveArgument(paramErrors, new ModelAndViewContainer(), webRequest, null); + } + + + @SuppressWarnings("unused") + private void handle(Errors errors) { + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/ExceptionHandlerMethodResolverTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/ExceptionHandlerMethodResolverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b35b724be8204db5545426f4d70eafbe12b0c296 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/ExceptionHandlerMethodResolverTests.java @@ -0,0 +1,152 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.net.BindException; +import java.net.SocketException; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.junit.Test; + +import org.springframework.stereotype.Controller; +import org.springframework.util.ClassUtils; +import org.springframework.web.bind.annotation.ExceptionHandler; + +import static org.junit.Assert.*; + +/** + * Test fixture for {@link ExceptionHandlerMethodResolver} tests. + * + * @author Rossen Stoyanchev + */ +public class ExceptionHandlerMethodResolverTests { + + @Test + public void resolveMethodFromAnnotation() { + ExceptionHandlerMethodResolver resolver = new ExceptionHandlerMethodResolver(ExceptionController.class); + IOException exception = new IOException(); + assertEquals("handleIOException", resolver.resolveMethod(exception).getName()); + } + + @Test + public void resolveMethodFromArgument() { + ExceptionHandlerMethodResolver resolver = new ExceptionHandlerMethodResolver(ExceptionController.class); + IllegalArgumentException exception = new IllegalArgumentException(); + assertEquals("handleIllegalArgumentException", resolver.resolveMethod(exception).getName()); + } + + @Test + public void resolveMethodExceptionSubType() { + ExceptionHandlerMethodResolver resolver = new ExceptionHandlerMethodResolver(ExceptionController.class); + IOException ioException = new FileNotFoundException(); + assertEquals("handleIOException", resolver.resolveMethod(ioException).getName()); + SocketException bindException = new BindException(); + assertEquals("handleSocketException", resolver.resolveMethod(bindException).getName()); + } + + @Test + public void resolveMethodBestMatch() { + ExceptionHandlerMethodResolver resolver = new ExceptionHandlerMethodResolver(ExceptionController.class); + SocketException exception = new SocketException(); + assertEquals("handleSocketException", resolver.resolveMethod(exception).getName()); + } + + @Test + public void resolveMethodNoMatch() { + ExceptionHandlerMethodResolver resolver = new ExceptionHandlerMethodResolver(ExceptionController.class); + Exception exception = new Exception(); + assertNull("1st lookup", resolver.resolveMethod(exception)); + assertNull("2nd lookup from cache", resolver.resolveMethod(exception)); + } + + @Test + public void resolveMethodInherited() { + ExceptionHandlerMethodResolver resolver = new ExceptionHandlerMethodResolver(InheritedController.class); + IOException exception = new IOException(); + assertEquals("handleIOException", resolver.resolveMethod(exception).getName()); + } + + @Test(expected = IllegalStateException.class) + public void ambiguousExceptionMapping() { + new ExceptionHandlerMethodResolver(AmbiguousController.class); + } + + @Test(expected = IllegalStateException.class) + public void noExceptionMapping() { + new ExceptionHandlerMethodResolver(NoExceptionController.class); + } + + + @Controller + static class ExceptionController { + + public void handle() {} + + @ExceptionHandler(IOException.class) + public void handleIOException() { + } + + @ExceptionHandler(SocketException.class) + public void handleSocketException() { + } + + @ExceptionHandler + public void handleIllegalArgumentException(IllegalArgumentException exception) { + } + } + + + @Controller + static class InheritedController extends ExceptionController { + + @Override + public void handleIOException() { + } + } + + + @Controller + static class AmbiguousController { + + public void handle() {} + + @ExceptionHandler({BindException.class, IllegalArgumentException.class}) + public String handle1(Exception ex, HttpServletRequest request, HttpServletResponse response) + throws IOException { + return ClassUtils.getShortName(ex.getClass()); + } + + @ExceptionHandler + public String handle2(IllegalArgumentException ex) { + return ClassUtils.getShortName(ex.getClass()); + } + } + + + @Controller + static class NoExceptionController { + + @ExceptionHandler + public void handle() { + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/ExpressionValueMethodArgumentResolverTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/ExpressionValueMethodArgumentResolverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b566dd232c666161dc18486ddfdc4ae898792199 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/ExpressionValueMethodArgumentResolverTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.lang.reflect.Method; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.MethodParameter; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.context.support.GenericWebApplicationContext; + +import static org.junit.Assert.*; + +/** + * Test fixture with {@link ExpressionValueMethodArgumentResolver}. + * + * @author Rossen Stoyanchev + */ +public class ExpressionValueMethodArgumentResolverTests { + + private ExpressionValueMethodArgumentResolver resolver; + + private MethodParameter paramSystemProperty; + + private MethodParameter paramContextPath; + + private MethodParameter paramNotSupported; + + private NativeWebRequest webRequest; + + @Before + @SuppressWarnings("resource") + public void setUp() throws Exception { + GenericWebApplicationContext context = new GenericWebApplicationContext(); + context.refresh(); + resolver = new ExpressionValueMethodArgumentResolver(context.getBeanFactory()); + + Method method = getClass().getMethod("params", int.class, String.class, String.class); + paramSystemProperty = new MethodParameter(method, 0); + paramContextPath = new MethodParameter(method, 1); + paramNotSupported = new MethodParameter(method, 2); + + webRequest = new ServletWebRequest(new MockHttpServletRequest(), new MockHttpServletResponse()); + + // Expose request to the current thread (for SpEL expressions) + RequestContextHolder.setRequestAttributes(webRequest); + } + + @After + public void teardown() { + RequestContextHolder.resetRequestAttributes(); + } + + @Test + public void supportsParameter() throws Exception { + assertTrue(resolver.supportsParameter(paramSystemProperty)); + assertTrue(resolver.supportsParameter(paramContextPath)); + assertFalse(resolver.supportsParameter(paramNotSupported)); + } + + @Test + public void resolveSystemProperty() throws Exception { + System.setProperty("systemProperty", "22"); + Object value = resolver.resolveArgument(paramSystemProperty, null, webRequest, null); + System.clearProperty("systemProperty"); + + assertEquals("22", value); + } + + @Test + public void resolveContextPath() throws Exception { + webRequest.getNativeRequest(MockHttpServletRequest.class).setContextPath("/contextPath"); + Object value = resolver.resolveArgument(paramContextPath, null, webRequest, null); + + assertEquals("/contextPath", value); + } + + public void params(@Value("#{systemProperties.systemProperty}") int param1, + @Value("#{request.contextPath}") String param2, String notSupported) { + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/InitBinderDataBinderFactoryTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/InitBinderDataBinderFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..f858f90d242468e812000dadf0c3cd4f3dd5f650 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/InitBinderDataBinderFactoryTests.java @@ -0,0 +1,162 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.lang.reflect.Method; +import java.util.Collections; + +import org.junit.Test; + +import org.springframework.core.LocalVariableTableParameterNameDiscoverer; +import org.springframework.core.convert.ConversionService; +import org.springframework.format.support.DefaultFormattingConversionService; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.InitBinder; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.support.ConfigurableWebBindingInitializer; +import org.springframework.web.bind.support.DefaultDataBinderFactory; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.method.support.HandlerMethodArgumentResolverComposite; +import org.springframework.web.method.support.InvocableHandlerMethod; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +/** + * Test fixture with {@link InitBinderDataBinderFactory}. + * + * @author Rossen Stoyanchev + */ +public class InitBinderDataBinderFactoryTests { + + private final ConfigurableWebBindingInitializer bindingInitializer = + new ConfigurableWebBindingInitializer(); + + private final HandlerMethodArgumentResolverComposite argumentResolvers = + new HandlerMethodArgumentResolverComposite(); + + private final NativeWebRequest webRequest = new ServletWebRequest(new MockHttpServletRequest()); + + + @Test + public void createBinder() throws Exception { + WebDataBinderFactory factory = createFactory("initBinder", WebDataBinder.class); + WebDataBinder dataBinder = factory.createBinder(this.webRequest, null, null); + + assertNotNull(dataBinder.getDisallowedFields()); + assertEquals("id", dataBinder.getDisallowedFields()[0]); + } + + @Test + public void createBinderWithGlobalInitialization() throws Exception { + ConversionService conversionService = new DefaultFormattingConversionService(); + bindingInitializer.setConversionService(conversionService); + + WebDataBinderFactory factory = createFactory("initBinder", WebDataBinder.class); + WebDataBinder dataBinder = factory.createBinder(this.webRequest, null, null); + + assertSame(conversionService, dataBinder.getConversionService()); + } + + @Test + public void createBinderWithAttrName() throws Exception { + WebDataBinderFactory factory = createFactory("initBinderWithAttributeName", WebDataBinder.class); + WebDataBinder dataBinder = factory.createBinder(this.webRequest, null, "foo"); + + assertNotNull(dataBinder.getDisallowedFields()); + assertEquals("id", dataBinder.getDisallowedFields()[0]); + } + + @Test + public void createBinderWithAttrNameNoMatch() throws Exception { + WebDataBinderFactory factory = createFactory("initBinderWithAttributeName", WebDataBinder.class); + WebDataBinder dataBinder = factory.createBinder(this.webRequest, null, "invalidName"); + + assertNull(dataBinder.getDisallowedFields()); + } + + @Test + public void createBinderNullAttrName() throws Exception { + WebDataBinderFactory factory = createFactory("initBinderWithAttributeName", WebDataBinder.class); + WebDataBinder dataBinder = factory.createBinder(this.webRequest, null, null); + + assertNull(dataBinder.getDisallowedFields()); + } + + @Test(expected = IllegalStateException.class) + public void returnValueNotExpected() throws Exception { + WebDataBinderFactory factory = createFactory("initBinderReturnValue", WebDataBinder.class); + factory.createBinder(this.webRequest, null, "invalidName"); + } + + @Test + public void createBinderTypeConversion() throws Exception { + this.webRequest.getNativeRequest(MockHttpServletRequest.class).setParameter("requestParam", "22"); + this.argumentResolvers.addResolver(new RequestParamMethodArgumentResolver(null, false)); + + WebDataBinderFactory factory = createFactory("initBinderTypeConversion", WebDataBinder.class, int.class); + WebDataBinder dataBinder = factory.createBinder(this.webRequest, null, "foo"); + + assertNotNull(dataBinder.getDisallowedFields()); + assertEquals("requestParam-22", dataBinder.getDisallowedFields()[0]); + } + + private WebDataBinderFactory createFactory(String methodName, Class... parameterTypes) + throws Exception { + + Object handler = new InitBinderHandler(); + Method method = handler.getClass().getMethod(methodName, parameterTypes); + + InvocableHandlerMethod handlerMethod = new InvocableHandlerMethod(handler, method); + handlerMethod.setHandlerMethodArgumentResolvers(this.argumentResolvers); + handlerMethod.setDataBinderFactory(new DefaultDataBinderFactory(null)); + handlerMethod.setParameterNameDiscoverer(new LocalVariableTableParameterNameDiscoverer()); + + return new InitBinderDataBinderFactory( + Collections.singletonList(handlerMethod), this.bindingInitializer); + } + + + private static class InitBinderHandler { + + @InitBinder + public void initBinder(WebDataBinder dataBinder) { + dataBinder.setDisallowedFields("id"); + } + + @InitBinder(value="foo") + public void initBinderWithAttributeName(WebDataBinder dataBinder) { + dataBinder.setDisallowedFields("id"); + } + + @InitBinder + public String initBinderReturnValue(WebDataBinder dataBinder) { + return "invalid"; + } + + @InitBinder + public void initBinderTypeConversion(WebDataBinder dataBinder, @RequestParam int requestParam) { + dataBinder.setDisallowedFields("requestParam-" + requestParam); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/MapMethodProcessorTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/MapMethodProcessorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..933f2012f782ba7bbe8794b24c91f1d9e47bae27 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/MapMethodProcessorTests.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.lang.reflect.Method; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.MethodParameter; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.ui.ModelMap; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.method.support.ModelAndViewContainer; + +import static org.junit.Assert.*; + +/** + * Test fixture with {@link org.springframework.web.method.annotation.MapMethodProcessor}. + * + * @author Rossen Stoyanchev + */ +public class MapMethodProcessorTests { + + private MapMethodProcessor processor; + + private ModelAndViewContainer mavContainer; + + private MethodParameter paramMap; + + private MethodParameter returnParamMap; + + private NativeWebRequest webRequest; + + @Before + public void setUp() throws Exception { + processor = new MapMethodProcessor(); + mavContainer = new ModelAndViewContainer(); + + Method method = getClass().getDeclaredMethod("map", Map.class); + paramMap = new MethodParameter(method, 0); + returnParamMap = new MethodParameter(method, 0); + + webRequest = new ServletWebRequest(new MockHttpServletRequest()); + } + + @Test + public void supportsParameter() { + assertTrue(processor.supportsParameter(paramMap)); + } + + @Test + public void supportsReturnType() { + assertTrue(processor.supportsReturnType(returnParamMap)); + } + + @Test + public void resolveArgumentValue() throws Exception { + assertSame(mavContainer.getModel(), processor.resolveArgument(paramMap, mavContainer, webRequest, null)); + } + + @Test + public void handleMapReturnValue() throws Exception { + mavContainer.addAttribute("attr1", "value1"); + Map returnValue = new ModelMap("attr2", "value2"); + + processor.handleReturnValue(returnValue , returnParamMap, mavContainer, webRequest); + + assertEquals("value1", mavContainer.getModel().get("attr1")); + assertEquals("value2", mavContainer.getModel().get("attr2")); + } + + @SuppressWarnings("unused") + private Map map(Map map) { + return null; + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/ModelAttributeMethodProcessorTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/ModelAttributeMethodProcessorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..13ce6ae9155c570e19c026f81958e2cbb41bd554 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/ModelAttributeMethodProcessorTests.java @@ -0,0 +1,339 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; +import java.lang.reflect.Method; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.MethodParameter; +import org.springframework.core.annotation.SynthesizingMethodParameter; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.validation.BindException; +import org.springframework.validation.BindingResult; +import org.springframework.validation.Errors; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.ModelAttribute; +import org.springframework.web.bind.annotation.SessionAttributes; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.bind.support.WebRequestDataBinder; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.context.request.WebRequest; +import org.springframework.web.method.support.ModelAndViewContainer; + +import static java.lang.annotation.ElementType.*; +import static java.lang.annotation.RetentionPolicy.*; +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * Test fixture with {@link ModelAttributeMethodProcessor}. + * + * @author Rossen Stoyanchev + */ +public class ModelAttributeMethodProcessorTests { + + private NativeWebRequest request; + + private ModelAndViewContainer container; + + private ModelAttributeMethodProcessor processor; + + private MethodParameter paramNamedValidModelAttr; + private MethodParameter paramErrors; + private MethodParameter paramInt; + private MethodParameter paramModelAttr; + private MethodParameter paramBindingDisabledAttr; + private MethodParameter paramNonSimpleType; + + private MethodParameter returnParamNamedModelAttr; + private MethodParameter returnParamNonSimpleType; + + + @Before + public void setup() throws Exception { + this.request = new ServletWebRequest(new MockHttpServletRequest()); + this.container = new ModelAndViewContainer(); + this.processor = new ModelAttributeMethodProcessor(false); + + Method method = ModelAttributeHandler.class.getDeclaredMethod("modelAttribute", + TestBean.class, Errors.class, int.class, TestBean.class, + TestBean.class, TestBean.class); + + this.paramNamedValidModelAttr = new SynthesizingMethodParameter(method, 0); + this.paramErrors = new SynthesizingMethodParameter(method, 1); + this.paramInt = new SynthesizingMethodParameter(method, 2); + this.paramModelAttr = new SynthesizingMethodParameter(method, 3); + this.paramBindingDisabledAttr = new SynthesizingMethodParameter(method, 4); + this.paramNonSimpleType = new SynthesizingMethodParameter(method, 5); + + method = getClass().getDeclaredMethod("annotatedReturnValue"); + this.returnParamNamedModelAttr = new MethodParameter(method, -1); + + method = getClass().getDeclaredMethod("notAnnotatedReturnValue"); + this.returnParamNonSimpleType = new MethodParameter(method, -1); + } + + + @Test + public void supportedParameters() throws Exception { + assertTrue(this.processor.supportsParameter(this.paramNamedValidModelAttr)); + assertTrue(this.processor.supportsParameter(this.paramModelAttr)); + + assertFalse(this.processor.supportsParameter(this.paramErrors)); + assertFalse(this.processor.supportsParameter(this.paramInt)); + assertFalse(this.processor.supportsParameter(this.paramNonSimpleType)); + } + + @Test + public void supportedParametersInDefaultResolutionMode() throws Exception { + processor = new ModelAttributeMethodProcessor(true); + + // Only non-simple types, even if not annotated + assertTrue(this.processor.supportsParameter(this.paramNamedValidModelAttr)); + assertTrue(this.processor.supportsParameter(this.paramErrors)); + assertTrue(this.processor.supportsParameter(this.paramModelAttr)); + assertTrue(this.processor.supportsParameter(this.paramNonSimpleType)); + + assertFalse(this.processor.supportsParameter(this.paramInt)); + } + + @Test + public void supportedReturnTypes() throws Exception { + processor = new ModelAttributeMethodProcessor(false); + assertTrue(this.processor.supportsReturnType(returnParamNamedModelAttr)); + assertFalse(this.processor.supportsReturnType(returnParamNonSimpleType)); + } + + @Test + public void supportedReturnTypesInDefaultResolutionMode() throws Exception { + processor = new ModelAttributeMethodProcessor(true); + assertTrue(this.processor.supportsReturnType(returnParamNamedModelAttr)); + assertTrue(this.processor.supportsReturnType(returnParamNonSimpleType)); + } + + @Test + public void bindExceptionRequired() throws Exception { + assertTrue(this.processor.isBindExceptionRequired(null, this.paramNonSimpleType)); + assertFalse(this.processor.isBindExceptionRequired(null, this.paramNamedValidModelAttr)); + } + + @Test + public void resolveArgumentFromModel() throws Exception { + testGetAttributeFromModel("attrName", this.paramNamedValidModelAttr); + testGetAttributeFromModel("testBean", this.paramModelAttr); + testGetAttributeFromModel("testBean", this.paramNonSimpleType); + } + + @Test + public void resolveArgumentViaDefaultConstructor() throws Exception { + WebDataBinder dataBinder = new WebRequestDataBinder(null); + WebDataBinderFactory factory = mock(WebDataBinderFactory.class); + given(factory.createBinder(any(), notNull(), eq("attrName"))).willReturn(dataBinder); + + this.processor.resolveArgument(this.paramNamedValidModelAttr, this.container, this.request, factory); + verify(factory).createBinder(any(), notNull(), eq("attrName")); + } + + @Test + public void resolveArgumentValidation() throws Exception { + String name = "attrName"; + Object target = new TestBean(); + this.container.addAttribute(name, target); + + StubRequestDataBinder dataBinder = new StubRequestDataBinder(target, name); + WebDataBinderFactory factory = mock(WebDataBinderFactory.class); + given(factory.createBinder(this.request, target, name)).willReturn(dataBinder); + + this.processor.resolveArgument(this.paramNamedValidModelAttr, this.container, this.request, factory); + + assertTrue(dataBinder.isBindInvoked()); + assertTrue(dataBinder.isValidateInvoked()); + } + + @Test + public void resolveArgumentBindingDisabledPreviously() throws Exception { + String name = "attrName"; + Object target = new TestBean(); + this.container.addAttribute(name, target); + + // Declare binding disabled (e.g. via @ModelAttribute method) + this.container.setBindingDisabled(name); + + StubRequestDataBinder dataBinder = new StubRequestDataBinder(target, name); + WebDataBinderFactory factory = mock(WebDataBinderFactory.class); + given(factory.createBinder(this.request, target, name)).willReturn(dataBinder); + + this.processor.resolveArgument(this.paramNamedValidModelAttr, this.container, this.request, factory); + + assertFalse(dataBinder.isBindInvoked()); + assertTrue(dataBinder.isValidateInvoked()); + } + + @Test + public void resolveArgumentBindingDisabled() throws Exception { + String name = "noBindAttr"; + Object target = new TestBean(); + this.container.addAttribute(name, target); + + StubRequestDataBinder dataBinder = new StubRequestDataBinder(target, name); + WebDataBinderFactory factory = mock(WebDataBinderFactory.class); + given(factory.createBinder(this.request, target, name)).willReturn(dataBinder); + + this.processor.resolveArgument(this.paramBindingDisabledAttr, this.container, this.request, factory); + + assertFalse(dataBinder.isBindInvoked()); + assertTrue(dataBinder.isValidateInvoked()); + } + + @Test(expected = BindException.class) + public void resolveArgumentBindException() throws Exception { + String name = "testBean"; + Object target = new TestBean(); + this.container.getModel().addAttribute(target); + + StubRequestDataBinder dataBinder = new StubRequestDataBinder(target, name); + dataBinder.getBindingResult().reject("error"); + WebDataBinderFactory binderFactory = mock(WebDataBinderFactory.class); + given(binderFactory.createBinder(this.request, target, name)).willReturn(dataBinder); + + this.processor.resolveArgument(this.paramNonSimpleType, this.container, this.request, binderFactory); + verify(binderFactory).createBinder(this.request, target, name); + } + + @Test // SPR-9378 + public void resolveArgumentOrdering() throws Exception { + String name = "testBean"; + Object testBean = new TestBean(name); + this.container.addAttribute(name, testBean); + this.container.addAttribute(BindingResult.MODEL_KEY_PREFIX + name, testBean); + + Object anotherTestBean = new TestBean(); + this.container.addAttribute("anotherTestBean", anotherTestBean); + + StubRequestDataBinder dataBinder = new StubRequestDataBinder(testBean, name); + WebDataBinderFactory binderFactory = mock(WebDataBinderFactory.class); + given(binderFactory.createBinder(this.request, testBean, name)).willReturn(dataBinder); + + this.processor.resolveArgument(this.paramModelAttr, this.container, this.request, binderFactory); + + Object[] values = this.container.getModel().values().toArray(); + assertSame("Resolved attribute should be updated to be last", testBean, values[1]); + assertSame("BindingResult of resolved attr should be last", dataBinder.getBindingResult(), values[2]); + } + + @Test + public void handleAnnotatedReturnValue() throws Exception { + this.processor.handleReturnValue("expected", this.returnParamNamedModelAttr, this.container, this.request); + assertEquals("expected", this.container.getModel().get("modelAttrName")); + } + + @Test + public void handleNotAnnotatedReturnValue() throws Exception { + TestBean testBean = new TestBean("expected"); + this.processor.handleReturnValue(testBean, this.returnParamNonSimpleType, this.container, this.request); + assertSame(testBean, this.container.getModel().get("testBean")); + } + + + private void testGetAttributeFromModel(String expectedAttrName, MethodParameter param) throws Exception { + Object target = new TestBean(); + this.container.addAttribute(expectedAttrName, target); + + WebDataBinder dataBinder = new WebRequestDataBinder(target); + WebDataBinderFactory factory = mock(WebDataBinderFactory.class); + given(factory.createBinder(this.request, target, expectedAttrName)).willReturn(dataBinder); + + this.processor.resolveArgument(param, this.container, this.request, factory); + verify(factory).createBinder(this.request, target, expectedAttrName); + } + + + private static class StubRequestDataBinder extends WebRequestDataBinder { + + private boolean bindInvoked; + + private boolean validateInvoked; + + + public StubRequestDataBinder(Object target, String objectName) { + super(target, objectName); + } + + public boolean isBindInvoked() { + return bindInvoked; + } + + public boolean isValidateInvoked() { + return validateInvoked; + } + + @Override + public void bind(WebRequest request) { + bindInvoked = true; + } + + @Override + public void validate() { + validateInvoked = true; + } + + @Override + public void validate(Object... validationHints) { + validateInvoked = true; + } + } + + + @Target({METHOD, FIELD, CONSTRUCTOR, PARAMETER}) + @Retention(RUNTIME) + public @interface Valid { + } + + + @SessionAttributes(types=TestBean.class) + private static class ModelAttributeHandler { + + @SuppressWarnings("unused") + public void modelAttribute( + @ModelAttribute("attrName") @Valid TestBean annotatedAttr, + Errors errors, + int intArg, + @ModelAttribute TestBean defaultNameAttr, + @ModelAttribute(name="noBindAttr", binding=false) @Valid TestBean noBindAttr, + TestBean notAnnotatedAttr) { + } + } + + + @ModelAttribute("modelAttrName") @SuppressWarnings("unused") + private String annotatedReturnValue() { + return null; + } + + + @SuppressWarnings("unused") + private TestBean notAnnotatedReturnValue() { + return null; + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/ModelFactoryOrderingTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/ModelFactoryOrderingTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c5f504c3d71aeaed4c679db0aec7037ac6cd0118 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/ModelFactoryOrderingTests.java @@ -0,0 +1,333 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.io.IOException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.MethodIntrospector; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.ui.Model; +import org.springframework.util.ReflectionUtils; +import org.springframework.web.bind.annotation.ModelAttribute; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.support.DefaultDataBinderFactory; +import org.springframework.web.bind.support.DefaultSessionAttributeStore; +import org.springframework.web.bind.support.SessionAttributeStore; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.method.HandlerMethod; +import org.springframework.web.method.support.HandlerMethodArgumentResolverComposite; +import org.springframework.web.method.support.InvocableHandlerMethod; +import org.springframework.web.method.support.ModelAndViewContainer; + +import static org.junit.Assert.*; + +/** + * Unit tests verifying {@code @ModelAttribute} method inter-dependencies. + * + * @author Rossen Stoyanchev + */ +public class ModelFactoryOrderingTests { + + private static final Log logger = LogFactory.getLog(ModelFactoryOrderingTests.class); + + private NativeWebRequest webRequest; + + private ModelAndViewContainer mavContainer; + + private SessionAttributeStore sessionAttributeStore; + + + @Before + public void setup() { + this.sessionAttributeStore = new DefaultSessionAttributeStore(); + this.webRequest = new ServletWebRequest(new MockHttpServletRequest(), new MockHttpServletResponse()); + this.mavContainer = new ModelAndViewContainer(); + this.mavContainer.addAttribute("methods", new ArrayList()); + } + + @Test + public void straightLineDependency() throws Exception { + runTest(new StraightLineDependencyController()); + assertInvokedBefore("getA", "getB1", "getB2", "getC1", "getC2", "getC3", "getC4"); + assertInvokedBefore("getB1", "getB2", "getC1", "getC2", "getC3", "getC4"); + assertInvokedBefore("getB2", "getC1", "getC2", "getC3", "getC4"); + assertInvokedBefore("getC1", "getC2", "getC3", "getC4"); + assertInvokedBefore("getC2", "getC3", "getC4"); + assertInvokedBefore("getC3", "getC4"); + } + + @Test + public void treeDependency() throws Exception { + runTest(new TreeDependencyController()); + assertInvokedBefore("getA", "getB1", "getB2", "getC1", "getC2", "getC3", "getC4"); + assertInvokedBefore("getB1", "getC1", "getC2"); + assertInvokedBefore("getB2", "getC3", "getC4"); + } + + @Test + public void InvertedTreeDependency() throws Exception { + runTest(new InvertedTreeDependencyController()); + assertInvokedBefore("getC1", "getA", "getB1"); + assertInvokedBefore("getC2", "getA", "getB1"); + assertInvokedBefore("getC3", "getA", "getB2"); + assertInvokedBefore("getC4", "getA", "getB2"); + assertInvokedBefore("getB1", "getA"); + assertInvokedBefore("getB2", "getA"); + } + + @Test + public void unresolvedDependency() throws Exception { + runTest(new UnresolvedDependencyController()); + assertInvokedBefore("getA", "getC1", "getC2", "getC3", "getC4"); + + // No other order guarantees for methods with unresolvable dependencies (and methods that depend on them), + // Required dependencies will be created via default constructor. + } + + private void runTest(Object controller) throws Exception { + HandlerMethodArgumentResolverComposite resolvers = new HandlerMethodArgumentResolverComposite(); + resolvers.addResolver(new ModelAttributeMethodProcessor(false)); + resolvers.addResolver(new ModelMethodProcessor()); + WebDataBinderFactory dataBinderFactory = new DefaultDataBinderFactory(null); + + Class type = controller.getClass(); + Set methods = MethodIntrospector.selectMethods(type, METHOD_FILTER); + List modelMethods = new ArrayList<>(); + for (Method method : methods) { + InvocableHandlerMethod modelMethod = new InvocableHandlerMethod(controller, method); + modelMethod.setHandlerMethodArgumentResolvers(resolvers); + modelMethod.setDataBinderFactory(dataBinderFactory); + modelMethods.add(modelMethod); + } + Collections.shuffle(modelMethods); + + SessionAttributesHandler sessionHandler = new SessionAttributesHandler(type, this.sessionAttributeStore); + ModelFactory factory = new ModelFactory(modelMethods, dataBinderFactory, sessionHandler); + factory.initModel(this.webRequest, this.mavContainer, new HandlerMethod(controller, "handle")); + if (logger.isDebugEnabled()) { + StringBuilder sb = new StringBuilder(); + for (String name : getInvokedMethods()) { + sb.append(" >> ").append(name); + } + logger.debug(sb); + } + } + + private void assertInvokedBefore(String beforeMethod, String... afterMethods) { + List actual = getInvokedMethods(); + for (String afterMethod : afterMethods) { + assertTrue(beforeMethod + " should be before " + afterMethod + ". Actual order: " + + actual.toString(), actual.indexOf(beforeMethod) < actual.indexOf(afterMethod)); + } + } + + @SuppressWarnings("unchecked") + private List getInvokedMethods() { + return (List) this.mavContainer.getModel().get("methods"); + } + + + private static class AbstractController { + + @RequestMapping + public void handle() { + } + + @SuppressWarnings("unchecked") + T updateAndReturn(Model model, String methodName, T returnValue) throws IOException { + ((List) model.asMap().get("methods")).add(methodName); + return returnValue; + } + } + + private static class StraightLineDependencyController extends AbstractController { + + @ModelAttribute + public A getA(Model model) throws IOException { + return updateAndReturn(model, "getA", new A()); + } + + @ModelAttribute + public B1 getB1(@ModelAttribute A a, Model model) throws IOException { + return updateAndReturn(model, "getB1", new B1()); + } + + @ModelAttribute + public B2 getB2(@ModelAttribute B1 b1, Model model) throws IOException { + return updateAndReturn(model, "getB2", new B2()); + } + + @ModelAttribute + public C1 getC1(@ModelAttribute B2 b2, Model model) throws IOException { + return updateAndReturn(model, "getC1", new C1()); + } + + + @ModelAttribute + public C2 getC2(@ModelAttribute C1 c1, Model model) throws IOException { + return updateAndReturn(model, "getC2", new C2()); + } + + @ModelAttribute + public C3 getC3(@ModelAttribute C2 c2, Model model) throws IOException { + return updateAndReturn(model, "getC3", new C3()); + } + + @ModelAttribute + public C4 getC4(@ModelAttribute C3 c3, Model model) throws IOException { + return updateAndReturn(model, "getC4", new C4()); + } + } + + private static class TreeDependencyController extends AbstractController { + + @ModelAttribute + public A getA(Model model) throws IOException { + return updateAndReturn(model, "getA", new A()); + } + + @ModelAttribute + public B1 getB1(@ModelAttribute A a, Model model) throws IOException { + return updateAndReturn(model, "getB1", new B1()); + } + + @ModelAttribute + public B2 getB2(@ModelAttribute A a, Model model) throws IOException { + return updateAndReturn(model, "getB2", new B2()); + } + + @ModelAttribute + public C1 getC1(@ModelAttribute B1 b1, Model model) throws IOException { + return updateAndReturn(model, "getC1", new C1()); + } + + @ModelAttribute + public C2 getC2(@ModelAttribute B1 b1, Model model) throws IOException { + return updateAndReturn(model, "getC2", new C2()); + } + + @ModelAttribute + public C3 getC3(@ModelAttribute B2 b2, Model model) throws IOException { + return updateAndReturn(model, "getC3", new C3()); + } + + @ModelAttribute + public C4 getC4(@ModelAttribute B2 b2, Model model) throws IOException { + return updateAndReturn(model, "getC4", new C4()); + } + } + + private static class InvertedTreeDependencyController extends AbstractController { + + @ModelAttribute + public C1 getC1(Model model) throws IOException { + return updateAndReturn(model, "getC1", new C1()); + } + + @ModelAttribute + public C2 getC2(Model model) throws IOException { + return updateAndReturn(model, "getC2", new C2()); + } + + @ModelAttribute + public C3 getC3(Model model) throws IOException { + return updateAndReturn(model, "getC3", new C3()); + } + + @ModelAttribute + public C4 getC4(Model model) throws IOException { + return updateAndReturn(model, "getC4", new C4()); + } + + @ModelAttribute + public B1 getB1(@ModelAttribute C1 c1, @ModelAttribute C2 c2, Model model) throws IOException { + return updateAndReturn(model, "getB1", new B1()); + } + + @ModelAttribute + public B2 getB2(@ModelAttribute C3 c3, @ModelAttribute C4 c4, Model model) throws IOException { + return updateAndReturn(model, "getB2", new B2()); + } + + @ModelAttribute + public A getA(@ModelAttribute B1 b1, @ModelAttribute B2 b2, Model model) throws IOException { + return updateAndReturn(model, "getA", new A()); + } + + } + + private static class UnresolvedDependencyController extends AbstractController { + + @ModelAttribute + public A getA(Model model) throws IOException { + return updateAndReturn(model, "getA", new A()); + } + + @ModelAttribute + public C1 getC1(@ModelAttribute B1 b1, Model model) throws IOException { + return updateAndReturn(model, "getC1", new C1()); + } + + @ModelAttribute + public C2 getC2(@ModelAttribute B1 b1, Model model) throws IOException { + return updateAndReturn(model, "getC2", new C2()); + } + + @ModelAttribute + public C3 getC3(@ModelAttribute B2 b2, Model model) throws IOException { + return updateAndReturn(model, "getC3", new C3()); + } + + @ModelAttribute + public C4 getC4(@ModelAttribute B2 b2, Model model) throws IOException { + return updateAndReturn(model, "getC4", new C4()); + } + } + + private static class A { } + private static class B1 { } + private static class B2 { } + private static class C1 { } + private static class C2 { } + private static class C3 { } + private static class C4 { } + + + private static final ReflectionUtils.MethodFilter METHOD_FILTER = new ReflectionUtils.MethodFilter() { + + @Override + public boolean matches(Method method) { + return ((AnnotationUtils.findAnnotation(method, RequestMapping.class) == null) && + (AnnotationUtils.findAnnotation(method, ModelAttribute.class) != null)); + } + }; + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/ModelFactoryTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/ModelFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..340908b49d6a0d437c540eaefac9477c8898d339 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/ModelFactoryTests.java @@ -0,0 +1,313 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.lang.reflect.Method; +import java.util.Collections; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.LocalVariableTableParameterNameDiscoverer; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.ui.Model; +import org.springframework.ui.ModelMap; +import org.springframework.validation.BindingResult; +import org.springframework.web.HttpSessionRequiredException; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.ModelAttribute; +import org.springframework.web.bind.annotation.SessionAttributes; +import org.springframework.web.bind.support.DefaultSessionAttributeStore; +import org.springframework.web.bind.support.SessionAttributeStore; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.method.HandlerMethod; +import org.springframework.web.method.support.HandlerMethodArgumentResolverComposite; +import org.springframework.web.method.support.InvocableHandlerMethod; +import org.springframework.web.method.support.ModelAndViewContainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.mock; + +/** + * Text fixture for {@link ModelFactory} tests. + * + * @author Rossen Stoyanchev + */ +public class ModelFactoryTests { + + private NativeWebRequest webRequest; + + private SessionAttributesHandler attributeHandler; + + private SessionAttributeStore attributeStore; + + private TestController controller = new TestController(); + + private ModelAndViewContainer mavContainer; + + + @Before + public void setUp() throws Exception { + this.webRequest = new ServletWebRequest(new MockHttpServletRequest()); + this.attributeStore = new DefaultSessionAttributeStore(); + this.attributeHandler = new SessionAttributesHandler(TestController.class, this.attributeStore); + this.controller = new TestController(); + this.mavContainer = new ModelAndViewContainer(); + } + + + @Test + public void modelAttributeMethod() throws Exception { + ModelFactory modelFactory = createModelFactory("modelAttr", Model.class); + HandlerMethod handlerMethod = createHandlerMethod("handle"); + modelFactory.initModel(this.webRequest, this.mavContainer, handlerMethod); + + assertEquals(Boolean.TRUE, this.mavContainer.getModel().get("modelAttr")); + } + + @Test + public void modelAttributeMethodWithExplicitName() throws Exception { + ModelFactory modelFactory = createModelFactory("modelAttrWithName"); + HandlerMethod handlerMethod = createHandlerMethod("handle"); + modelFactory.initModel(this.webRequest, this.mavContainer, handlerMethod); + + assertEquals(Boolean.TRUE, this.mavContainer.getModel().get("name")); + } + + @Test + public void modelAttributeMethodWithNameByConvention() throws Exception { + ModelFactory modelFactory = createModelFactory("modelAttrConvention"); + HandlerMethod handlerMethod = createHandlerMethod("handle"); + modelFactory.initModel(this.webRequest, this.mavContainer, handlerMethod); + + assertEquals(Boolean.TRUE, this.mavContainer.getModel().get("boolean")); + } + + @Test + public void modelAttributeMethodWithNullReturnValue() throws Exception { + ModelFactory modelFactory = createModelFactory("nullModelAttr"); + HandlerMethod handlerMethod = createHandlerMethod("handle"); + modelFactory.initModel(this.webRequest, this.mavContainer, handlerMethod); + + assertTrue(this.mavContainer.containsAttribute("name")); + assertNull(this.mavContainer.getModel().get("name")); + } + + @Test + public void modelAttributeWithBindingDisabled() throws Exception { + ModelFactory modelFactory = createModelFactory("modelAttrWithBindingDisabled"); + HandlerMethod handlerMethod = createHandlerMethod("handle"); + modelFactory.initModel(this.webRequest, this.mavContainer, handlerMethod); + + assertTrue(this.mavContainer.containsAttribute("foo")); + assertTrue(this.mavContainer.isBindingDisabled("foo")); + } + + @Test + public void modelAttributeFromSessionWithBindingDisabled() throws Exception { + Foo foo = new Foo(); + this.attributeStore.storeAttribute(this.webRequest, "foo", foo); + + ModelFactory modelFactory = createModelFactory("modelAttrWithBindingDisabled"); + HandlerMethod handlerMethod = createHandlerMethod("handle"); + modelFactory.initModel(this.webRequest, this.mavContainer, handlerMethod); + + assertTrue(this.mavContainer.containsAttribute("foo")); + assertSame(foo, this.mavContainer.getModel().get("foo")); + assertTrue(this.mavContainer.isBindingDisabled("foo")); + } + + @Test + public void sessionAttribute() throws Exception { + this.attributeStore.storeAttribute(this.webRequest, "sessionAttr", "sessionAttrValue"); + + ModelFactory modelFactory = createModelFactory("modelAttr", Model.class); + HandlerMethod handlerMethod = createHandlerMethod("handle"); + modelFactory.initModel(this.webRequest, this.mavContainer, handlerMethod); + + assertEquals("sessionAttrValue", this.mavContainer.getModel().get("sessionAttr")); + } + + @Test + public void sessionAttributeNotPresent() throws Exception { + ModelFactory modelFactory = new ModelFactory(null, null, this.attributeHandler); + HandlerMethod handlerMethod = createHandlerMethod("handleSessionAttr", String.class); + try { + modelFactory.initModel(this.webRequest, this.mavContainer, handlerMethod); + fail("Expected HttpSessionRequiredException"); + } + catch (HttpSessionRequiredException ex) { + // expected + } + + // Now add attribute and try again + this.attributeStore.storeAttribute(this.webRequest, "sessionAttr", "sessionAttrValue"); + + modelFactory.initModel(this.webRequest, this.mavContainer, handlerMethod); + assertEquals("sessionAttrValue", this.mavContainer.getModel().get("sessionAttr")); + } + + @Test + public void updateModelBindingResult() throws Exception { + String commandName = "attr1"; + Object command = new Object(); + ModelAndViewContainer container = new ModelAndViewContainer(); + container.addAttribute(commandName, command); + + WebDataBinder dataBinder = new WebDataBinder(command, commandName); + WebDataBinderFactory binderFactory = mock(WebDataBinderFactory.class); + given(binderFactory.createBinder(this.webRequest, command, commandName)).willReturn(dataBinder); + + ModelFactory modelFactory = new ModelFactory(null, binderFactory, this.attributeHandler); + modelFactory.updateModel(this.webRequest, container); + + assertEquals(command, container.getModel().get(commandName)); + String bindingResultKey = BindingResult.MODEL_KEY_PREFIX + commandName; + assertSame(dataBinder.getBindingResult(), container.getModel().get(bindingResultKey)); + assertEquals(2, container.getModel().size()); + } + + @Test + public void updateModelSessionAttributesSaved() throws Exception { + String attributeName = "sessionAttr"; + String attribute = "value"; + ModelAndViewContainer container = new ModelAndViewContainer(); + container.addAttribute(attributeName, attribute); + + WebDataBinder dataBinder = new WebDataBinder(attribute, attributeName); + WebDataBinderFactory binderFactory = mock(WebDataBinderFactory.class); + given(binderFactory.createBinder(this.webRequest, attribute, attributeName)).willReturn(dataBinder); + + ModelFactory modelFactory = new ModelFactory(null, binderFactory, this.attributeHandler); + modelFactory.updateModel(this.webRequest, container); + + assertEquals(attribute, container.getModel().get(attributeName)); + assertEquals(attribute, this.attributeStore.retrieveAttribute(this.webRequest, attributeName)); + } + + @Test + public void updateModelSessionAttributesRemoved() throws Exception { + String attributeName = "sessionAttr"; + String attribute = "value"; + ModelAndViewContainer container = new ModelAndViewContainer(); + container.addAttribute(attributeName, attribute); + + this.attributeStore.storeAttribute(this.webRequest, attributeName, attribute); + + WebDataBinder dataBinder = new WebDataBinder(attribute, attributeName); + WebDataBinderFactory binderFactory = mock(WebDataBinderFactory.class); + given(binderFactory.createBinder(this.webRequest, attribute, attributeName)).willReturn(dataBinder); + + container.getSessionStatus().setComplete(); + + ModelFactory modelFactory = new ModelFactory(null, binderFactory, this.attributeHandler); + modelFactory.updateModel(this.webRequest, container); + + assertEquals(attribute, container.getModel().get(attributeName)); + assertNull(this.attributeStore.retrieveAttribute(this.webRequest, attributeName)); + } + + @Test // SPR-12542 + public void updateModelWhenRedirecting() throws Exception { + String attributeName = "sessionAttr"; + String attribute = "value"; + ModelAndViewContainer container = new ModelAndViewContainer(); + container.addAttribute(attributeName, attribute); + + String queryParam = "123"; + String queryParamName = "q"; + container.setRedirectModel(new ModelMap(queryParamName, queryParam)); + container.setRedirectModelScenario(true); + + WebDataBinder dataBinder = new WebDataBinder(attribute, attributeName); + WebDataBinderFactory binderFactory = mock(WebDataBinderFactory.class); + given(binderFactory.createBinder(this.webRequest, attribute, attributeName)).willReturn(dataBinder); + + ModelFactory modelFactory = new ModelFactory(null, binderFactory, this.attributeHandler); + modelFactory.updateModel(this.webRequest, container); + + assertEquals(queryParam, container.getModel().get(queryParamName)); + assertEquals(1, container.getModel().size()); + assertEquals(attribute, this.attributeStore.retrieveAttribute(this.webRequest, attributeName)); + } + + + private ModelFactory createModelFactory(String methodName, Class... parameterTypes) throws Exception { + HandlerMethodArgumentResolverComposite resolvers = new HandlerMethodArgumentResolverComposite(); + resolvers.addResolver(new ModelMethodProcessor()); + + InvocableHandlerMethod modelMethod = createHandlerMethod(methodName, parameterTypes); + modelMethod.setHandlerMethodArgumentResolvers(resolvers); + modelMethod.setDataBinderFactory(null); + modelMethod.setParameterNameDiscoverer(new LocalVariableTableParameterNameDiscoverer()); + + return new ModelFactory(Collections.singletonList(modelMethod), null, this.attributeHandler); + } + + private InvocableHandlerMethod createHandlerMethod(String methodName, Class... paramTypes) throws Exception { + Method method = this.controller.getClass().getMethod(methodName, paramTypes); + return new InvocableHandlerMethod(this.controller, method); + } + + + @SessionAttributes({"sessionAttr", "foo"}) + static class TestController { + + @ModelAttribute + public void modelAttr(Model model) { + model.addAttribute("modelAttr", Boolean.TRUE); + } + + @ModelAttribute("name") + public Boolean modelAttrWithName() { + return Boolean.TRUE; + } + + @ModelAttribute + public Boolean modelAttrConvention() { + return Boolean.TRUE; + } + + @ModelAttribute("name") + public Boolean nullModelAttr() { + return null; + } + + @ModelAttribute(name="foo", binding=false) + public Foo modelAttrWithBindingDisabled() { + return new Foo(); + } + + public void handle() { + } + + public void handleSessionAttr(@ModelAttribute("sessionAttr") String sessionAttr) { + } + } + + + private static class Foo { + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/ModelMethodProcessorTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/ModelMethodProcessorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..2a1662d46c9ddb89578656b9eaa9289d027039d5 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/ModelMethodProcessorTests.java @@ -0,0 +1,95 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.lang.reflect.Method; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.MethodParameter; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.ui.ExtendedModelMap; +import org.springframework.ui.Model; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.method.support.ModelAndViewContainer; + +import static org.junit.Assert.*; + +/** + * Test fixture with {@link org.springframework.web.method.annotation.ModelMethodProcessor}. + * + * @author Rossen Stoyanchev + */ +public class ModelMethodProcessorTests { + + private ModelMethodProcessor processor; + + private ModelAndViewContainer mavContainer; + + private MethodParameter paramModel; + + private MethodParameter returnParamModel; + + private NativeWebRequest webRequest; + + @Before + public void setUp() throws Exception { + processor = new ModelMethodProcessor(); + mavContainer = new ModelAndViewContainer(); + + Method method = getClass().getDeclaredMethod("model", Model.class); + paramModel = new MethodParameter(method, 0); + returnParamModel = new MethodParameter(method, -1); + + webRequest = new ServletWebRequest(new MockHttpServletRequest()); + } + + @Test + public void supportsParameter() { + assertTrue(processor.supportsParameter(paramModel)); + } + + @Test + public void supportsReturnType() { + assertTrue(processor.supportsReturnType(returnParamModel)); + } + + @Test + public void resolveArgumentValue() throws Exception { + assertSame(mavContainer.getModel(), processor.resolveArgument(paramModel, mavContainer, webRequest, null)); + } + + @Test + public void handleModelReturnValue() throws Exception { + mavContainer.addAttribute("attr1", "value1"); + Model returnValue = new ExtendedModelMap(); + returnValue.addAttribute("attr2", "value2"); + + processor.handleReturnValue(returnValue , returnParamModel, mavContainer, webRequest); + + assertEquals("value1", mavContainer.getModel().get("attr1")); + assertEquals("value2", mavContainer.getModel().get("attr2")); + } + + @SuppressWarnings("unused") + private Model model(Model model) { + return null; + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/RequestHeaderMapMethodArgumentResolverTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/RequestHeaderMapMethodArgumentResolverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8e16ccee015fd3293bd850182bd8e44eea88a2d2 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/RequestHeaderMapMethodArgumentResolverTests.java @@ -0,0 +1,142 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.lang.reflect.Method; +import java.util.Collections; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.MethodParameter; +import org.springframework.core.annotation.SynthesizingMethodParameter; +import org.springframework.http.HttpHeaders; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.bind.annotation.RequestHeader; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; + +import static org.junit.Assert.*; + +/** + * Text fixture with {@link RequestHeaderMapMethodArgumentResolver}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + */ +public class RequestHeaderMapMethodArgumentResolverTests { + + private RequestHeaderMapMethodArgumentResolver resolver; + + private MethodParameter paramMap; + + private MethodParameter paramMultiValueMap; + + private MethodParameter paramHttpHeaders; + + private MethodParameter paramUnsupported; + + private NativeWebRequest webRequest; + + private MockHttpServletRequest request; + + + @Before + public void setup() throws Exception { + resolver = new RequestHeaderMapMethodArgumentResolver(); + + Method method = getClass().getMethod("params", Map.class, MultiValueMap.class, HttpHeaders.class, Map.class); + paramMap = new SynthesizingMethodParameter(method, 0); + paramMultiValueMap = new SynthesizingMethodParameter(method, 1); + paramHttpHeaders = new SynthesizingMethodParameter(method, 2); + paramUnsupported = new SynthesizingMethodParameter(method, 3); + + request = new MockHttpServletRequest(); + webRequest = new ServletWebRequest(request, new MockHttpServletResponse()); + } + + + @Test + public void supportsParameter() { + assertTrue("Map parameter not supported", resolver.supportsParameter(paramMap)); + assertTrue("MultiValueMap parameter not supported", resolver.supportsParameter(paramMultiValueMap)); + assertTrue("HttpHeaders parameter not supported", resolver.supportsParameter(paramHttpHeaders)); + assertFalse("non-@RequestParam map supported", resolver.supportsParameter(paramUnsupported)); + } + + @Test + public void resolveMapArgument() throws Exception { + String name = "foo"; + String value = "bar"; + Map expected = Collections.singletonMap(name, value); + request.addHeader(name, value); + + Object result = resolver.resolveArgument(paramMap, null, webRequest, null); + + assertTrue(result instanceof Map); + assertEquals("Invalid result", expected, result); + } + + @Test + public void resolveMultiValueMapArgument() throws Exception { + String name = "foo"; + String value1 = "bar"; + String value2 = "baz"; + + request.addHeader(name, value1); + request.addHeader(name, value2); + + MultiValueMap expected = new LinkedMultiValueMap<>(1); + expected.add(name, value1); + expected.add(name, value2); + + Object result = resolver.resolveArgument(paramMultiValueMap, null, webRequest, null); + + assertTrue(result instanceof MultiValueMap); + assertEquals("Invalid result", expected, result); + } + + @Test + public void resolveHttpHeadersArgument() throws Exception { + String name = "foo"; + String value1 = "bar"; + String value2 = "baz"; + + request.addHeader(name, value1); + request.addHeader(name, value2); + + HttpHeaders expected = new HttpHeaders(); + expected.add(name, value1); + expected.add(name, value2); + + Object result = resolver.resolveArgument(paramHttpHeaders, null, webRequest, null); + + assertTrue(result instanceof HttpHeaders); + assertEquals("Invalid result", expected, result); + } + + + public void params(@RequestHeader Map param1, + @RequestHeader MultiValueMap param2, @RequestHeader HttpHeaders param3, + Map unsupported) { + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/RequestHeaderMethodArgumentResolverTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/RequestHeaderMethodArgumentResolverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e21163f60a1c852e5ec59ddc54697d451ef61d4d --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/RequestHeaderMethodArgumentResolverTests.java @@ -0,0 +1,237 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.lang.reflect.Method; +import java.time.Instant; +import java.time.format.DateTimeFormatter; +import java.util.Date; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.MethodParameter; +import org.springframework.core.annotation.SynthesizingMethodParameter; +import org.springframework.format.support.DefaultFormattingConversionService; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.util.ReflectionUtils; +import org.springframework.web.bind.ServletRequestBindingException; +import org.springframework.web.bind.annotation.RequestHeader; +import org.springframework.web.bind.support.ConfigurableWebBindingInitializer; +import org.springframework.web.bind.support.DefaultDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.context.support.GenericWebApplicationContext; + +import static org.junit.Assert.*; + +/** + * Test fixture with {@link RequestHeaderMethodArgumentResolver}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + */ +public class RequestHeaderMethodArgumentResolverTests { + + private RequestHeaderMethodArgumentResolver resolver; + + private MethodParameter paramNamedDefaultValueStringHeader; + private MethodParameter paramNamedValueStringArray; + private MethodParameter paramSystemProperty; + private MethodParameter paramContextPath; + private MethodParameter paramResolvedNameWithExpression; + private MethodParameter paramResolvedNameWithPlaceholder; + private MethodParameter paramNamedValueMap; + private MethodParameter paramDate; + private MethodParameter paramInstant; + + private MockHttpServletRequest servletRequest; + + private NativeWebRequest webRequest; + + + @Before + @SuppressWarnings("resource") + public void setup() throws Exception { + GenericWebApplicationContext context = new GenericWebApplicationContext(); + context.refresh(); + resolver = new RequestHeaderMethodArgumentResolver(context.getBeanFactory()); + + Method method = ReflectionUtils.findMethod(getClass(), "params", (Class[]) null); + paramNamedDefaultValueStringHeader = new SynthesizingMethodParameter(method, 0); + paramNamedValueStringArray = new SynthesizingMethodParameter(method, 1); + paramSystemProperty = new SynthesizingMethodParameter(method, 2); + paramContextPath = new SynthesizingMethodParameter(method, 3); + paramResolvedNameWithExpression = new SynthesizingMethodParameter(method, 4); + paramResolvedNameWithPlaceholder = new SynthesizingMethodParameter(method, 5); + paramNamedValueMap = new SynthesizingMethodParameter(method, 6); + paramDate = new SynthesizingMethodParameter(method, 7); + paramInstant = new SynthesizingMethodParameter(method, 8); + + servletRequest = new MockHttpServletRequest(); + webRequest = new ServletWebRequest(servletRequest, new MockHttpServletResponse()); + + // Expose request to the current thread (for SpEL expressions) + RequestContextHolder.setRequestAttributes(webRequest); + } + + @After + public void reset() { + RequestContextHolder.resetRequestAttributes(); + } + + + @Test + public void supportsParameter() { + assertTrue("String parameter not supported", resolver.supportsParameter(paramNamedDefaultValueStringHeader)); + assertTrue("String array parameter not supported", resolver.supportsParameter(paramNamedValueStringArray)); + assertFalse("non-@RequestParam parameter supported", resolver.supportsParameter(paramNamedValueMap)); + } + + @Test + public void resolveStringArgument() throws Exception { + String expected = "foo"; + servletRequest.addHeader("name", expected); + + Object result = resolver.resolveArgument(paramNamedDefaultValueStringHeader, null, webRequest, null); + assertTrue(result instanceof String); + assertEquals(expected, result); + } + + @Test + public void resolveStringArrayArgument() throws Exception { + String[] expected = new String[] {"foo", "bar"}; + servletRequest.addHeader("name", expected); + + Object result = resolver.resolveArgument(paramNamedValueStringArray, null, webRequest, null); + assertTrue(result instanceof String[]); + assertArrayEquals(expected, (String[]) result); + } + + @Test + public void resolveDefaultValue() throws Exception { + Object result = resolver.resolveArgument(paramNamedDefaultValueStringHeader, null, webRequest, null); + assertTrue(result instanceof String); + assertEquals("bar", result); + } + + @Test + public void resolveDefaultValueFromSystemProperty() throws Exception { + System.setProperty("systemProperty", "bar"); + try { + Object result = resolver.resolveArgument(paramSystemProperty, null, webRequest, null); + assertTrue(result instanceof String); + assertEquals("bar", result); + } + finally { + System.clearProperty("systemProperty"); + } + } + + @Test + public void resolveNameFromSystemPropertyThroughExpression() throws Exception { + String expected = "foo"; + servletRequest.addHeader("bar", expected); + + System.setProperty("systemProperty", "bar"); + try { + Object result = resolver.resolveArgument(paramResolvedNameWithExpression, null, webRequest, null); + assertTrue(result instanceof String); + assertEquals(expected, result); + } + finally { + System.clearProperty("systemProperty"); + } + } + + @Test + public void resolveNameFromSystemPropertyThroughPlaceholder() throws Exception { + String expected = "foo"; + servletRequest.addHeader("bar", expected); + + System.setProperty("systemProperty", "bar"); + try { + Object result = resolver.resolveArgument(paramResolvedNameWithPlaceholder, null, webRequest, null); + assertTrue(result instanceof String); + assertEquals(expected, result); + } + finally { + System.clearProperty("systemProperty"); + } + } + + @Test + public void resolveDefaultValueFromRequest() throws Exception { + servletRequest.setContextPath("/bar"); + + Object result = resolver.resolveArgument(paramContextPath, null, webRequest, null); + assertTrue(result instanceof String); + assertEquals("/bar", result); + } + + @Test(expected = ServletRequestBindingException.class) + public void notFound() throws Exception { + resolver.resolveArgument(paramNamedValueStringArray, null, webRequest, null); + } + + @Test + @SuppressWarnings("deprecation") + public void dateConversion() throws Exception { + String rfc1123val = "Thu, 21 Apr 2016 17:11:08 +0100"; + servletRequest.addHeader("name", rfc1123val); + + ConfigurableWebBindingInitializer bindingInitializer = new ConfigurableWebBindingInitializer(); + bindingInitializer.setConversionService(new DefaultFormattingConversionService()); + Object result = resolver.resolveArgument(paramDate, null, webRequest, + new DefaultDataBinderFactory(bindingInitializer)); + + assertTrue(result instanceof Date); + assertEquals(new Date(rfc1123val), result); + } + + @Test + public void instantConversion() throws Exception { + String rfc1123val = "Thu, 21 Apr 2016 17:11:08 +0100"; + servletRequest.addHeader("name", rfc1123val); + + ConfigurableWebBindingInitializer bindingInitializer = new ConfigurableWebBindingInitializer(); + bindingInitializer.setConversionService(new DefaultFormattingConversionService()); + Object result = resolver.resolveArgument(paramInstant, null, webRequest, + new DefaultDataBinderFactory(bindingInitializer)); + + assertTrue(result instanceof Instant); + assertEquals(Instant.from(DateTimeFormatter.RFC_1123_DATE_TIME.parse(rfc1123val)), result); + } + + + public void params( + @RequestHeader(name = "name", defaultValue = "bar") String param1, + @RequestHeader("name") String[] param2, + @RequestHeader(name = "name", defaultValue="#{systemProperties.systemProperty}") String param3, + @RequestHeader(name = "name", defaultValue="#{request.contextPath}") String param4, + @RequestHeader("#{systemProperties.systemProperty}") String param5, + @RequestHeader("${systemProperty}") String param6, + @RequestHeader("name") Map unsupported, + @RequestHeader("name") Date dateParam, + @RequestHeader("name") Instant instantParam) { + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/RequestParamMapMethodArgumentResolverTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/RequestParamMapMethodArgumentResolverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..19ec7996ac5a5c96f171450db6665871c8dc029f --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/RequestParamMapMethodArgumentResolverTests.java @@ -0,0 +1,212 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.util.Collections; +import java.util.Map; + +import javax.servlet.http.Part; + +import org.junit.Test; + +import org.springframework.core.MethodParameter; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.mock.web.test.MockMultipartFile; +import org.springframework.mock.web.test.MockMultipartHttpServletRequest; +import org.springframework.mock.web.test.MockPart; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.method.ResolvableMethod; +import org.springframework.web.multipart.MultipartFile; + +import static org.junit.Assert.*; +import static org.springframework.web.method.MvcAnnotationPredicates.*; + +/** + * Test fixture with {@link RequestParamMapMethodArgumentResolver}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Juergen Hoeller + */ +public class RequestParamMapMethodArgumentResolverTests { + + private RequestParamMapMethodArgumentResolver resolver = new RequestParamMapMethodArgumentResolver(); + + private MockHttpServletRequest request = new MockHttpServletRequest(); + + private NativeWebRequest webRequest = new ServletWebRequest(request, new MockHttpServletResponse()); + + private ResolvableMethod testMethod = ResolvableMethod.on(getClass()).named("handle").build(); + + + @Test + public void supportsParameter() { + MethodParameter param = this.testMethod.annot(requestParam().noName()).arg(Map.class, String.class, String.class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annotPresent(RequestParam.class).arg(MultiValueMap.class, String.class, String.class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annot(requestParam().name("name")).arg(Map.class, String.class, String.class); + assertFalse(resolver.supportsParameter(param)); + + param = this.testMethod.annotNotPresent(RequestParam.class).arg(Map.class, String.class, String.class); + assertFalse(resolver.supportsParameter(param)); + } + + @Test + public void resolveMapOfString() throws Exception { + String name = "foo"; + String value = "bar"; + request.addParameter(name, value); + Map expected = Collections.singletonMap(name, value); + + MethodParameter param = this.testMethod.annot(requestParam().noName()).arg(Map.class, String.class, String.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + + assertTrue(result instanceof Map); + assertEquals("Invalid result", expected, result); + } + + @Test + public void resolveMultiValueMapOfString() throws Exception { + String name = "foo"; + String value1 = "bar"; + String value2 = "baz"; + request.addParameter(name, value1, value2); + + MultiValueMap expected = new LinkedMultiValueMap<>(1); + expected.add(name, value1); + expected.add(name, value2); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(MultiValueMap.class, String.class, String.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + + assertTrue(result instanceof MultiValueMap); + assertEquals("Invalid result", expected, result); + } + + @Test + @SuppressWarnings("unchecked") + public void resolveMapOfMultipartFile() throws Exception { + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); + MultipartFile expected1 = new MockMultipartFile("mfile", "Hello World".getBytes()); + MultipartFile expected2 = new MockMultipartFile("other", "Hello World 3".getBytes()); + request.addFile(expected1); + request.addFile(expected2); + webRequest = new ServletWebRequest(request); + + MethodParameter param = this.testMethod.annot(requestParam().noName()).arg(Map.class, String.class, MultipartFile.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + + assertTrue(result instanceof Map); + Map resultMap = (Map) result; + assertEquals(2, resultMap.size()); + assertEquals(expected1, resultMap.get("mfile")); + assertEquals(expected2, resultMap.get("other")); + } + + @Test + @SuppressWarnings("unchecked") + public void resolveMultiValueMapOfMultipartFile() throws Exception { + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); + MultipartFile expected1 = new MockMultipartFile("mfilelist", "Hello World 1".getBytes()); + MultipartFile expected2 = new MockMultipartFile("mfilelist", "Hello World 2".getBytes()); + MultipartFile expected3 = new MockMultipartFile("other", "Hello World 3".getBytes()); + request.addFile(expected1); + request.addFile(expected2); + request.addFile(expected3); + webRequest = new ServletWebRequest(request); + + MethodParameter param = this.testMethod.annot(requestParam().noName()).arg(MultiValueMap.class, String.class, MultipartFile.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + + assertTrue(result instanceof MultiValueMap); + MultiValueMap resultMap = (MultiValueMap) result; + assertEquals(2, resultMap.size()); + assertEquals(2, resultMap.get("mfilelist").size()); + assertEquals(expected1, resultMap.get("mfilelist").get(0)); + assertEquals(expected2, resultMap.get("mfilelist").get(1)); + assertEquals(1, resultMap.get("other").size()); + assertEquals(expected3, resultMap.get("other").get(0)); + } + + @Test + @SuppressWarnings("unchecked") + public void resolveMapOfPart() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setContentType("multipart/form-data"); + Part expected1 = new MockPart("mfile", "Hello World".getBytes()); + Part expected2 = new MockPart("other", "Hello World 3".getBytes()); + request.addPart(expected1); + request.addPart(expected2); + webRequest = new ServletWebRequest(request); + + MethodParameter param = this.testMethod.annot(requestParam().noName()).arg(Map.class, String.class, Part.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + + assertTrue(result instanceof Map); + Map resultMap = (Map) result; + assertEquals(2, resultMap.size()); + assertEquals(expected1, resultMap.get("mfile")); + assertEquals(expected2, resultMap.get("other")); + } + + @Test + @SuppressWarnings("unchecked") + public void resolveMultiValueMapOfPart() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setContentType("multipart/form-data"); + Part expected1 = new MockPart("mfilelist", "Hello World 1".getBytes()); + Part expected2 = new MockPart("mfilelist", "Hello World 2".getBytes()); + Part expected3 = new MockPart("other", "Hello World 3".getBytes()); + request.addPart(expected1); + request.addPart(expected2); + request.addPart(expected3); + webRequest = new ServletWebRequest(request); + + MethodParameter param = this.testMethod.annot(requestParam().noName()).arg(MultiValueMap.class, String.class, Part.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + + assertTrue(result instanceof MultiValueMap); + MultiValueMap resultMap = (MultiValueMap) result; + assertEquals(2, resultMap.size()); + assertEquals(2, resultMap.get("mfilelist").size()); + assertEquals(expected1, resultMap.get("mfilelist").get(0)); + assertEquals(expected2, resultMap.get("mfilelist").get(1)); + assertEquals(1, resultMap.get("other").size()); + assertEquals(expected3, resultMap.get("other").get(0)); + } + + + public void handle( + @RequestParam Map param1, + @RequestParam MultiValueMap param2, + @RequestParam Map param3, + @RequestParam MultiValueMap param4, + @RequestParam Map param5, + @RequestParam MultiValueMap param6, + @RequestParam("name") Map param7, + Map param8) { + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/RequestParamMethodArgumentResolverTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/RequestParamMethodArgumentResolverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..18af6d462ac1551e479dee60e661e0f79541d2ea --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/RequestParamMethodArgumentResolverTests.java @@ -0,0 +1,613 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import javax.servlet.http.Part; + +import org.junit.Test; + +import org.springframework.beans.propertyeditors.StringTrimmerEditor; +import org.springframework.core.MethodParameter; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.mock.web.test.MockMultipartFile; +import org.springframework.mock.web.test.MockMultipartHttpServletRequest; +import org.springframework.mock.web.test.MockPart; +import org.springframework.web.bind.MissingServletRequestParameterException; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RequestPart; +import org.springframework.web.bind.support.ConfigurableWebBindingInitializer; +import org.springframework.web.bind.support.DefaultDataBinderFactory; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.bind.support.WebRequestDataBinder; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.method.ResolvableMethod; +import org.springframework.web.multipart.MultipartException; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.multipart.support.MissingServletRequestPartException; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; +import static org.springframework.web.method.MvcAnnotationPredicates.*; + +/** + * Test fixture with {@link RequestParamMethodArgumentResolver}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Brian Clozel + */ +public class RequestParamMethodArgumentResolverTests { + + private RequestParamMethodArgumentResolver resolver = new RequestParamMethodArgumentResolver(null, true); + + private MockHttpServletRequest request = new MockHttpServletRequest(); + + private NativeWebRequest webRequest = new ServletWebRequest(request, new MockHttpServletResponse()); + + private ResolvableMethod testMethod = ResolvableMethod.on(getClass()).named("handle").build(); + + + @Test + public void supportsParameter() { + resolver = new RequestParamMethodArgumentResolver(null, true); + + MethodParameter param = this.testMethod.annot(requestParam().notRequired("bar")).arg(String.class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annotPresent(RequestParam.class).arg(String[].class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annot(requestParam().name("name")).arg(Map.class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annotPresent(RequestParam.class).arg(MultipartFile.class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annotPresent(RequestParam.class).arg(List.class, MultipartFile.class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annotPresent(RequestParam.class).arg(MultipartFile[].class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annotPresent(RequestParam.class).arg(Part.class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annotPresent(RequestParam.class).arg(List.class, Part.class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annotPresent(RequestParam.class).arg(Part[].class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annot(requestParam().noName()).arg(Map.class); + assertFalse(resolver.supportsParameter(param)); + + param = this.testMethod.annotNotPresent(RequestParam.class).arg(String.class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annotNotPresent().arg(MultipartFile.class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annotNotPresent(RequestParam.class).arg(List.class, MultipartFile.class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annotNotPresent(RequestParam.class).arg(Part.class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annot(requestPart()).arg(MultipartFile.class); + assertFalse(resolver.supportsParameter(param)); + + param = this.testMethod.annot(requestParam()).arg(String.class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annot(requestParam().notRequired()).arg(String.class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annotPresent(RequestParam.class).arg(Optional.class, Integer.class); + assertTrue(resolver.supportsParameter(param)); + + param = this.testMethod.annotPresent(RequestParam.class).arg(Optional.class, MultipartFile.class); + assertTrue(resolver.supportsParameter(param)); + + resolver = new RequestParamMethodArgumentResolver(null, false); + + param = this.testMethod.annotNotPresent(RequestParam.class).arg(String.class); + assertFalse(resolver.supportsParameter(param)); + + param = this.testMethod.annotPresent(RequestPart.class).arg(MultipartFile.class); + assertFalse(resolver.supportsParameter(param)); + } + + @Test + public void resolveString() throws Exception { + String expected = "foo"; + request.addParameter("name", expected); + + MethodParameter param = this.testMethod.annot(requestParam().notRequired("bar")).arg(String.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + assertTrue(result instanceof String); + assertEquals("Invalid result", expected, result); + } + + @Test + public void resolveStringArray() throws Exception { + String[] expected = new String[] {"foo", "bar"}; + request.addParameter("name", expected); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(String[].class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + assertTrue(result instanceof String[]); + assertArrayEquals("Invalid result", expected, (String[]) result); + } + + @Test + public void resolveMultipartFile() throws Exception { + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); + MultipartFile expected = new MockMultipartFile("mfile", "Hello World".getBytes()); + request.addFile(expected); + webRequest = new ServletWebRequest(request); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(MultipartFile.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + assertTrue(result instanceof MultipartFile); + assertEquals("Invalid result", expected, result); + } + + @Test + public void resolveMultipartFileList() throws Exception { + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); + MultipartFile expected1 = new MockMultipartFile("mfilelist", "Hello World 1".getBytes()); + MultipartFile expected2 = new MockMultipartFile("mfilelist", "Hello World 2".getBytes()); + request.addFile(expected1); + request.addFile(expected2); + request.addFile(new MockMultipartFile("other", "Hello World 3".getBytes())); + webRequest = new ServletWebRequest(request); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(List.class, MultipartFile.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + + assertTrue(result instanceof List); + assertEquals(Arrays.asList(expected1, expected2), result); + } + + @Test + public void resolveMultipartFileArray() throws Exception { + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); + MultipartFile expected1 = new MockMultipartFile("mfilearray", "Hello World 1".getBytes()); + MultipartFile expected2 = new MockMultipartFile("mfilearray", "Hello World 2".getBytes()); + request.addFile(expected1); + request.addFile(expected2); + request.addFile(new MockMultipartFile("other", "Hello World 3".getBytes())); + webRequest = new ServletWebRequest(request); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(MultipartFile[].class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + + assertTrue(result instanceof MultipartFile[]); + MultipartFile[] parts = (MultipartFile[]) result; + assertEquals(2, parts.length); + assertEquals(parts[0], expected1); + assertEquals(parts[1], expected2); + } + + @Test + public void resolvePart() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockPart expected = new MockPart("pfile", "Hello World".getBytes()); + request.setMethod("POST"); + request.setContentType("multipart/form-data"); + request.addPart(expected); + webRequest = new ServletWebRequest(request); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(Part.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + + assertTrue(result instanceof Part); + assertEquals("Invalid result", expected, result); + } + + @Test + public void resolvePartList() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod("POST"); + request.setContentType("multipart/form-data"); + MockPart expected1 = new MockPart("pfilelist", "Hello World 1".getBytes()); + MockPart expected2 = new MockPart("pfilelist", "Hello World 2".getBytes()); + request.addPart(expected1); + request.addPart(expected2); + request.addPart(new MockPart("other", "Hello World 3".getBytes())); + webRequest = new ServletWebRequest(request); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(List.class, Part.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + + assertTrue(result instanceof List); + assertEquals(Arrays.asList(expected1, expected2), result); + } + + @Test + public void resolvePartArray() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockPart expected1 = new MockPart("pfilearray", "Hello World 1".getBytes()); + MockPart expected2 = new MockPart("pfilearray", "Hello World 2".getBytes()); + request.setMethod("POST"); + request.setContentType("multipart/form-data"); + request.addPart(expected1); + request.addPart(expected2); + request.addPart(new MockPart("other", "Hello World 3".getBytes())); + webRequest = new ServletWebRequest(request); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(Part[].class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + + assertTrue(result instanceof Part[]); + Part[] parts = (Part[]) result; + assertEquals(2, parts.length); + assertEquals(parts[0], expected1); + assertEquals(parts[1], expected2); + } + + @Test + public void resolveMultipartFileNotAnnot() throws Exception { + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); + MultipartFile expected = new MockMultipartFile("multipartFileNotAnnot", "Hello World".getBytes()); + request.addFile(expected); + webRequest = new ServletWebRequest(request); + + MethodParameter param = this.testMethod.annotNotPresent().arg(MultipartFile.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + assertTrue(result instanceof MultipartFile); + assertEquals("Invalid result", expected, result); + } + + @Test + public void resolveMultipartFileListNotannot() throws Exception { + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); + MultipartFile expected1 = new MockMultipartFile("multipartFileList", "Hello World 1".getBytes()); + MultipartFile expected2 = new MockMultipartFile("multipartFileList", "Hello World 2".getBytes()); + request.addFile(expected1); + request.addFile(expected2); + webRequest = new ServletWebRequest(request); + + MethodParameter param = this.testMethod + .annotNotPresent(RequestParam.class).arg(List.class, MultipartFile.class); + + Object result = resolver.resolveArgument(param, null, webRequest, null); + assertTrue(result instanceof List); + assertEquals(Arrays.asList(expected1, expected2), result); + } + + @Test(expected = MultipartException.class) + public void isMultipartRequest() throws Exception { + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(MultipartFile.class); + resolver.resolveArgument(param, null, webRequest, null); + fail("Expected exception: request is not a multipart request"); + } + + @Test // SPR-9079 + public void isMultipartRequestHttpPut() throws Exception { + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); + MultipartFile expected = new MockMultipartFile("multipartFileList", "Hello World".getBytes()); + request.addFile(expected); + request.setMethod("PUT"); + webRequest = new ServletWebRequest(request); + + MethodParameter param = this.testMethod + .annotNotPresent(RequestParam.class).arg(List.class, MultipartFile.class); + + Object actual = resolver.resolveArgument(param, null, webRequest, null); + assertTrue(actual instanceof List); + assertEquals(expected, ((List) actual).get(0)); + } + + @Test(expected = MultipartException.class) + public void noMultipartContent() throws Exception { + request.setMethod("POST"); + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(MultipartFile.class); + resolver.resolveArgument(param, null, webRequest, null); + fail("Expected exception: no multipart content"); + } + + @Test(expected = MissingServletRequestPartException.class) + public void missingMultipartFile() throws Exception { + request.setMethod("POST"); + request.setContentType("multipart/form-data"); + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(MultipartFile.class); + resolver.resolveArgument(param, null, webRequest, null); + fail("Expected exception: no such part found"); + } + + @Test + public void resolvePartNotAnnot() throws Exception { + MockPart expected = new MockPart("part", "Hello World".getBytes()); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod("POST"); + request.setContentType("multipart/form-data"); + request.addPart(expected); + webRequest = new ServletWebRequest(request); + + MethodParameter param = this.testMethod.annotNotPresent(RequestParam.class).arg(Part.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + assertTrue(result instanceof Part); + assertEquals("Invalid result", expected, result); + } + + @Test + public void resolveDefaultValue() throws Exception { + MethodParameter param = this.testMethod.annot(requestParam().notRequired("bar")).arg(String.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + assertTrue(result instanceof String); + assertEquals("Invalid result", "bar", result); + } + + @Test(expected = MissingServletRequestParameterException.class) + public void missingRequestParam() throws Exception { + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(String[].class); + resolver.resolveArgument(param, null, webRequest, null); + fail("Expected exception"); + } + + @Test // SPR-10578 + public void missingRequestParamEmptyValueConvertedToNull() throws Exception { + WebDataBinder binder = new WebRequestDataBinder(null); + binder.registerCustomEditor(String.class, new StringTrimmerEditor(true)); + + WebDataBinderFactory binderFactory = mock(WebDataBinderFactory.class); + given(binderFactory.createBinder(webRequest, null, "stringNotAnnot")).willReturn(binder); + + request.addParameter("stringNotAnnot", ""); + + MethodParameter param = this.testMethod.annotNotPresent(RequestParam.class).arg(String.class); + Object arg = resolver.resolveArgument(param, null, webRequest, binderFactory); + assertNull(arg); + } + + @Test + public void missingRequestParamEmptyValueNotRequired() throws Exception { + WebDataBinder binder = new WebRequestDataBinder(null); + binder.registerCustomEditor(String.class, new StringTrimmerEditor(true)); + + WebDataBinderFactory binderFactory = mock(WebDataBinderFactory.class); + given(binderFactory.createBinder(webRequest, null, "name")).willReturn(binder); + + request.addParameter("name", ""); + + MethodParameter param = this.testMethod.annot(requestParam().notRequired()).arg(String.class); + Object arg = resolver.resolveArgument(param, null, webRequest, binderFactory); + assertNull(arg); + } + + @Test + public void resolveSimpleTypeParam() throws Exception { + request.setParameter("stringNotAnnot", "plainValue"); + MethodParameter param = this.testMethod.annotNotPresent(RequestParam.class).arg(String.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + + assertTrue(result instanceof String); + assertEquals("plainValue", result); + } + + @Test // SPR-8561 + public void resolveSimpleTypeParamToNull() throws Exception { + MethodParameter param = this.testMethod.annotNotPresent(RequestParam.class).arg(String.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + assertNull(result); + } + + @Test // SPR-10180 + public void resolveEmptyValueToDefault() throws Exception { + request.addParameter("name", ""); + MethodParameter param = this.testMethod.annot(requestParam().notRequired("bar")).arg(String.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + assertEquals("bar", result); + } + + @Test + public void resolveEmptyValueWithoutDefault() throws Exception { + request.addParameter("stringNotAnnot", ""); + MethodParameter param = this.testMethod.annotNotPresent(RequestParam.class).arg(String.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + assertEquals("", result); + } + + @Test + public void resolveEmptyValueRequiredWithoutDefault() throws Exception { + request.addParameter("name", ""); + MethodParameter param = this.testMethod.annot(requestParam().notRequired()).arg(String.class); + Object result = resolver.resolveArgument(param, null, webRequest, null); + assertEquals("", result); + } + + @Test + @SuppressWarnings("rawtypes") + public void resolveOptionalParamValue() throws Exception { + ConfigurableWebBindingInitializer initializer = new ConfigurableWebBindingInitializer(); + initializer.setConversionService(new DefaultConversionService()); + WebDataBinderFactory binderFactory = new DefaultDataBinderFactory(initializer); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(Optional.class, Integer.class); + Object result = resolver.resolveArgument(param, null, webRequest, binderFactory); + assertEquals(Optional.empty(), result); + + request.addParameter("name", "123"); + result = resolver.resolveArgument(param, null, webRequest, binderFactory); + assertEquals(Optional.class, result.getClass()); + assertEquals(123, ((Optional) result).get()); + } + + @Test + @SuppressWarnings("rawtypes") + public void missingOptionalParamValue() throws Exception { + ConfigurableWebBindingInitializer initializer = new ConfigurableWebBindingInitializer(); + initializer.setConversionService(new DefaultConversionService()); + WebDataBinderFactory binderFactory = new DefaultDataBinderFactory(initializer); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(Optional.class, Integer.class); + Object result = resolver.resolveArgument(param, null, webRequest, binderFactory); + assertEquals(Optional.empty(), result); + + result = resolver.resolveArgument(param, null, webRequest, binderFactory); + assertEquals(Optional.class, result.getClass()); + assertFalse(((Optional) result).isPresent()); + } + + @Test + @SuppressWarnings("rawtypes") + public void resolveOptionalParamArray() throws Exception { + ConfigurableWebBindingInitializer initializer = new ConfigurableWebBindingInitializer(); + initializer.setConversionService(new DefaultConversionService()); + WebDataBinderFactory binderFactory = new DefaultDataBinderFactory(initializer); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(Optional.class, Integer[].class); + Object result = resolver.resolveArgument(param, null, webRequest, binderFactory); + assertEquals(Optional.empty(), result); + + request.addParameter("name", "123", "456"); + result = resolver.resolveArgument(param, null, webRequest, binderFactory); + assertEquals(Optional.class, result.getClass()); + assertArrayEquals(new Integer[] {123, 456}, (Integer[]) ((Optional) result).get()); + } + + @Test + @SuppressWarnings("rawtypes") + public void missingOptionalParamArray() throws Exception { + ConfigurableWebBindingInitializer initializer = new ConfigurableWebBindingInitializer(); + initializer.setConversionService(new DefaultConversionService()); + WebDataBinderFactory binderFactory = new DefaultDataBinderFactory(initializer); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(Optional.class, Integer[].class); + Object result = resolver.resolveArgument(param, null, webRequest, binderFactory); + assertEquals(Optional.empty(), result); + + result = resolver.resolveArgument(param, null, webRequest, binderFactory); + assertEquals(Optional.class, result.getClass()); + assertFalse(((Optional) result).isPresent()); + } + + @Test + @SuppressWarnings("rawtypes") + public void resolveOptionalParamList() throws Exception { + ConfigurableWebBindingInitializer initializer = new ConfigurableWebBindingInitializer(); + initializer.setConversionService(new DefaultConversionService()); + WebDataBinderFactory binderFactory = new DefaultDataBinderFactory(initializer); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(Optional.class, List.class); + Object result = resolver.resolveArgument(param, null, webRequest, binderFactory); + assertEquals(Optional.empty(), result); + + request.addParameter("name", "123", "456"); + result = resolver.resolveArgument(param, null, webRequest, binderFactory); + assertEquals(Optional.class, result.getClass()); + assertEquals(Arrays.asList("123", "456"), ((Optional) result).get()); + } + + @Test + @SuppressWarnings("rawtypes") + public void missingOptionalParamList() throws Exception { + ConfigurableWebBindingInitializer initializer = new ConfigurableWebBindingInitializer(); + initializer.setConversionService(new DefaultConversionService()); + WebDataBinderFactory binderFactory = new DefaultDataBinderFactory(initializer); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(Optional.class, List.class); + Object result = resolver.resolveArgument(param, null, webRequest, binderFactory); + assertEquals(Optional.empty(), result); + + result = resolver.resolveArgument(param, null, webRequest, binderFactory); + assertEquals(Optional.class, result.getClass()); + assertFalse(((Optional) result).isPresent()); + } + + @Test + public void resolveOptionalMultipartFile() throws Exception { + ConfigurableWebBindingInitializer initializer = new ConfigurableWebBindingInitializer(); + initializer.setConversionService(new DefaultConversionService()); + WebDataBinderFactory binderFactory = new DefaultDataBinderFactory(initializer); + + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); + MultipartFile expected = new MockMultipartFile("mfile", "Hello World".getBytes()); + request.addFile(expected); + webRequest = new ServletWebRequest(request); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(Optional.class, MultipartFile.class); + Object result = resolver.resolveArgument(param, null, webRequest, binderFactory); + + assertTrue(result instanceof Optional); + assertEquals("Invalid result", expected, ((Optional) result).get()); + } + + @Test + public void missingOptionalMultipartFile() throws Exception { + ConfigurableWebBindingInitializer initializer = new ConfigurableWebBindingInitializer(); + initializer.setConversionService(new DefaultConversionService()); + WebDataBinderFactory binderFactory = new DefaultDataBinderFactory(initializer); + + request.setMethod("POST"); + request.setContentType("multipart/form-data"); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(Optional.class, MultipartFile.class); + Object actual = resolver.resolveArgument(param, null, webRequest, binderFactory); + + assertEquals(Optional.empty(), actual); + } + + @Test + public void optionalMultipartFileWithoutMultipartRequest() throws Exception { + ConfigurableWebBindingInitializer initializer = new ConfigurableWebBindingInitializer(); + initializer.setConversionService(new DefaultConversionService()); + WebDataBinderFactory binderFactory = new DefaultDataBinderFactory(initializer); + + MethodParameter param = this.testMethod.annotPresent(RequestParam.class).arg(Optional.class, MultipartFile.class); + Object actual = resolver.resolveArgument(param, null, webRequest, binderFactory); + + assertEquals(Optional.empty(), actual); + } + + + @SuppressWarnings({"unused", "OptionalUsedAsFieldOrParameterType"}) + public void handle( + @RequestParam(name = "name", defaultValue = "bar") String param1, + @RequestParam("name") String[] param2, + @RequestParam("name") Map param3, + @RequestParam("mfile") MultipartFile param4, + @RequestParam("mfilelist") List param5, + @RequestParam("mfilearray") MultipartFile[] param6, + @RequestParam("pfile") Part param7, + @RequestParam("pfilelist") List param8, + @RequestParam("pfilearray") Part[] param9, + @RequestParam Map param10, + String stringNotAnnot, + MultipartFile multipartFileNotAnnot, + List multipartFileList, + Part part, + @RequestPart MultipartFile requestPartAnnot, + @RequestParam("name") String paramRequired, + @RequestParam(name = "name", required = false) String paramNotRequired, + @RequestParam("name") Optional paramOptional, + @RequestParam("name") Optional paramOptionalArray, + @RequestParam("name") Optional paramOptionalList, + @RequestParam("mfile") Optional multipartFileOptional) { + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/SessionAttributesHandlerTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/SessionAttributesHandlerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..413d11da1d81bc28627b193026aec3560a79356a --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/SessionAttributesHandlerTests.java @@ -0,0 +1,119 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + + +import java.util.HashSet; + +import org.junit.Test; + +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.ui.ModelMap; +import org.springframework.web.bind.annotation.SessionAttributes; +import org.springframework.web.bind.support.DefaultSessionAttributeStore; +import org.springframework.web.bind.support.SessionAttributeStore; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; + +import static java.util.Arrays.asList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** + * Test fixture with {@link SessionAttributesHandler}. + * @author Rossen Stoyanchev + */ +public class SessionAttributesHandlerTests { + + private final SessionAttributeStore sessionAttributeStore = new DefaultSessionAttributeStore(); + + private final SessionAttributesHandler sessionAttributesHandler = new SessionAttributesHandler( + SessionAttributeHandler.class, sessionAttributeStore); + + private final NativeWebRequest request = new ServletWebRequest(new MockHttpServletRequest()); + + + @Test + public void isSessionAttribute() throws Exception { + assertTrue(sessionAttributesHandler.isHandlerSessionAttribute("attr1", String.class)); + assertTrue(sessionAttributesHandler.isHandlerSessionAttribute("attr2", String.class)); + assertTrue(sessionAttributesHandler.isHandlerSessionAttribute("simple", TestBean.class)); + assertFalse(sessionAttributesHandler.isHandlerSessionAttribute("simple", String.class)); + } + + @Test + public void retrieveAttributes() throws Exception { + sessionAttributeStore.storeAttribute(request, "attr1", "value1"); + sessionAttributeStore.storeAttribute(request, "attr2", "value2"); + sessionAttributeStore.storeAttribute(request, "attr3", new TestBean()); + sessionAttributeStore.storeAttribute(request, "attr4", new TestBean()); + + assertEquals("Named attributes (attr1, attr2) should be 'known' right away", + new HashSet<>(asList("attr1", "attr2")), + sessionAttributesHandler.retrieveAttributes(request).keySet()); + + // Resolve 'attr3' by type + sessionAttributesHandler.isHandlerSessionAttribute("attr3", TestBean.class); + + assertEquals("Named attributes (attr1, attr2) and resolved attribute (att3) should be 'known'", + new HashSet<>(asList("attr1", "attr2", "attr3")), + sessionAttributesHandler.retrieveAttributes(request).keySet()); + } + + @Test + public void cleanupAttributes() throws Exception { + sessionAttributeStore.storeAttribute(request, "attr1", "value1"); + sessionAttributeStore.storeAttribute(request, "attr2", "value2"); + sessionAttributeStore.storeAttribute(request, "attr3", new TestBean()); + + sessionAttributesHandler.cleanupAttributes(request); + + assertNull(sessionAttributeStore.retrieveAttribute(request, "attr1")); + assertNull(sessionAttributeStore.retrieveAttribute(request, "attr2")); + assertNotNull(sessionAttributeStore.retrieveAttribute(request, "attr3")); + + // Resolve 'attr3' by type + sessionAttributesHandler.isHandlerSessionAttribute("attr3", TestBean.class); + sessionAttributesHandler.cleanupAttributes(request); + + assertNull(sessionAttributeStore.retrieveAttribute(request, "attr3")); + } + + @Test + public void storeAttributes() throws Exception { + ModelMap model = new ModelMap(); + model.put("attr1", "value1"); + model.put("attr2", "value2"); + model.put("attr3", new TestBean()); + + sessionAttributesHandler.storeAttributes(request, model); + + assertEquals("value1", sessionAttributeStore.retrieveAttribute(request, "attr1")); + assertEquals("value2", sessionAttributeStore.retrieveAttribute(request, "attr2")); + assertTrue(sessionAttributeStore.retrieveAttribute(request, "attr3") instanceof TestBean); + } + + + @SessionAttributes(names = { "attr1", "attr2" }, types = { TestBean.class }) + private static class SessionAttributeHandler { + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/WebArgumentResolverAdapterTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/WebArgumentResolverAdapterTests.java new file mode 100644 index 0000000000000000000000000000000000000000..02e40509baf68d776fdf8e9addb918cd313a95c6 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/WebArgumentResolverAdapterTests.java @@ -0,0 +1,146 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.MethodParameter; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.web.bind.support.WebArgumentResolver; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletWebRequest; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * Test fixture with {@link WebArgumentResolverAdapterTests}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + */ +public class WebArgumentResolverAdapterTests { + + private TestWebArgumentResolverAdapter adapter; + + private WebArgumentResolver adaptee; + + private MethodParameter parameter; + + private NativeWebRequest webRequest; + + @Before + public void setUp() throws Exception { + adaptee = mock(WebArgumentResolver.class); + adapter = new TestWebArgumentResolverAdapter(adaptee); + parameter = new MethodParameter(getClass().getMethod("handle", Integer.TYPE), 0); + webRequest = new ServletWebRequest(new MockHttpServletRequest()); + + // Expose request to the current thread (for SpEL expressions) + RequestContextHolder.setRequestAttributes(webRequest); + } + + @After + public void resetRequestContextHolder() { + RequestContextHolder.resetRequestAttributes(); + } + + @Test + public void supportsParameter() throws Exception { + given(adaptee.resolveArgument(parameter, webRequest)).willReturn(42); + + assertTrue("Parameter not supported", adapter.supportsParameter(parameter)); + + verify(adaptee).resolveArgument(parameter, webRequest); + } + + @Test + public void supportsParameterUnresolved() throws Exception { + given(adaptee.resolveArgument(parameter, webRequest)).willReturn(WebArgumentResolver.UNRESOLVED); + + assertFalse("Parameter supported", adapter.supportsParameter(parameter)); + + verify(adaptee).resolveArgument(parameter, webRequest); + } + + @Test + public void supportsParameterWrongType() throws Exception { + given(adaptee.resolveArgument(parameter, webRequest)).willReturn("Foo"); + + assertFalse("Parameter supported", adapter.supportsParameter(parameter)); + + verify(adaptee).resolveArgument(parameter, webRequest); + } + + @Test + public void supportsParameterThrowsException() throws Exception { + given(adaptee.resolveArgument(parameter, webRequest)).willThrow(new Exception()); + + assertFalse("Parameter supported", adapter.supportsParameter(parameter)); + + verify(adaptee).resolveArgument(parameter, webRequest); + } + + @Test + public void resolveArgument() throws Exception { + int expected = 42; + given(adaptee.resolveArgument(parameter, webRequest)).willReturn(expected); + + Object result = adapter.resolveArgument(parameter, null, webRequest, null); + assertEquals("Invalid result", expected, result); + } + + @Test(expected = IllegalStateException.class) + public void resolveArgumentUnresolved() throws Exception { + given(adaptee.resolveArgument(parameter, webRequest)).willReturn(WebArgumentResolver.UNRESOLVED); + + adapter.resolveArgument(parameter, null, webRequest, null); + } + + @Test(expected = IllegalStateException.class) + public void resolveArgumentWrongType() throws Exception { + given(adaptee.resolveArgument(parameter, webRequest)).willReturn("Foo"); + + adapter.resolveArgument(parameter, null, webRequest, null); + } + + @Test(expected = Exception.class) + public void resolveArgumentThrowsException() throws Exception { + given(adaptee.resolveArgument(parameter, webRequest)).willThrow(new Exception()); + + adapter.resolveArgument(parameter, null, webRequest, null); + } + + public void handle(int param) { + } + + private class TestWebArgumentResolverAdapter extends AbstractWebArgumentResolverAdapter { + + public TestWebArgumentResolverAdapter(WebArgumentResolver adaptee) { + super(adaptee); + } + + @Override + protected NativeWebRequest getWebRequest() { + return WebArgumentResolverAdapterTests.this.webRequest; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/support/CompositeUriComponentsContributorTests.java b/spring-web/src/test/java/org/springframework/web/method/support/CompositeUriComponentsContributorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..df37493f1d688e22806d904249acfb20c6a14872 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/support/CompositeUriComponentsContributorTests.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.support; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Test; + +import org.springframework.core.MethodParameter; +import org.springframework.util.ClassUtils; +import org.springframework.web.bind.annotation.RequestHeader; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.method.annotation.RequestHeaderMethodArgumentResolver; +import org.springframework.web.method.annotation.RequestParamMethodArgumentResolver; + +import static org.junit.Assert.*; + +/** + * Unit tests for + * {@link org.springframework.web.method.support.CompositeUriComponentsContributor}. + * + * @author Rossen Stoyanchev + */ +public class CompositeUriComponentsContributorTests { + + + @Test + public void supportsParameter() { + + List resolvers = new ArrayList<>(); + resolvers.add(new RequestParamMethodArgumentResolver(false)); + resolvers.add(new RequestHeaderMethodArgumentResolver(null)); + resolvers.add(new RequestParamMethodArgumentResolver(true)); + + Method method = ClassUtils.getMethod(this.getClass(), "handleRequest", String.class, String.class, String.class); + + CompositeUriComponentsContributor contributor = new CompositeUriComponentsContributor(resolvers); + assertTrue(contributor.supportsParameter(new MethodParameter(method, 0))); + assertTrue(contributor.supportsParameter(new MethodParameter(method, 1))); + assertFalse(contributor.supportsParameter(new MethodParameter(method, 2))); + } + + + public void handleRequest(@RequestParam String p1, String p2, @RequestHeader String h) { + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/support/HandlerMethodArgumentResolverCompositeTests.java b/spring-web/src/test/java/org/springframework/web/method/support/HandlerMethodArgumentResolverCompositeTests.java new file mode 100644 index 0000000000000000000000000000000000000000..81c35a483a4789d8f6c178dfe450f670cc401464 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/support/HandlerMethodArgumentResolverCompositeTests.java @@ -0,0 +1,87 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.support; + +import java.lang.reflect.Method; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.MethodParameter; + +import static org.junit.Assert.*; + +/** + * Test fixture with {@link HandlerMethodArgumentResolverComposite}. + * + * @author Rossen Stoyanchev + */ +public class HandlerMethodArgumentResolverCompositeTests { + + private HandlerMethodArgumentResolverComposite resolverComposite; + + private MethodParameter paramInt; + + private MethodParameter paramStr; + + + @Before + public void setup() throws Exception { + this.resolverComposite = new HandlerMethodArgumentResolverComposite(); + + Method method = getClass().getDeclaredMethod("handle", Integer.class, String.class); + paramInt = new MethodParameter(method, 0); + paramStr = new MethodParameter(method, 1); + } + + + @Test + public void supportsParameter() throws Exception { + this.resolverComposite.addResolver(new StubArgumentResolver(Integer.class)); + + assertTrue(this.resolverComposite.supportsParameter(paramInt)); + assertFalse(this.resolverComposite.supportsParameter(paramStr)); + } + + @Test + public void resolveArgument() throws Exception { + this.resolverComposite.addResolver(new StubArgumentResolver(55)); + Object resolvedValue = this.resolverComposite.resolveArgument(paramInt, null, null, null); + + assertEquals(55, resolvedValue); + } + + @Test + public void checkArgumentResolverOrder() throws Exception { + this.resolverComposite.addResolver(new StubArgumentResolver(1)); + this.resolverComposite.addResolver(new StubArgumentResolver(2)); + Object resolvedValue = this.resolverComposite.resolveArgument(paramInt, null, null, null); + + assertEquals("Didn't use the first registered resolver", 1, resolvedValue); + } + + @Test(expected = IllegalArgumentException.class) + public void noSuitableArgumentResolver() throws Exception { + this.resolverComposite.resolveArgument(paramStr, null, null, null); + } + + + @SuppressWarnings("unused") + private void handle(Integer arg1, String arg2) { + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/support/HandlerMethodReturnValueHandlerCompositeTests.java b/spring-web/src/test/java/org/springframework/web/method/support/HandlerMethodReturnValueHandlerCompositeTests.java new file mode 100644 index 0000000000000000000000000000000000000000..d45db77e1389385ac43eb5fdee04e9ec1f02990c --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/support/HandlerMethodReturnValueHandlerCompositeTests.java @@ -0,0 +1,131 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.support; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.MethodParameter; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +/** + * Test fixture with {@link HandlerMethodReturnValueHandlerComposite}. + * + * @author Rossen Stoyanchev + */ +@SuppressWarnings("unused") +public class HandlerMethodReturnValueHandlerCompositeTests { + + private HandlerMethodReturnValueHandlerComposite handlers; + + private HandlerMethodReturnValueHandler integerHandler; + + ModelAndViewContainer mavContainer; + + private MethodParameter integerType; + + private MethodParameter stringType; + + + @Before + public void setup() throws Exception { + this.integerType = new MethodParameter(getClass().getDeclaredMethod("handleInteger"), -1); + this.stringType = new MethodParameter(getClass().getDeclaredMethod("handleString"), -1); + + this.integerHandler = mock(HandlerMethodReturnValueHandler.class); + when(this.integerHandler.supportsReturnType(this.integerType)).thenReturn(true); + + this.handlers = new HandlerMethodReturnValueHandlerComposite(); + this.handlers.addHandler(this.integerHandler); + + mavContainer = new ModelAndViewContainer(); + } + + + @Test + public void supportsReturnType() throws Exception { + assertTrue(this.handlers.supportsReturnType(this.integerType)); + assertFalse(this.handlers.supportsReturnType(this.stringType)); + } + + @Test + public void handleReturnValue() throws Exception { + this.handlers.handleReturnValue(55, this.integerType, this.mavContainer, null); + verify(this.integerHandler).handleReturnValue(55, this.integerType, this.mavContainer, null); + } + + @Test + public void handleReturnValueWithMultipleHandlers() throws Exception { + HandlerMethodReturnValueHandler anotherIntegerHandler = mock(HandlerMethodReturnValueHandler.class); + when(anotherIntegerHandler.supportsReturnType(this.integerType)).thenReturn(true); + + this.handlers.handleReturnValue(55, this.integerType, this.mavContainer, null); + + verify(this.integerHandler).handleReturnValue(55, this.integerType, this.mavContainer, null); + verifyNoMoreInteractions(anotherIntegerHandler); + } + + @Test // SPR-13083 + public void handleReturnValueWithAsyncHandler() throws Exception { + Promise promise = new Promise<>(); + MethodParameter promiseType = new MethodParameter(getClass().getDeclaredMethod("handlePromise"), -1); + + HandlerMethodReturnValueHandler responseBodyHandler = mock(HandlerMethodReturnValueHandler.class); + when(responseBodyHandler.supportsReturnType(promiseType)).thenReturn(true); + this.handlers.addHandler(responseBodyHandler); + + AsyncHandlerMethodReturnValueHandler promiseHandler = mock(AsyncHandlerMethodReturnValueHandler.class); + when(promiseHandler.supportsReturnType(promiseType)).thenReturn(true); + when(promiseHandler.isAsyncReturnValue(promise, promiseType)).thenReturn(true); + this.handlers.addHandler(promiseHandler); + + this.handlers.handleReturnValue(promise, promiseType, this.mavContainer, null); + + verify(promiseHandler).isAsyncReturnValue(promise, promiseType); + verify(promiseHandler).supportsReturnType(promiseType); + verify(promiseHandler).handleReturnValue(promise, promiseType, this.mavContainer, null); + verifyNoMoreInteractions(promiseHandler); + verifyNoMoreInteractions(responseBodyHandler); + } + + @Test(expected = IllegalArgumentException.class) + public void noSuitableReturnValueHandler() throws Exception { + this.handlers.handleReturnValue("value", this.stringType, null, null); + } + + + private Integer handleInteger() { + return null; + } + + private String handleString() { + return null; + } + + private Promise handlePromise() { + return null; + } + + private static class Promise {} + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/support/InvocableHandlerMethodTests.java b/spring-web/src/test/java/org/springframework/web/method/support/InvocableHandlerMethodTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5a2424c1681e1c8db5e79471f5556d6418182725 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/support/InvocableHandlerMethodTests.java @@ -0,0 +1,236 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.support; + +import java.lang.reflect.Method; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.MethodParameter; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.method.ResolvableMethod; + +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; + +/** + * Unit tests for {@link InvocableHandlerMethod}. + * + * @author Rossen Stoyanchev + */ +public class InvocableHandlerMethodTests { + + private NativeWebRequest request; + + private final HandlerMethodArgumentResolverComposite composite = new HandlerMethodArgumentResolverComposite(); + + + @Before + public void setUp() throws Exception { + this.request = new ServletWebRequest(new MockHttpServletRequest(), new MockHttpServletResponse()); + } + + + @Test + public void resolveArg() throws Exception { + this.composite.addResolver(new StubArgumentResolver(99)); + this.composite.addResolver(new StubArgumentResolver("value")); + + Object value = getInvocable(Integer.class, String.class).invokeForRequest(request, null); + + assertEquals(1, getStubResolver(0).getResolvedParameters().size()); + assertEquals(1, getStubResolver(1).getResolvedParameters().size()); + assertEquals("99-value", value); + assertEquals("intArg", getStubResolver(0).getResolvedParameters().get(0).getParameterName()); + assertEquals("stringArg", getStubResolver(1).getResolvedParameters().get(0).getParameterName()); + } + + @Test + public void resolveNoArgValue() throws Exception { + this.composite.addResolver(new StubArgumentResolver(Integer.class)); + this.composite.addResolver(new StubArgumentResolver(String.class)); + + Object returnValue = getInvocable(Integer.class, String.class).invokeForRequest(request, null); + + assertEquals(1, getStubResolver(0).getResolvedParameters().size()); + assertEquals(1, getStubResolver(1).getResolvedParameters().size()); + assertEquals("null-null", returnValue); + } + + @Test + public void cannotResolveArg() throws Exception { + try { + getInvocable(Integer.class, String.class).invokeForRequest(request, null); + fail("Expected exception"); + } + catch (IllegalStateException ex) { + assertTrue(ex.getMessage().contains("Could not resolve parameter [0]")); + } + } + + @Test + public void resolveProvidedArg() throws Exception { + Object value = getInvocable(Integer.class, String.class).invokeForRequest(request, null, 99, "value"); + + assertNotNull(value); + assertEquals(String.class, value.getClass()); + assertEquals("99-value", value); + } + + @Test + public void resolveProvidedArgFirst() throws Exception { + this.composite.addResolver(new StubArgumentResolver(1)); + this.composite.addResolver(new StubArgumentResolver("value1")); + Object value = getInvocable(Integer.class, String.class).invokeForRequest(request, null, 2, "value2"); + + assertEquals("2-value2", value); + } + + @Test + public void exceptionInResolvingArg() throws Exception { + this.composite.addResolver(new ExceptionRaisingArgumentResolver()); + try { + getInvocable(Integer.class, String.class).invokeForRequest(request, null); + fail("Expected exception"); + } + catch (IllegalArgumentException ex) { + // expected - allow HandlerMethodArgumentResolver exceptions to propagate + } + } + + @Test + public void illegalArgumentException() throws Exception { + this.composite.addResolver(new StubArgumentResolver(Integer.class, "__not_an_int__")); + this.composite.addResolver(new StubArgumentResolver("value")); + try { + getInvocable(Integer.class, String.class).invokeForRequest(request, null); + fail("Expected exception"); + } + catch (IllegalStateException ex) { + assertNotNull("Exception not wrapped", ex.getCause()); + assertTrue(ex.getCause() instanceof IllegalArgumentException); + assertTrue(ex.getMessage().contains("Controller [")); + assertTrue(ex.getMessage().contains("Method [")); + assertTrue(ex.getMessage().contains("with argument values:")); + assertTrue(ex.getMessage().contains("[0] [type=java.lang.String] [value=__not_an_int__]")); + assertTrue(ex.getMessage().contains("[1] [type=java.lang.String] [value=value")); + } + } + + @Test + public void invocationTargetException() throws Exception { + Throwable expected = new RuntimeException("error"); + try { + getInvocable(Throwable.class).invokeForRequest(this.request, null, expected); + fail("Expected exception"); + } + catch (RuntimeException actual) { + assertSame(expected, actual); + } + + expected = new Error("error"); + try { + getInvocable(Throwable.class).invokeForRequest(this.request, null, expected); + fail("Expected exception"); + } + catch (Error actual) { + assertSame(expected, actual); + } + + expected = new Exception("error"); + try { + getInvocable(Throwable.class).invokeForRequest(this.request, null, expected); + fail("Expected exception"); + } + catch (Exception actual) { + assertSame(expected, actual); + } + + expected = new Throwable("error"); + try { + getInvocable(Throwable.class).invokeForRequest(this.request, null, expected); + fail("Expected exception"); + } + catch (IllegalStateException actual) { + assertNotNull(actual.getCause()); + assertSame(expected, actual.getCause()); + assertTrue(actual.getMessage().contains("Invocation failure")); + } + } + + @Test // SPR-13917 + public void invocationErrorMessage() throws Exception { + this.composite.addResolver(new StubArgumentResolver(double.class)); + try { + getInvocable(double.class).invokeForRequest(this.request, null); + fail(); + } + catch (IllegalStateException ex) { + assertThat(ex.getMessage(), containsString("Illegal argument")); + } + } + + private InvocableHandlerMethod getInvocable(Class... argTypes) { + Method method = ResolvableMethod.on(Handler.class).argTypes(argTypes).resolveMethod(); + InvocableHandlerMethod handlerMethod = new InvocableHandlerMethod(new Handler(), method); + handlerMethod.setHandlerMethodArgumentResolvers(this.composite); + return handlerMethod; + } + + private StubArgumentResolver getStubResolver(int index) { + return (StubArgumentResolver) this.composite.getResolvers().get(index); + } + + + + @SuppressWarnings("unused") + private static class Handler { + + public String handle(Integer intArg, String stringArg) { + return intArg + "-" + stringArg; + } + + public void handle(double amount) { + } + + public void handleWithException(Throwable ex) throws Throwable { + throw ex; + } + } + + + private static class ExceptionRaisingArgumentResolver implements HandlerMethodArgumentResolver { + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return true; + } + + @Override + public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer mavContainer, + NativeWebRequest webRequest, WebDataBinderFactory binderFactory) { + + throw new IllegalArgumentException("oops, can't read"); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/support/ModelAndViewContainerTests.java b/spring-web/src/test/java/org/springframework/web/method/support/ModelAndViewContainerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..f0713b67a423a0395cb3942ad1923e110b8ce614 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/support/ModelAndViewContainerTests.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.support; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.ui.ModelMap; + +import static org.junit.Assert.*; + +/** + * Test fixture for {@link ModelAndViewContainer}. + * + * @author Rossen Stoyanchev + * @since 3.1 + */ +public class ModelAndViewContainerTests { + + private ModelAndViewContainer mavContainer; + + + @Before + public void setup() { + this.mavContainer = new ModelAndViewContainer(); + } + + + @Test + public void getModel() { + this.mavContainer.addAttribute("name", "value"); + assertEquals(1, this.mavContainer.getModel().size()); + assertEquals("value", this.mavContainer.getModel().get("name")); + } + + @Test + public void redirectScenarioWithRedirectModel() { + this.mavContainer.addAttribute("name1", "value1"); + this.mavContainer.setRedirectModel(new ModelMap("name2", "value2")); + this.mavContainer.setRedirectModelScenario(true); + + assertEquals(1, this.mavContainer.getModel().size()); + assertEquals("value2", this.mavContainer.getModel().get("name2")); + } + + @Test + public void redirectScenarioWithoutRedirectModel() { + this.mavContainer.addAttribute("name", "value"); + this.mavContainer.setRedirectModelScenario(true); + + assertEquals(1, this.mavContainer.getModel().size()); + assertEquals("value", this.mavContainer.getModel().get("name")); + } + + @Test + public void ignoreDefaultModel() { + this.mavContainer.setIgnoreDefaultModelOnRedirect(true); + this.mavContainer.addAttribute("name", "value"); + this.mavContainer.setRedirectModelScenario(true); + + assertTrue(this.mavContainer.getModel().isEmpty()); + } + + @Test // SPR-14045 + public void ignoreDefaultModelAndWithoutRedirectModel() { + this.mavContainer.setIgnoreDefaultModelOnRedirect(true); + this.mavContainer.setRedirectModelScenario(true); + this.mavContainer.addAttribute("name", "value"); + + assertEquals(1, this.mavContainer.getModel().size()); + assertEquals("value", this.mavContainer.getModel().get("name")); + } + + +} diff --git a/spring-web/src/test/java/org/springframework/web/method/support/StubArgumentResolver.java b/spring-web/src/test/java/org/springframework/web/method/support/StubArgumentResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..a76c4376c75c2e3c8a60f543decfa4d605ca1dee --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/method/support/StubArgumentResolver.java @@ -0,0 +1,74 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.support; + +import java.util.ArrayList; +import java.util.List; + +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; + +/** + * Stub resolver for a fixed value type and/or value. + * + * @author Rossen Stoyanchev + */ +public class StubArgumentResolver implements HandlerMethodArgumentResolver { + + private final Class valueType; + + @Nullable + private final Object value; + + private List resolvedParameters = new ArrayList<>(); + + + public StubArgumentResolver(Object value) { + this(value.getClass(), value); + } + + public StubArgumentResolver(Class valueType) { + this(valueType, null); + } + + public StubArgumentResolver(Class valueType, Object value) { + this.valueType = valueType; + this.value = value; + } + + + public List getResolvedParameters() { + return resolvedParameters; + } + + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return parameter.getParameterType().equals(this.valueType); + } + + @Override + public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer mavContainer, + NativeWebRequest webRequest, WebDataBinderFactory binderFactory) { + + this.resolvedParameters.add(parameter); + return this.value; + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/multipart/commons/CommonsMultipartResolverTests.java b/spring-web/src/test/java/org/springframework/web/multipart/commons/CommonsMultipartResolverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..ba0df2baf526eb2a70a9d8a1978b7b6c4f05097e --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/multipart/commons/CommonsMultipartResolverTests.java @@ -0,0 +1,557 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.commons; + +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UnsupportedEncodingException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Enumeration; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.apache.commons.fileupload.FileItem; +import org.apache.commons.fileupload.FileItemFactory; +import org.apache.commons.fileupload.FileItemHeaders; +import org.apache.commons.fileupload.FileUpload; +import org.apache.commons.fileupload.servlet.ServletFileUpload; +import org.junit.Test; + +import org.springframework.beans.MutablePropertyValues; +import org.springframework.mock.web.test.MockFilterConfig; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.mock.web.test.MockServletContext; +import org.springframework.mock.web.test.PassThroughFilterChain; +import org.springframework.util.MultiValueMap; +import org.springframework.web.bind.ServletRequestDataBinder; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.StaticWebApplicationContext; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.multipart.MultipartHttpServletRequest; +import org.springframework.web.multipart.support.ByteArrayMultipartFileEditor; +import org.springframework.web.multipart.support.MultipartFilter; +import org.springframework.web.multipart.support.StringMultipartFileEditor; +import org.springframework.web.util.WebUtils; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + * @author Arjen Poutsma + * @since 08.10.2003 + */ +public class CommonsMultipartResolverTests { + + @Test + public void withApplicationContext() throws Exception { + doTestWithApplicationContext(false); + } + + @Test + public void withApplicationContextAndLazyResolution() throws Exception { + doTestWithApplicationContext(true); + } + + private void doTestWithApplicationContext(boolean lazy) throws Exception { + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + wac.setServletContext(new MockServletContext()); + wac.getServletContext().setAttribute(WebUtils.TEMP_DIR_CONTEXT_ATTRIBUTE, new File("mytemp")); + wac.refresh(); + MockCommonsMultipartResolver resolver = new MockCommonsMultipartResolver(); + resolver.setMaxUploadSize(1000); + resolver.setMaxInMemorySize(100); + resolver.setDefaultEncoding("enc"); + if (lazy) { + resolver.setResolveLazily(false); + } + resolver.setServletContext(wac.getServletContext()); + assertEquals(1000, resolver.getFileUpload().getSizeMax()); + assertEquals(100, resolver.getFileItemFactory().getSizeThreshold()); + assertEquals("enc", resolver.getFileUpload().getHeaderEncoding()); + assertTrue(resolver.getFileItemFactory().getRepository().getAbsolutePath().endsWith("mytemp")); + + MockHttpServletRequest originalRequest = new MockHttpServletRequest(); + originalRequest.setMethod("POST"); + originalRequest.setContentType("multipart/form-data"); + originalRequest.addHeader("Content-type", "multipart/form-data"); + originalRequest.addParameter("getField", "getValue"); + assertTrue(resolver.isMultipart(originalRequest)); + MultipartHttpServletRequest request = resolver.resolveMultipart(originalRequest); + + doTestParameters(request); + + doTestFiles(request); + + doTestBinding(resolver, originalRequest, request); + + wac.close(); + } + + private void doTestParameters(MultipartHttpServletRequest request) { + Set parameterNames = new HashSet<>(); + Enumeration parameterEnum = request.getParameterNames(); + while (parameterEnum.hasMoreElements()) { + parameterNames.add(parameterEnum.nextElement()); + } + assertEquals(3, parameterNames.size()); + assertTrue(parameterNames.contains("field3")); + assertTrue(parameterNames.contains("field4")); + assertTrue(parameterNames.contains("getField")); + assertEquals("value3", request.getParameter("field3")); + List parameterValues = Arrays.asList(request.getParameterValues("field3")); + assertEquals(1, parameterValues.size()); + assertTrue(parameterValues.contains("value3")); + assertEquals("value4", request.getParameter("field4")); + parameterValues = Arrays.asList(request.getParameterValues("field4")); + assertEquals(2, parameterValues.size()); + assertTrue(parameterValues.contains("value4")); + assertTrue(parameterValues.contains("value5")); + assertEquals("value4", request.getParameter("field4")); + assertEquals("getValue", request.getParameter("getField")); + + List parameterMapKeys = new ArrayList<>(); + List parameterMapValues = new ArrayList<>(); + for (Object o : request.getParameterMap().keySet()) { + String key = (String) o; + parameterMapKeys.add(key); + parameterMapValues.add(request.getParameterMap().get(key)); + } + assertEquals(3, parameterMapKeys.size()); + assertEquals(3, parameterMapValues.size()); + int field3Index = parameterMapKeys.indexOf("field3"); + int field4Index = parameterMapKeys.indexOf("field4"); + int getFieldIndex = parameterMapKeys.indexOf("getField"); + assertTrue(field3Index != -1); + assertTrue(field4Index != -1); + assertTrue(getFieldIndex != -1); + parameterValues = Arrays.asList((String[]) parameterMapValues.get(field3Index)); + assertEquals(1, parameterValues.size()); + assertTrue(parameterValues.contains("value3")); + parameterValues = Arrays.asList((String[]) parameterMapValues.get(field4Index)); + assertEquals(2, parameterValues.size()); + assertTrue(parameterValues.contains("value4")); + assertTrue(parameterValues.contains("value5")); + parameterValues = Arrays.asList((String[]) parameterMapValues.get(getFieldIndex)); + assertEquals(1, parameterValues.size()); + assertTrue(parameterValues.contains("getValue")); + } + + private void doTestFiles(MultipartHttpServletRequest request) throws IOException { + Set fileNames = new HashSet<>(); + Iterator fileIter = request.getFileNames(); + while (fileIter.hasNext()) { + fileNames.add(fileIter.next()); + } + assertEquals(3, fileNames.size()); + assertTrue(fileNames.contains("field1")); + assertTrue(fileNames.contains("field2")); + assertTrue(fileNames.contains("field2x")); + CommonsMultipartFile file1 = (CommonsMultipartFile) request.getFile("field1"); + CommonsMultipartFile file2 = (CommonsMultipartFile) request.getFile("field2"); + CommonsMultipartFile file2x = (CommonsMultipartFile) request.getFile("field2x"); + + Map fileMap = request.getFileMap(); + assertEquals(3, fileMap.size()); + assertTrue(fileMap.containsKey("field1")); + assertTrue(fileMap.containsKey("field2")); + assertTrue(fileMap.containsKey("field2x")); + assertEquals(file1, fileMap.get("field1")); + assertEquals(file2, fileMap.get("field2")); + assertEquals(file2x, fileMap.get("field2x")); + + MultiValueMap multiFileMap = request.getMultiFileMap(); + assertEquals(3, multiFileMap.size()); + assertTrue(multiFileMap.containsKey("field1")); + assertTrue(multiFileMap.containsKey("field2")); + assertTrue(multiFileMap.containsKey("field2x")); + List field1Files = multiFileMap.get("field1"); + assertEquals(2, field1Files.size()); + assertTrue(field1Files.contains(file1)); + assertEquals(file1, multiFileMap.getFirst("field1")); + assertEquals(file2, multiFileMap.getFirst("field2")); + assertEquals(file2x, multiFileMap.getFirst("field2x")); + + assertEquals("type1", file1.getContentType()); + assertEquals("type2", file2.getContentType()); + assertEquals("type2", file2x.getContentType()); + assertEquals("field1.txt", file1.getOriginalFilename()); + assertEquals("field2.txt", file2.getOriginalFilename()); + assertEquals("field2x.txt", file2x.getOriginalFilename()); + assertEquals("text1", new String(file1.getBytes())); + assertEquals("text2", new String(file2.getBytes())); + assertEquals(5, file1.getSize()); + assertEquals(5, file2.getSize()); + assertTrue(file1.getInputStream() instanceof ByteArrayInputStream); + assertTrue(file2.getInputStream() instanceof ByteArrayInputStream); + File transfer1 = new File("C:/transfer1"); + file1.transferTo(transfer1); + File transfer2 = new File("C:/transfer2"); + file2.transferTo(transfer2); + assertEquals(transfer1, ((MockFileItem) file1.getFileItem()).writtenFile); + assertEquals(transfer2, ((MockFileItem) file2.getFileItem()).writtenFile); + + } + + private void doTestBinding(MockCommonsMultipartResolver resolver, MockHttpServletRequest originalRequest, + MultipartHttpServletRequest request) throws UnsupportedEncodingException { + + MultipartTestBean1 mtb1 = new MultipartTestBean1(); + assertArrayEquals(null, mtb1.getField1()); + assertEquals(null, mtb1.getField2()); + ServletRequestDataBinder binder = new ServletRequestDataBinder(mtb1, "mybean"); + binder.registerCustomEditor(byte[].class, new ByteArrayMultipartFileEditor()); + binder.bind(request); + List file1List = request.getFiles("field1"); + CommonsMultipartFile file1a = (CommonsMultipartFile) file1List.get(0); + CommonsMultipartFile file1b = (CommonsMultipartFile) file1List.get(1); + CommonsMultipartFile file2 = (CommonsMultipartFile) request.getFile("field2"); + assertEquals(file1a, mtb1.getField1()[0]); + assertEquals(file1b, mtb1.getField1()[1]); + assertEquals(new String(file2.getBytes()), new String(mtb1.getField2())); + + MultipartTestBean2 mtb2 = new MultipartTestBean2(); + assertArrayEquals(null, mtb2.getField1()); + assertEquals(null, mtb2.getField2()); + binder = new ServletRequestDataBinder(mtb2, "mybean"); + binder.registerCustomEditor(String.class, "field1", new StringMultipartFileEditor()); + binder.registerCustomEditor(String.class, "field2", new StringMultipartFileEditor("UTF-16")); + binder.bind(request); + assertEquals(new String(file1a.getBytes()), mtb2.getField1()[0]); + assertEquals(new String(file1b.getBytes()), mtb2.getField1()[1]); + assertEquals(new String(file2.getBytes(), "UTF-16"), mtb2.getField2()); + + resolver.cleanupMultipart(request); + assertTrue(((MockFileItem) file1a.getFileItem()).deleted); + assertTrue(((MockFileItem) file1b.getFileItem()).deleted); + assertTrue(((MockFileItem) file2.getFileItem()).deleted); + + resolver.setEmpty(true); + request = resolver.resolveMultipart(originalRequest); + binder.setBindEmptyMultipartFiles(false); + String firstBound = mtb2.getField2(); + binder.bind(request); + assertFalse(mtb2.getField2().isEmpty()); + assertEquals(firstBound, mtb2.getField2()); + + request = resolver.resolveMultipart(originalRequest); + binder.setBindEmptyMultipartFiles(true); + binder.bind(request); + assertTrue(mtb2.getField2().isEmpty()); + } + + @Test + public void withServletContextAndFilter() throws Exception { + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + wac.setServletContext(new MockServletContext()); + wac.registerSingleton("filterMultipartResolver", MockCommonsMultipartResolver.class, new MutablePropertyValues()); + wac.getServletContext().setAttribute(WebUtils.TEMP_DIR_CONTEXT_ATTRIBUTE, new File("mytemp")); + wac.refresh(); + wac.getServletContext().setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, wac); + CommonsMultipartResolver resolver = new CommonsMultipartResolver(wac.getServletContext()); + assertTrue(resolver.getFileItemFactory().getRepository().getAbsolutePath().endsWith("mytemp")); + + MockFilterConfig filterConfig = new MockFilterConfig(wac.getServletContext(), "filter"); + filterConfig.addInitParameter("class", "notWritable"); + filterConfig.addInitParameter("unknownParam", "someValue"); + final MultipartFilter filter = new MultipartFilter(); + filter.init(filterConfig); + + final List files = new ArrayList<>(); + final FilterChain filterChain = new FilterChain() { + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) { + MultipartHttpServletRequest request = (MultipartHttpServletRequest) servletRequest; + files.addAll(request.getFileMap().values()); + } + }; + + FilterChain filterChain2 = new PassThroughFilterChain(filter, filterChain); + + MockHttpServletRequest originalRequest = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + originalRequest.setMethod("POST"); + originalRequest.setContentType("multipart/form-data"); + originalRequest.addHeader("Content-type", "multipart/form-data"); + filter.doFilter(originalRequest, response, filterChain2); + + CommonsMultipartFile file1 = (CommonsMultipartFile) files.get(0); + CommonsMultipartFile file2 = (CommonsMultipartFile) files.get(1); + assertTrue(((MockFileItem) file1.getFileItem()).deleted); + assertTrue(((MockFileItem) file2.getFileItem()).deleted); + } + + @Test + public void withServletContextAndFilterWithCustomBeanName() throws Exception { + StaticWebApplicationContext wac = new StaticWebApplicationContext(); + wac.setServletContext(new MockServletContext()); + wac.refresh(); + wac.registerSingleton("myMultipartResolver", MockCommonsMultipartResolver.class, new MutablePropertyValues()); + wac.getServletContext().setAttribute(WebUtils.TEMP_DIR_CONTEXT_ATTRIBUTE, new File("mytemp")); + wac.getServletContext().setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, wac); + CommonsMultipartResolver resolver = new CommonsMultipartResolver(wac.getServletContext()); + assertTrue(resolver.getFileItemFactory().getRepository().getAbsolutePath().endsWith("mytemp")); + + MockFilterConfig filterConfig = new MockFilterConfig(wac.getServletContext(), "filter"); + filterConfig.addInitParameter("multipartResolverBeanName", "myMultipartResolver"); + + final List files = new ArrayList<>(); + FilterChain filterChain = new FilterChain() { + @Override + public void doFilter(ServletRequest originalRequest, ServletResponse response) { + if (originalRequest instanceof MultipartHttpServletRequest) { + MultipartHttpServletRequest request = (MultipartHttpServletRequest) originalRequest; + files.addAll(request.getFileMap().values()); + } + } + }; + + MultipartFilter filter = new MultipartFilter() { + private boolean invoked = false; + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + super.doFilterInternal(request, response, filterChain); + super.doFilterInternal(request, response, filterChain); + if (invoked) { + throw new ServletException("Should not have been invoked twice"); + } + invoked = true; + } + }; + filter.init(filterConfig); + + MockHttpServletRequest originalRequest = new MockHttpServletRequest(); + originalRequest.setMethod("POST"); + originalRequest.setContentType("multipart/form-data"); + originalRequest.addHeader("Content-type", "multipart/form-data"); + HttpServletResponse response = new MockHttpServletResponse(); + filter.doFilter(originalRequest, response, filterChain); + CommonsMultipartFile file1 = (CommonsMultipartFile) files.get(0); + CommonsMultipartFile file2 = (CommonsMultipartFile) files.get(1); + assertTrue(((MockFileItem) file1.getFileItem()).deleted); + assertTrue(((MockFileItem) file2.getFileItem()).deleted); + } + + + public static class MockCommonsMultipartResolver extends CommonsMultipartResolver { + + private boolean empty; + + protected void setEmpty(boolean empty) { + this.empty = empty; + } + + @Override + protected FileUpload newFileUpload(FileItemFactory fileItemFactory) { + return new ServletFileUpload() { + @Override + public List parseRequest(HttpServletRequest request) { + if (request instanceof MultipartHttpServletRequest) { + throw new IllegalStateException("Already a multipart request"); + } + List fileItems = new ArrayList<>(); + MockFileItem fileItem1 = new MockFileItem( + "field1", "type1", empty ? "" : "field1.txt", empty ? "" : "text1"); + MockFileItem fileItem1x = new MockFileItem( + "field1", "type1", empty ? "" : "field1.txt", empty ? "" : "text1"); + MockFileItem fileItem2 = new MockFileItem( + "field2", "type2", empty ? "" : "C:\\mypath/field2.txt", empty ? "" : "text2"); + MockFileItem fileItem2x = new MockFileItem( + "field2x", "type2", empty ? "" : "C:/mypath\\field2x.txt", empty ? "" : "text2"); + MockFileItem fileItem3 = new MockFileItem("field3", null, null, "value3"); + MockFileItem fileItem4 = new MockFileItem("field4", "text/html; charset=iso-8859-1", null, "value4"); + MockFileItem fileItem5 = new MockFileItem("field4", null, null, "value5"); + fileItems.add(fileItem1); + fileItems.add(fileItem1x); + fileItems.add(fileItem2); + fileItems.add(fileItem2x); + fileItems.add(fileItem3); + fileItems.add(fileItem4); + fileItems.add(fileItem5); + return fileItems; + } + }; + } + } + + + @SuppressWarnings("serial") + private static class MockFileItem implements FileItem { + + private String fieldName; + private String contentType; + private String name; + private String value; + + private File writtenFile; + private boolean deleted; + + public MockFileItem(String fieldName, String contentType, String name, String value) { + this.fieldName = fieldName; + this.contentType = contentType; + this.name = name; + this.value = value; + } + + @Override + public InputStream getInputStream() throws IOException { + return new ByteArrayInputStream(value.getBytes()); + } + + @Override + public String getContentType() { + return contentType; + } + + @Override + public String getName() { + return name; + } + + @Override + public boolean isInMemory() { + return true; + } + + @Override + public long getSize() { + return value.length(); + } + + @Override + public byte[] get() { + return value.getBytes(); + } + + @Override + public String getString(String encoding) throws UnsupportedEncodingException { + return new String(get(), encoding); + } + + @Override + public String getString() { + return value; + } + + @Override + public void write(File file) throws Exception { + this.writtenFile = file; + } + + @Override + public void delete() { + this.deleted = true; + } + + @Override + public String getFieldName() { + return fieldName; + } + + @Override + public void setFieldName(String s) { + this.fieldName = s; + } + + @Override + public boolean isFormField() { + return (this.name == null); + } + + @Override + public void setFormField(boolean b) { + throw new UnsupportedOperationException(); + } + + @Override + public OutputStream getOutputStream() throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public FileItemHeaders getHeaders() { + throw new UnsupportedOperationException(); + } + + @Override + public void setHeaders(FileItemHeaders headers) { + throw new UnsupportedOperationException(); + } + } + + + public class MultipartTestBean1 { + + private MultipartFile[] field1; + private byte[] field2; + + public void setField1(MultipartFile[] field1) { + this.field1 = field1; + } + + public MultipartFile[] getField1() { + return field1; + } + + public void setField2(byte[] field2) { + this.field2 = field2; + } + + public byte[] getField2() { + return field2; + } + } + + + public class MultipartTestBean2 { + + private String[] field1; + private String field2; + + public void setField1(String[] field1) { + this.field1 = field1; + } + + public String[] getField1() { + return field1; + } + + public void setField2(String field2) { + this.field2 = field2; + } + + public String getField2() { + return field2; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/multipart/support/ByteArrayMultipartFileEditorTests.java b/spring-web/src/test/java/org/springframework/web/multipart/support/ByteArrayMultipartFileEditorTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e8c19ac30fdc4e3402de65d86678095c4151270f --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/multipart/support/ByteArrayMultipartFileEditorTests.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.support; + +import java.io.IOException; + +import org.junit.Test; + +import org.springframework.web.multipart.MultipartFile; + +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; + +/** + * @author Rick Evans + * @author Sam Brannen + */ +public class ByteArrayMultipartFileEditorTests { + + private final ByteArrayMultipartFileEditor editor = new ByteArrayMultipartFileEditor(); + + @Test + public void setValueAsByteArray() throws Exception { + String expectedValue = "Shumwere, shumhow, a shuck ish washing you. - Drunken Far Side"; + editor.setValue(expectedValue.getBytes()); + assertEquals(expectedValue, editor.getAsText()); + } + + @Test + public void setValueAsString() throws Exception { + String expectedValue = "'Green Wing' - classic British comedy"; + editor.setValue(expectedValue); + assertEquals(expectedValue, editor.getAsText()); + } + + @Test + public void setValueAsCustomObjectInvokesToString() throws Exception { + final String expectedValue = "'Green Wing' - classic British comedy"; + Object object = new Object() { + @Override + public String toString() { + return expectedValue; + } + }; + + editor.setValue(object); + assertEquals(expectedValue, editor.getAsText()); + } + + @Test + public void setValueAsNullGetsBackEmptyString() throws Exception { + editor.setValue(null); + assertEquals("", editor.getAsText()); + } + + @Test + public void setValueAsMultipartFile() throws Exception { + String expectedValue = "That is comforting to know"; + MultipartFile file = mock(MultipartFile.class); + given(file.getBytes()).willReturn(expectedValue.getBytes()); + editor.setValue(file); + assertEquals(expectedValue, editor.getAsText()); + } + + @Test(expected = IllegalArgumentException.class) + public void setValueAsMultipartFileWithBadBytes() throws Exception { + MultipartFile file = mock(MultipartFile.class); + given(file.getBytes()).willThrow(new IOException()); + editor.setValue(file); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/multipart/support/DefaultMultipartHttpServletRequestTests.java b/spring-web/src/test/java/org/springframework/web/multipart/support/DefaultMultipartHttpServletRequestTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e816974801373ceaa05f3f5248bd378da2e24833 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/multipart/support/DefaultMultipartHttpServletRequestTests.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.multipart.support; + +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.junit.Test; + +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link DefaultMultipartHttpServletRequest}. + * + * @author Rossen Stoyanchev + */ +public class DefaultMultipartHttpServletRequestTests { + + private final MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + + private final Map multipartParams = new LinkedHashMap<>(); + + private final MultiValueMap queryParams = new LinkedMultiValueMap<>(); + + + @Test // SPR-16590 + public void parameterValues() { + + this.multipartParams.put("key", new String[] {"p"}); + this.queryParams.add("key", "q"); + + String[] values = createMultipartRequest().getParameterValues("key"); + + assertArrayEquals(new String[] {"p", "q"}, values); + } + + @Test // SPR-16590 + public void parameterMap() { + + this.multipartParams.put("key1", new String[] {"p1"}); + this.multipartParams.put("key2", new String[] {"p2"}); + + this.queryParams.add("key1", "q1"); + this.queryParams.add("key3", "q3"); + + Map map = createMultipartRequest().getParameterMap(); + + assertEquals(3, map.size()); + assertArrayEquals(new String[] {"p1", "q1"}, map.get("key1")); + assertArrayEquals(new String[] {"p2"}, map.get("key2")); + assertArrayEquals(new String[] {"q3"}, map.get("key3")); + } + + private DefaultMultipartHttpServletRequest createMultipartRequest() { + insertQueryParams(); + return new DefaultMultipartHttpServletRequest(this.servletRequest, new LinkedMultiValueMap<>(), + this.multipartParams, new HashMap<>()); + } + + private void insertQueryParams() { + StringBuilder query = new StringBuilder(); + for (String key : this.queryParams.keySet()) { + for (String value : this.queryParams.get(key)) { + this.servletRequest.addParameter(key, value); + query.append(query.length() > 0 ? "&" : "").append(key).append("=").append(value); + } + } + this.servletRequest.setQueryString(query.toString()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequestTests.java b/spring-web/src/test/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequestTests.java new file mode 100644 index 0000000000000000000000000000000000000000..87082cc8bcf32d4f43026f59aad231d5abd38785 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequestTests.java @@ -0,0 +1,154 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.support; + +import java.net.URI; +import java.nio.charset.StandardCharsets; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; + +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.mock.web.test.MockMultipartFile; +import org.springframework.mock.web.test.MockMultipartHttpServletRequest; +import org.springframework.util.FileCopyUtils; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.mock.web.test.MockPart; + +import static org.junit.Assert.*; + +/** + * @author Rossen Stoyanchev + * @author Juergen Hoeller + */ +public class RequestPartServletServerHttpRequestTests { + + private final MockMultipartHttpServletRequest mockRequest = new MockMultipartHttpServletRequest(); + + + @Test + public void getMethod() throws Exception { + this.mockRequest.addFile(new MockMultipartFile("part", "", "", "content".getBytes("UTF-8"))); + ServerHttpRequest request = new RequestPartServletServerHttpRequest(this.mockRequest, "part"); + this.mockRequest.setMethod("POST"); + + assertEquals(HttpMethod.POST, request.getMethod()); + } + + @Test + public void getURI() throws Exception { + this.mockRequest.addFile(new MockMultipartFile("part", "", "application/json", "content".getBytes("UTF-8"))); + ServerHttpRequest request = new RequestPartServletServerHttpRequest(this.mockRequest, "part"); + + URI uri = new URI("https://example.com/path?query"); + this.mockRequest.setScheme("https"); + this.mockRequest.setServerName(uri.getHost()); + this.mockRequest.setServerPort(uri.getPort()); + this.mockRequest.setRequestURI(uri.getPath()); + this.mockRequest.setQueryString(uri.getQuery()); + assertEquals(uri, request.getURI()); + } + + @Test + public void getContentType() throws Exception { + MultipartFile part = new MockMultipartFile("part", "", "application/json", "content".getBytes("UTF-8")); + this.mockRequest.addFile(part); + ServerHttpRequest request = new RequestPartServletServerHttpRequest(this.mockRequest, "part"); + + HttpHeaders headers = request.getHeaders(); + assertNotNull(headers); + assertEquals(MediaType.APPLICATION_JSON, headers.getContentType()); + } + + @Test + public void getBody() throws Exception { + byte[] bytes = "content".getBytes("UTF-8"); + MultipartFile part = new MockMultipartFile("part", "", "application/json", bytes); + this.mockRequest.addFile(part); + ServerHttpRequest request = new RequestPartServletServerHttpRequest(this.mockRequest, "part"); + + byte[] result = FileCopyUtils.copyToByteArray(request.getBody()); + assertArrayEquals(bytes, result); + } + + @Test // SPR-13317 + public void getBodyWithWrappedRequest() throws Exception { + byte[] bytes = "content".getBytes("UTF-8"); + MultipartFile part = new MockMultipartFile("part", "", "application/json", bytes); + this.mockRequest.addFile(part); + HttpServletRequest wrapped = new HttpServletRequestWrapper(this.mockRequest); + ServerHttpRequest request = new RequestPartServletServerHttpRequest(wrapped, "part"); + + byte[] result = FileCopyUtils.copyToByteArray(request.getBody()); + assertArrayEquals(bytes, result); + } + + @Test // SPR-13096 + public void getBodyViaRequestParameter() throws Exception { + MockMultipartHttpServletRequest mockRequest = new MockMultipartHttpServletRequest() { + @Override + public HttpHeaders getMultipartHeaders(String paramOrFileName) { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(new MediaType("application", "octet-stream", StandardCharsets.ISO_8859_1)); + return headers; + } + }; + + byte[] bytes = {(byte) 0xC4}; + mockRequest.setParameter("part", new String(bytes, StandardCharsets.ISO_8859_1)); + ServerHttpRequest request = new RequestPartServletServerHttpRequest(mockRequest, "part"); + byte[] result = FileCopyUtils.copyToByteArray(request.getBody()); + assertArrayEquals(bytes, result); + } + + @Test + public void getBodyViaRequestParameterWithRequestEncoding() throws Exception { + MockMultipartHttpServletRequest mockRequest = new MockMultipartHttpServletRequest() { + @Override + public HttpHeaders getMultipartHeaders(String paramOrFileName) { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_OCTET_STREAM); + return headers; + } + }; + + byte[] bytes = {(byte) 0xC4}; + mockRequest.setParameter("part", new String(bytes, StandardCharsets.ISO_8859_1)); + mockRequest.setCharacterEncoding("iso-8859-1"); + ServerHttpRequest request = new RequestPartServletServerHttpRequest(mockRequest, "part"); + byte[] result = FileCopyUtils.copyToByteArray(request.getBody()); + assertArrayEquals(bytes, result); + } + + @Test // gh-25829 + public void getBodyViaRequestPart() throws Exception { + byte[] bytes = "content".getBytes("UTF-8"); + MockPart mockPart = new MockPart("part", bytes); + mockPart.getHeaders().setContentType(MediaType.APPLICATION_JSON); + this.mockRequest.addPart(mockPart); + ServerHttpRequest request = new RequestPartServletServerHttpRequest(this.mockRequest, "part"); + + byte[] result = FileCopyUtils.copyToByteArray(request.getBody()); + assertArrayEquals(bytes, result); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequestTests.java b/spring-web/src/test/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequestTests.java new file mode 100644 index 0000000000000000000000000000000000000000..b0d06315568f20fd7e63877221bbd471f566ec72 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/multipart/support/StandardMultipartHttpServletRequestTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.support; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import org.junit.Test; + +import org.springframework.http.MockHttpOutputMessage; +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockPart; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.multipart.MultipartFile; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +/** + * Unit tests for {@link StandardMultipartHttpServletRequest}. + * + * @author Rossen Stoyanchev + */ +public class StandardMultipartHttpServletRequestTests { + + @Test + public void filename() throws Exception { + String disposition = "form-data; name=\"file\"; filename=\"myFile.txt\""; + StandardMultipartHttpServletRequest request = requestWithPart("file", disposition, ""); + + MultipartFile multipartFile = request.getFile("file"); + assertNotNull(multipartFile); + assertEquals("myFile.txt", multipartFile.getOriginalFilename()); + } + + @Test // SPR-13319 + public void filenameRfc5987() throws Exception { + String disposition = "form-data; name=\"file\"; filename*=\"UTF-8''foo-%c3%a4-%e2%82%ac.html\""; + StandardMultipartHttpServletRequest request = requestWithPart("file", disposition, ""); + + MultipartFile multipartFile = request.getFile("file"); + assertNotNull(multipartFile); + assertEquals("foo-ä-€.html", multipartFile.getOriginalFilename()); + } + + @Test // SPR-15205 + public void filenameRfc2047() throws Exception { + String disposition = "form-data; name=\"file\"; filename=\"=?UTF-8?Q?Declara=C3=A7=C3=A3o.pdf?=\""; + StandardMultipartHttpServletRequest request = requestWithPart("file", disposition, ""); + + MultipartFile multipartFile = request.getFile("file"); + assertNotNull(multipartFile); + assertEquals("Declaração.pdf", multipartFile.getOriginalFilename()); + } + + @Test + public void multipartFileResource() throws IOException { + String name = "file"; + String disposition = "form-data; name=\"" + name + "\"; filename=\"myFile.txt\""; + StandardMultipartHttpServletRequest request = requestWithPart(name, disposition, "myBody"); + MultipartFile multipartFile = request.getFile(name); + + assertNotNull(multipartFile); + + MultiValueMap map = new LinkedMultiValueMap<>(); + map.add(name, multipartFile.getResource()); + + MockHttpOutputMessage output = new MockHttpOutputMessage(); + new FormHttpMessageConverter().write(map, null, output); + + assertThat(output.getBodyAsString(StandardCharsets.UTF_8), containsString( + "Content-Disposition: form-data; name=\"file\"; filename=\"myFile.txt\"\r\n" + + "Content-Type: text/plain\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "myBody\r\n")); + } + + + private StandardMultipartHttpServletRequest requestWithPart(String name, String disposition, String content) { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockPart part = new MockPart(name, null, content.getBytes(StandardCharsets.UTF_8)); + part.getHeaders().set("Content-Disposition", disposition); + request.addPart(part); + return new StandardMultipartHttpServletRequest(request); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/server/adapter/DefaultServerWebExchangeCheckNotModifiedTests.java b/spring-web/src/test/java/org/springframework/web/server/adapter/DefaultServerWebExchangeCheckNotModifiedTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5c0036393540f1e2d46906ac1f5795031d34d396 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/adapter/DefaultServerWebExchangeCheckNotModifiedTests.java @@ -0,0 +1,336 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.adapter; + +import java.net.URISyntaxException; +import java.text.SimpleDateFormat; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Locale; +import java.util.TimeZone; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.web.test.server.MockServerWebExchange; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.springframework.mock.http.server.reactive.test.MockServerHttpRequest.get; + +/** + * "checkNotModified" unit tests for {@link DefaultServerWebExchange}. + * + * @author Rossen Stoyanchev + */ +@RunWith(Parameterized.class) +public class DefaultServerWebExchangeCheckNotModifiedTests { + + private static final String CURRENT_TIME = "Wed, 09 Apr 2014 09:57:42 GMT"; + + + private SimpleDateFormat dateFormat; + + private Instant currentDate; + + @Parameter + public HttpMethod method; + + @Parameters(name = "{0}") + static public Iterable safeMethods() { + return Arrays.asList(new Object[][] { + {HttpMethod.GET}, + {HttpMethod.HEAD} + }); + } + + + @Before + public void setup() throws URISyntaxException { + this.currentDate = Instant.now().truncatedTo(ChronoUnit.SECONDS); + this.dateFormat = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US); + this.dateFormat.setTimeZone(TimeZone.getTimeZone("GMT")); + } + + + @Test + public void checkNotModifiedNon2xxStatus() { + MockServerHttpRequest request = get("/").ifModifiedSince(this.currentDate.toEpochMilli()).build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + exchange.getResponse().setStatusCode(HttpStatus.NOT_MODIFIED); + + assertFalse(exchange.checkNotModified(this.currentDate)); + assertEquals(304, exchange.getResponse().getStatusCode().value()); + assertEquals(-1, exchange.getResponse().getHeaders().getLastModified()); + } + + @Test // SPR-14559 + public void checkNotModifiedInvalidIfNoneMatchHeader() { + String eTag = "\"etagvalue\""; + MockServerWebExchange exchange = MockServerWebExchange.from(get("/").ifNoneMatch("missingquotes")); + assertFalse(exchange.checkNotModified(eTag)); + assertNull(exchange.getResponse().getStatusCode()); + assertEquals(eTag, exchange.getResponse().getHeaders().getETag()); + } + + @Test + public void checkNotModifiedHeaderAlreadySet() { + MockServerHttpRequest request = get("/").ifModifiedSince(currentDate.toEpochMilli()).build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + exchange.getResponse().getHeaders().add("Last-Modified", CURRENT_TIME); + + assertTrue(exchange.checkNotModified(currentDate)); + assertEquals(304, exchange.getResponse().getStatusCode().value()); + assertEquals(1, exchange.getResponse().getHeaders().get("Last-Modified").size()); + assertEquals(CURRENT_TIME, exchange.getResponse().getHeaders().getFirst("Last-Modified")); + } + + @Test + public void checkNotModifiedTimestamp() throws Exception { + MockServerHttpRequest request = get("/").ifModifiedSince(currentDate.toEpochMilli()).build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + + assertTrue(exchange.checkNotModified(currentDate)); + + assertEquals(304, exchange.getResponse().getStatusCode().value()); + assertEquals(currentDate.toEpochMilli(), exchange.getResponse().getHeaders().getLastModified()); + } + + @Test + public void checkModifiedTimestamp() { + Instant oneMinuteAgo = currentDate.minusSeconds(60); + MockServerHttpRequest request = get("/").ifModifiedSince(oneMinuteAgo.toEpochMilli()).build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + + assertFalse(exchange.checkNotModified(currentDate)); + + assertNull(exchange.getResponse().getStatusCode()); + assertEquals(currentDate.toEpochMilli(), exchange.getResponse().getHeaders().getLastModified()); + } + + @Test + public void checkNotModifiedETag() { + String eTag = "\"Foo\""; + MockServerWebExchange exchange = MockServerWebExchange.from(get("/").ifNoneMatch(eTag)); + + assertTrue(exchange.checkNotModified(eTag)); + + assertEquals(304, exchange.getResponse().getStatusCode().value()); + assertEquals(eTag, exchange.getResponse().getHeaders().getETag()); + } + + @Test + public void checkNotModifiedETagWithSeparatorChars() { + String eTag = "\"Foo, Bar\""; + MockServerWebExchange exchange = MockServerWebExchange.from(get("/").ifNoneMatch(eTag)); + + assertTrue(exchange.checkNotModified(eTag)); + + assertEquals(304, exchange.getResponse().getStatusCode().value()); + assertEquals(eTag, exchange.getResponse().getHeaders().getETag()); + } + + + @Test + public void checkModifiedETag() { + String currentETag = "\"Foo\""; + String oldEtag = "Bar"; + MockServerWebExchange exchange = MockServerWebExchange.from(get("/").ifNoneMatch(oldEtag)); + + assertFalse(exchange.checkNotModified(currentETag)); + + assertNull(exchange.getResponse().getStatusCode()); + assertEquals(currentETag, exchange.getResponse().getHeaders().getETag()); + } + + @Test + public void checkNotModifiedUnpaddedETag() { + String eTag = "Foo"; + String paddedEtag = String.format("\"%s\"", eTag); + MockServerWebExchange exchange = MockServerWebExchange.from(get("/").ifNoneMatch(paddedEtag)); + + assertTrue(exchange.checkNotModified(eTag)); + + assertEquals(304, exchange.getResponse().getStatusCode().value()); + assertEquals(paddedEtag, exchange.getResponse().getHeaders().getETag()); + } + + @Test + public void checkModifiedUnpaddedETag() { + String currentETag = "Foo"; + String oldEtag = "Bar"; + MockServerWebExchange exchange = MockServerWebExchange.from(get("/").ifNoneMatch(oldEtag)); + + assertFalse(exchange.checkNotModified(currentETag)); + + assertNull(exchange.getResponse().getStatusCode()); + assertEquals(String.format("\"%s\"", currentETag), exchange.getResponse().getHeaders().getETag()); + } + + @Test + public void checkNotModifiedWildcardIsIgnored() { + String eTag = "\"Foo\""; + MockServerWebExchange exchange = MockServerWebExchange.from(get("/").ifNoneMatch("*")); + assertFalse(exchange.checkNotModified(eTag)); + + assertNull(exchange.getResponse().getStatusCode()); + assertEquals(eTag, exchange.getResponse().getHeaders().getETag()); + } + + @Test + public void checkNotModifiedETagAndTimestamp() { + String eTag = "\"Foo\""; + long time = currentDate.toEpochMilli(); + MockServerHttpRequest request = get("/").ifNoneMatch(eTag).ifModifiedSince(time).build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + + assertTrue(exchange.checkNotModified(eTag, currentDate)); + + assertEquals(304, exchange.getResponse().getStatusCode().value()); + assertEquals(eTag, exchange.getResponse().getHeaders().getETag()); + assertEquals(time, exchange.getResponse().getHeaders().getLastModified()); + } + + // SPR-14224 + @Test + public void checkNotModifiedETagAndModifiedTimestamp() { + String eTag = "\"Foo\""; + Instant oneMinuteAgo = currentDate.minusSeconds(60); + MockServerWebExchange exchange = MockServerWebExchange.from(get("/") + .ifNoneMatch(eTag) + .ifModifiedSince(oneMinuteAgo.toEpochMilli()) + ); + + assertTrue(exchange.checkNotModified(eTag, currentDate)); + + assertEquals(304, exchange.getResponse().getStatusCode().value()); + assertEquals(eTag, exchange.getResponse().getHeaders().getETag()); + assertEquals(currentDate.toEpochMilli(), exchange.getResponse().getHeaders().getLastModified()); + } + + @Test + public void checkModifiedETagAndNotModifiedTimestamp() throws Exception { + String currentETag = "\"Foo\""; + String oldEtag = "\"Bar\""; + long time = currentDate.toEpochMilli(); + MockServerHttpRequest request = get("/").ifNoneMatch(oldEtag).ifModifiedSince(time).build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + + assertFalse(exchange.checkNotModified(currentETag, currentDate)); + + assertNull(exchange.getResponse().getStatusCode()); + assertEquals(currentETag, exchange.getResponse().getHeaders().getETag()); + assertEquals(time, exchange.getResponse().getHeaders().getLastModified()); + } + + @Test + public void checkNotModifiedETagWeakStrong() { + String eTag = "\"Foo\""; + String weakEtag = String.format("W/%s", eTag); + MockServerWebExchange exchange = MockServerWebExchange.from(get("/").ifNoneMatch(eTag)); + + assertTrue(exchange.checkNotModified(weakEtag)); + + assertEquals(304, exchange.getResponse().getStatusCode().value()); + assertEquals(weakEtag, exchange.getResponse().getHeaders().getETag()); + } + + @Test + public void checkNotModifiedETagStrongWeak() { + String eTag = "\"Foo\""; + MockServerHttpRequest request = get("/").ifNoneMatch(String.format("W/%s", eTag)).build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + + assertTrue(exchange.checkNotModified(eTag)); + + assertEquals(304, exchange.getResponse().getStatusCode().value()); + assertEquals(eTag, exchange.getResponse().getHeaders().getETag()); + } + + @Test + public void checkNotModifiedMultipleETags() { + String eTag = "\"Bar\""; + String multipleETags = String.format("\"Foo\", %s", eTag); + MockServerWebExchange exchange = MockServerWebExchange.from(get("/").ifNoneMatch(multipleETags)); + + assertTrue(exchange.checkNotModified(eTag)); + + assertEquals(304, exchange.getResponse().getStatusCode().value()); + assertEquals(eTag, exchange.getResponse().getHeaders().getETag()); + } + + @Test + public void checkNotModifiedTimestampWithLengthPart() throws Exception { + long epochTime = dateFormat.parse(CURRENT_TIME).getTime(); + String header = "Wed, 09 Apr 2014 09:57:42 GMT; length=13774"; + MockServerHttpRequest request = get("/").header("If-Modified-Since", header).build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + + assertTrue(exchange.checkNotModified(Instant.ofEpochMilli(epochTime))); + + assertEquals(304, exchange.getResponse().getStatusCode().value()); + assertEquals(epochTime, exchange.getResponse().getHeaders().getLastModified()); + } + + @Test + public void checkModifiedTimestampWithLengthPart() throws Exception { + long epochTime = dateFormat.parse(CURRENT_TIME).getTime(); + String header = "Tue, 08 Apr 2014 09:57:42 GMT; length=13774"; + MockServerHttpRequest request = get("/").header("If-Modified-Since", header).build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + + assertFalse(exchange.checkNotModified(Instant.ofEpochMilli(epochTime))); + + assertNull(exchange.getResponse().getStatusCode()); + assertEquals(epochTime, exchange.getResponse().getHeaders().getLastModified()); + } + + @Test + public void checkNotModifiedTimestampConditionalPut() throws Exception { + Instant oneMinuteAgo = currentDate.minusSeconds(60); + long millis = currentDate.toEpochMilli(); + MockServerHttpRequest request = MockServerHttpRequest.put("/").ifUnmodifiedSince(millis).build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + + assertFalse(exchange.checkNotModified(oneMinuteAgo)); + assertNull(exchange.getResponse().getStatusCode()); + assertEquals(-1, exchange.getResponse().getHeaders().getLastModified()); + } + + @Test + public void checkNotModifiedTimestampConditionalPutConflict() throws Exception { + Instant oneMinuteAgo = currentDate.minusSeconds(60); + long millis = oneMinuteAgo.toEpochMilli(); + MockServerHttpRequest request = MockServerHttpRequest.put("/").ifUnmodifiedSince(millis).build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + + assertTrue(exchange.checkNotModified(currentDate)); + assertEquals(412, exchange.getResponse().getStatusCode().value()); + assertEquals(-1, exchange.getResponse().getHeaders().getLastModified()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/server/adapter/DefaultServerWebExchangeTests.java b/spring-web/src/test/java/org/springframework/web/server/adapter/DefaultServerWebExchangeTests.java new file mode 100644 index 0000000000000000000000000000000000000000..e3f7151a86c33398459c8f4890316aafd60bd7de --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/adapter/DefaultServerWebExchangeTests.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.adapter; + +import org.junit.Test; + +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; +import org.springframework.web.server.session.DefaultWebSessionManager; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link DefaultServerWebExchange}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + */ +public class DefaultServerWebExchangeTests { + + @Test + public void transformUrlDefault() { + ServerWebExchange exchange = createExchange(); + assertEquals("/foo", exchange.transformUrl("/foo")); + } + + @Test + public void transformUrlWithEncoder() { + ServerWebExchange exchange = createExchange(); + exchange.addUrlTransformer(s -> s + "?nonce=123"); + assertEquals("/foo?nonce=123", exchange.transformUrl("/foo")); + } + + @Test + public void transformUrlWithMultipleEncoders() { + ServerWebExchange exchange = createExchange(); + exchange.addUrlTransformer(s -> s + ";p=abc"); + exchange.addUrlTransformer(s -> s + "?q=123"); + assertEquals("/foo;p=abc?q=123", exchange.transformUrl("/foo")); + } + + + private DefaultServerWebExchange createExchange() { + MockServerHttpRequest request = MockServerHttpRequest.get("https://example.com").build(); + return createExchange(request); + } + + private DefaultServerWebExchange createExchange(MockServerHttpRequest request) { + return new DefaultServerWebExchange(request, new MockServerHttpResponse(), + new DefaultWebSessionManager(), ServerCodecConfigurer.create(), + new AcceptHeaderLocaleContextResolver()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/server/adapter/ForwardedHeaderTransformerTests.java b/spring-web/src/test/java/org/springframework/web/server/adapter/ForwardedHeaderTransformerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..585cfb73d0cd4fb544394394ded858ec2882de93 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/adapter/ForwardedHeaderTransformerTests.java @@ -0,0 +1,146 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.adapter; + +import java.net.URI; + +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link ForwardedHeaderTransformer}. + * @author Rossen Stoyanchev + */ +public class ForwardedHeaderTransformerTests { + + private static final String BASE_URL = "https://example.com/path"; + + + private final ForwardedHeaderTransformer requestMutator = new ForwardedHeaderTransformer(); + + + @Test + public void removeOnly() { + + this.requestMutator.setRemoveOnly(true); + + HttpHeaders headers = new HttpHeaders(); + headers.add("Forwarded", "for=192.0.2.60;proto=http;by=203.0.113.43"); + headers.add("X-Forwarded-Host", "example.com"); + headers.add("X-Forwarded-Port", "8080"); + headers.add("X-Forwarded-Proto", "http"); + headers.add("X-Forwarded-Prefix", "prefix"); + headers.add("X-Forwarded-Ssl", "on"); + ServerHttpRequest request = this.requestMutator.apply(getRequest(headers)); + + assertForwardedHeadersRemoved(request); + } + + @Test + public void xForwardedHeaders() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add("X-Forwarded-Host", "84.198.58.199"); + headers.add("X-Forwarded-Port", "443"); + headers.add("X-Forwarded-Proto", "https"); + headers.add("foo", "bar"); + ServerHttpRequest request = this.requestMutator.apply(getRequest(headers)); + + assertEquals(new URI("https://84.198.58.199/path"), request.getURI()); + assertForwardedHeadersRemoved(request); + } + + @Test + public void forwardedHeader() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add("Forwarded", "host=84.198.58.199;proto=https"); + ServerHttpRequest request = this.requestMutator.apply(getRequest(headers)); + + assertEquals(new URI("https://84.198.58.199/path"), request.getURI()); + assertForwardedHeadersRemoved(request); + } + + @Test + public void xForwardedPrefix() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add("X-Forwarded-Prefix", "/prefix"); + ServerHttpRequest request = this.requestMutator.apply(getRequest(headers)); + + assertEquals(new URI("https://example.com/prefix/path"), request.getURI()); + assertEquals("/prefix/path", request.getPath().value()); + assertForwardedHeadersRemoved(request); + } + + @Test // gh-23305 + public void xForwardedPrefixShouldNotLeadToDecodedPath() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add("X-Forwarded-Prefix", "/prefix"); + ServerHttpRequest request = MockServerHttpRequest + .method(HttpMethod.GET, new URI("https://example.com/a%20b?q=a%2Bb")) + .headers(headers) + .build(); + + request = this.requestMutator.apply(request); + + assertEquals(new URI("https://example.com/prefix/a%20b?q=a%2Bb"), request.getURI()); + assertEquals("/prefix/a%20b", request.getPath().value()); + assertForwardedHeadersRemoved(request); + } + + @Test + public void xForwardedPrefixTrailingSlash() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add("X-Forwarded-Prefix", "/prefix////"); + ServerHttpRequest request = this.requestMutator.apply(getRequest(headers)); + + assertEquals(new URI("https://example.com/prefix/path"), request.getURI()); + assertEquals("/prefix/path", request.getPath().value()); + assertForwardedHeadersRemoved(request); + } + + @Test // SPR-17525 + public void shouldNotDoubleEncode() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add("Forwarded", "host=84.198.58.199;proto=https"); + + ServerHttpRequest request = MockServerHttpRequest + .method(HttpMethod.GET, new URI("https://example.com/a%20b?q=a%2Bb")) + .headers(headers) + .build(); + + request = this.requestMutator.apply(request); + + assertEquals(new URI("https://84.198.58.199/a%20b?q=a%2Bb"), request.getURI()); + assertForwardedHeadersRemoved(request); + } + + + private MockServerHttpRequest getRequest(HttpHeaders headers) { + return MockServerHttpRequest.get(BASE_URL).headers(headers).build(); + } + + private void assertForwardedHeadersRemoved(ServerHttpRequest request) { + ForwardedHeaderTransformer.FORWARDED_HEADER_NAMES + .forEach(name -> assertFalse(request.getHeaders().containsKey(name))); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/server/adapter/WebHttpHandlerBuilderTests.java b/spring-web/src/test/java/org/springframework/web/server/adapter/WebHttpHandlerBuilderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a3232dc8d215a7fbd6b18fd528e09a3c312af19f --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/adapter/WebHttpHandlerBuilderTests.java @@ -0,0 +1,207 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.adapter; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; + +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.annotation.Order; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.web.filter.reactive.ForwardedHeaderFilter; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebExceptionHandler; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebHandler; + +import static java.time.Duration.*; +import static org.junit.Assert.*; + +/** + * Unit tests for {@link WebHttpHandlerBuilder}. + * @author Rossen Stoyanchev + */ +public class WebHttpHandlerBuilderTests { + + @Test // SPR-15074 + public void orderedWebFilterBeans() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(OrderedWebFilterBeanConfig.class); + context.refresh(); + + HttpHandler httpHandler = WebHttpHandlerBuilder.applicationContext(context).build(); + assertTrue(httpHandler instanceof HttpWebHandlerAdapter); + assertSame(context, ((HttpWebHandlerAdapter) httpHandler).getApplicationContext()); + + MockServerHttpRequest request = MockServerHttpRequest.get("/").build(); + MockServerHttpResponse response = new MockServerHttpResponse(); + httpHandler.handle(request, response).block(ofMillis(5000)); + + assertEquals("FilterB::FilterA", response.getBodyAsString().block(ofMillis(5000))); + } + + @Test + public void forwardedHeaderFilter() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(ForwardedHeaderFilterConfig.class); + context.refresh(); + + WebHttpHandlerBuilder builder = WebHttpHandlerBuilder.applicationContext(context); + builder.filters(filters -> assertEquals(Collections.emptyList(), filters)); + assertTrue(builder.hasForwardedHeaderTransformer()); + } + + @Test // SPR-15074 + public void orderedWebExceptionHandlerBeans() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(OrderedExceptionHandlerBeanConfig.class); + context.refresh(); + + HttpHandler httpHandler = WebHttpHandlerBuilder.applicationContext(context).build(); + MockServerHttpRequest request = MockServerHttpRequest.get("/").build(); + MockServerHttpResponse response = new MockServerHttpResponse(); + httpHandler.handle(request, response).block(ofMillis(5000)); + + assertEquals("ExceptionHandlerB", response.getBodyAsString().block(ofMillis(5000))); + } + + @Test + public void configWithoutFilters() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(NoFilterConfig.class); + context.refresh(); + + HttpHandler httpHandler = WebHttpHandlerBuilder.applicationContext(context).build(); + MockServerHttpRequest request = MockServerHttpRequest.get("/").build(); + MockServerHttpResponse response = new MockServerHttpResponse(); + httpHandler.handle(request, response).block(ofMillis(5000)); + + assertEquals("handled", response.getBodyAsString().block(ofMillis(5000))); + } + + @Test // SPR-16972 + public void cloneWithApplicationContext() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(NoFilterConfig.class); + context.refresh(); + + WebHttpHandlerBuilder builder = WebHttpHandlerBuilder.applicationContext(context); + assertSame(context, ((HttpWebHandlerAdapter) builder.build()).getApplicationContext()); + assertSame(context, ((HttpWebHandlerAdapter) builder.clone().build()).getApplicationContext()); + } + + + private static Mono writeToResponse(ServerWebExchange exchange, String value) { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + DataBuffer buffer = new DefaultDataBufferFactory().wrap(bytes); + return exchange.getResponse().writeWith(Flux.just(buffer)); + } + + + @Configuration + @SuppressWarnings("unused") + static class OrderedWebFilterBeanConfig { + + private static final String ATTRIBUTE = "attr"; + + @Bean @Order(2) + public WebFilter filterA() { + return createFilter("FilterA"); + } + + @Bean @Order(1) + public WebFilter filterB() { + return createFilter("FilterB"); + } + + private WebFilter createFilter(String name) { + return (exchange, chain) -> { + String value = exchange.getAttribute(ATTRIBUTE); + value = (value != null ? value + "::" + name : name); + exchange.getAttributes().put(ATTRIBUTE, value); + return chain.filter(exchange); + }; + } + + @Bean + public WebHandler webHandler() { + return exchange -> { + String value = exchange.getAttributeOrDefault(ATTRIBUTE, "none"); + return writeToResponse(exchange, value); + }; + } + } + + + @Configuration + @SuppressWarnings("unused") + static class OrderedExceptionHandlerBeanConfig { + + @Bean + @Order(2) + public WebExceptionHandler exceptionHandlerA() { + return (exchange, ex) -> writeToResponse(exchange, "ExceptionHandlerA"); + } + + @Bean + @Order(1) + public WebExceptionHandler exceptionHandlerB() { + return (exchange, ex) -> writeToResponse(exchange, "ExceptionHandlerB"); + } + + @Bean + public WebHandler webHandler() { + return exchange -> Mono.error(new Exception()); + } + } + + @Configuration + @SuppressWarnings({"unused", "deprecation"}) + static class ForwardedHeaderFilterConfig { + + @Bean + public ForwardedHeaderFilter forwardedHeaderFilter() { + return new ForwardedHeaderFilter(); + } + + @Bean + public WebHandler webHandler() { + return exchange -> Mono.error(new Exception()); + } + } + + @Configuration + @SuppressWarnings("unused") + static class NoFilterConfig { + + @Bean + public WebHandler webHandler() { + return exchange -> writeToResponse(exchange, "handled"); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/server/handler/ExceptionHandlingWebHandlerTests.java b/spring-web/src/test/java/org/springframework/web/server/handler/ExceptionHandlingWebHandlerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..04522b0cda04781e89a90794c3593afeeb332bbc --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/handler/ExceptionHandlingWebHandlerTests.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.handler; + +import java.util.Arrays; + +import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.http.HttpStatus; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.web.test.server.MockServerWebExchange; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebExceptionHandler; +import org.springframework.web.server.WebHandler; +import org.springframework.web.server.adapter.HttpWebHandlerAdapter; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** + * Unit tests for {@link ExceptionHandlingWebHandler}. + * @author Rossen Stoyanchev + */ +public class ExceptionHandlingWebHandlerTests { + + private final WebHandler targetHandler = new StubWebHandler(new IllegalStateException("boo")); + + private final ServerWebExchange exchange = + MockServerWebExchange.from(MockServerHttpRequest.get("http://localhost:8080")); + + + @Test + public void handleErrorSignal() throws Exception { + createWebHandler(new BadRequestExceptionHandler()).handle(this.exchange).block(); + assertEquals(HttpStatus.BAD_REQUEST, this.exchange.getResponse().getStatusCode()); + } + + @Test + public void handleErrorSignalWithMultipleHttpErrorHandlers() throws Exception { + createWebHandler( + new UnresolvedExceptionHandler(), + new UnresolvedExceptionHandler(), + new BadRequestExceptionHandler(), + new UnresolvedExceptionHandler()).handle(this.exchange).block(); + + assertEquals(HttpStatus.BAD_REQUEST, this.exchange.getResponse().getStatusCode()); + } + + @Test + public void unresolvedException() throws Exception { + Mono mono = createWebHandler(new UnresolvedExceptionHandler()).handle(this.exchange); + StepVerifier.create(mono).expectErrorMessage("boo").verify(); + assertNull(this.exchange.getResponse().getStatusCode()); + } + + @Test + public void unresolvedExceptionWithWebHttpHandlerAdapter() throws Exception { + + // HttpWebHandlerAdapter handles unresolved errors + + new HttpWebHandlerAdapter(createWebHandler(new UnresolvedExceptionHandler())) + .handle(this.exchange.getRequest(), this.exchange.getResponse()).block(); + + assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, this.exchange.getResponse().getStatusCode()); + } + + @Test + public void thrownExceptionBecomesErrorSignal() throws Exception { + createWebHandler(new BadRequestExceptionHandler()).handle(this.exchange).block(); + assertEquals(HttpStatus.BAD_REQUEST, this.exchange.getResponse().getStatusCode()); + } + + private WebHandler createWebHandler(WebExceptionHandler... handlers) { + return new ExceptionHandlingWebHandler(this.targetHandler, Arrays.asList(handlers)); + } + + + private static class StubWebHandler implements WebHandler { + + private final RuntimeException exception; + + private final boolean raise; + + + StubWebHandler(RuntimeException exception) { + this(exception, false); + } + + StubWebHandler(RuntimeException exception, boolean raise) { + this.exception = exception; + this.raise = raise; + } + + @Override + public Mono handle(ServerWebExchange exchange) { + if (this.raise) { + throw this.exception; + } + return Mono.error(this.exception); + } + } + + private static class BadRequestExceptionHandler implements WebExceptionHandler { + + @Override + public Mono handle(ServerWebExchange exchange, Throwable ex) { + exchange.getResponse().setStatusCode(HttpStatus.BAD_REQUEST); + return Mono.empty(); + } + } + + /** Leave the exception unresolved. */ + private static class UnresolvedExceptionHandler implements WebExceptionHandler { + + @Override + public Mono handle(ServerWebExchange exchange, Throwable ex) { + return Mono.error(ex); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/server/handler/FilteringWebHandlerTests.java b/spring-web/src/test/java/org/springframework/web/server/handler/FilteringWebHandlerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..337f12585f1c5828cb4b6ae4f5a8efd9dff25c1c --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/handler/FilteringWebHandlerTests.java @@ -0,0 +1,217 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.handler; + +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpStatus; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.mock.web.test.server.MockServerWebExchange; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebExceptionHandler; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; +import org.springframework.web.server.WebHandler; +import org.springframework.web.server.adapter.WebHttpHandlerBuilder; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * Unit tests for {@link FilteringWebHandler}. + * @author Rossen Stoyanchev + */ +public class FilteringWebHandlerTests { + + private static Log logger = LogFactory.getLog(FilteringWebHandlerTests.class); + + + @Test + public void multipleFilters() throws Exception { + + TestFilter filter1 = new TestFilter(); + TestFilter filter2 = new TestFilter(); + TestFilter filter3 = new TestFilter(); + StubWebHandler targetHandler = new StubWebHandler(); + + new FilteringWebHandler(targetHandler, Arrays.asList(filter1, filter2, filter3)) + .handle(MockServerWebExchange.from(MockServerHttpRequest.get("/"))) + .block(Duration.ZERO); + + assertTrue(filter1.invoked()); + assertTrue(filter2.invoked()); + assertTrue(filter3.invoked()); + assertTrue(targetHandler.invoked()); + } + + @Test + public void zeroFilters() throws Exception { + + StubWebHandler targetHandler = new StubWebHandler(); + + new FilteringWebHandler(targetHandler, Collections.emptyList()) + .handle(MockServerWebExchange.from(MockServerHttpRequest.get("/"))) + .block(Duration.ZERO); + + assertTrue(targetHandler.invoked()); + } + + @Test + public void shortcircuitFilter() throws Exception { + + TestFilter filter1 = new TestFilter(); + ShortcircuitingFilter filter2 = new ShortcircuitingFilter(); + TestFilter filter3 = new TestFilter(); + StubWebHandler targetHandler = new StubWebHandler(); + + new FilteringWebHandler(targetHandler, Arrays.asList(filter1, filter2, filter3)) + .handle(MockServerWebExchange.from(MockServerHttpRequest.get("/"))) + .block(Duration.ZERO); + + assertTrue(filter1.invoked()); + assertTrue(filter2.invoked()); + assertFalse(filter3.invoked()); + assertFalse(targetHandler.invoked()); + } + + @Test + public void asyncFilter() throws Exception { + + AsyncFilter filter = new AsyncFilter(); + StubWebHandler targetHandler = new StubWebHandler(); + + new FilteringWebHandler(targetHandler, Collections.singletonList(filter)) + .handle(MockServerWebExchange.from(MockServerHttpRequest.get("/"))) + .block(Duration.ofSeconds(5)); + + assertTrue(filter.invoked()); + assertTrue(targetHandler.invoked()); + } + + @Test + public void handleErrorFromFilter() throws Exception { + + MockServerHttpRequest request = MockServerHttpRequest.get("/").build(); + MockServerHttpResponse response = new MockServerHttpResponse(); + + TestExceptionHandler exceptionHandler = new TestExceptionHandler(); + + WebHttpHandlerBuilder.webHandler(new StubWebHandler()) + .filter(new ExceptionFilter()) + .exceptionHandler(exceptionHandler).build() + .handle(request, response) + .block(); + + assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, response.getStatusCode()); + assertNotNull(exceptionHandler.ex); + assertEquals("boo", exceptionHandler.ex.getMessage()); + } + + + private static class TestFilter implements WebFilter { + + private volatile boolean invoked; + + public boolean invoked() { + return this.invoked; + } + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + this.invoked = true; + return doFilter(exchange, chain); + } + + public Mono doFilter(ServerWebExchange exchange, WebFilterChain chain) { + return chain.filter(exchange); + } + } + + + private static class ShortcircuitingFilter extends TestFilter { + + @Override + public Mono doFilter(ServerWebExchange exchange, WebFilterChain chain) { + return Mono.empty(); + } + } + + + private static class AsyncFilter extends TestFilter { + + @Override + public Mono doFilter(ServerWebExchange exchange, WebFilterChain chain) { + return doAsyncWork().flatMap(asyncResult -> { + logger.debug("Async result: " + asyncResult); + return chain.filter(exchange); + }); + } + + private Mono doAsyncWork() { + return Mono.delay(Duration.ofMillis(100L)).map(l -> "123"); + } + } + + + private static class ExceptionFilter implements WebFilter { + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + return Mono.error(new IllegalStateException("boo")); + } + } + + + private static class TestExceptionHandler implements WebExceptionHandler { + + private Throwable ex; + + @Override + public Mono handle(ServerWebExchange exchange, Throwable ex) { + this.ex = ex; + return Mono.error(ex); + } + } + + + private static class StubWebHandler implements WebHandler { + + private volatile boolean invoked; + + public boolean invoked() { + return this.invoked; + } + + @Override + public Mono handle(ServerWebExchange exchange) { + logger.trace("StubHandler invoked."); + this.invoked = true; + return Mono.empty(); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/server/handler/ResponseStatusExceptionHandlerTests.java b/spring-web/src/test/java/org/springframework/web/server/handler/ResponseStatusExceptionHandlerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c7528ef2a422a89ad7d20767bbb7d6637be201b1 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/handler/ResponseStatusExceptionHandlerTests.java @@ -0,0 +1,115 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.handler; + +import java.time.Duration; +import java.util.Arrays; + +import org.junit.Before; +import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.mock.web.test.server.MockServerWebExchange; +import org.springframework.web.server.MethodNotAllowedException; +import org.springframework.web.server.NotAcceptableStatusException; +import org.springframework.web.server.ResponseStatusException; + +import static org.hamcrest.collection.IsIterableContainingInOrder.contains; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThat; + +/** + * Unit tests for {@link ResponseStatusExceptionHandler}. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + */ +public class ResponseStatusExceptionHandlerTests { + + protected final MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")); + + protected ResponseStatusExceptionHandler handler; + + + @Before + public void setup() { + this.handler = createResponseStatusExceptionHandler(); + } + + protected ResponseStatusExceptionHandler createResponseStatusExceptionHandler() { + return new ResponseStatusExceptionHandler(); + } + + + @Test + public void handleResponseStatusException() { + Throwable ex = new ResponseStatusException(HttpStatus.BAD_REQUEST, ""); + this.handler.handle(this.exchange, ex).block(Duration.ofSeconds(5)); + assertEquals(HttpStatus.BAD_REQUEST, this.exchange.getResponse().getStatusCode()); + } + + @Test + public void handleNestedResponseStatusException() { + Throwable ex = new Exception(new ResponseStatusException(HttpStatus.BAD_REQUEST, "")); + this.handler.handle(this.exchange, ex).block(Duration.ofSeconds(5)); + assertEquals(HttpStatus.BAD_REQUEST, this.exchange.getResponse().getStatusCode()); + } + + @Test // gh-23741 + public void handleMethodNotAllowed() { + Throwable ex = new MethodNotAllowedException(HttpMethod.PATCH, Arrays.asList(HttpMethod.POST, HttpMethod.PUT)); + this.handler.handle(this.exchange, ex).block(Duration.ofSeconds(5)); + + MockServerHttpResponse response = this.exchange.getResponse(); + assertEquals(HttpStatus.METHOD_NOT_ALLOWED, response.getStatusCode()); + assertThat(response.getHeaders().getAllow(), contains(HttpMethod.POST, HttpMethod.PUT)); + } + + @Test // gh-23741 + public void handleResponseStatusExceptionWithHeaders() { + Throwable ex = new NotAcceptableStatusException(Arrays.asList(MediaType.TEXT_PLAIN, MediaType.TEXT_HTML)); + this.handler.handle(this.exchange, ex).block(Duration.ofSeconds(5)); + + MockServerHttpResponse response = this.exchange.getResponse(); + assertEquals(HttpStatus.NOT_ACCEPTABLE, response.getStatusCode()); + assertThat(response.getHeaders().getAccept(), contains(MediaType.TEXT_PLAIN, MediaType.TEXT_HTML)); + } + + @Test + public void unresolvedException() { + Throwable expected = new IllegalStateException(); + Mono mono = this.handler.handle(this.exchange, expected); + StepVerifier.create(mono).consumeErrorWith(actual -> assertSame(expected, actual)).verify(); + } + + @Test // SPR-16231 + public void responseCommitted() { + Throwable ex = new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, "Oops"); + this.exchange.getResponse().setStatusCode(HttpStatus.CREATED); + Mono mono = this.exchange.getResponse().setComplete() + .then(Mono.defer(() -> this.handler.handle(this.exchange, ex))); + StepVerifier.create(mono).consumeErrorWith(actual -> assertSame(ex, actual)).verify(); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/server/i18n/AcceptHeaderLocaleContextResolverTests.java b/spring-web/src/test/java/org/springframework/web/server/i18n/AcceptHeaderLocaleContextResolverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..bf859cab728250a7c85973cf7ab785c64cac0bba --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/i18n/AcceptHeaderLocaleContextResolverTests.java @@ -0,0 +1,152 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.i18n; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Locale; + +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.web.test.server.MockServerWebExchange; +import org.springframework.web.server.ServerWebExchange; + +import static java.util.Locale.*; +import static org.junit.Assert.*; + +/** + * Unit tests for {@link AcceptHeaderLocaleContextResolver}. + * + * @author Sebastien Deleuze + * @author Juergen Hoeller + */ +public class AcceptHeaderLocaleContextResolverTests { + + private final AcceptHeaderLocaleContextResolver resolver = new AcceptHeaderLocaleContextResolver(); + + + @Test + public void resolve() { + assertEquals(CANADA, this.resolver.resolveLocaleContext(exchange(CANADA)).getLocale()); + assertEquals(US, this.resolver.resolveLocaleContext(exchange(US, CANADA)).getLocale()); + } + + @Test + public void resolvePreferredSupported() { + this.resolver.setSupportedLocales(Collections.singletonList(CANADA)); + assertEquals(CANADA, this.resolver.resolveLocaleContext(exchange(US, CANADA)).getLocale()); + } + + @Test + public void resolvePreferredNotSupported() { + this.resolver.setSupportedLocales(Collections.singletonList(CANADA)); + assertEquals(US, this.resolver.resolveLocaleContext(exchange(US, UK)).getLocale()); + } + + @Test + public void resolvePreferredNotSupportedWithDefault() { + this.resolver.setSupportedLocales(Arrays.asList(US, JAPAN)); + this.resolver.setDefaultLocale(JAPAN); + assertEquals(JAPAN, this.resolver.resolveLocaleContext(exchange(KOREA)).getLocale()); + } + + @Test + public void resolvePreferredAgainstLanguageOnly() { + this.resolver.setSupportedLocales(Collections.singletonList(ENGLISH)); + assertEquals(ENGLISH, this.resolver.resolveLocaleContext(exchange(GERMANY, US, UK)).getLocale()); + } + + @Test + public void resolvePreferredAgainstCountryIfPossible() { + this.resolver.setSupportedLocales(Arrays.asList(ENGLISH, UK)); + assertEquals(UK, this.resolver.resolveLocaleContext(exchange(GERMANY, US, UK)).getLocale()); + } + + @Test + public void resolvePreferredAgainstLanguageWithMultipleSupportedLocales() { + this.resolver.setSupportedLocales(Arrays.asList(GERMAN, US)); + assertEquals(GERMAN, this.resolver.resolveLocaleContext(exchange(GERMANY, US, UK)).getLocale()); + } + + @Test + public void resolveMissingAcceptLanguageHeader() { + MockServerHttpRequest request = MockServerHttpRequest.get("/").build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + assertNull(this.resolver.resolveLocaleContext(exchange).getLocale()); + } + + @Test + public void resolveMissingAcceptLanguageHeaderWithDefault() { + this.resolver.setDefaultLocale(US); + + MockServerHttpRequest request = MockServerHttpRequest.get("/").build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + assertEquals(US, this.resolver.resolveLocaleContext(exchange).getLocale()); + } + + @Test + public void resolveEmptyAcceptLanguageHeader() { + MockServerHttpRequest request = MockServerHttpRequest.get("/").header(HttpHeaders.ACCEPT_LANGUAGE, "").build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + assertNull(this.resolver.resolveLocaleContext(exchange).getLocale()); + } + + @Test + public void resolveEmptyAcceptLanguageHeaderWithDefault() { + this.resolver.setDefaultLocale(US); + + MockServerHttpRequest request = MockServerHttpRequest.get("/").header(HttpHeaders.ACCEPT_LANGUAGE, "").build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + assertEquals(US, this.resolver.resolveLocaleContext(exchange).getLocale()); + } + + @Test + public void resolveInvalidAcceptLanguageHeader() { + MockServerHttpRequest request = MockServerHttpRequest.get("/").header(HttpHeaders.ACCEPT_LANGUAGE, "en_US").build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + assertNull(this.resolver.resolveLocaleContext(exchange).getLocale()); + } + + @Test + public void resolveInvalidAcceptLanguageHeaderWithDefault() { + this.resolver.setDefaultLocale(US); + + MockServerHttpRequest request = MockServerHttpRequest.get("/").header(HttpHeaders.ACCEPT_LANGUAGE, "en_US").build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + assertEquals(US, this.resolver.resolveLocaleContext(exchange).getLocale()); + } + + @Test + public void defaultLocale() { + this.resolver.setDefaultLocale(JAPANESE); + MockServerHttpRequest request = MockServerHttpRequest.get("/").build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + assertEquals(JAPANESE, this.resolver.resolveLocaleContext(exchange).getLocale()); + + request = MockServerHttpRequest.get("/").acceptLanguageAsLocales(US).build(); + exchange = MockServerWebExchange.from(request); + assertEquals(US, this.resolver.resolveLocaleContext(exchange).getLocale()); + } + + + private ServerWebExchange exchange(Locale... locales) { + return MockServerWebExchange.from(MockServerHttpRequest.get("").acceptLanguageAsLocales(locales)); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/server/i18n/FixedLocaleContextResolverTests.java b/spring-web/src/test/java/org/springframework/web/server/i18n/FixedLocaleContextResolverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..76986d3867bd1f516cec22e264f897844b29bffc --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/i18n/FixedLocaleContextResolverTests.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.i18n; + +import java.time.ZoneId; +import java.util.Locale; +import java.util.TimeZone; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.context.i18n.TimeZoneAwareLocaleContext; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.web.test.server.MockServerWebExchange; +import org.springframework.web.server.ServerWebExchange; + +import static java.util.Locale.*; +import static org.junit.Assert.*; + +/** + * Unit tests for {@link FixedLocaleContextResolver}. + * + * @author Sebastien Deleuze + */ +public class FixedLocaleContextResolverTests { + + @Before + public void setup() { + Locale.setDefault(US); + } + + @Test + public void resolveDefaultLocale() { + FixedLocaleContextResolver resolver = new FixedLocaleContextResolver(); + assertEquals(US, resolver.resolveLocaleContext(exchange()).getLocale()); + assertEquals(US, resolver.resolveLocaleContext(exchange(CANADA)).getLocale()); + } + + @Test + public void resolveCustomizedLocale() { + FixedLocaleContextResolver resolver = new FixedLocaleContextResolver(FRANCE); + assertEquals(FRANCE, resolver.resolveLocaleContext(exchange()).getLocale()); + assertEquals(FRANCE, resolver.resolveLocaleContext(exchange(CANADA)).getLocale()); + } + + @Test + public void resolveCustomizedAndTimeZoneLocale() { + TimeZone timeZone = TimeZone.getTimeZone(ZoneId.of("UTC")); + FixedLocaleContextResolver resolver = new FixedLocaleContextResolver(FRANCE, timeZone); + TimeZoneAwareLocaleContext context = (TimeZoneAwareLocaleContext) resolver.resolveLocaleContext(exchange()); + assertEquals(FRANCE, context.getLocale()); + assertEquals(timeZone, context.getTimeZone()); + } + + private ServerWebExchange exchange(Locale... locales) { + return MockServerWebExchange.from(MockServerHttpRequest.get("").acceptLanguageAsLocales(locales)); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/server/session/CookieWebSessionIdResolverTests.java b/spring-web/src/test/java/org/springframework/web/server/session/CookieWebSessionIdResolverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..61e0a0abf88e1cd6259d414dbb129a29b4fb25a5 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/session/CookieWebSessionIdResolverTests.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.server.session; + +import org.junit.Test; + +import org.springframework.http.ResponseCookie; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.web.test.server.MockServerWebExchange; +import org.springframework.util.MultiValueMap; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * Unit tests for {@link CookieWebSessionIdResolver}. + * @author Rossen Stoyanchev + */ +public class CookieWebSessionIdResolverTests { + + private final CookieWebSessionIdResolver resolver = new CookieWebSessionIdResolver(); + + + @Test + public void setSessionId() { + MockServerHttpRequest request = MockServerHttpRequest.get("https://example.org/path").build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + this.resolver.setSessionId(exchange, "123"); + + MultiValueMap cookies = exchange.getResponse().getCookies(); + assertEquals(1, cookies.size()); + ResponseCookie cookie = cookies.getFirst(this.resolver.getCookieName()); + assertNotNull(cookie); + assertEquals("SESSION=123; Path=/; Secure; HttpOnly; SameSite=Lax", cookie.toString()); + } + + @Test + public void cookieInitializer() { + this.resolver.addCookieInitializer(builder -> builder.domain("example.org")); + this.resolver.addCookieInitializer(builder -> builder.sameSite("Strict")); + this.resolver.addCookieInitializer(builder -> builder.secure(false)); + + MockServerHttpRequest request = MockServerHttpRequest.get("https://example.org/path").build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + this.resolver.setSessionId(exchange, "123"); + + MultiValueMap cookies = exchange.getResponse().getCookies(); + assertEquals(1, cookies.size()); + ResponseCookie cookie = cookies.getFirst(this.resolver.getCookieName()); + assertNotNull(cookie); + assertEquals("SESSION=123; Path=/; Domain=example.org; HttpOnly; SameSite=Strict", cookie.toString()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java b/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..1330820613c9fa10bca19922cae2ea465e54ca91 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java @@ -0,0 +1,140 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.server.session; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; + +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; +import org.springframework.web.server.adapter.DefaultServerWebExchange; +import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link DefaultWebSessionManager}. + * @author Rossen Stoyanchev + * @author Rob Winch + */ +@RunWith(MockitoJUnitRunner.class) +public class DefaultWebSessionManagerTests { + + private DefaultWebSessionManager sessionManager; + + private ServerWebExchange exchange; + + @Mock + private WebSessionIdResolver sessionIdResolver; + + @Mock + private WebSessionStore sessionStore; + + @Mock + private WebSession createSession; + + @Mock + private WebSession updateSession; + + + @Before + public void setUp() throws Exception { + + when(this.createSession.save()).thenReturn(Mono.empty()); + when(this.createSession.getId()).thenReturn("create-session-id"); + when(this.updateSession.getId()).thenReturn("update-session-id"); + + when(this.sessionStore.createWebSession()).thenReturn(Mono.just(this.createSession)); + when(this.sessionStore.retrieveSession(this.updateSession.getId())).thenReturn(Mono.just(this.updateSession)); + + this.sessionManager = new DefaultWebSessionManager(); + this.sessionManager.setSessionIdResolver(this.sessionIdResolver); + this.sessionManager.setSessionStore(this.sessionStore); + + MockServerHttpRequest request = MockServerHttpRequest.get("/path").build(); + MockServerHttpResponse response = new MockServerHttpResponse(); + this.exchange = new DefaultServerWebExchange(request, response, this.sessionManager, + ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + } + + @Test + public void getSessionSaveWhenCreatedAndNotStartedThenNotSaved() { + + when(this.sessionIdResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.emptyList()); + WebSession session = this.sessionManager.getSession(this.exchange).block(); + this.exchange.getResponse().setComplete().block(); + + assertSame(this.createSession, session); + assertFalse(session.isStarted()); + assertFalse(session.isExpired()); + verify(this.createSession, never()).save(); + verify(this.sessionIdResolver, never()).setSessionId(any(), any()); + } + + @Test + public void getSessionSaveWhenCreatedAndStartedThenSavesAndSetsId() { + + when(this.sessionIdResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.emptyList()); + WebSession session = this.sessionManager.getSession(this.exchange).block(); + assertSame(this.createSession, session); + String sessionId = this.createSession.getId(); + + when(this.createSession.isStarted()).thenReturn(true); + this.exchange.getResponse().setComplete().block(); + + verify(this.sessionStore).createWebSession(); + verify(this.sessionIdResolver).setSessionId(any(), eq(sessionId)); + verify(this.createSession).save(); + } + + @Test + public void existingSession() { + + String sessionId = this.updateSession.getId(); + when(this.sessionIdResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.singletonList(sessionId)); + + WebSession actual = this.sessionManager.getSession(this.exchange).block(); + assertNotNull(actual); + assertEquals(sessionId, actual.getId()); + } + + @Test + public void multipleSessionIds() { + + List ids = Arrays.asList("not-this", "not-that", this.updateSession.getId()); + when(this.sessionStore.retrieveSession("not-this")).thenReturn(Mono.empty()); + when(this.sessionStore.retrieveSession("not-that")).thenReturn(Mono.empty()); + when(this.sessionIdResolver.resolveSessionIds(this.exchange)).thenReturn(ids); + WebSession actual = this.sessionManager.getSession(this.exchange).block(); + + assertNotNull(actual); + assertEquals(this.updateSession.getId(), actual.getId()); + } +} diff --git a/spring-web/src/test/java/org/springframework/web/server/session/HeaderWebSessionIdResolverTests.java b/spring-web/src/test/java/org/springframework/web/server/session/HeaderWebSessionIdResolverTests.java new file mode 100644 index 0000000000000000000000000000000000000000..431a113a1fc8a1c73914264e21fceccd817074ec --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/session/HeaderWebSessionIdResolverTests.java @@ -0,0 +1,145 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.server.session; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.web.test.server.MockServerWebExchange; +import org.springframework.web.server.ServerWebExchange; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Tests using {@link HeaderWebSessionIdResolver}. + * + * @author Greg Turnquist + * @author Rob Winch + */ +public class HeaderWebSessionIdResolverTests { + private HeaderWebSessionIdResolver idResolver; + + private ServerWebExchange exchange; + + @Before + public void setUp() { + this.idResolver = new HeaderWebSessionIdResolver(); + this.exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/path")); + } + + @Test + public void expireWhenValidThenSetsEmptyHeader() { + this.idResolver.expireSession(this.exchange); + + assertEquals(Arrays.asList(""), + this.exchange.getResponse().getHeaders().get(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME)); + } + + @Test + public void expireWhenMultipleInvocationThenSetsSingleEmptyHeader() { + this.idResolver.expireSession(this.exchange); + + this.idResolver.expireSession(this.exchange); + + assertEquals(Arrays.asList(""), + this.exchange.getResponse().getHeaders().get(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME)); + } + + @Test + public void expireWhenAfterSetSessionIdThenSetsEmptyHeader() { + this.idResolver.setSessionId(this.exchange, "123"); + + this.idResolver.expireSession(this.exchange); + + assertEquals(Arrays.asList(""), + this.exchange.getResponse().getHeaders().get(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME)); + } + + @Test + public void setSessionIdWhenValidThenSetsHeader() { + String id = "123"; + + this.idResolver.setSessionId(this.exchange, id); + + assertEquals(Arrays.asList(id), + this.exchange.getResponse().getHeaders().get(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME)); + } + + @Test + public void setSessionIdWhenMultipleThenSetsSingleHeader() { + String id = "123"; + this.idResolver.setSessionId(this.exchange, "overriddenByNextInvocation"); + + this.idResolver.setSessionId(this.exchange, id); + + assertEquals(Arrays.asList(id), + this.exchange.getResponse().getHeaders().get(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME)); + } + + @Test + public void setSessionIdWhenCustomHeaderNameThenSetsHeader() { + String headerName = "x-auth"; + String id = "123"; + this.idResolver.setHeaderName(headerName); + + this.idResolver.setSessionId(this.exchange, id); + + assertEquals(Arrays.asList(id), + this.exchange.getResponse().getHeaders().get(headerName)); + } + + @Test(expected = IllegalArgumentException.class) + public void setSessionIdWhenNullIdThenIllegalArgumentException() { + String id = null; + + this.idResolver.setSessionId(this.exchange, id); + } + + @Test + public void resolveSessionIdsWhenNoIdsThenEmpty() { + List ids = this.idResolver.resolveSessionIds(this.exchange); + + assertTrue(ids.isEmpty()); + } + + @Test + public void resolveSessionIdsWhenIdThenIdFound() { + String id = "123"; + this.exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/path") + .header(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME, id)); + + List ids = this.idResolver.resolveSessionIds(this.exchange); + + assertEquals(Arrays.asList(id), ids); + } + + @Test + public void resolveSessionIdsWhenMultipleIdsThenIdsFound() { + String id1 = "123"; + String id2 = "abc"; + this.exchange = MockServerWebExchange.from( + MockServerHttpRequest.get("/path") + .header(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME, id1, id2)); + + List ids = this.idResolver.resolveSessionIds(this.exchange); + + assertEquals(Arrays.asList(id1, id2), ids); + } +} diff --git a/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java b/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java new file mode 100644 index 0000000000000000000000000000000000000000..51cbf59f2a7cb6ad4fabdcb0f79b2727d7a807dd --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java @@ -0,0 +1,169 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.server.session; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Map; +import java.util.stream.IntStream; + +import org.junit.Test; + +import org.springframework.beans.DirectFieldAccessor; +import org.springframework.web.server.WebSession; + +import static junit.framework.TestCase.assertSame; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Unit tests for {@link InMemoryWebSessionStore}. + * @author Rob Winch + */ +public class InMemoryWebSessionStoreTests { + + private InMemoryWebSessionStore store = new InMemoryWebSessionStore(); + + + @Test + public void startsSessionExplicitly() { + WebSession session = this.store.createWebSession().block(); + assertNotNull(session); + session.start(); + assertTrue(session.isStarted()); + } + + @Test + public void startsSessionImplicitly() { + WebSession session = this.store.createWebSession().block(); + assertNotNull(session); + session.start(); + session.getAttributes().put("foo", "bar"); + assertTrue(session.isStarted()); + } + + @Test + public void retrieveExpiredSession() { + WebSession session = this.store.createWebSession().block(); + assertNotNull(session); + session.getAttributes().put("foo", "bar"); + session.save().block(); + + String id = session.getId(); + WebSession retrieved = this.store.retrieveSession(id).block(); + assertNotNull(retrieved); + assertSame(session, retrieved); + + // Fast-forward 31 minutes + this.store.setClock(Clock.offset(this.store.getClock(), Duration.ofMinutes(31))); + WebSession retrievedAgain = this.store.retrieveSession(id).block(); + assertNull(retrievedAgain); + } + + @Test + public void lastAccessTimeIsUpdatedOnRetrieve() { + WebSession session1 = this.store.createWebSession().block(); + assertNotNull(session1); + String id = session1.getId(); + Instant time1 = session1.getLastAccessTime(); + session1.start(); + session1.save().block(); + + // Fast-forward a few seconds + this.store.setClock(Clock.offset(this.store.getClock(), Duration.ofSeconds(5))); + + WebSession session2 = this.store.retrieveSession(id).block(); + assertNotNull(session2); + assertSame(session1, session2); + Instant time2 = session2.getLastAccessTime(); + assertTrue(time1.isBefore(time2)); + } + + @Test // SPR-17051 + public void sessionInvalidatedBeforeSave() { + // Request 1 creates session + WebSession session1 = this.store.createWebSession().block(); + assertNotNull(session1); + String id = session1.getId(); + session1.start(); + session1.save().block(); + + // Request 2 retrieves session + WebSession session2 = this.store.retrieveSession(id).block(); + assertNotNull(session2); + assertSame(session1, session2); + + // Request 3 retrieves and invalidates + WebSession session3 = this.store.retrieveSession(id).block(); + assertNotNull(session3); + assertSame(session1, session3); + session3.invalidate().block(); + + // Request 2 saves session after invalidated + session2.save().block(); + + // Session should not be present + WebSession session4 = this.store.retrieveSession(id).block(); + assertNull(session4); + } + + @Test + public void expirationCheckPeriod() { + + DirectFieldAccessor accessor = new DirectFieldAccessor(this.store); + Map sessions = (Map) accessor.getPropertyValue("sessions"); + assertNotNull(sessions); + + // Create 100 sessions + IntStream.range(0, 100).forEach(i -> insertSession()); + assertEquals(100, sessions.size()); + + // Force a new clock (31 min later), don't use setter which would clean expired sessions + accessor.setPropertyValue("clock", Clock.offset(this.store.getClock(), Duration.ofMinutes(31))); + assertEquals(100, sessions.size()); + + // Create 1 more which forces a time-based check (clock moved forward) + insertSession(); + assertEquals(1, sessions.size()); + } + + @Test + public void maxSessions() { + + IntStream.range(0, 10000).forEach(i -> insertSession()); + + try { + insertSession(); + fail(); + } + catch (IllegalStateException ex) { + assertEquals("Max sessions limit reached: 10000", ex.getMessage()); + } + } + + private WebSession insertSession() { + WebSession session = this.store.createWebSession().block(); + assertNotNull(session); + session.start(); + session.save().block(); + return session; + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c935cccb7214f6656f1a9e4f148760ee152150bd --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java @@ -0,0 +1,244 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.server.session; + +import java.net.URI; +import java.net.URISyntaxException; +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.http.server.reactive.AbstractHttpHandlerIntegrationTests; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebHandler; +import org.springframework.web.server.WebSession; +import org.springframework.web.server.adapter.WebHttpHandlerBuilder; + +import static org.junit.Assert.*; + +/** + * Integration tests for with a server-side session. + * + * @author Rossen Stoyanchev + */ +public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTests { + + private final RestTemplate restTemplate = new RestTemplate(); + + private DefaultWebSessionManager sessionManager; + + private TestWebHandler handler; + + + @Override + protected HttpHandler createHttpHandler() { + this.sessionManager = new DefaultWebSessionManager(); + this.handler = new TestWebHandler(); + return WebHttpHandlerBuilder.webHandler(this.handler).sessionManager(this.sessionManager).build(); + } + + + @Test + public void createSession() throws Exception { + RequestEntity request = RequestEntity.get(createUri()).build(); + ResponseEntity response = this.restTemplate.exchange(request, Void.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + String id = extractSessionId(response.getHeaders()); + assertNotNull(id); + assertEquals(1, this.handler.getSessionRequestCount()); + + request = RequestEntity.get(createUri()).header("Cookie", "SESSION=" + id).build(); + response = this.restTemplate.exchange(request, Void.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + assertNull(response.getHeaders().get("Set-Cookie")); + assertEquals(2, this.handler.getSessionRequestCount()); + } + + @Test + public void expiredSessionIsRecreated() throws Exception { + + // First request: no session yet, new session created + RequestEntity request = RequestEntity.get(createUri()).build(); + ResponseEntity response = this.restTemplate.exchange(request, Void.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + String id = extractSessionId(response.getHeaders()); + assertNotNull(id); + assertEquals(1, this.handler.getSessionRequestCount()); + + // Second request: same session + request = RequestEntity.get(createUri()).header("Cookie", "SESSION=" + id).build(); + response = this.restTemplate.exchange(request, Void.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + assertNull(response.getHeaders().get("Set-Cookie")); + assertEquals(2, this.handler.getSessionRequestCount()); + + // Now fast-forward by 31 minutes + InMemoryWebSessionStore store = (InMemoryWebSessionStore) this.sessionManager.getSessionStore(); + WebSession session = store.retrieveSession(id).block(); + assertNotNull(session); + store.setClock(Clock.offset(store.getClock(), Duration.ofMinutes(31))); + + // Third request: expired session, new session created + request = RequestEntity.get(createUri()).header("Cookie", "SESSION=" + id).build(); + response = this.restTemplate.exchange(request, Void.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + id = extractSessionId(response.getHeaders()); + assertNotNull("Expected new session id", id); + assertEquals(1, this.handler.getSessionRequestCount()); + } + + @Test + public void expiredSessionEnds() throws Exception { + + // First request: no session yet, new session created + RequestEntity request = RequestEntity.get(createUri()).build(); + ResponseEntity response = this.restTemplate.exchange(request, Void.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + String id = extractSessionId(response.getHeaders()); + assertNotNull(id); + + // Now fast-forward by 31 minutes + InMemoryWebSessionStore store = (InMemoryWebSessionStore) this.sessionManager.getSessionStore(); + store.setClock(Clock.offset(store.getClock(), Duration.ofMinutes(31))); + + // Second request: session expires + URI uri = new URI("http://localhost:" + this.port + "/?expire"); + request = RequestEntity.get(uri).header("Cookie", "SESSION=" + id).build(); + response = this.restTemplate.exchange(request, Void.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + String value = response.getHeaders().getFirst("Set-Cookie"); + assertNotNull(value); + assertTrue("Actual value: " + value, value.contains("Max-Age=0")); + } + + @Test + public void changeSessionId() throws Exception { + + // First request: no session yet, new session created + RequestEntity request = RequestEntity.get(createUri()).build(); + ResponseEntity response = this.restTemplate.exchange(request, Void.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + String oldId = extractSessionId(response.getHeaders()); + assertNotNull(oldId); + assertEquals(1, this.handler.getSessionRequestCount()); + + // Second request: session id changes + URI uri = new URI("http://localhost:" + this.port + "/?changeId"); + request = RequestEntity.get(uri).header("Cookie", "SESSION=" + oldId).build(); + response = this.restTemplate.exchange(request, Void.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + String newId = extractSessionId(response.getHeaders()); + assertNotNull("Expected new session id", newId); + assertNotEquals(oldId, newId); + assertEquals(2, this.handler.getSessionRequestCount()); + } + + @Test + public void invalidate() throws Exception { + + // First request: no session yet, new session created + RequestEntity request = RequestEntity.get(createUri()).build(); + ResponseEntity response = this.restTemplate.exchange(request, Void.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + String id = extractSessionId(response.getHeaders()); + assertNotNull(id); + + // Second request: invalidates session + URI uri = new URI("http://localhost:" + this.port + "/?invalidate"); + request = RequestEntity.get(uri).header("Cookie", "SESSION=" + id).build(); + response = this.restTemplate.exchange(request, Void.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + String value = response.getHeaders().getFirst("Set-Cookie"); + assertNotNull(value); + assertTrue("Actual value: " + value, value.contains("Max-Age=0")); + } + + private String extractSessionId(HttpHeaders headers) { + List headerValues = headers.get("Set-Cookie"); + assertNotNull(headerValues); + assertEquals(1, headerValues.size()); + + for (String s : headerValues.get(0).split(";")){ + if (s.startsWith("SESSION=")) { + return s.substring("SESSION=".length()); + } + } + return null; + } + + private URI createUri() throws URISyntaxException { + return new URI("http://localhost:" + this.port + "/"); + } + + + private static class TestWebHandler implements WebHandler { + + private AtomicInteger currentValue = new AtomicInteger(); + + + public int getSessionRequestCount() { + return this.currentValue.get(); + } + + @Override + public Mono handle(ServerWebExchange exchange) { + if (exchange.getRequest().getQueryParams().containsKey("expire")) { + return exchange.getSession().doOnNext(session -> { + // Don't do anything, leave it expired... + }).then(); + } + else if (exchange.getRequest().getQueryParams().containsKey("changeId")) { + return exchange.getSession().flatMap(session -> + session.changeSessionId().doOnSuccess(aVoid -> updateSessionAttribute(session))); + } + else if (exchange.getRequest().getQueryParams().containsKey("invalidate")) { + return exchange.getSession().doOnNext(WebSession::invalidate).then(); + } + else { + return exchange.getSession().doOnSuccess(this::updateSessionAttribute).then(); + } + } + + private void updateSessionAttribute(WebSession session) { + int value = session.getAttributeOrDefault("counter", 0); + session.getAttributes().put("counter", ++value); + this.currentValue.set(value); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/util/ContentCachingRequestWrapperTests.java b/spring-web/src/test/java/org/springframework/web/util/ContentCachingRequestWrapperTests.java new file mode 100644 index 0000000000000000000000000000000000000000..81c5212b4fae26e636f215c8968713c116bfbe0e --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/ContentCachingRequestWrapperTests.java @@ -0,0 +1,113 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import org.junit.Test; + +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.util.FileCopyUtils; + +import static org.junit.Assert.*; + +/** + * @author Brian Clozel + */ +public class ContentCachingRequestWrapperTests { + + protected static final String FORM_CONTENT_TYPE = "application/x-www-form-urlencoded"; + + protected static final String CHARSET = "UTF-8"; + + private final MockHttpServletRequest request = new MockHttpServletRequest(); + + + @Test + public void cachedContent() throws Exception { + this.request.setMethod("GET"); + this.request.setCharacterEncoding(CHARSET); + this.request.setContent("Hello World".getBytes(CHARSET)); + + ContentCachingRequestWrapper wrapper = new ContentCachingRequestWrapper(this.request); + byte[] response = FileCopyUtils.copyToByteArray(wrapper.getInputStream()); + assertArrayEquals(response, wrapper.getContentAsByteArray()); + } + + @Test + public void cachedContentWithLimit() throws Exception { + this.request.setMethod("GET"); + this.request.setCharacterEncoding(CHARSET); + this.request.setContent("Hello World".getBytes(CHARSET)); + + ContentCachingRequestWrapper wrapper = new ContentCachingRequestWrapper(this.request, 3); + byte[] response = FileCopyUtils.copyToByteArray(wrapper.getInputStream()); + assertArrayEquals("Hello World".getBytes(CHARSET), response); + assertArrayEquals("Hel".getBytes(CHARSET), wrapper.getContentAsByteArray()); + } + + @Test + public void cachedContentWithOverflow() throws Exception { + this.request.setMethod("GET"); + this.request.setCharacterEncoding(CHARSET); + this.request.setContent("Hello World".getBytes(CHARSET)); + + ContentCachingRequestWrapper wrapper = new ContentCachingRequestWrapper(this.request, 3) { + @Override + protected void handleContentOverflow(int contentCacheLimit) { + throw new IllegalStateException(String.valueOf(contentCacheLimit)); + } + }; + + try { + FileCopyUtils.copyToByteArray(wrapper.getInputStream()); + fail("Should have thrown IllegalStateException"); + } + catch (IllegalStateException ex) { + assertEquals("3", ex.getMessage()); + } + } + + @Test + public void requestParams() throws Exception { + this.request.setMethod("POST"); + this.request.setContentType(FORM_CONTENT_TYPE); + this.request.setCharacterEncoding(CHARSET); + this.request.setParameter("first", "value"); + this.request.setParameter("second", "foo", "bar"); + + ContentCachingRequestWrapper wrapper = new ContentCachingRequestWrapper(this.request); + // getting request parameters will consume the request body + assertFalse(wrapper.getParameterMap().isEmpty()); + assertEquals("first=value&second=foo&second=bar", new String(wrapper.getContentAsByteArray())); + // SPR-12810 : inputstream body should be consumed + assertEquals("", new String(FileCopyUtils.copyToByteArray(wrapper.getInputStream()))); + } + + @Test // SPR-12810 + public void inputStreamFormPostRequest() throws Exception { + this.request.setMethod("POST"); + this.request.setContentType(FORM_CONTENT_TYPE); + this.request.setCharacterEncoding(CHARSET); + this.request.setParameter("first", "value"); + this.request.setParameter("second", "foo", "bar"); + + ContentCachingRequestWrapper wrapper = new ContentCachingRequestWrapper(this.request); + + byte[] response = FileCopyUtils.copyToByteArray(wrapper.getInputStream()); + assertArrayEquals(response, wrapper.getContentAsByteArray()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/util/DefaultUriBuilderFactoryTests.java b/spring-web/src/test/java/org/springframework/web/util/DefaultUriBuilderFactoryTests.java new file mode 100644 index 0000000000000000000000000000000000000000..4d6ed8e87b8afd6d1e94f5d16c5f03c88bb82c73 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/DefaultUriBuilderFactoryTests.java @@ -0,0 +1,191 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.net.URI; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; + +import org.springframework.web.util.DefaultUriBuilderFactory.EncodingMode; + +import static java.util.Collections.singletonMap; +import static junit.framework.TestCase.assertEquals; + +/** + * Unit tests for {@link DefaultUriBuilderFactory}. + * @author Rossen Stoyanchev + */ +public class DefaultUriBuilderFactoryTests { + + @Test + public void defaultSettings() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory(); + URI uri = factory.uriString("/foo/{id}").build("a/b"); + assertEquals("/foo/a%2Fb", uri.toString()); + } + + @Test // SPR-17465 + public void defaultSettingsWithBuilder() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory(); + URI uri = factory.builder().path("/foo/{id}").build("a/b"); + assertEquals("/foo/a%2Fb", uri.toString()); + } + + @Test + public void baseUri() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory("http://example.com/v1?id=123"); + URI uri = factory.uriString("/bar").port(8080).build(); + assertEquals("http://example.com:8080/v1/bar?id=123", uri.toString()); + } + + @Test + public void baseUriWithFullOverride() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory("http://example.com/v1?id=123"); + URI uri = factory.uriString("https://example.com/1/2").build(); + assertEquals("Use of host should case baseUri to be completely ignored", + "https://example.com/1/2", uri.toString()); + } + + @Test + public void baseUriWithPathOverride() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory("http://example.com/v1"); + URI uri = factory.builder().replacePath("/baz").build(); + assertEquals("http://example.com/baz", uri.toString()); + } + + @Test + public void defaultUriVars() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory("http://{host}/v1"); + factory.setDefaultUriVariables(singletonMap("host", "example.com")); + URI uri = factory.uriString("/{id}").build(singletonMap("id", "123")); + assertEquals("http://example.com/v1/123", uri.toString()); + } + + @Test + public void defaultUriVarsWithOverride() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory("http://{host}/v1"); + factory.setDefaultUriVariables(singletonMap("host", "spring.io")); + URI uri = factory.uriString("/bar").build(singletonMap("host", "docs.spring.io")); + assertEquals("http://docs.spring.io/v1/bar", uri.toString()); + } + + @Test + public void defaultUriVarsWithEmptyVarArg() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory("http://{host}/v1"); + factory.setDefaultUriVariables(singletonMap("host", "example.com")); + URI uri = factory.uriString("/bar").build(); + assertEquals("Expected delegation to build(Map) method", "http://example.com/v1/bar", uri.toString()); + } + + @Test + public void defaultUriVarsSpr14147() { + Map defaultUriVars = new HashMap<>(2); + defaultUriVars.put("host", "api.example.com"); + defaultUriVars.put("port", "443"); + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory(); + factory.setDefaultUriVariables(defaultUriVars); + + URI uri = factory.expand("https://{host}:{port}/v42/customers/{id}", singletonMap("id", 123L)); + assertEquals("https://api.example.com:443/v42/customers/123", uri.toString()); + } + + @Test + public void encodeTemplateAndValues() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory(); + factory.setEncodingMode(EncodingMode.TEMPLATE_AND_VALUES); + UriBuilder uriBuilder = factory.uriString("/hotel list/{city} specials?q={value}"); + + String expected = "/hotel%20list/Z%C3%BCrich%20specials?q=a%2Bb"; + + Map vars = new HashMap<>(); + vars.put("city", "Z\u00fcrich"); + vars.put("value", "a+b"); + + assertEquals(expected, uriBuilder.build("Z\u00fcrich", "a+b").toString()); + assertEquals(expected, uriBuilder.build(vars).toString()); + } + + @Test + public void encodingValuesOnly() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory(); + factory.setEncodingMode(EncodingMode.VALUES_ONLY); + UriBuilder uriBuilder = factory.uriString("/foo/a%2Fb/{id}"); + + String id = "c/d"; + String expected = "/foo/a%2Fb/c%2Fd"; + + assertEquals(expected, uriBuilder.build(id).toString()); + assertEquals(expected, uriBuilder.build(singletonMap("id", id)).toString()); + } + + @Test + public void encodingValuesOnlySpr14147() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory(); + factory.setEncodingMode(EncodingMode.VALUES_ONLY); + factory.setDefaultUriVariables(singletonMap("host", "example.com")); + UriBuilder uriBuilder = factory.uriString("http://{host}/user/{userId}/dashboard"); + + assertEquals("http://example.com/user/john%3Bdoe/dashboard", + uriBuilder.build(singletonMap("userId", "john;doe")).toString()); + } + + @Test + public void encodingNone() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory(); + factory.setEncodingMode(EncodingMode.NONE); + UriBuilder uriBuilder = factory.uriString("/foo/a%2Fb/{id}"); + + String id = "c%2Fd"; + String expected = "/foo/a%2Fb/c%2Fd"; + + assertEquals(expected, uriBuilder.build(id).toString()); + assertEquals(expected, uriBuilder.build(singletonMap("id", id)).toString()); + } + + @Test + public void parsePathWithDefaultSettings() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory("/foo/{bar}"); + URI uri = factory.uriString("/baz/{id}").build("a/b", "c/d"); + assertEquals("/foo/a%2Fb/baz/c%2Fd", uri.toString()); + } + + @Test + public void parsePathIsTurnedOff() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory("/foo/{bar}"); + factory.setEncodingMode(EncodingMode.URI_COMPONENT); + factory.setParsePath(false); + URI uri = factory.uriString("/baz/{id}").build("a/b", "c/d"); + assertEquals("/foo/a/b/baz/c/d", uri.toString()); + } + + @Test // SPR-15201 + public void pathWithTrailingSlash() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory(); + URI uri = factory.expand("http://localhost:8080/spring/"); + assertEquals("http://localhost:8080/spring/", uri.toString()); + } + + @Test + public void pathWithDuplicateSlashes() { + DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory(); + URI uri = factory.expand("/foo/////////bar"); + assertEquals("/foo/bar", uri.toString()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/util/DefaultUriTemplateHandlerTests.java b/spring-web/src/test/java/org/springframework/web/util/DefaultUriTemplateHandlerTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a3666d375cf02d7f37145260bc3e0299cc7a755b --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/DefaultUriTemplateHandlerTests.java @@ -0,0 +1,151 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.net.URI; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link DefaultUriTemplateHandler}. + * + * @author Rossen Stoyanchev + */ +@SuppressWarnings("deprecation") +public class DefaultUriTemplateHandlerTests { + + private final DefaultUriTemplateHandler handler = new DefaultUriTemplateHandler(); + + + @Test + public void baseUrlWithoutPath() throws Exception { + this.handler.setBaseUrl("http://localhost:8080"); + URI actual = this.handler.expand("/myapiresource"); + + assertEquals("http://localhost:8080/myapiresource", actual.toString()); + } + + @Test + public void baseUrlWithPath() throws Exception { + this.handler.setBaseUrl("http://localhost:8080/context"); + URI actual = this.handler.expand("/myapiresource"); + + assertEquals("http://localhost:8080/context/myapiresource", actual.toString()); + } + + @Test // SPR-14147 + public void defaultUriVariables() throws Exception { + Map defaultVars = new HashMap<>(2); + defaultVars.put("host", "api.example.com"); + defaultVars.put("port", "443"); + this.handler.setDefaultUriVariables(defaultVars); + + Map vars = new HashMap<>(1); + vars.put("id", 123L); + + String template = "https://{host}:{port}/v42/customers/{id}"; + URI actual = this.handler.expand(template, vars); + + assertEquals("https://api.example.com:443/v42/customers/123", actual.toString()); + } + + @Test + public void parsePathIsOff() throws Exception { + this.handler.setParsePath(false); + Map vars = new HashMap<>(2); + vars.put("hotel", "1"); + vars.put("publicpath", "pics/logo.png"); + String template = "https://example.com/hotels/{hotel}/pic/{publicpath}"; + URI actual = this.handler.expand(template, vars); + + assertEquals("https://example.com/hotels/1/pic/pics/logo.png", actual.toString()); + } + + @Test + public void parsePathIsOn() throws Exception { + this.handler.setParsePath(true); + Map vars = new HashMap<>(2); + vars.put("hotel", "1"); + vars.put("publicpath", "pics/logo.png"); + vars.put("scale", "150x150"); + String template = "https://example.com/hotels/{hotel}/pic/{publicpath}/size/{scale}"; + URI actual = this.handler.expand(template, vars); + + assertEquals("https://example.com/hotels/1/pic/pics%2Flogo.png/size/150x150", actual.toString()); + } + + @Test + public void strictEncodingIsOffWithMap() throws Exception { + this.handler.setStrictEncoding(false); + Map vars = new HashMap<>(2); + vars.put("userId", "john;doe"); + String template = "https://www.example.com/user/{userId}/dashboard"; + URI actual = this.handler.expand(template, vars); + + assertEquals("https://www.example.com/user/john;doe/dashboard", actual.toString()); + } + + @Test + public void strictEncodingOffWithArray() throws Exception { + this.handler.setStrictEncoding(false); + String template = "https://www.example.com/user/{userId}/dashboard"; + URI actual = this.handler.expand(template, "john;doe"); + + assertEquals("https://www.example.com/user/john;doe/dashboard", actual.toString()); + } + + @Test + public void strictEncodingOnWithMap() throws Exception { + this.handler.setStrictEncoding(true); + Map vars = new HashMap<>(2); + vars.put("userId", "john;doe"); + String template = "https://www.example.com/user/{userId}/dashboard"; + URI actual = this.handler.expand(template, vars); + + assertEquals("https://www.example.com/user/john%3Bdoe/dashboard", actual.toString()); + } + + @Test + public void strictEncodingOnWithArray() throws Exception { + this.handler.setStrictEncoding(true); + String template = "https://www.example.com/user/{userId}/dashboard"; + URI actual = this.handler.expand(template, "john;doe"); + + assertEquals("https://www.example.com/user/john%3Bdoe/dashboard", actual.toString()); + } + + @Test // SPR-14147 + public void strictEncodingAndDefaultUriVariables() throws Exception { + Map defaultVars = new HashMap<>(1); + defaultVars.put("host", "www.example.com"); + this.handler.setDefaultUriVariables(defaultVars); + this.handler.setStrictEncoding(true); + + Map vars = new HashMap<>(1); + vars.put("userId", "john;doe"); + + String template = "https://{host}/user/{userId}/dashboard"; + URI actual = this.handler.expand(template, vars); + + assertEquals("https://www.example.com/user/john%3Bdoe/dashboard", actual.toString()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/util/HtmlCharacterEntityReferencesTests.java b/spring-web/src/test/java/org/springframework/web/util/HtmlCharacterEntityReferencesTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5d0f668378fc532efee2936faa5397a5ff369f94 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/HtmlCharacterEntityReferencesTests.java @@ -0,0 +1,186 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.StreamTokenizer; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * @author Martin Kersten + * @author Juergen Hoeller + */ +public class HtmlCharacterEntityReferencesTests { + + private static final String DTD_FILE = "HtmlCharacterEntityReferences.dtd"; + + @Test + public void testSupportsAllCharacterEntityReferencesDefinedByHtml() { + HtmlCharacterEntityReferences entityReferences = new HtmlCharacterEntityReferences(); + Map referenceCharactersMap = getReferenceCharacterMap(); + + for (int character = 0; character < 10000; character++) { + String referenceName = referenceCharactersMap.get(character); + if (referenceName != null) { + String fullReference = + HtmlCharacterEntityReferences.REFERENCE_START + + referenceName + + HtmlCharacterEntityReferences.REFERENCE_END; + assertTrue("The unicode character " + character + " should be mapped to a reference", + entityReferences.isMappedToReference((char) character)); + assertEquals("The reference of unicode character " + character + " should be entity " + referenceName, + fullReference, entityReferences.convertToReference((char) character)); + assertEquals("The entity reference [" + referenceName + "] should be mapped to unicode character " + + character, (char) character, entityReferences.convertToCharacter(referenceName)); + } + else if (character == 39) { + assertTrue(entityReferences.isMappedToReference((char) character)); + assertEquals("'", entityReferences.convertToReference((char) character)); + } + else { + assertFalse("The unicode character " + character + " should not be mapped to a reference", + entityReferences.isMappedToReference((char) character)); + assertNull("No entity reference of unicode character " + character + " should exist", + entityReferences.convertToReference((char) character)); + } + } + + assertEquals("The registered entity count of entityReferences should match the number of entity references", + referenceCharactersMap.size() + 1, entityReferences.getSupportedReferenceCount()); + assertEquals("The HTML 4.0 Standard defines 252+1 entity references so do entityReferences", + 252 + 1, entityReferences.getSupportedReferenceCount()); + + assertEquals("Invalid entity reference names should not be convertible", + (char) -1, entityReferences.convertToCharacter("invalid")); + } + + // SPR-9293 + @Test + public void testConvertToReferenceUTF8() { + HtmlCharacterEntityReferences entityReferences = new HtmlCharacterEntityReferences(); + String utf8 = "UTF-8"; + assertEquals("<", entityReferences.convertToReference('<', utf8)); + assertEquals(">", entityReferences.convertToReference('>', utf8)); + assertEquals("&", entityReferences.convertToReference('&', utf8)); + assertEquals(""", entityReferences.convertToReference('"', utf8)); + assertEquals("'", entityReferences.convertToReference('\'', utf8)); + assertNull(entityReferences.convertToReference((char) 233, utf8)); + assertNull(entityReferences.convertToReference((char) 934, utf8)); + } + + private Map getReferenceCharacterMap() { + CharacterEntityResourceIterator entityIterator = new CharacterEntityResourceIterator(); + Map referencedCharactersMap = new HashMap<>(); + while (entityIterator.hasNext()) { + int character = entityIterator.getReferredCharacter(); + String entityName = entityIterator.nextEntry(); + referencedCharactersMap.put(new Integer(character), entityName); + } + return referencedCharactersMap; + } + + + private static class CharacterEntityResourceIterator { + + private final StreamTokenizer tokenizer; + + private String currentEntityName = null; + + private int referredCharacter = -1; + + public CharacterEntityResourceIterator() { + try { + InputStream inputStream = getClass().getResourceAsStream(DTD_FILE); + if (inputStream == null) { + throw new IOException("Cannot find definition resource [" + DTD_FILE + "]"); + } + tokenizer = new StreamTokenizer(new BufferedReader(new InputStreamReader(inputStream, "UTF-8"))); + } + catch (IOException ex) { + throw new IllegalStateException("Failed to open definition resource [" + DTD_FILE + "]"); + } + } + + public boolean hasNext() { + return (currentEntityName != null || readNextEntity()); + } + + public String nextEntry() { + if (hasNext()) { + String entityName = currentEntityName; + currentEntityName = null; + return entityName; + } + return null; + } + + public int getReferredCharacter() { + return referredCharacter; + } + + private boolean readNextEntity() { + try { + while (navigateToNextEntity()) { + String entityName = nextWordToken(); + if ("CDATA".equals(nextWordToken())) { + int referredCharacter = nextReferredCharacterId(); + if (entityName != null && referredCharacter != -1) { + this.currentEntityName = entityName; + this.referredCharacter = referredCharacter; + return true; + } + } + } + return false; + } + catch (IOException ex) { + throw new IllegalStateException("Could not parse definition resource: " + ex.getMessage()); + } + } + + private boolean navigateToNextEntity() throws IOException { + while (tokenizer.nextToken() != StreamTokenizer.TT_WORD || !"ENTITY".equals(tokenizer.sval)) { + if (tokenizer.ttype == StreamTokenizer.TT_EOF) { + return false; + } + } + return true; + } + + private int nextReferredCharacterId() throws IOException { + String reference = nextWordToken(); + if (reference != null && reference.startsWith("&#") && reference.endsWith(";")) { + return Integer.parseInt(reference.substring(2, reference.length() - 1)); + } + return -1; + } + + private String nextWordToken() throws IOException { + tokenizer.nextToken(); + return tokenizer.sval; + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/util/HtmlUtilsTests.java b/spring-web/src/test/java/org/springframework/web/util/HtmlUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..5331e9ac3ce8cccaa7cb360fa406b58e48919f46 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/HtmlUtilsTests.java @@ -0,0 +1,138 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * @author Alef Arendsen + * @author Martin Kersten + * @author Rick Evans + */ +public class HtmlUtilsTests { + + @Test + public void testHtmlEscape() { + String unescaped = "\"This is a quote'"; + String escaped = HtmlUtils.htmlEscape(unescaped); + assertEquals(""This is a quote'", escaped); + escaped = HtmlUtils.htmlEscapeDecimal(unescaped); + assertEquals(""This is a quote'", escaped); + escaped = HtmlUtils.htmlEscapeHex(unescaped); + assertEquals(""This is a quote'", escaped); + } + + @Test + public void testHtmlUnescape() { + String escaped = ""This is a quote'"; + String unescaped = HtmlUtils.htmlUnescape(escaped); + assertEquals("\"This is a quote'", unescaped); + } + + @Test + public void testEncodeIntoHtmlCharacterSet() { + assertEquals("An empty string should be converted to an empty string", + "", HtmlUtils.htmlEscape("")); + assertEquals("A string containing no special characters should not be affected", + "A sentence containing no special characters.", + HtmlUtils.htmlEscape("A sentence containing no special characters.")); + + assertEquals("'< >' should be encoded to '< >'", + "< >", HtmlUtils.htmlEscape("< >")); + assertEquals("'< >' should be encoded to '< >'", + "< >", HtmlUtils.htmlEscapeDecimal("< >")); + + assertEquals("The special character 8709 should be encoded to '∅'", + "∅", HtmlUtils.htmlEscape("" + (char) 8709)); + assertEquals("The special character 8709 should be encoded to '∅'", + "∅", HtmlUtils.htmlEscapeDecimal("" + (char) 8709)); + + assertEquals("The special character 977 should be encoded to 'ϑ'", + "ϑ", HtmlUtils.htmlEscape("" + (char) 977)); + assertEquals("The special character 977 should be encoded to 'ϑ'", + "ϑ", HtmlUtils.htmlEscapeDecimal("" + (char) 977)); + } + + // SPR-9293 + @Test + public void testEncodeIntoHtmlCharacterSetFromUtf8() { + String utf8 = ("UTF-8"); + assertEquals("An empty string should be converted to an empty string", + "", HtmlUtils.htmlEscape("", utf8)); + assertEquals("A string containing no special characters should not be affected", + "A sentence containing no special characters.", + HtmlUtils.htmlEscape("A sentence containing no special characters.")); + + assertEquals("'< >' should be encoded to '< >'", + "< >", HtmlUtils.htmlEscape("< >", utf8)); + assertEquals("'< >' should be encoded to '< >'", + "< >", HtmlUtils.htmlEscapeDecimal("< >", utf8)); + + assertEquals("UTF-8 supported chars should not be escaped", + "Μερικοί Ελληνικοί "χαρακτήρες"", + HtmlUtils.htmlEscape("Μερικοί Ελληνικοί \"χαρακτήρες\"", utf8)); + } + + @Test + public void testDecodeFromHtmlCharacterSet() { + assertEquals("An empty string should be converted to an empty string", + "", HtmlUtils.htmlUnescape("")); + assertEquals("A string containing no special characters should not be affected", + "This is a sentence containing no special characters.", + HtmlUtils.htmlUnescape("This is a sentence containing no special characters.")); + + assertEquals("'A B' should be decoded to 'A B'", + "A" + (char) 160 + "B", HtmlUtils.htmlUnescape("A B")); + + assertEquals("'< >' should be decoded to '< >'", + "< >", HtmlUtils.htmlUnescape("< >")); + assertEquals("'< >' should be decoded to '< >'", + "< >", HtmlUtils.htmlUnescape("< >")); + + assertEquals("'ABC' should be decoded to 'ABC'", + "ABC", HtmlUtils.htmlUnescape("ABC")); + + assertEquals("'φ' should be decoded to uni-code character 966", + "" + (char) 966, HtmlUtils.htmlUnescape("φ")); + + assertEquals("'″' should be decoded to uni-code character 8243", + "" + (char) 8243, HtmlUtils.htmlUnescape("″")); + + assertEquals("A not supported named reference leads should be ignored", + "&prIme;", HtmlUtils.htmlUnescape("&prIme;")); + + assertEquals("An empty reference '&;' should be survive the decoding", + "&;", HtmlUtils.htmlUnescape("&;")); + + assertEquals("The longest character entity reference 'ϑ' should be processable", + "" + (char) 977, HtmlUtils.htmlUnescape("ϑ")); + + assertEquals("A malformed decimal reference should survive the decoding", + "&#notADecimalNumber;", HtmlUtils.htmlUnescape("&#notADecimalNumber;")); + assertEquals("A malformed hex reference should survive the decoding", + "&#XnotAHexNumber;", HtmlUtils.htmlUnescape("&#XnotAHexNumber;")); + + assertEquals("The numerical reference '' should be converted to char 1", + "" + (char) 1, HtmlUtils.htmlUnescape("")); + + assertEquals("The malformed hex reference '&#x;' should remain '&#x;'", + "&#x;", HtmlUtils.htmlUnescape("&#x;")); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/util/JavaScriptUtilsTests.java b/spring-web/src/test/java/org/springframework/web/util/JavaScriptUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..8415db47f793d5a692dd4ddd7a833d8ce451729a --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/JavaScriptUtilsTests.java @@ -0,0 +1,67 @@ +/* + * Copyright 2004-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.io.UnsupportedEncodingException; + +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Test fixture for {@link JavaScriptUtils}. + * + * @author Rossen Stoyanchev + */ +public class JavaScriptUtilsTests { + + @Test + public void escape() { + StringBuilder sb = new StringBuilder(); + sb.append('"'); + sb.append("'"); + sb.append("\\"); + sb.append("/"); + sb.append("\t"); + sb.append("\n"); + sb.append("\r"); + sb.append("\f"); + sb.append("\b"); + sb.append("\013"); + assertEquals("\\\"\\'\\\\\\/\\t\\n\\n\\f\\b\\v", JavaScriptUtils.javaScriptEscape(sb.toString())); + } + + // SPR-9983 + + @Test + public void escapePsLsLineTerminators() { + StringBuilder sb = new StringBuilder(); + sb.append('\u2028'); + sb.append('\u2029'); + String result = JavaScriptUtils.javaScriptEscape(sb.toString()); + + assertEquals("\\u2028\\u2029", result); + } + + // SPR-9983 + + @Test + public void escapeLessThanGreaterThanSigns() throws UnsupportedEncodingException { + assertEquals("\\u003C\\u003E", JavaScriptUtils.javaScriptEscape("<>")); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/util/ServletContextPropertyUtilsTests.java b/spring-web/src/test/java/org/springframework/web/util/ServletContextPropertyUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..62b248154907fd4f15af59b806078555ccb30630 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/ServletContextPropertyUtilsTests.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */package org.springframework.web.util; + +import org.junit.Test; + +import org.springframework.mock.web.test.MockServletContext; + +import static org.junit.Assert.*; + +/** + * @author Marten Deinum + * @since 3.2.2 + */ +public class ServletContextPropertyUtilsTests { + + @Test + public void resolveAsServletContextInitParameter() { + MockServletContext servletContext = new MockServletContext(); + servletContext.setInitParameter("test.prop", "bar"); + String resolved = ServletContextPropertyUtils.resolvePlaceholders("${test.prop:foo}", servletContext); + assertEquals("bar", resolved); + } + + @Test + public void fallbackToSystemProperties() { + MockServletContext servletContext = new MockServletContext(); + System.setProperty("test.prop", "bar"); + try { + String resolved = ServletContextPropertyUtils.resolvePlaceholders("${test.prop:foo}", servletContext); + assertEquals("bar", resolved); + } + finally { + System.clearProperty("test.prop"); + } + } +} diff --git a/spring-web/src/test/java/org/springframework/web/util/TagUtilsTests.java b/spring-web/src/test/java/org/springframework/web/util/TagUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..dd7ef947450576ec20f9ca1acf8a070b1421fb58 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/TagUtilsTests.java @@ -0,0 +1,151 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import javax.servlet.jsp.PageContext; +import javax.servlet.jsp.tagext.Tag; +import javax.servlet.jsp.tagext.TagSupport; + +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Unit tests for the {@link TagUtils} class. + * + * @author Alef Arendsen + * @author Rick Evans + */ +public class TagUtilsTests { + + @Test + public void getScopeSunnyDay() { + assertEquals("page", TagUtils.SCOPE_PAGE); + assertEquals("application", TagUtils.SCOPE_APPLICATION); + assertEquals("session", TagUtils.SCOPE_SESSION); + assertEquals("request", TagUtils.SCOPE_REQUEST); + + assertEquals(PageContext.PAGE_SCOPE, TagUtils.getScope("page")); + assertEquals(PageContext.REQUEST_SCOPE, TagUtils.getScope("request")); + assertEquals(PageContext.SESSION_SCOPE, TagUtils.getScope("session")); + assertEquals(PageContext.APPLICATION_SCOPE, TagUtils.getScope("application")); + + // non-existent scope + assertEquals("TagUtils.getScope(..) with a non-existent scope argument must " + + "just return the default scope (PageContext.PAGE_SCOPE).", PageContext.PAGE_SCOPE, + TagUtils.getScope("bla")); + } + + @Test(expected = IllegalArgumentException.class) + public void getScopeWithNullScopeArgument() { + TagUtils.getScope(null); + } + + @Test(expected = IllegalArgumentException.class) + public void hasAncestorOfTypeWhereAncestorTagIsNotATagType() throws Exception { + assertFalse(TagUtils.hasAncestorOfType(new TagSupport(), String.class)); + } + + @Test(expected = IllegalArgumentException.class) + public void hasAncestorOfTypeWithNullTagArgument() throws Exception { + assertFalse(TagUtils.hasAncestorOfType(null, TagSupport.class)); + } + + @Test(expected = IllegalArgumentException.class) + public void hasAncestorOfTypeWithNullAncestorTagClassArgument() throws Exception { + assertFalse(TagUtils.hasAncestorOfType(new TagSupport(), null)); + } + + @Test + public void hasAncestorOfTypeTrueScenario() throws Exception { + Tag a = new TagA(); + Tag b = new TagB(); + Tag c = new TagC(); + + a.setParent(b); + b.setParent(c); + + assertTrue(TagUtils.hasAncestorOfType(a, TagC.class)); + } + + @Test + public void hasAncestorOfTypeFalseScenario() throws Exception { + Tag a = new TagA(); + Tag b = new TagB(); + Tag anotherB = new TagB(); + + a.setParent(b); + b.setParent(anotherB); + + assertFalse(TagUtils.hasAncestorOfType(a, TagC.class)); + } + + @Test + public void hasAncestorOfTypeWhenTagHasNoParent() throws Exception { + assertFalse(TagUtils.hasAncestorOfType(new TagA(), TagC.class)); + } + + @Test(expected = IllegalArgumentException.class) + public void assertHasAncestorOfTypeWithNullTagName() throws Exception { + TagUtils.assertHasAncestorOfType(new TagA(), TagC.class, null, "c"); + } + + @Test(expected = IllegalArgumentException.class) + public void assertHasAncestorOfTypeWithNullAncestorTagName() throws Exception { + TagUtils.assertHasAncestorOfType(new TagA(), TagC.class, "a", null); + } + + @Test(expected = IllegalStateException.class) + public void assertHasAncestorOfTypeThrowsExceptionOnFail() throws Exception { + Tag a = new TagA(); + Tag b = new TagB(); + Tag anotherB = new TagB(); + + a.setParent(b); + b.setParent(anotherB); + + TagUtils.assertHasAncestorOfType(a, TagC.class, "a", "c"); + } + + @Test + public void testAssertHasAncestorOfTypeDoesNotThrowExceptionOnPass() throws Exception { + Tag a = new TagA(); + Tag b = new TagB(); + Tag c = new TagC(); + + a.setParent(b); + b.setParent(c); + + TagUtils.assertHasAncestorOfType(a, TagC.class, "a", "c"); + } + + @SuppressWarnings("serial") + private static final class TagA extends TagSupport { + + } + + @SuppressWarnings("serial") + private static final class TagB extends TagSupport { + + } + + @SuppressWarnings("serial") + private static final class TagC extends TagSupport { + + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java b/spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java new file mode 100644 index 0000000000000000000000000000000000000000..959aa563b715dd49ccb3566439659d00ec2148a2 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java @@ -0,0 +1,975 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; + +import org.springframework.http.HttpRequest; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; + +/** + * Unit tests for {@link UriComponentsBuilder}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Phillip Webb + * @author Oliver Gierke + * @author Juergen Hoeller + * @author Sam Brannen + * @author David Eckel + */ +public class UriComponentsBuilderTests { + + @Test + public void plain() throws URISyntaxException { + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + UriComponents result = builder.scheme("https").host("example.com") + .path("foo").queryParam("bar").fragment("baz") + .build(); + assertEquals("https", result.getScheme()); + assertEquals("example.com", result.getHost()); + assertEquals("foo", result.getPath()); + assertEquals("bar", result.getQuery()); + assertEquals("baz", result.getFragment()); + + URI expected = new URI("https://example.com/foo?bar#baz"); + assertEquals("Invalid result URI", expected, result.toUri()); + } + + @Test + public void multipleFromSameBuilder() throws URISyntaxException { + UriComponentsBuilder builder = UriComponentsBuilder.newInstance() + .scheme("https").host("example.com").pathSegment("foo"); + UriComponents result1 = builder.build(); + builder = builder.pathSegment("foo2").queryParam("bar").fragment("baz"); + UriComponents result2 = builder.build(); + + assertEquals("https", result1.getScheme()); + assertEquals("example.com", result1.getHost()); + assertEquals("/foo", result1.getPath()); + URI expected = new URI("https://example.com/foo"); + assertEquals("Invalid result URI", expected, result1.toUri()); + + assertEquals("https", result2.getScheme()); + assertEquals("example.com", result2.getHost()); + assertEquals("/foo/foo2", result2.getPath()); + assertEquals("bar", result2.getQuery()); + assertEquals("baz", result2.getFragment()); + expected = new URI("https://example.com/foo/foo2?bar#baz"); + assertEquals("Invalid result URI", expected, result2.toUri()); + } + + @Test + public void fromPath() throws URISyntaxException { + UriComponents result = UriComponentsBuilder.fromPath("foo").queryParam("bar").fragment("baz").build(); + assertEquals("foo", result.getPath()); + assertEquals("bar", result.getQuery()); + assertEquals("baz", result.getFragment()); + + assertEquals("Invalid result URI String", "foo?bar#baz", result.toUriString()); + + URI expected = new URI("foo?bar#baz"); + assertEquals("Invalid result URI", expected, result.toUri()); + + result = UriComponentsBuilder.fromPath("/foo").build(); + assertEquals("/foo", result.getPath()); + + expected = new URI("/foo"); + assertEquals("Invalid result URI", expected, result.toUri()); + } + + @Test + public void fromHierarchicalUri() throws URISyntaxException { + URI uri = new URI("https://example.com/foo?bar#baz"); + UriComponents result = UriComponentsBuilder.fromUri(uri).build(); + assertEquals("https", result.getScheme()); + assertEquals("example.com", result.getHost()); + assertEquals("/foo", result.getPath()); + assertEquals("bar", result.getQuery()); + assertEquals("baz", result.getFragment()); + + assertEquals("Invalid result URI", uri, result.toUri()); + } + + @Test + public void fromOpaqueUri() throws URISyntaxException { + URI uri = new URI("mailto:foo@bar.com#baz"); + UriComponents result = UriComponentsBuilder.fromUri(uri).build(); + assertEquals("mailto", result.getScheme()); + assertEquals("foo@bar.com", result.getSchemeSpecificPart()); + assertEquals("baz", result.getFragment()); + + assertEquals("Invalid result URI", uri, result.toUri()); + } + + @Test // SPR-9317 + public void fromUriEncodedQuery() throws URISyntaxException { + URI uri = new URI("https://www.example.org/?param=aGVsbG9Xb3JsZA%3D%3D"); + String fromUri = UriComponentsBuilder.fromUri(uri).build().getQueryParams().get("param").get(0); + String fromUriString = UriComponentsBuilder.fromUriString(uri.toString()) + .build().getQueryParams().get("param").get(0); + + assertEquals(fromUri, fromUriString); + } + + @Test + public void fromUriString() { + UriComponents result = UriComponentsBuilder.fromUriString("https://www.ietf.org/rfc/rfc3986.txt").build(); + assertEquals("https", result.getScheme()); + assertNull(result.getUserInfo()); + assertEquals("www.ietf.org", result.getHost()); + assertEquals(-1, result.getPort()); + assertEquals("/rfc/rfc3986.txt", result.getPath()); + assertEquals(Arrays.asList("rfc", "rfc3986.txt"), result.getPathSegments()); + assertNull(result.getQuery()); + assertNull(result.getFragment()); + + String url = "https://arjen:foobar@java.sun.com:80" + + "/javase/6/docs/api/java/util/BitSet.html?foo=bar#and(java.util.BitSet)"; + result = UriComponentsBuilder.fromUriString(url).build(); + assertEquals("https", result.getScheme()); + assertEquals("arjen:foobar", result.getUserInfo()); + assertEquals("java.sun.com", result.getHost()); + assertEquals(80, result.getPort()); + assertEquals("/javase/6/docs/api/java/util/BitSet.html", result.getPath()); + assertEquals("foo=bar", result.getQuery()); + MultiValueMap expectedQueryParams = new LinkedMultiValueMap<>(1); + expectedQueryParams.add("foo", "bar"); + assertEquals(expectedQueryParams, result.getQueryParams()); + assertEquals("and(java.util.BitSet)", result.getFragment()); + + result = UriComponentsBuilder.fromUriString("mailto:java-net@java.sun.com#baz").build(); + assertEquals("mailto", result.getScheme()); + assertNull(result.getUserInfo()); + assertNull(result.getHost()); + assertEquals(-1, result.getPort()); + assertEquals("java-net@java.sun.com", result.getSchemeSpecificPart()); + assertNull(result.getPath()); + assertNull(result.getQuery()); + assertEquals("baz", result.getFragment()); + + result = UriComponentsBuilder.fromUriString("docs/guide/collections/designfaq.html#28").build(); + assertNull(result.getScheme()); + assertNull(result.getUserInfo()); + assertNull(result.getHost()); + assertEquals(-1, result.getPort()); + assertEquals("docs/guide/collections/designfaq.html", result.getPath()); + assertNull(result.getQuery()); + assertEquals("28", result.getFragment()); + } + + @Test // SPR-9832 + public void fromUriStringQueryParamWithReservedCharInValue() { + String uri = "https://www.google.com/ig/calculator?q=1USD=?EUR"; + UriComponents result = UriComponentsBuilder.fromUriString(uri).build(); + + assertEquals("q=1USD=?EUR", result.getQuery()); + assertEquals("1USD=?EUR", result.getQueryParams().getFirst("q")); + } + + @Test // SPR-14828 + public void fromUriStringQueryParamEncodedAndContainingPlus() { + String httpUrl = "http://localhost:8080/test/print?value=%EA%B0%80+%EB%82%98"; + URI uri = UriComponentsBuilder.fromHttpUrl(httpUrl).build(true).toUri(); + + assertEquals(httpUrl, uri.toString()); + } + + @Test // SPR-10779 + public void fromHttpUrlStringCaseInsesitiveScheme() { + assertEquals("http", UriComponentsBuilder.fromHttpUrl("HTTP://www.google.com").build().getScheme()); + assertEquals("https", UriComponentsBuilder.fromHttpUrl("HTTPS://www.google.com").build().getScheme()); + } + + @Test(expected = IllegalArgumentException.class) // SPR-10539 + public void fromHttpUrlStringInvalidIPv6Host() { + UriComponentsBuilder.fromHttpUrl("http://[1abc:2abc:3abc::5ABC:6abc:8080/resource").build().encode(); + } + + @Test // SPR-10539 + public void fromUriStringIPv6Host() { + UriComponents result = UriComponentsBuilder + .fromUriString("http://[1abc:2abc:3abc::5ABC:6abc]:8080/resource").build().encode(); + assertEquals("[1abc:2abc:3abc::5ABC:6abc]", result.getHost()); + + UriComponents resultWithScopeId = UriComponentsBuilder + .fromUriString("http://[1abc:2abc:3abc::5ABC:6abc%eth0]:8080/resource").build().encode(); + assertEquals("[1abc:2abc:3abc::5ABC:6abc%25eth0]", resultWithScopeId.getHost()); + + UriComponents resultIPv4compatible = UriComponentsBuilder + .fromUriString("http://[::192.168.1.1]:8080/resource").build().encode(); + assertEquals("[::192.168.1.1]", resultIPv4compatible.getHost()); + } + + @Test // SPR-11970 + public void fromUriStringNoPathWithReservedCharInQuery() { + UriComponents result = UriComponentsBuilder.fromUriString("https://example.com?foo=bar@baz").build(); + assertTrue(StringUtils.isEmpty(result.getUserInfo())); + assertEquals("example.com", result.getHost()); + assertTrue(result.getQueryParams().containsKey("foo")); + assertEquals("bar@baz", result.getQueryParams().getFirst("foo")); + } + + @Test + public void fromHttpRequest() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("localhost"); + request.setServerPort(-1); + request.setRequestURI("/path"); + request.setQueryString("a=1"); + + UriComponents result = UriComponentsBuilder.fromHttpRequest(new ServletServerHttpRequest(request)).build(); + assertEquals("http", result.getScheme()); + assertEquals("localhost", result.getHost()); + assertEquals(-1, result.getPort()); + assertEquals("/path", result.getPath()); + assertEquals("a=1", result.getQuery()); + } + + @Test // SPR-12771 + public void fromHttpRequestResetsPortBeforeSettingIt() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("X-Forwarded-Proto", "https"); + request.addHeader("X-Forwarded-Host", "84.198.58.199"); + request.addHeader("X-Forwarded-Port", 443); + request.setScheme("http"); + request.setServerName("example.com"); + request.setServerPort(80); + request.setRequestURI("/rest/mobile/users/1"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https", result.getScheme()); + assertEquals("84.198.58.199", result.getHost()); + assertEquals(-1, result.getPort()); + assertEquals("/rest/mobile/users/1", result.getPath()); + } + + @Test // SPR-14761 + public void fromHttpRequestWithForwardedIPv4Host() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("https"); + request.setServerName("localhost"); + request.setServerPort(-1); + request.setRequestURI("/mvc-showcase"); + request.addHeader("Forwarded", "host=192.168.0.1"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https://192.168.0.1/mvc-showcase", result.toString()); + } + + @Test // SPR-14761 + public void fromHttpRequestWithForwardedIPv6() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("localhost"); + request.setServerPort(-1); + request.setRequestURI("/mvc-showcase"); + request.addHeader("Forwarded", "host=[1abc:2abc:3abc::5ABC:6abc]"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("http://[1abc:2abc:3abc::5ABC:6abc]/mvc-showcase", result.toString()); + } + + @Test // SPR-14761 + public void fromHttpRequestWithForwardedIPv6Host() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("localhost"); + request.setServerPort(-1); + request.setRequestURI("/mvc-showcase"); + request.addHeader("X-Forwarded-Host", "[1abc:2abc:3abc::5ABC:6abc]"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("http://[1abc:2abc:3abc::5ABC:6abc]/mvc-showcase", result.toString()); + } + + @Test // SPR-14761 + public void fromHttpRequestWithForwardedIPv6HostAndPort() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("localhost"); + request.setServerPort(-1); + request.setRequestURI("/mvc-showcase"); + request.addHeader("X-Forwarded-Host", "[1abc:2abc:3abc::5ABC:6abc]:8080"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("http://[1abc:2abc:3abc::5ABC:6abc]:8080/mvc-showcase", result.toString()); + } + + @Test + public void fromHttpRequestWithForwardedHost() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("localhost"); + request.setServerPort(-1); + request.setRequestURI("/mvc-showcase"); + request.addHeader("X-Forwarded-Host", "anotherHost"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("http://anotherHost/mvc-showcase", result.toString()); + } + + @Test // SPR-10701 + public void fromHttpRequestWithForwardedHostIncludingPort() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("localhost"); + request.setServerPort(-1); + request.setRequestURI("/mvc-showcase"); + request.addHeader("X-Forwarded-Host", "webtest.foo.bar.com:443"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("webtest.foo.bar.com", result.getHost()); + assertEquals(443, result.getPort()); + } + + @Test // SPR-11140 + public void fromHttpRequestWithForwardedHostMultiValuedHeader() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("localhost"); + request.setServerPort(-1); + request.addHeader("X-Forwarded-Host", "a.example.org, b.example.org, c.example.org"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("a.example.org", result.getHost()); + assertEquals(-1, result.getPort()); + } + + @Test // SPR-11855 + public void fromHttpRequestWithForwardedHostAndPort() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("localhost"); + request.setServerPort(8080); + request.addHeader("X-Forwarded-Host", "foobarhost"); + request.addHeader("X-Forwarded-Port", "9090"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("foobarhost", result.getHost()); + assertEquals(9090, result.getPort()); + } + + @Test // SPR-11872 + public void fromHttpRequestWithForwardedHostWithDefaultPort() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("localhost"); + request.setServerPort(10080); + request.addHeader("X-Forwarded-Host", "example.org"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("example.org", result.getHost()); + assertEquals(-1, result.getPort()); + } + + @Test // SPR-16262 + public void fromHttpRequestWithForwardedProtoWithDefaultPort() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("example.org"); + request.setServerPort(10080); + request.addHeader("X-Forwarded-Proto", "https"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https", result.getScheme()); + assertEquals("example.org", result.getHost()); + assertEquals(-1, result.getPort()); + } + + @Test // SPR-16863 + public void fromHttpRequestWithForwardedSsl() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("example.org"); + request.setServerPort(10080); + request.addHeader("X-Forwarded-Ssl", "on"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https", result.getScheme()); + assertEquals("example.org", result.getHost()); + assertEquals(-1, result.getPort()); + } + + @Test + public void fromHttpRequestWithForwardedHostWithForwardedScheme() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("localhost"); + request.setServerPort(10080); + request.addHeader("X-Forwarded-Host", "example.org"); + request.addHeader("X-Forwarded-Proto", "https"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("example.org", result.getHost()); + assertEquals("https", result.getScheme()); + assertEquals(-1, result.getPort()); + } + + @Test // SPR-12771 + public void fromHttpRequestWithForwardedProtoAndDefaultPort() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("localhost"); + request.setServerPort(80); + request.setRequestURI("/mvc-showcase"); + request.addHeader("X-Forwarded-Proto", "https"); + request.addHeader("X-Forwarded-Host", "84.198.58.199"); + request.addHeader("X-Forwarded-Port", "443"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https://84.198.58.199/mvc-showcase", result.toString()); + } + + @Test // SPR-12813 + public void fromHttpRequestWithForwardedPortMultiValueHeader() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("localhost"); + request.setServerPort(9090); + request.setRequestURI("/mvc-showcase"); + request.addHeader("X-Forwarded-Host", "a.example.org"); + request.addHeader("X-Forwarded-Port", "80,52022"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("http://a.example.org/mvc-showcase", result.toString()); + } + + @Test // SPR-12816 + public void fromHttpRequestWithForwardedProtoMultiValueHeader() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("localhost"); + request.setServerPort(8080); + request.setRequestURI("/mvc-showcase"); + request.addHeader("X-Forwarded-Host", "a.example.org"); + request.addHeader("X-Forwarded-Port", "443"); + request.addHeader("X-Forwarded-Proto", "https,https"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https://a.example.org/mvc-showcase", result.toString()); + } + + @Test // SPR-12742 + public void fromHttpRequestWithTrailingSlash() { + UriComponents before = UriComponentsBuilder.fromPath("/foo/").build(); + UriComponents after = UriComponentsBuilder.newInstance().uriComponents(before).build(); + assertEquals("/foo/", after.getPath()); + } + + @Test + public void path() { + UriComponentsBuilder builder = UriComponentsBuilder.fromPath("/foo/bar"); + UriComponents result = builder.build(); + + assertEquals("/foo/bar", result.getPath()); + assertEquals(Arrays.asList("foo", "bar"), result.getPathSegments()); + } + + @Test + public void pathSegments() { + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + UriComponents result = builder.pathSegment("foo").pathSegment("bar").build(); + + assertEquals("/foo/bar", result.getPath()); + assertEquals(Arrays.asList("foo", "bar"), result.getPathSegments()); + } + + @Test + public void pathThenPath() { + UriComponentsBuilder builder = UriComponentsBuilder.fromPath("/foo/bar").path("ba/z"); + UriComponents result = builder.build().encode(); + + assertEquals("/foo/barba/z", result.getPath()); + assertEquals(Arrays.asList("foo", "barba", "z"), result.getPathSegments()); + } + + @Test + public void pathThenPathSegments() { + UriComponentsBuilder builder = UriComponentsBuilder.fromPath("/foo/bar").pathSegment("ba/z"); + UriComponents result = builder.build().encode(); + + assertEquals("/foo/bar/ba%2Fz", result.getPath()); + assertEquals(Arrays.asList("foo", "bar", "ba%2Fz"), result.getPathSegments()); + } + + @Test + public void pathSegmentsThenPathSegments() { + UriComponentsBuilder builder = UriComponentsBuilder.newInstance().pathSegment("foo").pathSegment("bar"); + UriComponents result = builder.build(); + + assertEquals("/foo/bar", result.getPath()); + assertEquals(Arrays.asList("foo", "bar"), result.getPathSegments()); + } + + @Test + public void pathSegmentsThenPath() { + UriComponentsBuilder builder = UriComponentsBuilder.newInstance().pathSegment("foo").path("/"); + UriComponents result = builder.build(); + + assertEquals("/foo/", result.getPath()); + assertEquals(Collections.singletonList("foo"), result.getPathSegments()); + } + + @Test + public void pathSegmentsSomeEmpty() { + UriComponentsBuilder builder = UriComponentsBuilder.newInstance().pathSegment("", "foo", "", "bar"); + UriComponents result = builder.build(); + + assertEquals("/foo/bar", result.getPath()); + assertEquals(Arrays.asList("foo", "bar"), result.getPathSegments()); + } + + @Test // SPR-12398 + public void pathWithDuplicateSlashes() { + UriComponents uriComponents = UriComponentsBuilder.fromPath("/foo/////////bar").build(); + assertEquals("/foo/bar", uriComponents.getPath()); + } + + @Test + public void replacePath() { + UriComponentsBuilder builder = UriComponentsBuilder.fromUriString("https://www.ietf.org/rfc/rfc2396.txt"); + builder.replacePath("/rfc/rfc3986.txt"); + UriComponents result = builder.build(); + + assertEquals("https://www.ietf.org/rfc/rfc3986.txt", result.toUriString()); + + builder = UriComponentsBuilder.fromUriString("https://www.ietf.org/rfc/rfc2396.txt"); + builder.replacePath(null); + result = builder.build(); + + assertEquals("https://www.ietf.org", result.toUriString()); + } + + @Test + public void replaceQuery() { + UriComponentsBuilder builder = UriComponentsBuilder.fromUriString("https://example.com/foo?foo=bar&baz=qux"); + builder.replaceQuery("baz=42"); + UriComponents result = builder.build(); + + assertEquals("https://example.com/foo?baz=42", result.toUriString()); + + builder = UriComponentsBuilder.fromUriString("https://example.com/foo?foo=bar&baz=qux"); + builder.replaceQuery(null); + result = builder.build(); + + assertEquals("https://example.com/foo", result.toUriString()); + } + + @Test + public void queryParams() { + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + UriComponents result = builder.queryParam("baz", "qux", 42).build(); + + assertEquals("baz=qux&baz=42", result.getQuery()); + MultiValueMap expectedQueryParams = new LinkedMultiValueMap<>(2); + expectedQueryParams.add("baz", "qux"); + expectedQueryParams.add("baz", "42"); + assertEquals(expectedQueryParams, result.getQueryParams()); + } + + @Test + public void emptyQueryParam() { + UriComponentsBuilder builder = UriComponentsBuilder.newInstance(); + UriComponents result = builder.queryParam("baz").build(); + + assertEquals("baz", result.getQuery()); + MultiValueMap expectedQueryParams = new LinkedMultiValueMap<>(2); + expectedQueryParams.add("baz", null); + assertEquals(expectedQueryParams, result.getQueryParams()); + } + + @Test + public void replaceQueryParam() { + UriComponentsBuilder builder = UriComponentsBuilder.newInstance().queryParam("baz", "qux", 42); + builder.replaceQueryParam("baz", "xuq", 24); + UriComponents result = builder.build(); + + assertEquals("baz=xuq&baz=24", result.getQuery()); + + builder = UriComponentsBuilder.newInstance().queryParam("baz", "qux", 42); + builder.replaceQueryParam("baz"); + result = builder.build(); + + assertNull("Query param should have been deleted", result.getQuery()); + } + + @Test + public void buildAndExpandHierarchical() { + UriComponents result = UriComponentsBuilder.fromPath("/{foo}").buildAndExpand("fooValue"); + assertEquals("/fooValue", result.toUriString()); + + Map values = new HashMap<>(); + values.put("foo", "fooValue"); + values.put("bar", "barValue"); + result = UriComponentsBuilder.fromPath("/{foo}/{bar}").buildAndExpand(values); + assertEquals("/fooValue/barValue", result.toUriString()); + } + + @Test + public void buildAndExpandOpaque() { + UriComponents result = UriComponentsBuilder.fromUriString("mailto:{user}@{domain}") + .buildAndExpand("foo", "example.com"); + assertEquals("mailto:foo@example.com", result.toUriString()); + + Map values = new HashMap<>(); + values.put("user", "foo"); + values.put("domain", "example.com"); + UriComponentsBuilder.fromUriString("mailto:{user}@{domain}").buildAndExpand(values); + assertEquals("mailto:foo@example.com", result.toUriString()); + } + + @Test + public void queryParamWithValueWithEquals() { + UriComponents uriComponents = UriComponentsBuilder.fromUriString("https://example.com/foo?bar=baz").build(); + assertThat(uriComponents.toUriString(), equalTo("https://example.com/foo?bar=baz")); + assertThat(uriComponents.getQueryParams().get("bar").get(0), equalTo("baz")); + } + + @Test + public void queryParamWithoutValueWithEquals() { + UriComponents uriComponents = UriComponentsBuilder.fromUriString("https://example.com/foo?bar=").build(); + assertThat(uriComponents.toUriString(), equalTo("https://example.com/foo?bar=")); + assertThat(uriComponents.getQueryParams().get("bar").get(0), equalTo("")); + } + + @Test + public void queryParamWithoutValueWithoutEquals() { + UriComponents uriComponents = UriComponentsBuilder.fromUriString("https://example.com/foo?bar").build(); + assertThat(uriComponents.toUriString(), equalTo("https://example.com/foo?bar")); + + // TODO [SPR-13537] Change equalTo(null) to equalTo(""). + assertThat(uriComponents.getQueryParams().get("bar").get(0), equalTo(null)); + } + + @Test + public void relativeUrls() { + String baseUrl = "https://example.com"; + assertThat(UriComponentsBuilder.fromUriString(baseUrl + "/foo/../bar").build().toString(), + equalTo(baseUrl + "/foo/../bar")); + assertThat(UriComponentsBuilder.fromUriString(baseUrl + "/foo/../bar").build().toUriString(), + equalTo(baseUrl + "/foo/../bar")); + assertThat(UriComponentsBuilder.fromUriString(baseUrl + "/foo/../bar").build().toUri().getPath(), + equalTo("/foo/../bar")); + assertThat(UriComponentsBuilder.fromUriString("../../").build().toString(), + equalTo("../../")); + assertThat(UriComponentsBuilder.fromUriString("../../").build().toUriString(), + equalTo("../../")); + assertThat(UriComponentsBuilder.fromUriString("../../").build().toUri().getPath(), + equalTo("../../")); + assertThat(UriComponentsBuilder.fromUriString(baseUrl).path("foo/../bar").build().toString(), + equalTo(baseUrl + "/foo/../bar")); + assertThat(UriComponentsBuilder.fromUriString(baseUrl).path("foo/../bar").build().toUriString(), + equalTo(baseUrl + "/foo/../bar")); + assertThat(UriComponentsBuilder.fromUriString(baseUrl).path("foo/../bar").build().toUri().getPath(), + equalTo("/foo/../bar")); + } + + @Test + public void emptySegments() { + String baseUrl = "https://example.com/abc/"; + assertThat(UriComponentsBuilder.fromUriString(baseUrl).path("/x/y/z").build().toString(), + equalTo("https://example.com/abc/x/y/z")); + assertThat(UriComponentsBuilder.fromUriString(baseUrl).pathSegment("x", "y", "z").build().toString(), + equalTo("https://example.com/abc/x/y/z")); + assertThat(UriComponentsBuilder.fromUriString(baseUrl).path("/x/").path("/y/z").build().toString(), + equalTo("https://example.com/abc/x/y/z")); + assertThat(UriComponentsBuilder.fromUriString(baseUrl).pathSegment("x").path("y").build().toString(), + equalTo("https://example.com/abc/x/y")); + } + + @Test + public void parsesEmptyFragment() { + UriComponents components = UriComponentsBuilder.fromUriString("/example#").build(); + assertThat(components.getFragment(), is(nullValue())); + assertThat(components.toString(), equalTo("/example")); + } + + @Test // SPR-13257 + public void parsesEmptyUri() { + UriComponents components = UriComponentsBuilder.fromUriString("").build(); + assertThat(components.toString(), equalTo("")); + } + + @Test // gh-25243 + public void testCloneAndMerge() { + UriComponentsBuilder builder1 = UriComponentsBuilder.newInstance(); + builder1.scheme("http").host("e1.com").path("/p1").pathSegment("ps1").queryParam("q1", "x").fragment("f1").encode(); + + UriComponentsBuilder builder2 = builder1.cloneBuilder(); + builder2.scheme("https").host("e2.com").path("p2").pathSegment("{ps2}").queryParam("q2").fragment("f2"); + + builder1.queryParam("q1", "y"); // one more entry for an existing parameter + + UriComponents result1 = builder1.build(); + assertEquals("http", result1.getScheme()); + assertEquals("e1.com", result1.getHost()); + assertEquals("/p1/ps1", result1.getPath()); + assertEquals("q1=x&q1=y", result1.getQuery()); + assertEquals("f1", result1.getFragment()); + + UriComponents result2 = builder2.buildAndExpand("ps2;a"); + assertEquals("https", result2.getScheme()); + assertEquals("e2.com", result2.getHost()); + assertEquals("/p1/ps1/p2/ps2%3Ba", result2.getPath()); + assertEquals("q1=x&q2", result2.getQuery()); + assertEquals("f2", result2.getFragment()); + } + + @Test // gh-24772 + public void testDeepClone() { + HashMap vars = new HashMap<>(); + vars.put("ps1", "foo"); + vars.put("ps2", "bar"); + + UriComponentsBuilder builder1 = UriComponentsBuilder.newInstance(); + builder1.scheme("http").host("e1.com").userInfo("user:pwd").path("/p1").pathSegment("{ps1}") + .pathSegment("{ps2}").queryParam("q1").fragment("f1").uriVariables(vars).encode(); + + UriComponentsBuilder builder2 = builder1.cloneBuilder(); + + UriComponents result1 = builder1.build(); + assertEquals("http", result1.getScheme()); + assertEquals("user:pwd", result1.getUserInfo()); + assertEquals("e1.com", result1.getHost()); + assertEquals("/p1/foo/bar", result1.getPath()); + assertEquals("q1", result1.getQuery()); + assertEquals("f1", result1.getFragment()); + assertNull(result1.getSchemeSpecificPart()); + + UriComponents result2 = builder2.build(); + assertEquals("http", result2.getScheme()); + assertEquals("user:pwd", result2.getUserInfo()); + assertEquals("e1.com", result2.getHost()); + assertEquals("/p1/foo/bar", result2.getPath()); + assertEquals("q1", result2.getQuery()); + assertEquals("f1", result2.getFragment()); + assertNull(result1.getSchemeSpecificPart()); + } + + @Test // SPR-11856 + public void fromHttpRequestForwardedHeader() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Forwarded", "proto=https; host=84.198.58.199"); + request.setScheme("http"); + request.setServerName("example.com"); + request.setRequestURI("/rest/mobile/users/1"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https", result.getScheme()); + assertEquals("84.198.58.199", result.getHost()); + assertEquals("/rest/mobile/users/1", result.getPath()); + } + + @Test + public void fromHttpRequestForwardedHeaderQuoted() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Forwarded", "proto=\"https\"; host=\"84.198.58.199\""); + request.setScheme("http"); + request.setServerName("example.com"); + request.setRequestURI("/rest/mobile/users/1"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https", result.getScheme()); + assertEquals("84.198.58.199", result.getHost()); + assertEquals("/rest/mobile/users/1", result.getPath()); + } + + @Test + public void fromHttpRequestMultipleForwardedHeader() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Forwarded", "host=84.198.58.199;proto=https"); + request.addHeader("Forwarded", "proto=ftp; host=1.2.3.4"); + request.setScheme("http"); + request.setServerName("example.com"); + request.setRequestURI("/rest/mobile/users/1"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https", result.getScheme()); + assertEquals("84.198.58.199", result.getHost()); + assertEquals("/rest/mobile/users/1", result.getPath()); + } + + @Test + public void fromHttpRequestMultipleForwardedHeaderComma() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Forwarded", "host=84.198.58.199 ;proto=https, proto=ftp; host=1.2.3.4"); + request.setScheme("http"); + request.setServerName("example.com"); + request.setRequestURI("/rest/mobile/users/1"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https", result.getScheme()); + assertEquals("84.198.58.199", result.getHost()); + assertEquals("/rest/mobile/users/1", result.getPath()); + } + + @Test + public void fromHttpRequestForwardedHeaderWithHostPortAndWithoutServerPort() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Forwarded", "proto=https; host=84.198.58.199:9090"); + request.setScheme("http"); + request.setServerName("example.com"); + request.setRequestURI("/rest/mobile/users/1"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https", result.getScheme()); + assertEquals("84.198.58.199", result.getHost()); + assertEquals("/rest/mobile/users/1", result.getPath()); + assertEquals(9090, result.getPort()); + assertEquals("https://84.198.58.199:9090/rest/mobile/users/1", result.toUriString()); + } + + @Test + public void fromHttpRequestForwardedHeaderWithHostPortAndServerPort() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Forwarded", "proto=https; host=84.198.58.199:9090"); + request.setScheme("http"); + request.setServerPort(8080); + request.setServerName("example.com"); + request.setRequestURI("/rest/mobile/users/1"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https", result.getScheme()); + assertEquals("84.198.58.199", result.getHost()); + assertEquals("/rest/mobile/users/1", result.getPath()); + assertEquals(9090, result.getPort()); + assertEquals("https://84.198.58.199:9090/rest/mobile/users/1", result.toUriString()); + } + + @Test + public void fromHttpRequestForwardedHeaderWithoutHostPortAndWithServerPort() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Forwarded", "proto=https; host=84.198.58.199"); + request.setScheme("http"); + request.setServerPort(8080); + request.setServerName("example.com"); + request.setRequestURI("/rest/mobile/users/1"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https", result.getScheme()); + assertEquals("84.198.58.199", result.getHost()); + assertEquals("/rest/mobile/users/1", result.getPath()); + assertEquals(-1, result.getPort()); + assertEquals("https://84.198.58.199/rest/mobile/users/1", result.toUriString()); + } + + @Test // SPR-16262 + public void fromHttpRequestForwardedHeaderWithProtoAndServerPort() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Forwarded", "proto=https"); + request.setScheme("http"); + request.setServerPort(8080); + request.setServerName("example.com"); + request.setRequestURI("/rest/mobile/users/1"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https", result.getScheme()); + assertEquals("example.com", result.getHost()); + assertEquals("/rest/mobile/users/1", result.getPath()); + assertEquals(-1, result.getPort()); + assertEquals("https://example.com/rest/mobile/users/1", result.toUriString()); + } + + @Test // SPR-16364 + public void uriComponentsNotEqualAfterNormalization() { + UriComponents uri1 = UriComponentsBuilder.fromUriString("http://test.com").build().normalize(); + UriComponents uri2 = UriComponentsBuilder.fromUriString("http://test.com/").build(); + + assertTrue(uri1.getPathSegments().isEmpty()); + assertTrue(uri2.getPathSegments().isEmpty()); + assertNotEquals(uri1, uri2); + } + + @Test // SPR-17256 + public void uriComponentsWithMergedQueryParams() { + String uri = UriComponentsBuilder.fromUriString("http://localhost:8081") + .uriComponents(UriComponentsBuilder.fromUriString("/{path}?sort={sort}").build()) + .queryParam("sort", "another_value").build().toString(); + + assertEquals("http://localhost:8081/{path}?sort={sort}&sort=another_value", uri); + } + + @Test // SPR-17630 + public void toUriStringWithCurlyBraces() { + assertEquals("/path?q=%7Basa%7Dasa", + UriComponentsBuilder.fromUriString("/path?q={asa}asa").toUriString()); + } +} diff --git a/spring-web/src/test/java/org/springframework/web/util/UriComponentsTests.java b/spring-web/src/test/java/org/springframework/web/util/UriComponentsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a5b589d18def544b5e3e28a69dfe0869fc168335 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/UriComponentsTests.java @@ -0,0 +1,216 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Arrays; +import java.util.Collections; + +import org.junit.Test; + +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; +import static org.springframework.web.util.UriComponentsBuilder.*; + +/** + * Unit tests for {@link UriComponents}. + * + * @author Arjen Poutsma + * @author Phillip Webb + * @author Rossen Stoyanchev + */ +public class UriComponentsTests { + + @Test + public void expandAndEncode() { + + UriComponents uri = UriComponentsBuilder + .fromPath("/hotel list/{city} specials").queryParam("q", "{value}").build() + .expand("Z\u00fcrich", "a+b").encode(); + + assertEquals("/hotel%20list/Z%C3%BCrich%20specials?q=a+b", uri.toString()); + } + + @Test + public void encodeAndExpand() { + + UriComponents uri = UriComponentsBuilder + .fromPath("/hotel list/{city} specials").queryParam("q", "{value}").encode().build() + .expand("Z\u00fcrich", "a+b"); + + assertEquals("/hotel%20list/Z%C3%BCrich%20specials?q=a%2Bb", uri.toString()); + } + + @Test + public void encodeAndExpandPartially() { + + UriComponents uri = UriComponentsBuilder + .fromPath("/hotel list/{city} specials").queryParam("q", "{value}").encode() + .uriVariables(Collections.singletonMap("city", "Z\u00fcrich")) + .build(); + + assertEquals("/hotel%20list/Z%C3%BCrich%20specials?q=a%2Bb", uri.expand("a+b").toString()); + } + + @Test // SPR-17168 + public void encodeAndExpandWithDollarSign() { + UriComponents uri = UriComponentsBuilder.fromPath("/path").queryParam("q", "{value}").encode().build(); + assertEquals("/path?q=JavaClass%241.class", uri.expand("JavaClass$1.class").toString()); + } + + @Test + public void toUriEncoded() throws URISyntaxException { + UriComponents uriComponents = UriComponentsBuilder.fromUriString( + "https://example.com/hotel list/Z\u00fcrich").build(); + assertEquals(new URI("https://example.com/hotel%20list/Z%C3%BCrich"), uriComponents.encode().toUri()); + } + + @Test + public void toUriNotEncoded() throws URISyntaxException { + UriComponents uriComponents = UriComponentsBuilder.fromUriString( + "https://example.com/hotel list/Z\u00fcrich").build(); + assertEquals(new URI("https://example.com/hotel%20list/Z\u00fcrich"), uriComponents.toUri()); + } + + @Test + public void toUriAlreadyEncoded() throws URISyntaxException { + UriComponents uriComponents = UriComponentsBuilder.fromUriString( + "https://example.com/hotel%20list/Z%C3%BCrich").build(true); + UriComponents encoded = uriComponents.encode(); + assertEquals(new URI("https://example.com/hotel%20list/Z%C3%BCrich"), encoded.toUri()); + } + + @Test + public void toUriWithIpv6HostAlreadyEncoded() throws URISyntaxException { + UriComponents uriComponents = UriComponentsBuilder.fromUriString( + "http://[1abc:2abc:3abc::5ABC:6abc]:8080/hotel%20list/Z%C3%BCrich").build(true); + UriComponents encoded = uriComponents.encode(); + assertEquals(new URI("http://[1abc:2abc:3abc::5ABC:6abc]:8080/hotel%20list/Z%C3%BCrich"), encoded.toUri()); + } + + @Test + public void expand() { + UriComponents uriComponents = UriComponentsBuilder.fromUriString( + "https://example.com").path("/{foo} {bar}").build(); + uriComponents = uriComponents.expand("1 2", "3 4"); + assertEquals("/1 2 3 4", uriComponents.getPath()); + assertEquals("https://example.com/1 2 3 4", uriComponents.toUriString()); + } + + @Test // SPR-13311 + public void expandWithRegexVar() { + String template = "/myurl/{name:[a-z]{1,5}}/show"; + UriComponents uriComponents = UriComponentsBuilder.fromUriString(template).build(); + uriComponents = uriComponents.expand(Collections.singletonMap("name", "test")); + assertEquals("/myurl/test/show", uriComponents.getPath()); + } + + @Test // SPR-17630 + public void uirTemplateExpandWithMismatchedCurlyBraces() { + assertEquals("/myurl/?q=%7B%7B%7B%7B", + UriComponentsBuilder.fromUriString("/myurl/?q={{{{").encode().build().toUriString()); + } + + @Test // SPR-12123 + public void port() { + UriComponents uri1 = fromUriString("https://example.com:8080/bar").build(); + UriComponents uri2 = fromUriString("https://example.com/bar").port(8080).build(); + UriComponents uri3 = fromUriString("https://example.com/bar").port("{port}").build().expand(8080); + UriComponents uri4 = fromUriString("https://example.com/bar").port("808{digit}").build().expand(0); + assertEquals(8080, uri1.getPort()); + assertEquals("https://example.com:8080/bar", uri1.toUriString()); + assertEquals(8080, uri2.getPort()); + assertEquals("https://example.com:8080/bar", uri2.toUriString()); + assertEquals(8080, uri3.getPort()); + assertEquals("https://example.com:8080/bar", uri3.toUriString()); + assertEquals(8080, uri4.getPort()); + assertEquals("https://example.com:8080/bar", uri4.toUriString()); + } + + @Test(expected = IllegalStateException.class) + public void expandEncoded() { + UriComponentsBuilder.fromPath("/{foo}").build().encode().expand("bar"); + } + + @Test(expected = IllegalArgumentException.class) + public void invalidCharacters() { + UriComponentsBuilder.fromPath("/{foo}").build(true); + } + + @Test(expected = IllegalArgumentException.class) + public void invalidEncodedSequence() { + UriComponentsBuilder.fromPath("/fo%2o").build(true); + } + + @Test + public void normalize() { + UriComponents uriComponents = UriComponentsBuilder.fromUriString("https://example.com/foo/../bar").build(); + assertEquals("https://example.com/bar", uriComponents.normalize().toString()); + } + + @Test + public void serializable() throws Exception { + UriComponents uriComponents = UriComponentsBuilder.fromUriString( + "https://example.com").path("/{foo}").query("bar={baz}").build(); + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(bos); + oos.writeObject(uriComponents); + ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray())); + UriComponents readObject = (UriComponents) ois.readObject(); + assertThat(uriComponents.toString(), equalTo(readObject.toString())); + } + + @Test + public void copyToUriComponentsBuilder() { + UriComponents source = UriComponentsBuilder.fromPath("/foo/bar").pathSegment("ba/z").build(); + UriComponentsBuilder targetBuilder = UriComponentsBuilder.newInstance(); + source.copyToUriComponentsBuilder(targetBuilder); + UriComponents result = targetBuilder.build().encode(); + assertEquals("/foo/bar/ba%2Fz", result.getPath()); + assertEquals(Arrays.asList("foo", "bar", "ba%2Fz"), result.getPathSegments()); + } + + @Test + public void equalsHierarchicalUriComponents() { + String url = "https://example.com"; + UriComponents uric1 = UriComponentsBuilder.fromUriString(url).path("/{foo}").query("bar={baz}").build(); + UriComponents uric2 = UriComponentsBuilder.fromUriString(url).path("/{foo}").query("bar={baz}").build(); + UriComponents uric3 = UriComponentsBuilder.fromUriString(url).path("/{foo}").query("bin={baz}").build(); + assertThat(uric1, instanceOf(HierarchicalUriComponents.class)); + assertThat(uric1, equalTo(uric1)); + assertThat(uric1, equalTo(uric2)); + assertThat(uric1, not(equalTo(uric3))); + } + + @Test + public void equalsOpaqueUriComponents() { + String baseUrl = "http:example.com"; + UriComponents uric1 = UriComponentsBuilder.fromUriString(baseUrl + "/foo/bar").build(); + UriComponents uric2 = UriComponentsBuilder.fromUriString(baseUrl + "/foo/bar").build(); + UriComponents uric3 = UriComponentsBuilder.fromUriString(baseUrl + "/foo/bin").build(); + assertThat(uric1, instanceOf(OpaqueUriComponents.class)); + assertThat(uric1, equalTo(uric1)); + assertThat(uric1, equalTo(uric2)); + assertThat(uric1, not(equalTo(uric3))); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/util/UriTemplateTests.java b/spring-web/src/test/java/org/springframework/web/util/UriTemplateTests.java new file mode 100644 index 0000000000000000000000000000000000000000..cfbca2f9f31b769f3edc281268ed6e98b67f5095 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/UriTemplateTests.java @@ -0,0 +1,222 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.net.URI; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + * @author Juergen Hoeller + * @author Rossen Stoyanchev + */ +public class UriTemplateTests { + + @Test + public void getVariableNames() throws Exception { + UriTemplate template = new UriTemplate("/hotels/{hotel}/bookings/{booking}"); + List variableNames = template.getVariableNames(); + assertEquals("Invalid variable names", Arrays.asList("hotel", "booking"), variableNames); + } + + @Test + public void expandVarArgs() throws Exception { + UriTemplate template = new UriTemplate("/hotels/{hotel}/bookings/{booking}"); + URI result = template.expand("1", "42"); + assertEquals("Invalid expanded template", new URI("/hotels/1/bookings/42"), result); + } + + // SPR-9712 + + @Test + public void expandVarArgsWithArrayValue() throws Exception { + UriTemplate template = new UriTemplate("/sum?numbers={numbers}"); + URI result = template.expand(new int[] {1, 2, 3}); + assertEquals(new URI("/sum?numbers=1,2,3"), result); + } + + @Test(expected = IllegalArgumentException.class) + public void expandVarArgsNotEnoughVariables() throws Exception { + UriTemplate template = new UriTemplate("/hotels/{hotel}/bookings/{booking}"); + template.expand("1"); + } + + @Test + public void expandMap() throws Exception { + Map uriVariables = new HashMap<>(2); + uriVariables.put("booking", "42"); + uriVariables.put("hotel", "1"); + UriTemplate template = new UriTemplate("/hotels/{hotel}/bookings/{booking}"); + URI result = template.expand(uriVariables); + assertEquals("Invalid expanded template", new URI("/hotels/1/bookings/42"), result); + } + + @Test + public void expandMapDuplicateVariables() throws Exception { + UriTemplate template = new UriTemplate("/order/{c}/{c}/{c}"); + assertEquals(Arrays.asList("c", "c", "c"), template.getVariableNames()); + URI result = template.expand(Collections.singletonMap("c", "cheeseburger")); + assertEquals(new URI("/order/cheeseburger/cheeseburger/cheeseburger"), result); + } + + @Test + public void expandMapNonString() throws Exception { + Map uriVariables = new HashMap<>(2); + uriVariables.put("booking", 42); + uriVariables.put("hotel", 1); + UriTemplate template = new UriTemplate("/hotels/{hotel}/bookings/{booking}"); + URI result = template.expand(uriVariables); + assertEquals("Invalid expanded template", new URI("/hotels/1/bookings/42"), result); + } + + @Test + public void expandMapEncoded() throws Exception { + Map uriVariables = Collections.singletonMap("hotel", "Z\u00fcrich"); + UriTemplate template = new UriTemplate("/hotel list/{hotel}"); + URI result = template.expand(uriVariables); + assertEquals("Invalid expanded template", new URI("/hotel%20list/Z%C3%BCrich"), result); + } + + @Test(expected = IllegalArgumentException.class) + public void expandMapUnboundVariables() throws Exception { + Map uriVariables = new HashMap<>(2); + uriVariables.put("booking", "42"); + uriVariables.put("bar", "1"); + UriTemplate template = new UriTemplate("/hotels/{hotel}/bookings/{booking}"); + template.expand(uriVariables); + } + + @Test + public void expandEncoded() throws Exception { + UriTemplate template = new UriTemplate("/hotel list/{hotel}"); + URI result = template.expand("Z\u00fcrich"); + assertEquals("Invalid expanded template", new URI("/hotel%20list/Z%C3%BCrich"), result); + } + + @Test + public void matches() throws Exception { + UriTemplate template = new UriTemplate("/hotels/{hotel}/bookings/{booking}"); + assertTrue("UriTemplate does not match", template.matches("/hotels/1/bookings/42")); + assertFalse("UriTemplate matches", template.matches("/hotels/bookings")); + assertFalse("UriTemplate matches", template.matches("")); + assertFalse("UriTemplate matches", template.matches(null)); + } + + @Test + public void matchesCustomRegex() throws Exception { + UriTemplate template = new UriTemplate("/hotels/{hotel:\\d+}"); + assertTrue("UriTemplate does not match", template.matches("/hotels/42")); + assertFalse("UriTemplate matches", template.matches("/hotels/foo")); + } + + @Test + public void match() throws Exception { + Map expected = new HashMap<>(2); + expected.put("booking", "42"); + expected.put("hotel", "1"); + + UriTemplate template = new UriTemplate("/hotels/{hotel}/bookings/{booking}"); + Map result = template.match("/hotels/1/bookings/42"); + assertEquals("Invalid match", expected, result); + } + + @Test + public void matchCustomRegex() throws Exception { + Map expected = new HashMap<>(2); + expected.put("booking", "42"); + expected.put("hotel", "1"); + + UriTemplate template = new UriTemplate("/hotels/{hotel:\\d}/bookings/{booking:\\d+}"); + Map result = template.match("/hotels/1/bookings/42"); + assertEquals("Invalid match", expected, result); + } + + @Test // SPR-13627 + public void matchCustomRegexWithNestedCurlyBraces() throws Exception { + UriTemplate template = new UriTemplate("/site.{domain:co.[a-z]{2}}"); + Map result = template.match("/site.co.eu"); + assertEquals("Invalid match", Collections.singletonMap("domain", "co.eu"), result); + } + + @Test + public void matchDuplicate() throws Exception { + UriTemplate template = new UriTemplate("/order/{c}/{c}/{c}"); + Map result = template.match("/order/cheeseburger/cheeseburger/cheeseburger"); + Map expected = Collections.singletonMap("c", "cheeseburger"); + assertEquals("Invalid match", expected, result); + } + + @Test + public void matchMultipleInOneSegment() throws Exception { + UriTemplate template = new UriTemplate("/{foo}-{bar}"); + Map result = template.match("/12-34"); + Map expected = new HashMap<>(2); + expected.put("foo", "12"); + expected.put("bar", "34"); + assertEquals("Invalid match", expected, result); + } + + @Test // SPR-16169 + public void matchWithMultipleSegmentsAtTheEnd() { + UriTemplate template = new UriTemplate("/account/{accountId}"); + assertFalse(template.matches("/account/15/alias/5")); + } + + @Test + public void queryVariables() throws Exception { + UriTemplate template = new UriTemplate("/search?q={query}"); + assertTrue(template.matches("/search?q=foo")); + } + + @Test + public void fragments() throws Exception { + UriTemplate template = new UriTemplate("/search#{fragment}"); + assertTrue(template.matches("/search#foo")); + + template = new UriTemplate("/search?query={query}#{fragment}"); + assertTrue(template.matches("/search?query=foo#bar")); + } + + @Test // SPR-13705 + public void matchesWithSlashAtTheEnd() { + UriTemplate uriTemplate = new UriTemplate("/test/"); + assertTrue(uriTemplate.matches("/test/")); + } + + @Test + public void expandWithDollar() { + UriTemplate template = new UriTemplate("/{a}"); + URI uri = template.expand("$replacement"); + assertEquals("/$replacement", uri.toString()); + } + + @Test + public void expandWithAtSign() { + UriTemplate template = new UriTemplate("http://localhost/query={query}"); + URI uri = template.expand("foo@bar"); + assertEquals("http://localhost/query=foo@bar", uri.toString()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/util/UriUtilsTests.java b/spring-web/src/test/java/org/springframework/web/util/UriUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..86a66bff8d2dd9be0a4f7566abb4d2fbb01a2477 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/UriUtilsTests.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; + +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + * @author Juergen Hoeller + * @author Med Belamachi + */ +public class UriUtilsTests { + + private static final Charset CHARSET = StandardCharsets.UTF_8; + + + @Test + public void encodeScheme() { + assertEquals("Invalid encoded result", "foobar+-.", UriUtils.encodeScheme("foobar+-.", CHARSET)); + assertEquals("Invalid encoded result", "foo%20bar", UriUtils.encodeScheme("foo bar", CHARSET)); + } + + @Test + public void encodeUserInfo() { + assertEquals("Invalid encoded result", "foobar:", UriUtils.encodeUserInfo("foobar:", CHARSET)); + assertEquals("Invalid encoded result", "foo%20bar", UriUtils.encodeUserInfo("foo bar", CHARSET)); + } + + @Test + public void encodeHost() { + assertEquals("Invalid encoded result", "foobar", UriUtils.encodeHost("foobar", CHARSET)); + assertEquals("Invalid encoded result", "foo%20bar", UriUtils.encodeHost("foo bar", CHARSET)); + } + + @Test + public void encodePort() { + assertEquals("Invalid encoded result", "80", UriUtils.encodePort("80", CHARSET)); + } + + @Test + public void encodePath() { + assertEquals("Invalid encoded result", "/foo/bar", UriUtils.encodePath("/foo/bar", CHARSET)); + assertEquals("Invalid encoded result", "/foo%20bar", UriUtils.encodePath("/foo bar", CHARSET)); + assertEquals("Invalid encoded result", "/Z%C3%BCrich", UriUtils.encodePath("/Z\u00fcrich", CHARSET)); + } + + @Test + public void encodePathSegment() { + assertEquals("Invalid encoded result", "foobar", UriUtils.encodePathSegment("foobar", CHARSET)); + assertEquals("Invalid encoded result", "%2Ffoo%2Fbar", UriUtils.encodePathSegment("/foo/bar", CHARSET)); + } + + @Test + public void encodeQuery() { + assertEquals("Invalid encoded result", "foobar", UriUtils.encodeQuery("foobar", CHARSET)); + assertEquals("Invalid encoded result", "foo%20bar", UriUtils.encodeQuery("foo bar", CHARSET)); + assertEquals("Invalid encoded result", "foobar/+", UriUtils.encodeQuery("foobar/+", CHARSET)); + assertEquals("Invalid encoded result", "T%C5%8Dky%C5%8D", UriUtils.encodeQuery("T\u014dky\u014d", CHARSET)); + } + + @Test + public void encodeQueryParam() { + assertEquals("Invalid encoded result", "foobar", UriUtils.encodeQueryParam("foobar", CHARSET)); + assertEquals("Invalid encoded result", "foo%20bar", UriUtils.encodeQueryParam("foo bar", CHARSET)); + assertEquals("Invalid encoded result", "foo%26bar", UriUtils.encodeQueryParam("foo&bar", CHARSET)); + } + + @Test + public void encodeFragment() { + assertEquals("Invalid encoded result", "foobar", UriUtils.encodeFragment("foobar", CHARSET)); + assertEquals("Invalid encoded result", "foo%20bar", UriUtils.encodeFragment("foo bar", CHARSET)); + assertEquals("Invalid encoded result", "foobar/", UriUtils.encodeFragment("foobar/", CHARSET)); + } + + @Test + public void encode() { + assertEquals("Invalid encoded result", "foo", UriUtils.encode("foo", CHARSET)); + assertEquals("Invalid encoded result", "https%3A%2F%2Fexample.com%2Ffoo%20bar", + UriUtils.encode("https://example.com/foo bar", CHARSET)); + } + + @Test + public void decode() { + assertEquals("Invalid encoded URI", "", UriUtils.decode("", CHARSET)); + assertEquals("Invalid encoded URI", "foobar", UriUtils.decode("foobar", CHARSET)); + assertEquals("Invalid encoded URI", "foo bar", UriUtils.decode("foo%20bar", CHARSET)); + assertEquals("Invalid encoded URI", "foo+bar", UriUtils.decode("foo%2bbar", CHARSET)); + assertEquals("Invalid encoded result", "T\u014dky\u014d", UriUtils.decode("T%C5%8Dky%C5%8D", CHARSET)); + assertEquals("Invalid encoded result", "/Z\u00fcrich", UriUtils.decode("/Z%C3%BCrich", CHARSET)); + assertEquals("Invalid encoded result", "T\u014dky\u014d", UriUtils.decode("T\u014dky\u014d", CHARSET)); + } + + @Test(expected = IllegalArgumentException.class) + public void decodeInvalidSequence() { + UriUtils.decode("foo%2", CHARSET); + } + + @Test + public void extractFileExtension() { + assertEquals("html", UriUtils.extractFileExtension("index.html")); + assertEquals("html", UriUtils.extractFileExtension("/index.html")); + assertEquals("html", UriUtils.extractFileExtension("/products/view.html")); + assertEquals("html", UriUtils.extractFileExtension("/products/view.html#/a")); + assertEquals("html", UriUtils.extractFileExtension("/products/view.html#/path/a")); + assertEquals("html", UriUtils.extractFileExtension("/products/view.html#/path/a.do")); + assertEquals("html", UriUtils.extractFileExtension("/products/view.html#aaa?bbb")); + assertEquals("html", UriUtils.extractFileExtension("/products/view.html#aaa.xml?bbb")); + assertEquals("html", UriUtils.extractFileExtension("/products/view.html?param=a")); + assertEquals("html", UriUtils.extractFileExtension("/products/view.html?param=/path/a")); + assertEquals("html", UriUtils.extractFileExtension("/products/view.html?param=/path/a.do")); + assertEquals("html", UriUtils.extractFileExtension("/products/view.html?param=/path/a#/path/a")); + assertEquals("html", UriUtils.extractFileExtension("/products/view.html?param=/path/a.do#/path/a.do")); + assertEquals("html", UriUtils.extractFileExtension("/products;q=11/view.html?param=/path/a.do")); + assertEquals("html", UriUtils.extractFileExtension("/products;q=11/view.html;r=22?param=/path/a.do")); + assertEquals("html", UriUtils.extractFileExtension("/products;q=11/view.html;r=22;s=33?param=/path/a.do")); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/util/UrlPathHelperTests.java b/spring-web/src/test/java/org/springframework/web/util/UrlPathHelperTests.java new file mode 100644 index 0000000000000000000000000000000000000000..c4b096501e2824050d0f5ab40d4982aa407eff13 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/UrlPathHelperTests.java @@ -0,0 +1,446 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.io.UnsupportedEncodingException; + +import org.junit.Ignore; +import org.junit.Test; + +import org.springframework.mock.web.test.MockHttpServletRequest; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link UrlPathHelper}. + * + * @author Rob Harrop + * @author Juergen Hoeller + * @author Costin Leau + */ +public class UrlPathHelperTests { + + private static final String WEBSPHERE_URI_ATTRIBUTE = "com.ibm.websphere.servlet.uri_non_decoded"; + + private final UrlPathHelper helper = new UrlPathHelper(); + + private final MockHttpServletRequest request = new MockHttpServletRequest(); + + + @Test + public void getPathWithinApplication() { + request.setContextPath("/petclinic"); + request.setRequestURI("/petclinic/welcome.html"); + + assertEquals("Incorrect path returned", "/welcome.html", helper.getPathWithinApplication(request)); + } + + @Test + public void getPathWithinApplicationForRootWithNoLeadingSlash() { + request.setContextPath("/petclinic"); + request.setRequestURI("/petclinic"); + + assertEquals("Incorrect root path returned", "/", helper.getPathWithinApplication(request)); + } + + @Test + public void getPathWithinApplicationForSlashContextPath() { + request.setContextPath("/"); + request.setRequestURI("/welcome.html"); + + assertEquals("Incorrect path returned", "/welcome.html", helper.getPathWithinApplication(request)); + } + + @Test + public void getPathWithinServlet() { + request.setContextPath("/petclinic"); + request.setServletPath("/main"); + request.setRequestURI("/petclinic/main/welcome.html"); + + assertEquals("Incorrect path returned", "/welcome.html", helper.getPathWithinServletMapping(request)); + } + + @Test + public void alwaysUseFullPath() { + helper.setAlwaysUseFullPath(true); + request.setContextPath("/petclinic"); + request.setServletPath("/main"); + request.setRequestURI("/petclinic/main/welcome.html"); + + assertEquals("Incorrect path returned", "/main/welcome.html", helper.getLookupPathForRequest(request)); + } + + // SPR-11101 + + @Test + public void getPathWithinServletWithoutUrlDecoding() { + request.setContextPath("/SPR-11101"); + request.setServletPath("/test_url_decoding/a/b"); + request.setRequestURI("/test_url_decoding/a%2Fb"); + + helper.setUrlDecode(false); + String actual = helper.getPathWithinServletMapping(request); + + assertEquals("/test_url_decoding/a%2Fb", actual); + } + + @Test + public void getRequestUri() { + request.setRequestURI("/welcome.html"); + assertEquals("Incorrect path returned", "/welcome.html", helper.getRequestUri(request)); + + request.setRequestURI("/foo%20bar"); + assertEquals("Incorrect path returned", "/foo bar", helper.getRequestUri(request)); + + request.setRequestURI("/foo+bar"); + assertEquals("Incorrect path returned", "/foo+bar", helper.getRequestUri(request)); + } + + @Test + public void getRequestRemoveSemicolonContent() throws UnsupportedEncodingException { + helper.setRemoveSemicolonContent(true); + + request.setRequestURI("/foo;f=F;o=O;o=O/bar;b=B;a=A;r=R"); + assertEquals("/foo/bar", helper.getRequestUri(request)); + + // SPR-13455 + + request.setServletPath("/foo/1"); + request.setRequestURI("/foo/;test/1"); + + assertEquals("/foo/1", helper.getRequestUri(request)); + } + + @Test + public void getRequestKeepSemicolonContent() { + helper.setRemoveSemicolonContent(false); + + testKeepSemicolonContent("/foo;a=b;c=d", "/foo;a=b;c=d"); + testKeepSemicolonContent("/test;jsessionid=1234", "/test"); + testKeepSemicolonContent("/test;JSESSIONID=1234", "/test"); + testKeepSemicolonContent("/test;jsessionid=1234;a=b", "/test;a=b"); + testKeepSemicolonContent("/test;a=b;jsessionid=1234;c=d", "/test;a=b;c=d"); + testKeepSemicolonContent("/test;jsessionid=1234/anotherTest", "/test/anotherTest"); + testKeepSemicolonContent("/test;jsessionid=;a=b", "/test;a=b"); + testKeepSemicolonContent("/somethingLongerThan12;jsessionid=1234", "/somethingLongerThan12"); + } + + private void testKeepSemicolonContent(String requestUri, String expectedPath) { + request.setRequestURI(requestUri); + assertEquals(expectedPath, helper.getRequestUri(request)); + } + + @Test + public void getLookupPathWithSemicolonContent() { + helper.setRemoveSemicolonContent(false); + + request.setContextPath("/petclinic"); + request.setServletPath("/main"); + request.setRequestURI("/petclinic;a=b/main;b=c/welcome.html;c=d"); + + assertEquals("/welcome.html;c=d", helper.getLookupPathForRequest(request)); + } + + @Test + public void getLookupPathWithSemicolonContentAndNullPathInfo() { + helper.setRemoveSemicolonContent(false); + + request.setContextPath("/petclinic"); + request.setServletPath("/welcome.html"); + request.setRequestURI("/petclinic;a=b/welcome.html;c=d"); + + assertEquals("/welcome.html;c=d", helper.getLookupPathForRequest(request)); + } + + + // + // suite of tests root requests for default servlets (SRV 11.2) on Websphere vs Tomcat and other containers + // see: https://jira.springframework.org/browse/SPR-7064 + // + + + // + // / mapping (default servlet) + // + + @Test + public void tomcatDefaultServletRoot() throws Exception { + request.setContextPath("/test"); + request.setPathInfo(null); + request.setServletPath("/"); + request.setRequestURI("/test/"); + assertEquals("/", helper.getLookupPathForRequest(request)); + } + + @Test + public void tomcatDefaultServletFile() throws Exception { + request.setContextPath("/test"); + request.setPathInfo(null); + request.setServletPath("/foo"); + request.setRequestURI("/test/foo"); + + assertEquals("/foo", helper.getLookupPathForRequest(request)); + } + + @Test + public void tomcatDefaultServletFolder() throws Exception { + request.setContextPath("/test"); + request.setPathInfo(null); + request.setServletPath("/foo/"); + request.setRequestURI("/test/foo/"); + + assertEquals("/foo/", helper.getLookupPathForRequest(request)); + } + + //SPR-12372 & SPR-13455 + @Test + public void removeDuplicateSlashesInPath() throws Exception { + request.setContextPath("/SPR-12372"); + request.setPathInfo(null); + request.setServletPath("/foo/bar/"); + request.setRequestURI("/SPR-12372/foo//bar/"); + + assertEquals("/foo/bar/", helper.getLookupPathForRequest(request)); + + request.setServletPath("/foo/bar/"); + request.setRequestURI("/SPR-12372/foo/bar//"); + + assertEquals("/foo/bar/", helper.getLookupPathForRequest(request)); + + // "normal" case + request.setServletPath("/foo/bar//"); + request.setRequestURI("/SPR-12372/foo/bar//"); + + assertEquals("/foo/bar//", helper.getLookupPathForRequest(request)); + } + + @Test + public void wasDefaultServletRoot() throws Exception { + request.setContextPath("/test"); + request.setPathInfo("/"); + request.setServletPath(""); + request.setRequestURI("/test/"); + request.setAttribute(WEBSPHERE_URI_ATTRIBUTE, "/test/"); + + assertEquals("/", helper.getLookupPathForRequest(request)); + } + + @Test + public void wasDefaultServletRootWithCompliantSetting() throws Exception { + request.setAttribute(WEBSPHERE_URI_ATTRIBUTE, "/test/"); + tomcatDefaultServletRoot(); + } + + @Test + public void wasDefaultServletFile() throws Exception { + request.setContextPath("/test"); + request.setPathInfo("/foo"); + request.setServletPath(""); + request.setRequestURI("/test/foo"); + request.setAttribute(WEBSPHERE_URI_ATTRIBUTE, "/test/foo"); + + assertEquals("/foo", helper.getLookupPathForRequest(request)); + } + + @Test + public void wasDefaultServletFileWithCompliantSetting() throws Exception { + request.setAttribute(WEBSPHERE_URI_ATTRIBUTE, "/test/foo"); + tomcatDefaultServletFile(); + } + + @Test + public void wasDefaultServletFolder() throws Exception { + request.setContextPath("/test"); + request.setPathInfo("/foo/"); + request.setServletPath(""); + request.setRequestURI("/test/foo/"); + request.setAttribute(WEBSPHERE_URI_ATTRIBUTE, "/test/foo/"); + + assertEquals("/foo/", helper.getLookupPathForRequest(request)); + } + + @Test + public void wasDefaultServletFolderWithCompliantSetting() throws Exception { + UrlPathHelper.websphereComplianceFlag = true; + try { + request.setAttribute(WEBSPHERE_URI_ATTRIBUTE, "/test/foo/"); + tomcatDefaultServletFolder(); + } + finally { + UrlPathHelper.websphereComplianceFlag = false; + } + } + + + // + // /foo/* mapping + // + + @Test + public void tomcatCasualServletRoot() throws Exception { + request.setContextPath("/test"); + request.setPathInfo("/"); + request.setServletPath("/foo"); + request.setRequestURI("/test/foo/"); + + assertEquals("/", helper.getLookupPathForRequest(request)); + } + + // test the root mapping for /foo/* w/o a trailing slash - //foo + @Test @Ignore + public void tomcatCasualServletRootWithMissingSlash() throws Exception { + request.setContextPath("/test"); + request.setPathInfo(null); + request.setServletPath("/foo"); + request.setRequestURI("/test/foo"); + + assertEquals("/", helper.getLookupPathForRequest(request)); + } + + @Test + public void tomcatCasualServletFile() throws Exception { + request.setContextPath("/test"); + request.setPathInfo("/foo"); + request.setServletPath("/foo"); + request.setRequestURI("/test/foo/foo"); + + assertEquals("/foo", helper.getLookupPathForRequest(request)); + } + + @Test + public void tomcatCasualServletFolder() throws Exception { + request.setContextPath("/test"); + request.setPathInfo("/foo/"); + request.setServletPath("/foo"); + request.setRequestURI("/test/foo/foo/"); + + assertEquals("/foo/", helper.getLookupPathForRequest(request)); + } + + @Test + public void wasCasualServletRoot() throws Exception { + request.setContextPath("/test"); + request.setPathInfo(null); + request.setServletPath("/foo/"); + request.setRequestURI("/test/foo/"); + request.setAttribute(WEBSPHERE_URI_ATTRIBUTE, "/test/foo/"); + + assertEquals("/", helper.getLookupPathForRequest(request)); + } + + @Test + public void wasCasualServletRootWithCompliantSetting() throws Exception { + request.setAttribute(WEBSPHERE_URI_ATTRIBUTE, "/test/foo/"); + tomcatCasualServletRoot(); + } + + // test the root mapping for /foo/* w/o a trailing slash - //foo + @Ignore + @Test + public void wasCasualServletRootWithMissingSlash() throws Exception { + request.setContextPath("/test"); + request.setPathInfo(null); + request.setServletPath("/foo"); + request.setRequestURI("/test/foo"); + request.setAttribute(WEBSPHERE_URI_ATTRIBUTE, "/test/foo"); + + assertEquals("/", helper.getLookupPathForRequest(request)); + } + + @Ignore + @Test + public void wasCasualServletRootWithMissingSlashWithCompliantSetting() throws Exception { + request.setAttribute(WEBSPHERE_URI_ATTRIBUTE, "/test/foo"); + tomcatCasualServletRootWithMissingSlash(); + } + + @Test + public void wasCasualServletFile() throws Exception { + request.setContextPath("/test"); + request.setPathInfo("/foo"); + request.setServletPath("/foo"); + request.setRequestURI("/test/foo/foo"); + request.setAttribute(WEBSPHERE_URI_ATTRIBUTE, "/test/foo/foo"); + + assertEquals("/foo", helper.getLookupPathForRequest(request)); + } + + @Test + public void wasCasualServletFileWithCompliantSetting() throws Exception { + request.setAttribute(WEBSPHERE_URI_ATTRIBUTE, "/test/foo/foo"); + tomcatCasualServletFile(); + } + + @Test + public void wasCasualServletFolder() throws Exception { + request.setContextPath("/test"); + request.setPathInfo("/foo/"); + request.setServletPath("/foo"); + request.setRequestURI("/test/foo/foo/"); + request.setAttribute(WEBSPHERE_URI_ATTRIBUTE, "/test/foo/foo/"); + + assertEquals("/foo/", helper.getLookupPathForRequest(request)); + } + + @Test + public void wasCasualServletFolderWithCompliantSetting() throws Exception { + request.setAttribute(WEBSPHERE_URI_ATTRIBUTE, "/test/foo/foo/"); + tomcatCasualServletFolder(); + } + + @Test + public void getOriginatingRequestUri() { + request.setAttribute(WebUtils.FORWARD_REQUEST_URI_ATTRIBUTE, "/path"); + request.setRequestURI("/forwarded"); + assertEquals("/path", helper.getOriginatingRequestUri(request)); + } + + @Test + public void getOriginatingRequestUriWebsphere() { + request.setAttribute(WEBSPHERE_URI_ATTRIBUTE, "/path"); + request.setRequestURI("/forwarded"); + assertEquals("/path", helper.getOriginatingRequestUri(request)); + } + + @Test + public void getOriginatingRequestUriDefault() { + request.setRequestURI("/forwarded"); + assertEquals("/forwarded", helper.getOriginatingRequestUri(request)); + } + + @Test + public void getOriginatingQueryString() { + request.setQueryString("forward=on"); + request.setAttribute(WebUtils.FORWARD_REQUEST_URI_ATTRIBUTE, "/path"); + request.setAttribute(WebUtils.FORWARD_QUERY_STRING_ATTRIBUTE, "original=on"); + assertEquals("original=on", this.helper.getOriginatingQueryString(request)); + } + + @Test + public void getOriginatingQueryStringNotPresent() { + request.setQueryString("forward=true"); + assertEquals("forward=true", this.helper.getOriginatingQueryString(request)); + } + + @Test + public void getOriginatingQueryStringIsNull() { + request.setQueryString("forward=true"); + request.setAttribute(WebUtils.FORWARD_REQUEST_URI_ATTRIBUTE, "/path"); + assertNull(this.helper.getOriginatingQueryString(request)); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java new file mode 100644 index 0000000000000000000000000000000000000000..3a419ba6681b890d5f89519e9cb5d895713097cb --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java @@ -0,0 +1,254 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; + +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.mock.web.test.MockFilterChain; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.util.MultiValueMap; +import org.springframework.web.filter.ForwardedHeaderFilter; + +import static org.junit.Assert.*; + +/** + * @author Juergen Hoeller + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + */ +public class WebUtilsTests { + + @Test + public void findParameterValue() { + Map params = new HashMap<>(); + params.put("myKey1", "myValue1"); + params.put("myKey2_myValue2", "xxx"); + params.put("myKey3_myValue3.x", "xxx"); + params.put("myKey4_myValue4.y", new String[] {"yyy"}); + + assertNull(WebUtils.findParameterValue(params, "myKey0")); + assertEquals("myValue1", WebUtils.findParameterValue(params, "myKey1")); + assertEquals("myValue2", WebUtils.findParameterValue(params, "myKey2")); + assertEquals("myValue3", WebUtils.findParameterValue(params, "myKey3")); + assertEquals("myValue4", WebUtils.findParameterValue(params, "myKey4")); + } + + @Test + public void parseMatrixVariablesString() { + MultiValueMap variables; + + variables = WebUtils.parseMatrixVariables(null); + assertEquals(0, variables.size()); + + variables = WebUtils.parseMatrixVariables("year"); + assertEquals(1, variables.size()); + assertEquals("", variables.getFirst("year")); + + variables = WebUtils.parseMatrixVariables("year=2012"); + assertEquals(1, variables.size()); + assertEquals("2012", variables.getFirst("year")); + + variables = WebUtils.parseMatrixVariables("year=2012;colors=red,blue,green"); + assertEquals(2, variables.size()); + assertEquals(Arrays.asList("red", "blue", "green"), variables.get("colors")); + assertEquals("2012", variables.getFirst("year")); + + variables = WebUtils.parseMatrixVariables(";year=2012;colors=red,blue,green;"); + assertEquals(2, variables.size()); + assertEquals(Arrays.asList("red", "blue", "green"), variables.get("colors")); + assertEquals("2012", variables.getFirst("year")); + + variables = WebUtils.parseMatrixVariables("colors=red;colors=blue;colors=green"); + assertEquals(1, variables.size()); + assertEquals(Arrays.asList("red", "blue", "green"), variables.get("colors")); + + variables = WebUtils.parseMatrixVariables("jsessionid=c0o7fszeb1"); + assertTrue(variables.isEmpty()); + + variables = WebUtils.parseMatrixVariables("a=b;jsessionid=c0o7fszeb1;c=d"); + assertEquals(2, variables.size()); + assertEquals(Collections.singletonList("b"), variables.get("a")); + assertEquals(Collections.singletonList("d"), variables.get("c")); + + variables = WebUtils.parseMatrixVariables("a=b;jsessionid=c0o7fszeb1;c=d"); + assertEquals(2, variables.size()); + assertEquals(Collections.singletonList("b"), variables.get("a")); + assertEquals(Collections.singletonList("d"), variables.get("c")); + } + + @Test + public void isValidOrigin() { + List allowed = Collections.emptyList(); + assertTrue(checkValidOrigin("mydomain1.com", -1, "http://mydomain1.com", allowed)); + assertFalse(checkValidOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed)); + + allowed = Collections.singletonList("*"); + assertTrue(checkValidOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed)); + + allowed = Collections.singletonList("http://mydomain1.com"); + assertTrue(checkValidOrigin("mydomain2.com", -1, "http://mydomain1.com", allowed)); + assertFalse(checkValidOrigin("mydomain2.com", -1, "http://mydomain3.com", allowed)); + } + + @Test + public void isSameOrigin() { + assertTrue(checkSameOrigin("http", "mydomain1.com", -1, "http://mydomain1.com")); + assertTrue(checkSameOrigin("http", "mydomain1.com", -1, "http://mydomain1.com:80")); + assertTrue(checkSameOrigin("https", "mydomain1.com", 443, "https://mydomain1.com")); + assertTrue(checkSameOrigin("https", "mydomain1.com", 443, "https://mydomain1.com:443")); + assertTrue(checkSameOrigin("http", "mydomain1.com", 123, "http://mydomain1.com:123")); + assertTrue(checkSameOrigin("ws", "mydomain1.com", -1, "ws://mydomain1.com")); + assertTrue(checkSameOrigin("wss", "mydomain1.com", 443, "wss://mydomain1.com")); + + assertFalse(checkSameOrigin("http", "mydomain1.com", -1, "http://mydomain2.com")); + assertFalse(checkSameOrigin("http", "mydomain1.com", -1, "https://mydomain1.com")); + assertFalse(checkSameOrigin("http", "mydomain1.com", -1, "invalid-origin")); + assertFalse(checkSameOrigin("https", "mydomain1.com", -1, "https://mydomain1.com")); + + // Handling of invalid origins as described in SPR-13478 + assertTrue(checkSameOrigin("http", "mydomain1.com", -1, "http://mydomain1.com/")); + assertTrue(checkSameOrigin("http", "mydomain1.com", -1, "http://mydomain1.com:80")); + assertTrue(checkSameOrigin("http", "mydomain1.com", -1, "http://mydomain1.com/path")); + assertTrue(checkSameOrigin("http", "mydomain1.com", -1, "http://mydomain1.com:80/path")); + assertFalse(checkSameOrigin("http", "mydomain2.com", -1, "http://mydomain1.com/")); + assertFalse(checkSameOrigin("http", "mydomain2.com", -1, "http://mydomain1.com:80/")); + assertFalse(checkSameOrigin("http", "mydomain2.com", -1, "http://mydomain1.com/path")); + assertFalse(checkSameOrigin("http", "mydomain2.com", -1, "http://mydomain1.com:80/path")); + + // Handling of IPv6 hosts as described in SPR-13525 + assertTrue(checkSameOrigin("http", "[::1]", -1, "http://[::1]")); + assertTrue(checkSameOrigin("http", "[::1]", 8080, "http://[::1]:8080")); + assertTrue(checkSameOrigin("http", + "[2001:0db8:0000:85a3:0000:0000:ac1f:8001]", -1, + "http://[2001:0db8:0000:85a3:0000:0000:ac1f:8001]")); + assertTrue(checkSameOrigin("http", + "[2001:0db8:0000:85a3:0000:0000:ac1f:8001]", 8080, + "http://[2001:0db8:0000:85a3:0000:0000:ac1f:8001]:8080")); + assertFalse(checkSameOrigin("http", "[::1]", -1, "http://[::1]:8080")); + assertFalse(checkSameOrigin("http", "[::1]", 8080, + "http://[2001:0db8:0000:85a3:0000:0000:ac1f:8001]:8080")); + } + + @Test // SPR-16262 + public void isSameOriginWithXForwardedHeaders() throws Exception { + String server = "mydomain1.com"; + testWithXForwardedHeaders(server, -1, "https", null, -1, "https://mydomain1.com"); + testWithXForwardedHeaders(server, 123, "https", null, -1, "https://mydomain1.com"); + testWithXForwardedHeaders(server, -1, "https", "mydomain2.com", -1, "https://mydomain2.com"); + testWithXForwardedHeaders(server, 123, "https", "mydomain2.com", -1, "https://mydomain2.com"); + testWithXForwardedHeaders(server, -1, "https", "mydomain2.com", 456, "https://mydomain2.com:456"); + testWithXForwardedHeaders(server, 123, "https", "mydomain2.com", 456, "https://mydomain2.com:456"); + } + + @Test // SPR-16262 + public void isSameOriginWithForwardedHeader() throws Exception { + String server = "mydomain1.com"; + testWithForwardedHeader(server, -1, "proto=https", "https://mydomain1.com"); + testWithForwardedHeader(server, 123, "proto=https", "https://mydomain1.com"); + testWithForwardedHeader(server, -1, "proto=https; host=mydomain2.com", "https://mydomain2.com"); + testWithForwardedHeader(server, 123, "proto=https; host=mydomain2.com", "https://mydomain2.com"); + testWithForwardedHeader(server, -1, "proto=https; host=mydomain2.com:456", "https://mydomain2.com:456"); + testWithForwardedHeader(server, 123, "proto=https; host=mydomain2.com:456", "https://mydomain2.com:456"); + } + + + private boolean checkValidOrigin(String serverName, int port, String originHeader, List allowed) { + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + ServerHttpRequest request = new ServletServerHttpRequest(servletRequest); + servletRequest.setServerName(serverName); + if (port != -1) { + servletRequest.setServerPort(port); + } + servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader); + return WebUtils.isValidOrigin(request, allowed); + } + + private boolean checkSameOrigin(String scheme, String serverName, int port, String originHeader) { + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + ServerHttpRequest request = new ServletServerHttpRequest(servletRequest); + servletRequest.setScheme(scheme); + servletRequest.setServerName(serverName); + if (port != -1) { + servletRequest.setServerPort(port); + } + servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader); + return WebUtils.isSameOrigin(request); + } + + private void testWithXForwardedHeaders(String serverName, int port, String forwardedProto, + String forwardedHost, int forwardedPort, String originHeader) throws Exception { + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setServerName(serverName); + if (port != -1) { + request.setServerPort(port); + } + if (forwardedProto != null) { + request.addHeader("X-Forwarded-Proto", forwardedProto); + } + if (forwardedHost != null) { + request.addHeader("X-Forwarded-Host", forwardedHost); + } + if (forwardedPort != -1) { + request.addHeader("X-Forwarded-Port", String.valueOf(forwardedPort)); + } + request.addHeader(HttpHeaders.ORIGIN, originHeader); + + HttpServletRequest requestToUse = adaptFromForwardedHeaders(request); + ServerHttpRequest httpRequest = new ServletServerHttpRequest(requestToUse); + + assertTrue(WebUtils.isSameOrigin(httpRequest)); + } + + private void testWithForwardedHeader(String serverName, int port, String forwardedHeader, + String originHeader) throws Exception { + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setServerName(serverName); + if (port != -1) { + request.setServerPort(port); + } + request.addHeader("Forwarded", forwardedHeader); + request.addHeader(HttpHeaders.ORIGIN, originHeader); + + HttpServletRequest requestToUse = adaptFromForwardedHeaders(request); + ServerHttpRequest httpRequest = new ServletServerHttpRequest(requestToUse); + + assertTrue(WebUtils.isSameOrigin(httpRequest)); + } + + // SPR-16668 + private HttpServletRequest adaptFromForwardedHeaders(HttpServletRequest request) throws Exception { + MockFilterChain chain = new MockFilterChain(); + new ForwardedHeaderFilter().doFilter(request, new MockHttpServletResponse(), chain); + return (HttpServletRequest) chain.getRequest(); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/util/pattern/PathPatternParserTests.java b/spring-web/src/test/java/org/springframework/web/util/pattern/PathPatternParserTests.java new file mode 100644 index 0000000000000000000000000000000000000000..a7a044cbdbad8da8b2c3cdb3eaa85ad02fd30e0b --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/pattern/PathPatternParserTests.java @@ -0,0 +1,486 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util.pattern; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.junit.Test; + +import org.springframework.http.server.PathContainer; +import org.springframework.web.util.pattern.PatternParseException.PatternMessage; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Exercise the {@link PathPatternParser}. + * @author Andy Clement + */ +public class PathPatternParserTests { + + private PathPattern pathPattern; + + @Test + public void basicPatterns() { + checkStructure("/"); + checkStructure("/foo"); + checkStructure("foo"); + checkStructure("foo/"); + checkStructure("/foo/"); + checkStructure(""); + } + + @Test + public void singleCharWildcardPatterns() { + pathPattern = checkStructure("?"); + assertPathElements(pathPattern, SingleCharWildcardedPathElement.class); + checkStructure("/?/"); + checkStructure("/?abc?/"); + } + + @Test + public void multiwildcardPattern() { + pathPattern = checkStructure("/**"); + assertPathElements(pathPattern, WildcardTheRestPathElement.class); + // this is not double wildcard, it's / then **acb (an odd, unnecessary use of double *) + pathPattern = checkStructure("/**acb"); + assertPathElements(pathPattern, SeparatorPathElement.class, RegexPathElement.class); + } + + @Test + public void toStringTests() { + assertEquals("CaptureTheRest(/{*foobar})", checkStructure("/{*foobar}").toChainString()); + assertEquals("CaptureVariable({foobar})", checkStructure("{foobar}").toChainString()); + assertEquals("Literal(abc)", checkStructure("abc").toChainString()); + assertEquals("Regex({a}_*_{b})", checkStructure("{a}_*_{b}").toChainString()); + assertEquals("Separator(/)", checkStructure("/").toChainString()); + assertEquals("SingleCharWildcarded(?a?b?c)", checkStructure("?a?b?c").toChainString()); + assertEquals("Wildcard(*)", checkStructure("*").toChainString()); + assertEquals("WildcardTheRest(/**)", checkStructure("/**").toChainString()); + } + + @Test + public void captureTheRestPatterns() { + pathPattern = parse("{*foobar}"); + assertEquals("/{*foobar}", pathPattern.computePatternString()); + assertPathElements(pathPattern, CaptureTheRestPathElement.class); + pathPattern = checkStructure("/{*foobar}"); + assertPathElements(pathPattern, CaptureTheRestPathElement.class); + checkError("/{*foobar}/", 10, PatternMessage.NO_MORE_DATA_EXPECTED_AFTER_CAPTURE_THE_REST); + checkError("/{*foobar}abc", 10, PatternMessage.NO_MORE_DATA_EXPECTED_AFTER_CAPTURE_THE_REST); + checkError("/{*f%obar}", 4, PatternMessage.ILLEGAL_CHARACTER_IN_CAPTURE_DESCRIPTOR); + checkError("/{*foobar}abc", 10, PatternMessage.NO_MORE_DATA_EXPECTED_AFTER_CAPTURE_THE_REST); + checkError("/{f*oobar}", 3, PatternMessage.ILLEGAL_CHARACTER_IN_CAPTURE_DESCRIPTOR); + checkError("/{*foobar}/abc", 10, PatternMessage.NO_MORE_DATA_EXPECTED_AFTER_CAPTURE_THE_REST); + checkError("/{*foobar:.*}/abc", 9, PatternMessage.ILLEGAL_CHARACTER_IN_CAPTURE_DESCRIPTOR); + checkError("/{abc}{*foobar}", 1, PatternMessage.CAPTURE_ALL_IS_STANDALONE_CONSTRUCT); + checkError("/{abc}{*foobar}{foo}", 15, PatternMessage.NO_MORE_DATA_EXPECTED_AFTER_CAPTURE_THE_REST); + } + + @Test + public void equalsAndHashcode() { + PathPatternParser caseInsensitiveParser = new PathPatternParser(); + caseInsensitiveParser.setCaseSensitive(false); + PathPatternParser caseSensitiveParser = new PathPatternParser(); + PathPattern pp1 = caseInsensitiveParser.parse("/abc"); + PathPattern pp2 = caseInsensitiveParser.parse("/abc"); + PathPattern pp3 = caseInsensitiveParser.parse("/def"); + assertEquals(pp1, pp2); + assertEquals(pp1.hashCode(), pp2.hashCode()); + assertNotEquals(pp1, pp3); + assertFalse(pp1.equals("abc")); + + pp1 = caseInsensitiveParser.parse("/abc"); + pp2 = caseSensitiveParser.parse("/abc"); + assertFalse(pp1.equals(pp2)); + assertNotEquals(pp1.hashCode(), pp2.hashCode()); + } + + @Test + public void regexPathElementPatterns() { + checkError("/{var:[^/]*}", 8, PatternMessage.MISSING_CLOSE_CAPTURE); + checkError("/{var:abc", 8, PatternMessage.MISSING_CLOSE_CAPTURE); + checkError("/{var:a{{1,2}}}", 6, PatternMessage.REGEX_PATTERN_SYNTAX_EXCEPTION); + + pathPattern = checkStructure("/{var:\\\\}"); + PathElement next = pathPattern.getHeadSection().next; + assertEquals(CaptureVariablePathElement.class.getName(), next.getClass().getName()); + assertMatches(pathPattern,"/\\"); + + pathPattern = checkStructure("/{var:\\/}"); + next = pathPattern.getHeadSection().next; + assertEquals(CaptureVariablePathElement.class.getName(), next.getClass().getName()); + assertNoMatch(pathPattern,"/aaa"); + + pathPattern = checkStructure("/{var:a{1,2}}"); + next = pathPattern.getHeadSection().next; + assertEquals(CaptureVariablePathElement.class.getName(), next.getClass().getName()); + + pathPattern = checkStructure("/{var:[^\\/]*}"); + next = pathPattern.getHeadSection().next; + assertEquals(CaptureVariablePathElement.class.getName(), next.getClass().getName()); + PathPattern.PathMatchInfo result = matchAndExtract(pathPattern,"/foo"); + assertEquals("foo", result.getUriVariables().get("var")); + + pathPattern = checkStructure("/{var:\\[*}"); + next = pathPattern.getHeadSection().next; + assertEquals(CaptureVariablePathElement.class.getName(), next.getClass().getName()); + result = matchAndExtract(pathPattern,"/[[["); + assertEquals("[[[", result.getUriVariables().get("var")); + + pathPattern = checkStructure("/{var:[\\{]*}"); + next = pathPattern.getHeadSection().next; + assertEquals(CaptureVariablePathElement.class.getName(), next.getClass().getName()); + result = matchAndExtract(pathPattern,"/{{{"); + assertEquals("{{{", result.getUriVariables().get("var")); + + pathPattern = checkStructure("/{var:[\\}]*}"); + next = pathPattern.getHeadSection().next; + assertEquals(CaptureVariablePathElement.class.getName(), next.getClass().getName()); + result = matchAndExtract(pathPattern,"/}}}"); + assertEquals("}}}", result.getUriVariables().get("var")); + + pathPattern = checkStructure("*"); + assertEquals(WildcardPathElement.class.getName(), pathPattern.getHeadSection().getClass().getName()); + checkStructure("/*"); + checkStructure("/*/"); + checkStructure("*/"); + checkStructure("/*/"); + pathPattern = checkStructure("/*a*/"); + next = pathPattern.getHeadSection().next; + assertEquals(RegexPathElement.class.getName(), next.getClass().getName()); + pathPattern = checkStructure("*/"); + assertEquals(WildcardPathElement.class.getName(), pathPattern.getHeadSection().getClass().getName()); + checkError("{foo}_{foo}", 0, PatternMessage.ILLEGAL_DOUBLE_CAPTURE, "foo"); + checkError("/{bar}/{bar}", 7, PatternMessage.ILLEGAL_DOUBLE_CAPTURE, "bar"); + checkError("/{bar}/{bar}_{foo}", 7, PatternMessage.ILLEGAL_DOUBLE_CAPTURE, "bar"); + + pathPattern = checkStructure("{symbolicName:[\\p{L}\\.]+}-sources-{version:[\\p{N}\\.]+}.jar"); + assertEquals(RegexPathElement.class.getName(), pathPattern.getHeadSection().getClass().getName()); + } + + @Test + public void completeCapturingPatterns() { + pathPattern = checkStructure("{foo}"); + assertEquals(CaptureVariablePathElement.class.getName(), pathPattern.getHeadSection().getClass().getName()); + checkStructure("/{foo}"); + checkStructure("/{f}/"); + checkStructure("/{foo}/{bar}/{wibble}"); + checkStructure("/{mobile-number}"); // gh-23101 + } + + @Test + public void noEncoding() { + // Check no encoding of expressions or constraints + PathPattern pp = parse("/{var:f o}"); + assertEquals("Separator(/) CaptureVariable({var:f o})",pp.toChainString()); + + pp = parse("/{var:f o}_"); + assertEquals("Separator(/) Regex({var:f o}_)",pp.toChainString()); + + pp = parse("{foo:f o}_ _{bar:b\\|o}"); + assertEquals("Regex({foo:f o}_ _{bar:b\\|o})",pp.toChainString()); + } + + @Test + public void completeCaptureWithConstraints() { + pathPattern = checkStructure("{foo:...}"); + assertPathElements(pathPattern, CaptureVariablePathElement.class); + pathPattern = checkStructure("{foo:[0-9]*}"); + assertPathElements(pathPattern, CaptureVariablePathElement.class); + checkError("{foo:}", 5, PatternMessage.MISSING_REGEX_CONSTRAINT); + } + + @Test + public void partialCapturingPatterns() { + pathPattern = checkStructure("{foo}abc"); + assertEquals(RegexPathElement.class.getName(), pathPattern.getHeadSection().getClass().getName()); + checkStructure("abc{foo}"); + checkStructure("/abc{foo}"); + checkStructure("{foo}def/"); + checkStructure("/abc{foo}def/"); + checkStructure("{foo}abc{bar}"); + checkStructure("{foo}abc{bar}/"); + checkStructure("/{foo}abc{bar}/"); + } + + @Test + public void illegalCapturePatterns() { + checkError("{abc/", 4, PatternMessage.MISSING_CLOSE_CAPTURE); + checkError("{abc:}/", 5, PatternMessage.MISSING_REGEX_CONSTRAINT); + checkError("{", 1, PatternMessage.MISSING_CLOSE_CAPTURE); + checkError("{abc", 4, PatternMessage.MISSING_CLOSE_CAPTURE); + checkError("{/}", 1, PatternMessage.MISSING_CLOSE_CAPTURE); + checkError("/{", 2, PatternMessage.MISSING_CLOSE_CAPTURE); + checkError("}", 0, PatternMessage.MISSING_OPEN_CAPTURE); + checkError("/}", 1, PatternMessage.MISSING_OPEN_CAPTURE); + checkError("def}", 3, PatternMessage.MISSING_OPEN_CAPTURE); + checkError("/{/}", 2, PatternMessage.MISSING_CLOSE_CAPTURE); + checkError("/{{/}", 2, PatternMessage.ILLEGAL_NESTED_CAPTURE); + checkError("/{abc{/}", 5, PatternMessage.ILLEGAL_NESTED_CAPTURE); + checkError("/{0abc}/abc", 2, PatternMessage.ILLEGAL_CHARACTER_AT_START_OF_CAPTURE_DESCRIPTOR); + checkError("/{a?bc}/abc", 3, PatternMessage.ILLEGAL_CHARACTER_IN_CAPTURE_DESCRIPTOR); + checkError("/{abc}_{abc}", 1, PatternMessage.ILLEGAL_DOUBLE_CAPTURE); + checkError("/foobar/{abc}_{abc}", 8, PatternMessage.ILLEGAL_DOUBLE_CAPTURE); + checkError("/foobar/{abc:..}_{abc:..}", 8, PatternMessage.ILLEGAL_DOUBLE_CAPTURE); + PathPattern pp = parse("/{abc:foo(bar)}"); + try { + pp.matchAndExtract(toPSC("/foo")); + fail("Should have raised exception"); + } + catch (IllegalArgumentException iae) { + assertEquals("No capture groups allowed in the constraint regex: foo(bar)", iae.getMessage()); + } + try { + pp.matchAndExtract(toPSC("/foobar")); + fail("Should have raised exception"); + } + catch (IllegalArgumentException iae) { + assertEquals("No capture groups allowed in the constraint regex: foo(bar)", iae.getMessage()); + } + } + + @Test + public void badPatterns() { +// checkError("/{foo}{bar}/",6,PatternMessage.CANNOT_HAVE_ADJACENT_CAPTURES); + checkError("/{?}/", 2, PatternMessage.ILLEGAL_CHARACTER_AT_START_OF_CAPTURE_DESCRIPTOR, "?"); + checkError("/{a?b}/", 3, PatternMessage.ILLEGAL_CHARACTER_IN_CAPTURE_DESCRIPTOR, "?"); + checkError("/{%%$}", 2, PatternMessage.ILLEGAL_CHARACTER_AT_START_OF_CAPTURE_DESCRIPTOR, "%"); + checkError("/{ }", 2, PatternMessage.ILLEGAL_CHARACTER_AT_START_OF_CAPTURE_DESCRIPTOR, " "); + checkError("/{%:[0-9]*}", 2, PatternMessage.ILLEGAL_CHARACTER_AT_START_OF_CAPTURE_DESCRIPTOR, "%"); + } + + @Test + public void patternPropertyGetCaptureCountTests() { + // Test all basic section types + assertEquals(1, parse("{foo}").getCapturedVariableCount()); + assertEquals(0, parse("foo").getCapturedVariableCount()); + assertEquals(1, parse("{*foobar}").getCapturedVariableCount()); + assertEquals(1, parse("/{*foobar}").getCapturedVariableCount()); + assertEquals(0, parse("/**").getCapturedVariableCount()); + assertEquals(1, parse("{abc}asdf").getCapturedVariableCount()); + assertEquals(1, parse("{abc}_*").getCapturedVariableCount()); + assertEquals(2, parse("{abc}_{def}").getCapturedVariableCount()); + assertEquals(0, parse("/").getCapturedVariableCount()); + assertEquals(0, parse("a?b").getCapturedVariableCount()); + assertEquals(0, parse("*").getCapturedVariableCount()); + + // Test on full templates + assertEquals(0, parse("/foo/bar").getCapturedVariableCount()); + assertEquals(1, parse("/{foo}").getCapturedVariableCount()); + assertEquals(2, parse("/{foo}/{bar}").getCapturedVariableCount()); + assertEquals(4, parse("/{foo}/{bar}_{goo}_{wibble}/abc/bar").getCapturedVariableCount()); + } + + @Test + public void patternPropertyGetWildcardCountTests() { + // Test all basic section types + assertEquals(computeScore(1, 0), parse("{foo}").getScore()); + assertEquals(computeScore(0, 0), parse("foo").getScore()); + assertEquals(computeScore(0, 0), parse("{*foobar}").getScore()); +// assertEquals(1,parse("/**").getScore()); + assertEquals(computeScore(1, 0), parse("{abc}asdf").getScore()); + assertEquals(computeScore(1, 1), parse("{abc}_*").getScore()); + assertEquals(computeScore(2, 0), parse("{abc}_{def}").getScore()); + assertEquals(computeScore(0, 0), parse("/").getScore()); + assertEquals(computeScore(0, 0), parse("a?b").getScore()); // currently deliberate + assertEquals(computeScore(0, 1), parse("*").getScore()); + + // Test on full templates + assertEquals(computeScore(0, 0), parse("/foo/bar").getScore()); + assertEquals(computeScore(1, 0), parse("/{foo}").getScore()); + assertEquals(computeScore(2, 0), parse("/{foo}/{bar}").getScore()); + assertEquals(computeScore(4, 0), parse("/{foo}/{bar}_{goo}_{wibble}/abc/bar").getScore()); + assertEquals(computeScore(4, 3), parse("/{foo}/*/*_*/{bar}_{goo}_{wibble}/abc/bar").getScore()); + } + + @Test + public void multipleSeparatorPatterns() { + pathPattern = checkStructure("///aaa"); + assertEquals(6, pathPattern.getNormalizedLength()); + assertPathElements(pathPattern, SeparatorPathElement.class, SeparatorPathElement.class, + SeparatorPathElement.class, LiteralPathElement.class); + pathPattern = checkStructure("///aaa////aaa/b"); + assertEquals(15, pathPattern.getNormalizedLength()); + assertPathElements(pathPattern, SeparatorPathElement.class, SeparatorPathElement.class, + SeparatorPathElement.class, LiteralPathElement.class, SeparatorPathElement.class, + SeparatorPathElement.class, SeparatorPathElement.class, SeparatorPathElement.class, + LiteralPathElement.class, SeparatorPathElement.class, LiteralPathElement.class); + pathPattern = checkStructure("/////**"); + assertEquals(5, pathPattern.getNormalizedLength()); + assertPathElements(pathPattern, SeparatorPathElement.class, SeparatorPathElement.class, + SeparatorPathElement.class, SeparatorPathElement.class, WildcardTheRestPathElement.class); + } + + @Test + public void patternPropertyGetLengthTests() { + // Test all basic section types + assertEquals(1, parse("{foo}").getNormalizedLength()); + assertEquals(3, parse("foo").getNormalizedLength()); + assertEquals(1, parse("{*foobar}").getNormalizedLength()); + assertEquals(1, parse("/{*foobar}").getNormalizedLength()); + assertEquals(1, parse("/**").getNormalizedLength()); + assertEquals(5, parse("{abc}asdf").getNormalizedLength()); + assertEquals(3, parse("{abc}_*").getNormalizedLength()); + assertEquals(3, parse("{abc}_{def}").getNormalizedLength()); + assertEquals(1, parse("/").getNormalizedLength()); + assertEquals(3, parse("a?b").getNormalizedLength()); + assertEquals(1, parse("*").getNormalizedLength()); + + // Test on full templates + assertEquals(8, parse("/foo/bar").getNormalizedLength()); + assertEquals(2, parse("/{foo}").getNormalizedLength()); + assertEquals(4, parse("/{foo}/{bar}").getNormalizedLength()); + assertEquals(16, parse("/{foo}/{bar}_{goo}_{wibble}/abc/bar").getNormalizedLength()); + } + + @Test + public void compareTests() { + PathPattern p1, p2, p3; + + // Based purely on number of captures + p1 = parse("{a}"); + p2 = parse("{a}/{b}"); + p3 = parse("{a}/{b}/{c}"); + assertEquals(-1, p1.compareTo(p2)); // Based on number of captures + List patterns = new ArrayList<>(); + patterns.add(p2); + patterns.add(p3); + patterns.add(p1); + Collections.sort(patterns); + assertEquals(p1, patterns.get(0)); + + // Based purely on length + p1 = parse("/a/b/c"); + p2 = parse("/a/boo/c/doo"); + p3 = parse("/asdjflaksjdfjasdf"); + assertEquals(1, p1.compareTo(p2)); + patterns = new ArrayList<>(); + patterns.add(p2); + patterns.add(p3); + patterns.add(p1); + Collections.sort(patterns); + assertEquals(p3, patterns.get(0)); + + // Based purely on 'wildness' + p1 = parse("/*"); + p2 = parse("/*/*"); + p3 = parse("/*/*/*_*"); + assertEquals(-1, p1.compareTo(p2)); + patterns = new ArrayList<>(); + patterns.add(p2); + patterns.add(p3); + patterns.add(p1); + Collections.sort(patterns); + assertEquals(p1, patterns.get(0)); + + // Based purely on catchAll + p1 = parse("{*foobar}"); + p2 = parse("{*goo}"); + assertTrue(p1.compareTo(p2) != 0); + + p1 = parse("/{*foobar}"); + p2 = parse("/abc/{*ww}"); + assertEquals(+1, p1.compareTo(p2)); + assertEquals(-1, p2.compareTo(p1)); + + p3 = parse("/this/that/theother"); + assertTrue(p1.isCatchAll()); + assertTrue(p2.isCatchAll()); + assertFalse(p3.isCatchAll()); + patterns = new ArrayList<>(); + patterns.add(p2); + patterns.add(p3); + patterns.add(p1); + Collections.sort(patterns); + assertEquals(p3, patterns.get(0)); + assertEquals(p2, patterns.get(1)); + } + + private PathPattern parse(String pattern) { + PathPatternParser patternParser = new PathPatternParser(); + return patternParser.parse(pattern); + } + + /** + * Verify the pattern string computed for a parsed pattern matches the original pattern text + */ + private PathPattern checkStructure(String pattern) { + PathPattern pp = parse(pattern); + assertEquals(pattern, pp.computePatternString()); + return pp; + } + + private void checkError(String pattern, int expectedPos, PatternMessage expectedMessage, + String... expectedInserts) { + + try { + pathPattern = parse(pattern); + fail("Expected to fail"); + } + catch (PatternParseException ppe) { + assertEquals(ppe.toDetailedString(), expectedPos, ppe.getPosition()); + assertEquals(ppe.toDetailedString(), expectedMessage, ppe.getMessageType()); + if (expectedInserts.length != 0) { + assertEquals(ppe.getInserts().length, expectedInserts.length); + for (int i = 0; i < expectedInserts.length; i++) { + assertEquals("Insert at position " + i + " is wrong", expectedInserts[i], ppe.getInserts()[i]); + } + } + } + } + + @SafeVarargs + private final void assertPathElements(PathPattern p, Class... sectionClasses) { + PathElement head = p.getHeadSection(); + for (Class sectionClass : sectionClasses) { + if (head == null) { + fail("Ran out of data in parsed pattern. Pattern is: " + p.toChainString()); + } + assertEquals("Not expected section type. Pattern is: " + p.toChainString(), + sectionClass.getSimpleName(), head.getClass().getSimpleName()); + head = head.next; + } + } + + // Mirrors the score computation logic in PathPattern + private int computeScore(int capturedVariableCount, int wildcardCount) { + return capturedVariableCount + wildcardCount * 100; + } + + private void assertMatches(PathPattern pp, String path) { + assertTrue(pp.matches(PathPatternTests.toPathContainer(path))); + } + + private void assertNoMatch(PathPattern pp, String path) { + assertFalse(pp.matches(PathPatternTests.toPathContainer(path))); + } + + private PathPattern.PathMatchInfo matchAndExtract(PathPattern pp, String path) { + return pp.matchAndExtract(PathPatternTests.toPathContainer(path)); + } + + private PathContainer toPSC(String path) { + return PathPatternTests.toPathContainer(path); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/util/pattern/PathPatternTests.java b/spring-web/src/test/java/org/springframework/web/util/pattern/PathPatternTests.java new file mode 100644 index 0000000000000000000000000000000000000000..13bf9248483e9c8b661cf15ad251579f277ca697 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/util/pattern/PathPatternTests.java @@ -0,0 +1,1230 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.util.pattern; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import org.springframework.http.server.PathContainer; +import org.springframework.http.server.PathContainer.Element; +import org.springframework.util.AntPathMatcher; +import org.springframework.web.util.pattern.PathPattern.PathRemainingMatchInfo; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Exercise matching of {@link PathPattern} objects. + * + * @author Andy Clement + */ +public class PathPatternTests { + + @Test + public void pathContainer() { + assertEquals("[/][abc][/][def]",elementsToString(toPathContainer("/abc/def").elements())); + assertEquals("[abc][/][def]",elementsToString(toPathContainer("abc/def").elements())); + assertEquals("[abc][/][def][/]",elementsToString(toPathContainer("abc/def/").elements())); + assertEquals("[abc][/][/][def][/][/]",elementsToString(toPathContainer("abc//def//").elements())); + assertEquals("[/]",elementsToString(toPathContainer("/").elements())); + assertEquals("[/][/][/]",elementsToString(toPathContainer("///").elements())); + } + + @Test + public void matching_LiteralPathElement() { + checkMatches("foo", "foo"); + checkNoMatch("foo", "bar"); + checkNoMatch("foo", "/foo"); + checkNoMatch("/foo", "foo"); + checkMatches("/f", "/f"); + checkMatches("/foo", "/foo"); + checkNoMatch("/foo", "/food"); + checkNoMatch("/food", "/foo"); + checkMatches("/foo/", "/foo/"); + checkMatches("/foo/bar/woo", "/foo/bar/woo"); + checkMatches("foo/bar/woo", "foo/bar/woo"); + } + + @Test + public void basicMatching() { + checkMatches("", ""); + checkMatches("", "/"); + checkMatches("", null); + checkNoMatch("/abc", "/"); + checkMatches("/", "/"); + checkNoMatch("/", "/a"); + checkMatches("foo/bar/", "foo/bar/"); + checkNoMatch("foo", "foobar"); + checkMatches("/foo/bar", "/foo/bar"); + checkNoMatch("/foo/bar", "/foo/baz"); + } + + private void assertMatches(PathPattern pp, String path) { + assertTrue(pp.matches(toPathContainer(path))); + } + + private void assertNoMatch(PathPattern pp, String path) { + assertFalse(pp.matches(toPathContainer(path))); + } + + @Test + public void optionalTrailingSeparators() { + PathPattern pp; + // LiteralPathElement + pp = parse("/resource"); + assertMatches(pp,"/resource"); + assertMatches(pp,"/resource"); + assertMatches(pp,"/resource/"); + assertNoMatch(pp,"/resource//"); + pp = parse("/resource/"); + assertNoMatch(pp,"/resource"); + assertMatches(pp,"/resource/"); + assertNoMatch(pp,"/resource//"); + + pp = parse("res?urce"); + assertNoMatch(pp,"resource//"); + // SingleCharWildcardPathElement + pp = parse("/res?urce"); + assertMatches(pp,"/resource"); + assertMatches(pp,"/resource/"); + assertNoMatch(pp,"/resource//"); + pp = parse("/res?urce/"); + assertNoMatch(pp,"/resource"); + assertMatches(pp,"/resource/"); + assertNoMatch(pp,"/resource//"); + + // CaptureVariablePathElement + pp = parse("/{var}"); + assertMatches(pp,"/resource"); + assertEquals("resource",pp.matchAndExtract(toPathContainer("/resource")).getUriVariables().get("var")); + assertMatches(pp,"/resource/"); + assertEquals("resource",pp.matchAndExtract(toPathContainer("/resource/")).getUriVariables().get("var")); + assertNoMatch(pp,"/resource//"); + pp = parse("/{var}/"); + assertNoMatch(pp,"/resource"); + assertMatches(pp,"/resource/"); + assertEquals("resource",pp.matchAndExtract(toPathContainer("/resource/")).getUriVariables().get("var")); + assertNoMatch(pp,"/resource//"); + + // CaptureTheRestPathElement + pp = parse("/{*var}"); + assertMatches(pp,"/resource"); + assertEquals("/resource",pp.matchAndExtract(toPathContainer("/resource")).getUriVariables().get("var")); + assertMatches(pp,"/resource/"); + assertEquals("/resource/",pp.matchAndExtract(toPathContainer("/resource/")).getUriVariables().get("var")); + assertMatches(pp,"/resource//"); + assertEquals("/resource//",pp.matchAndExtract(toPathContainer("/resource//")).getUriVariables().get("var")); + assertMatches(pp,"//resource//"); + assertEquals("//resource//",pp.matchAndExtract(toPathContainer("//resource//")).getUriVariables().get("var")); + + // WildcardTheRestPathElement + pp = parse("/**"); + assertMatches(pp,"/resource"); + assertMatches(pp,"/resource/"); + assertMatches(pp,"/resource//"); + assertMatches(pp,"//resource//"); + + // WildcardPathElement + pp = parse("/*"); + assertMatches(pp,"/resource"); + assertMatches(pp,"/resource/"); + assertNoMatch(pp,"/resource//"); + pp = parse("/*/"); + assertNoMatch(pp,"/resource"); + assertMatches(pp,"/resource/"); + assertNoMatch(pp,"/resource//"); + + // RegexPathElement + pp = parse("/{var1}_{var2}"); + assertMatches(pp,"/res1_res2"); + assertEquals("res1",pp.matchAndExtract(toPathContainer("/res1_res2")).getUriVariables().get("var1")); + assertEquals("res2",pp.matchAndExtract(toPathContainer("/res1_res2")).getUriVariables().get("var2")); + assertMatches(pp,"/res1_res2/"); + assertEquals("res1",pp.matchAndExtract(toPathContainer("/res1_res2/")).getUriVariables().get("var1")); + assertEquals("res2",pp.matchAndExtract(toPathContainer("/res1_res2/")).getUriVariables().get("var2")); + assertNoMatch(pp,"/res1_res2//"); + pp = parse("/{var1}_{var2}/"); + assertNoMatch(pp,"/res1_res2"); + assertMatches(pp,"/res1_res2/"); + assertEquals("res1",pp.matchAndExtract(toPathContainer("/res1_res2/")).getUriVariables().get("var1")); + assertEquals("res2",pp.matchAndExtract(toPathContainer("/res1_res2/")).getUriVariables().get("var2")); + assertNoMatch(pp,"/res1_res2//"); + pp = parse("/{var1}*"); + assertMatches(pp,"/a"); + assertMatches(pp,"/a/"); + assertNoMatch(pp,"/"); // no characters for var1 + assertNoMatch(pp,"//"); // no characters for var1 + + // Now with trailing matching turned OFF + PathPatternParser parser = new PathPatternParser(); + parser.setMatchOptionalTrailingSeparator(false); + // LiteralPathElement + pp = parser.parse("/resource"); + assertMatches(pp,"/resource"); + assertNoMatch(pp,"/resource/"); + assertNoMatch(pp,"/resource//"); + pp = parser.parse("/resource/"); + assertNoMatch(pp,"/resource"); + assertMatches(pp,"/resource/"); + assertNoMatch(pp,"/resource//"); + + // SingleCharWildcardPathElement + pp = parser.parse("/res?urce"); + assertMatches(pp,"/resource"); + assertNoMatch(pp,"/resource/"); + assertNoMatch(pp,"/resource//"); + pp = parser.parse("/res?urce/"); + assertNoMatch(pp,"/resource"); + assertMatches(pp,"/resource/"); + assertNoMatch(pp,"/resource//"); + + // CaptureVariablePathElement + pp = parser.parse("/{var}"); + assertMatches(pp,"/resource"); + assertEquals("resource",pp.matchAndExtract(toPathContainer("/resource")).getUriVariables().get("var")); + assertNoMatch(pp,"/resource/"); + assertNoMatch(pp,"/resource//"); + pp = parser.parse("/{var}/"); + assertNoMatch(pp,"/resource"); + assertMatches(pp,"/resource/"); + assertEquals("resource",pp.matchAndExtract(toPathContainer("/resource/")).getUriVariables().get("var")); + assertNoMatch(pp,"/resource//"); + + // CaptureTheRestPathElement + pp = parser.parse("/{*var}"); + assertMatches(pp,"/resource"); + assertEquals("/resource",pp.matchAndExtract(toPathContainer("/resource")).getUriVariables().get("var")); + assertMatches(pp,"/resource/"); + assertEquals("/resource/",pp.matchAndExtract(toPathContainer("/resource/")).getUriVariables().get("var")); + assertMatches(pp,"/resource//"); + assertEquals("/resource//",pp.matchAndExtract(toPathContainer("/resource//")).getUriVariables().get("var")); + assertMatches(pp,"//resource//"); + assertEquals("//resource//",pp.matchAndExtract(toPathContainer("//resource//")).getUriVariables().get("var")); + + // WildcardTheRestPathElement + pp = parser.parse("/**"); + assertMatches(pp,"/resource"); + assertMatches(pp,"/resource/"); + assertMatches(pp,"/resource//"); + assertMatches(pp,"//resource//"); + + // WildcardPathElement + pp = parser.parse("/*"); + assertMatches(pp,"/resource"); + assertNoMatch(pp,"/resource/"); + assertNoMatch(pp,"/resource//"); + pp = parser.parse("/*/"); + assertNoMatch(pp,"/resource"); + assertMatches(pp,"/resource/"); + assertNoMatch(pp,"/resource//"); + + // RegexPathElement + pp = parser.parse("/{var1}_{var2}"); + assertMatches(pp,"/res1_res2"); + assertEquals("res1",pp.matchAndExtract(toPathContainer("/res1_res2")).getUriVariables().get("var1")); + assertEquals("res2",pp.matchAndExtract(toPathContainer("/res1_res2")).getUriVariables().get("var2")); + assertNoMatch(pp,"/res1_res2/"); + assertNoMatch(pp,"/res1_res2//"); + pp = parser.parse("/{var1}_{var2}/"); + assertNoMatch(pp,"/res1_res2"); + assertMatches(pp,"/res1_res2/"); + assertEquals("res1",pp.matchAndExtract(toPathContainer("/res1_res2/")).getUriVariables().get("var1")); + assertEquals("res2",pp.matchAndExtract(toPathContainer("/res1_res2/")).getUriVariables().get("var2")); + assertNoMatch(pp,"/res1_res2//"); + pp = parser.parse("/{var1}*"); + assertMatches(pp,"/a"); + assertNoMatch(pp,"/a/"); + assertNoMatch(pp,"/"); // no characters for var1 + assertNoMatch(pp,"//"); // no characters for var1 + } + + @Test + public void pathRemainderBasicCases_spr15336() { + // Cover all PathElement kinds + assertEquals("/bar", getPathRemaining("/foo","/foo/bar").getPathRemaining().value()); + assertEquals("/", getPathRemaining("/foo","/foo/").getPathRemaining().value()); + assertEquals("/bar",getPathRemaining("/foo*","/foo/bar").getPathRemaining().value()); + assertEquals("/bar", getPathRemaining("/*","/foo/bar").getPathRemaining().value()); + assertEquals("/bar", getPathRemaining("/{foo}","/foo/bar").getPathRemaining().value()); + assertNull(getPathRemaining("/foo","/bar/baz")); + assertEquals("",getPathRemaining("/**","/foo/bar").getPathRemaining().value()); + assertEquals("",getPathRemaining("/{*bar}","/foo/bar").getPathRemaining().value()); + assertEquals("/bar",getPathRemaining("/a?b/d?e","/aab/dde/bar").getPathRemaining().value()); + assertEquals("/bar",getPathRemaining("/{abc}abc","/xyzabc/bar").getPathRemaining().value()); + assertEquals("/bar",getPathRemaining("/*y*","/xyzxyz/bar").getPathRemaining().value()); + assertEquals("",getPathRemaining("/","/").getPathRemaining().value()); + assertEquals("a",getPathRemaining("/","/a").getPathRemaining().value()); + assertEquals("a/",getPathRemaining("/","/a/").getPathRemaining().value()); + assertEquals("/bar",getPathRemaining("/a{abc}","/a/bar").getPathRemaining().value()); + assertEquals("/bar", getPathRemaining("/foo//","/foo///bar").getPathRemaining().value()); + } + + @Test + public void encodingAndBoundVariablesCapturePathElement() { + checkCapture("{var}","f%20o","var","f o"); + checkCapture("{var1}/{var2}","f%20o/f%7Co","var1","f o","var2","f|o"); + checkCapture("{var1}/{var2}","f%20o/f%7co","var1","f o","var2","f|o"); // lower case encoding + checkCapture("{var:foo}","foo","var","foo"); + checkCapture("{var:f o}","f%20o","var","f o"); // constraint is expressed in non encoded form + checkCapture("{var:f.o}","f%20o","var","f o"); + checkCapture("{var:f\\|o}","f%7co","var","f|o"); + } + + @Test + public void encodingAndBoundVariablesCaptureTheRestPathElement() { + checkCapture("/{*var}","/f%20o","var","/f o"); + checkCapture("{var1}/{*var2}","f%20o/f%7Co","var1","f o","var2","/f|o"); + checkCapture("/{*var}","/foo","var","/foo"); + checkCapture("/{*var}","/f%20o","var","/f o"); + checkCapture("/{*var}","/f%20o","var","/f o"); + checkCapture("/{*var}","/f%7co","var","/f|o"); + } + + @Test + public void encodingAndBoundVariablesRegexPathElement() { + checkCapture("/{var1:f o}_ _{var2}","/f%20o_%20_f%7co","var1","f o","var2","f|o"); + checkCapture("/{var1}_{var2}","/f%20o_foo","var1","f o","var2","foo"); + checkCapture("/{var1}_ _{var2}","/f%20o_%20_f%7co","var1","f o","var2","f|o"); + checkCapture("/{var1}_ _{var2:f\\|o}","/f%20o_%20_f%7co","var1","f o","var2","f|o"); + checkCapture("/{var1:f o}_ _{var2:f\\|o}","/f%20o_%20_f%7co","var1","f o","var2","f|o"); + } + + @Test + public void pathRemainingCornerCases_spr15336() { + // No match when the literal path element is a longer form of the segment in the pattern + assertNull(parse("/foo").matchStartOfPath(toPathContainer("/footastic/bar"))); + assertNull(parse("/f?o").matchStartOfPath(toPathContainer("/footastic/bar"))); + assertNull(parse("/f*o*p").matchStartOfPath(toPathContainer("/flooptastic/bar"))); + assertNull(parse("/{abc}abc").matchStartOfPath(toPathContainer("/xyzabcbar/bar"))); + + // With a /** on the end have to check if there is any more data post + // 'the match' it starts with a separator + assertNull(parse("/resource/**").matchStartOfPath(toPathContainer("/resourceX"))); + assertEquals("",parse("/resource/**") + .matchStartOfPath(toPathContainer("/resource")).getPathRemaining().value()); + + // Similar to above for the capture-the-rest variant + assertNull(parse("/resource/{*foo}").matchStartOfPath(toPathContainer("/resourceX"))); + assertEquals("", parse("/resource/{*foo}") + .matchStartOfPath(toPathContainer("/resource")).getPathRemaining().value()); + + PathPattern.PathRemainingMatchInfo pri = parse("/aaa/{bbb}/c?d/e*f/*/g") + .matchStartOfPath(toPathContainer("/aaa/b/ccd/ef/x/g/i")); + assertNotNull(pri); + assertEquals("/i",pri.getPathRemaining().value()); + assertEquals("b",pri.getUriVariables().get("bbb")); + + pri = parse("/aaa/{bbb}/c?d/e*f/*/g/").matchStartOfPath(toPathContainer("/aaa/b/ccd/ef/x/g/i")); + assertNotNull(pri); + assertEquals("i",pri.getPathRemaining().value()); + assertEquals("b",pri.getUriVariables().get("bbb")); + + pri = parse("/{aaa}_{bbb}/e*f/{x}/g").matchStartOfPath(toPathContainer("/aa_bb/ef/x/g/i")); + assertNotNull(pri); + assertEquals("/i",pri.getPathRemaining().value()); + assertEquals("aa",pri.getUriVariables().get("aaa")); + assertEquals("bb",pri.getUriVariables().get("bbb")); + assertEquals("x",pri.getUriVariables().get("x")); + + assertNull(parse("/a/b").matchStartOfPath(toPathContainer(""))); + assertEquals("/a/b",parse("").matchStartOfPath(toPathContainer("/a/b")).getPathRemaining().value()); + assertEquals("",parse("").matchStartOfPath(toPathContainer("")).getPathRemaining().value()); + } + + @Test + public void questionMarks() { + checkNoMatch("a", "ab"); + checkMatches("/f?o/bar", "/foo/bar"); + checkNoMatch("/foo/b2r", "/foo/bar"); + checkNoMatch("?", "te"); + checkMatches("?", "a"); + checkMatches("???", "abc"); + checkNoMatch("tes?", "te"); + checkNoMatch("tes?", "tes"); + checkNoMatch("tes?", "testt"); + checkNoMatch("tes?", "tsst"); + checkMatches(".?.a", ".a.a"); + checkNoMatch(".?.a", ".aba"); + checkMatches("/f?o/bar","/f%20o/bar"); + } + + @Test + public void captureTheRest() { + checkMatches("/resource/{*foobar}", "/resource"); + checkNoMatch("/resource/{*foobar}", "/resourceX"); + checkNoMatch("/resource/{*foobar}", "/resourceX/foobar"); + checkMatches("/resource/{*foobar}", "/resource/foobar"); + checkCapture("/resource/{*foobar}", "/resource/foobar", "foobar", "/foobar"); + checkCapture("/customer/{*something}", "/customer/99", "something", "/99"); + checkCapture("/customer/{*something}", "/customer/aa/bb/cc", "something", + "/aa/bb/cc"); + checkCapture("/customer/{*something}", "/customer/", "something", "/"); + checkCapture("/customer/////{*something}", "/customer/////", "something", "/"); + checkCapture("/customer/////{*something}", "/customer//////", "something", "//"); + checkCapture("/customer//////{*something}", "/customer//////99", "something", "/99"); + checkCapture("/customer//////{*something}", "/customer//////99", "something", "/99"); + checkCapture("/customer/{*something}", "/customer", "something", ""); + checkCapture("/{*something}", "", "something", ""); + checkCapture("/customer/{*something}", "/customer//////99", "something", "//////99"); + } + + @Test + public void multipleSeparatorsInPattern() { + PathPattern pp = parse("a//b//c"); + assertEquals("Literal(a) Separator(/) Separator(/) Literal(b) Separator(/) Separator(/) Literal(c)", + pp.toChainString()); + assertMatches(pp,"a//b//c"); + assertEquals("Literal(a) Separator(/) WildcardTheRest(/**)",parse("a//**").toChainString()); + checkMatches("///abc", "///abc"); + checkNoMatch("///abc", "/abc"); + checkNoMatch("//", "/"); + checkMatches("//", "//"); + checkNoMatch("///abc//d/e", "/abc/d/e"); + checkMatches("///abc//d/e", "///abc//d/e"); + checkNoMatch("///abc//{def}//////xyz", "/abc/foo/xyz"); + checkMatches("///abc//{def}//////xyz", "///abc//p//////xyz"); + } + + @Test + public void multipleSelectorsInPath() { + checkNoMatch("/abc", "////abc"); + checkNoMatch("/", "//"); + checkNoMatch("/abc/def/ghi", "/abc//def///ghi"); + checkNoMatch("/abc", "////abc"); + checkMatches("////abc", "////abc"); + checkNoMatch("/", "//"); + checkNoMatch("/abc//def", "/abc/def"); + checkNoMatch("/abc//def///ghi", "/abc/def/ghi"); + checkMatches("/abc//def///ghi", "/abc//def///ghi"); + } + + @Test + public void multipleSeparatorsInPatternAndPath() { + checkNoMatch("///one///two///three", "//one/////two///////three"); + checkMatches("//one/////two///////three", "//one/////two///////three"); + checkNoMatch("//one//two//three", "/one/////two/three"); + checkMatches("/one/////two/three", "/one/////two/three"); + checkCapture("///{foo}///bar", "///one///bar", "foo", "one"); + } + + @Test + public void wildcards() { + checkMatches("/*/bar", "/foo/bar"); + checkNoMatch("/*/bar", "/foo/baz"); + checkNoMatch("/*/bar", "//bar"); + checkMatches("/f*/bar", "/foo/bar"); + checkMatches("/*/bar", "/foo/bar"); + checkMatches("a/*","a/"); + checkMatches("/*","/"); + checkMatches("/*/bar", "/foo/bar"); + checkNoMatch("/*/bar", "/foo/baz"); + checkMatches("/f*/bar", "/foo/bar"); + checkMatches("/*/bar", "/foo/bar"); + checkMatches("/a*b*c*d/bar", "/abcd/bar"); + checkMatches("*a*", "testa"); + checkMatches("a/*", "a/"); + checkNoMatch("a/*", "a//"); // no data for * + checkMatches("a/*", "a/a/"); // trailing slash, so is allowed + PathPatternParser ppp = new PathPatternParser(); + ppp.setMatchOptionalTrailingSeparator(false); + assertFalse(ppp.parse("a/*").matches(toPathContainer("a//"))); + checkMatches("a/*", "a/a"); + checkMatches("a/*", "a/a/"); // trailing slash is optional + checkMatches("/resource/**", "/resource"); + checkNoMatch("/resource/**", "/resourceX"); + checkNoMatch("/resource/**", "/resourceX/foobar"); + checkMatches("/resource/**", "/resource/foobar"); + } + + @Test + public void constrainedMatches() { + checkCapture("{foo:[0-9]*}", "123", "foo", "123"); + checkNoMatch("{foo:[0-9]*}", "abc"); + checkNoMatch("/{foo:[0-9]*}", "abc"); + checkCapture("/*/{foo:....}/**", "/foo/barg/foo", "foo", "barg"); + checkCapture("/*/{foo:....}/**", "/foo/barg/abc/def/ghi", "foo", "barg"); + checkNoMatch("{foo:....}", "99"); + checkMatches("{foo:..}", "99"); + checkCapture("/{abc:\\{\\}}", "/{}", "abc", "{}"); + checkCapture("/{abc:\\[\\]}", "/[]", "abc", "[]"); + checkCapture("/{abc:\\\\\\\\}", "/\\\\"); // this is fun... + } + + @Test + public void antPathMatcherTests() { + // test exact matching + checkMatches("test", "test"); + checkMatches("/test", "/test"); + checkMatches("https://example.org", "https://example.org"); + checkNoMatch("/test.jpg", "test.jpg"); + checkNoMatch("test", "/test"); + checkNoMatch("/test", "test"); + + // test matching with ?'s + checkMatches("t?st", "test"); + checkMatches("??st", "test"); + checkMatches("tes?", "test"); + checkMatches("te??", "test"); + checkMatches("?es?", "test"); + checkNoMatch("tes?", "tes"); + checkNoMatch("tes?", "testt"); + checkNoMatch("tes?", "tsst"); + + // test matching with *'s + checkMatches("*", "test"); + checkMatches("test*", "test"); + checkMatches("test*", "testTest"); + checkMatches("test/*", "test/Test"); + checkMatches("test/*", "test/t"); + checkMatches("test/*", "test/"); + checkMatches("*test*", "AnothertestTest"); + checkMatches("*test", "Anothertest"); + checkMatches("*.*", "test."); + checkMatches("*.*", "test.test"); + checkMatches("*.*", "test.test.test"); + checkMatches("test*aaa", "testblaaaa"); + checkNoMatch("test*", "tst"); + checkNoMatch("test*", "tsttest"); + checkMatches("test*", "test/"); // trailing slash is optional + checkMatches("test*", "test"); // trailing slash is optional + checkNoMatch("test*", "test/t"); + checkNoMatch("test/*", "test"); + checkNoMatch("*test*", "tsttst"); + checkNoMatch("*test", "tsttst"); + checkNoMatch("*.*", "tsttst"); + checkNoMatch("test*aaa", "test"); + checkNoMatch("test*aaa", "testblaaab"); + + // test matching with ?'s and /'s + checkMatches("/?", "/a"); + checkMatches("/?/a", "/a/a"); + checkMatches("/a/?", "/a/b"); + checkMatches("/??/a", "/aa/a"); + checkMatches("/a/??", "/a/bb"); + checkMatches("/?", "/a"); + + checkMatches("/**", ""); + checkMatches("/books/**", "/books"); + checkMatches("/**", "/testing/testing"); + checkMatches("/*/**", "/testing/testing"); + checkMatches("/bla*bla/test", "/blaXXXbla/test"); + checkMatches("/*bla/test", "/XXXbla/test"); + checkNoMatch("/bla*bla/test", "/blaXXXbl/test"); + checkNoMatch("/*bla/test", "XXXblab/test"); + checkNoMatch("/*bla/test", "XXXbl/test"); + checkNoMatch("/????", "/bala/bla"); + checkMatches("/foo/bar/**", "/foo/bar/"); + checkMatches("/{bla}.html", "/testing.html"); + checkCapture("/{bla}.*", "/testing.html", "bla", "testing"); + } + + @Test + public void pathRemainingEnhancements_spr15419() { + PathPattern pp; + PathPattern.PathRemainingMatchInfo pri; + // It would be nice to partially match a path and get any bound variables in one step + pp = parse("/{this}/{one}/{here}"); + pri = getPathRemaining(pp, "/foo/bar/goo/boo"); + assertEquals("/boo",pri.getPathRemaining().value()); + assertEquals("foo",pri.getUriVariables().get("this")); + assertEquals("bar",pri.getUriVariables().get("one")); + assertEquals("goo",pri.getUriVariables().get("here")); + + pp = parse("/aaa/{foo}"); + pri = getPathRemaining(pp, "/aaa/bbb"); + assertEquals("",pri.getPathRemaining().value()); + assertEquals("bbb",pri.getUriVariables().get("foo")); + + pp = parse("/aaa/bbb"); + pri = getPathRemaining(pp, "/aaa/bbb"); + assertEquals("",pri.getPathRemaining().value()); + assertEquals(0,pri.getUriVariables().size()); + + pp = parse("/*/{foo}/b*"); + pri = getPathRemaining(pp, "/foo"); + assertNull(pri); + pri = getPathRemaining(pp, "/abc/def/bhi"); + assertEquals("",pri.getPathRemaining().value()); + assertEquals("def",pri.getUriVariables().get("foo")); + + pri = getPathRemaining(pp, "/abc/def/bhi/jkl"); + assertEquals("/jkl",pri.getPathRemaining().value()); + assertEquals("def",pri.getUriVariables().get("foo")); + } + + @Test + public void caseSensitivity() { + PathPatternParser pp = new PathPatternParser(); + pp.setCaseSensitive(false); + PathPattern p = pp.parse("abc"); + assertMatches(p,"AbC"); + assertNoMatch(p,"def"); + p = pp.parse("fOo"); + assertMatches(p,"FoO"); + p = pp.parse("/fOo/bAr"); + assertMatches(p,"/FoO/BaR"); + + pp = new PathPatternParser(); + pp.setCaseSensitive(true); + p = pp.parse("abc"); + assertNoMatch(p,"AbC"); + p = pp.parse("fOo"); + assertNoMatch(p,"FoO"); + p = pp.parse("/fOo/bAr"); + assertNoMatch(p,"/FoO/BaR"); + p = pp.parse("/fOO/bAr"); + assertMatches(p,"/fOO/bAr"); + + pp = new PathPatternParser(); + pp.setCaseSensitive(false); + p = pp.parse("{foo:[A-Z]*}"); + assertMatches(p,"abc"); + assertMatches(p,"ABC"); + + pp = new PathPatternParser(); + pp.setCaseSensitive(true); + p = pp.parse("{foo:[A-Z]*}"); + assertNoMatch(p,"abc"); + assertMatches(p,"ABC"); + + pp = new PathPatternParser(); + pp.setCaseSensitive(false); + p = pp.parse("ab?"); + assertMatches(p,"AbC"); + p = pp.parse("fO?"); + assertMatches(p,"FoO"); + p = pp.parse("/fO?/bA?"); + assertMatches(p,"/FoO/BaR"); + assertNoMatch(p,"/bAr/fOo"); + + pp = new PathPatternParser(); + pp.setCaseSensitive(true); + p = pp.parse("ab?"); + assertNoMatch(p,"AbC"); + p = pp.parse("fO?"); + assertNoMatch(p,"FoO"); + p = pp.parse("/fO?/bA?"); + assertNoMatch(p,"/FoO/BaR"); + p = pp.parse("/fO?/bA?"); + assertMatches(p,"/fOO/bAr"); + + pp = new PathPatternParser(); + pp.setCaseSensitive(false); + p = pp.parse("{abc:[A-Z]*}_{def:[A-Z]*}"); + assertMatches(p,"abc_abc"); + assertMatches(p,"ABC_aBc"); + + pp = new PathPatternParser(); + pp.setCaseSensitive(true); + p = pp.parse("{abc:[A-Z]*}_{def:[A-Z]*}"); + assertNoMatch(p,"abc_abc"); + assertMatches(p,"ABC_ABC"); + + pp = new PathPatternParser(); + pp.setCaseSensitive(false); + p = pp.parse("*?a?*"); + assertMatches(p,"bab"); + assertMatches(p,"bAb"); + + pp = new PathPatternParser(); + pp.setCaseSensitive(true); + p = pp.parse("*?A?*"); + assertNoMatch(p,"bab"); + assertMatches(p,"bAb"); + } + + @Test + public void extractPathWithinPattern_spr15259() { + checkExtractPathWithinPattern("/**","//",""); + checkExtractPathWithinPattern("/**","/",""); + checkExtractPathWithinPattern("/**","",""); + checkExtractPathWithinPattern("/**","/foobar","foobar"); + } + + @Test + public void extractPathWithinPattern() throws Exception { + checkExtractPathWithinPattern("/welcome*/", "/welcome/", "welcome"); + checkExtractPathWithinPattern("/docs/commit.html", "/docs/commit.html", ""); + checkExtractPathWithinPattern("/docs/*", "/docs/cvs/commit", "cvs/commit"); + checkExtractPathWithinPattern("/docs/cvs/*.html", "/docs/cvs/commit.html", "commit.html"); + checkExtractPathWithinPattern("/docs/**", "/docs/cvs/commit", "cvs/commit"); + checkExtractPathWithinPattern("/doo/{*foobar}", "/doo/customer.html", "customer.html"); + checkExtractPathWithinPattern("/doo/{*foobar}", "/doo/daa/customer.html", "daa/customer.html"); + checkExtractPathWithinPattern("/*.html", "/commit.html", "commit.html"); + checkExtractPathWithinPattern("/docs/*/*/*/*", "/docs/cvs/other/commit.html", "cvs/other/commit.html"); + checkExtractPathWithinPattern("/d?cs/**", "/docs/cvs/commit", "docs/cvs/commit"); + checkExtractPathWithinPattern("/*/**", "/docs/cvs/commit///", "docs/cvs/commit"); + checkExtractPathWithinPattern("/*/**", "/docs/cvs/commit/", "docs/cvs/commit"); + checkExtractPathWithinPattern("/aaa/bbb/**", "/aaa///",""); + checkExtractPathWithinPattern("/aaa/bbb/**", "/aaa//",""); + checkExtractPathWithinPattern("/aaa/bbb/**", "/aaa/",""); + checkExtractPathWithinPattern("/docs/**", "/docs/cvs/commit///", "cvs/commit"); + checkExtractPathWithinPattern("/docs/**", "/docs/cvs/commit/", "cvs/commit"); + checkExtractPathWithinPattern("/docs/c?s/*.html", "/docs/cvs/commit.html", "cvs/commit.html"); + checkExtractPathWithinPattern("/d?cs/*/*.html", "/docs/cvs/commit.html", "docs/cvs/commit.html"); + checkExtractPathWithinPattern("/a/b/c*d*/*.html", "/a/b/cod/foo.html", "cod/foo.html"); + checkExtractPathWithinPattern("a/{foo}/b/{bar}", "a/c/b/d", "c/b/d"); + checkExtractPathWithinPattern("a/{foo}_{bar}/d/e", "a/b_c/d/e", "b_c/d/e"); + checkExtractPathWithinPattern("aaa//*///ccc///ddd", "aaa//bbb///ccc///ddd", "bbb/ccc/ddd"); + checkExtractPathWithinPattern("aaa//*///ccc///ddd", "aaa//bbb//ccc/ddd", "bbb/ccc/ddd"); + checkExtractPathWithinPattern("aaa/c*/ddd/", "aaa/ccc///ddd///", "ccc/ddd"); + checkExtractPathWithinPattern("", "", ""); + checkExtractPathWithinPattern("/", "", ""); + checkExtractPathWithinPattern("", "/", ""); + checkExtractPathWithinPattern("//", "", ""); + checkExtractPathWithinPattern("", "//", ""); + checkExtractPathWithinPattern("//", "//", ""); + checkExtractPathWithinPattern("//", "/", ""); + checkExtractPathWithinPattern("/", "//", ""); + } + + @Test + public void extractUriTemplateVariables_spr15264() { + PathPattern pp; + pp = new PathPatternParser().parse("/{foo}"); + assertMatches(pp,"/abc"); + assertNoMatch(pp,"/"); + assertNoMatch(pp,"//"); + checkCapture("/{foo}", "/abc", "foo", "abc"); + + pp = new PathPatternParser().parse("/{foo}/{bar}"); + assertMatches(pp,"/abc/def"); + assertNoMatch(pp,"/def"); + assertNoMatch(pp,"/"); + assertNoMatch(pp,"//def"); + assertNoMatch(pp,"//"); + + pp = parse("/{foo}/boo"); + assertMatches(pp,"/abc/boo"); + assertMatches(pp,"/a/boo"); + assertNoMatch(pp,"/boo"); + assertNoMatch(pp,"//boo"); + + pp = parse("/{foo}*"); + assertMatches(pp,"/abc"); + assertNoMatch(pp,"/"); + + checkCapture("/{word:[a-z]*}", "/abc", "word", "abc"); + pp = parse("/{word:[a-z]*}"); + assertNoMatch(pp,"/1"); + assertMatches(pp,"/a"); + assertNoMatch(pp,"/"); + + // Two captures mean we use a RegexPathElement + pp = new PathPatternParser().parse("/{foo}{bar}"); + assertMatches(pp,"/abcdef"); + assertNoMatch(pp,"/"); + assertNoMatch(pp,"//"); + checkCapture("/{foo:[a-z][a-z]}{bar:[a-z]}", "/abc", "foo", "ab", "bar", "c"); + + // Only patterns not capturing variables cannot match against just / + PathPatternParser ppp = new PathPatternParser(); + ppp.setMatchOptionalTrailingSeparator(true); + pp = ppp.parse("/****"); + assertMatches(pp,"/abcdef"); + assertMatches(pp,"/"); + assertMatches(pp,"/"); + assertMatches(pp,"//"); + + // Confirming AntPathMatcher behaviour: + assertFalse(new AntPathMatcher().match("/{foo}", "/")); + assertTrue(new AntPathMatcher().match("/{foo}", "/a")); + assertTrue(new AntPathMatcher().match("/{foo}{bar}", "/a")); + assertFalse(new AntPathMatcher().match("/{foo}*", "/")); + assertTrue(new AntPathMatcher().match("/*", "/")); + assertFalse(new AntPathMatcher().match("/*{foo}", "/")); + Map vars = new AntPathMatcher().extractUriTemplateVariables("/{foo}{bar}", "/a"); + assertEquals("a",vars.get("foo")); + assertEquals("",vars.get("bar")); + } + + @Test + public void extractUriTemplateVariables() throws Exception { + assertMatches(parse("{hotel}"),"1"); + assertMatches(parse("/hotels/{hotel}"),"/hotels/1"); + checkCapture("/hotels/{hotel}", "/hotels/1", "hotel", "1"); + checkCapture("/h?tels/{hotel}", "/hotels/1", "hotel", "1"); + checkCapture("/hotels/{hotel}/bookings/{booking}", "/hotels/1/bookings/2", "hotel", "1", "booking", "2"); + checkCapture("/*/hotels/*/{hotel}", "/foo/hotels/bar/1", "hotel", "1"); + checkCapture("/{page}.html", "/42.html", "page", "42"); + checkNoMatch("/{var}","/"); + checkCapture("/{page}.*", "/42.html", "page", "42"); + checkCapture("/A-{B}-C", "/A-b-C", "B", "b"); + checkCapture("/{name}.{extension}", "/test.html", "name", "test", "extension", "html"); + + assertNull(checkCapture("/{one}/", "//")); + assertNull(checkCapture("", "/abc")); + + assertEquals(0, checkCapture("", "").getUriVariables().size()); + checkCapture("{id}", "99", "id", "99"); + checkCapture("/customer/{customerId}", "/customer/78", "customerId", "78"); + checkCapture("/customer/{customerId}/banana", "/customer/42/banana", "customerId", + "42"); + checkCapture("{id}/{id2}", "99/98", "id", "99", "id2", "98"); + checkCapture("/foo/{bar}/boo/{baz}", "/foo/plum/boo/apple", "bar", "plum", "baz", + "apple"); + checkCapture("/{bla}.*", "/testing.html", "bla", "testing"); + PathPattern.PathMatchInfo extracted = checkCapture("/abc", "/abc"); + assertEquals(0, extracted.getUriVariables().size()); + checkCapture("/{bla}/foo","/a/foo"); + } + + @Test + public void extractUriTemplateVariablesRegex() { + PathPatternParser pp = new PathPatternParser(); + PathPattern p = null; + + p = pp.parse("{symbolicName:[\\w\\.]+}-{version:[\\w\\.]+}.jar"); + PathPattern.PathMatchInfo result = matchAndExtract(p, "com.example-1.0.0.jar"); + assertEquals("com.example", result.getUriVariables().get("symbolicName")); + assertEquals("1.0.0", result.getUriVariables().get("version")); + + p = pp.parse("{symbolicName:[\\w\\.]+}-sources-{version:[\\w\\.]+}.jar"); + result = matchAndExtract(p, "com.example-sources-1.0.0.jar"); + assertEquals("com.example", result.getUriVariables().get("symbolicName")); + assertEquals("1.0.0", result.getUriVariables().get("version")); + } + + @Test + public void extractUriTemplateVarsRegexQualifiers() { + PathPatternParser pp = new PathPatternParser(); + + PathPattern p = pp.parse("{symbolicName:[\\p{L}\\.]+}-sources-{version:[\\p{N}\\.]+}.jar"); + PathPattern.PathMatchInfo result = p.matchAndExtract(toPathContainer("com.example-sources-1.0.0.jar")); + assertEquals("com.example", result.getUriVariables().get("symbolicName")); + assertEquals("1.0.0", result.getUriVariables().get("version")); + + p = pp.parse("{symbolicName:[\\w\\.]+}-sources-" + + "{version:[\\d\\.]+}-{year:\\d{4}}{month:\\d{2}}{day:\\d{2}}.jar"); + result = matchAndExtract(p,"com.example-sources-1.0.0-20100220.jar"); + assertEquals("com.example", result.getUriVariables().get("symbolicName")); + assertEquals("1.0.0", result.getUriVariables().get("version")); + assertEquals("2010", result.getUriVariables().get("year")); + assertEquals("02", result.getUriVariables().get("month")); + assertEquals("20", result.getUriVariables().get("day")); + + p = pp.parse("{symbolicName:[\\p{L}\\.]+}-sources-{version:[\\p{N}\\.\\{\\}]+}.jar"); + result = matchAndExtract(p, "com.example-sources-1.0.0.{12}.jar"); + assertEquals("com.example", result.getUriVariables().get("symbolicName")); + assertEquals("1.0.0.{12}", result.getUriVariables().get("version")); + } + + @Test + public void extractUriTemplateVarsRegexCapturingGroups() { + PathPatternParser ppp = new PathPatternParser(); + PathPattern pathPattern = ppp.parse("/web/{id:foo(bar)?}_{goo}"); + exception.expect(IllegalArgumentException.class); + exception.expectMessage(containsString("The number of capturing groups in the pattern")); + matchAndExtract(pathPattern,"/web/foobar_goo"); + } + + @Rule + public final ExpectedException exception = ExpectedException.none(); + + @Test + public void combine() { + TestPathCombiner pathMatcher = new TestPathCombiner(); + assertEquals("", pathMatcher.combine("", "")); + assertEquals("/hotels", pathMatcher.combine("/hotels", "")); + assertEquals("/hotels", pathMatcher.combine("", "/hotels")); + assertEquals("/hotels/booking", pathMatcher.combine("/hotels/*", "booking")); + assertEquals("/hotels/booking", pathMatcher.combine("/hotels/*", "/booking")); + assertEquals("/hotels/**/booking", pathMatcher.combine("/hotels/**", "booking")); + assertEquals("/hotels/**/booking", pathMatcher.combine("/hotels/**", "/booking")); + assertEquals("/hotels/booking", pathMatcher.combine("/hotels", "/booking")); + assertEquals("/hotels/booking", pathMatcher.combine("/hotels", "booking")); + assertEquals("/hotels/booking", pathMatcher.combine("/hotels/", "booking")); + assertEquals("/hotels/{hotel}", pathMatcher.combine("/hotels/*", "{hotel}")); + assertEquals("/hotels/**/{hotel}", pathMatcher.combine("/hotels/**", "{hotel}")); + assertEquals("/hotels/{hotel}", pathMatcher.combine("/hotels", "{hotel}")); + assertEquals("/hotels/{hotel}.*", pathMatcher.combine("/hotels", "{hotel}.*")); + assertEquals("/hotels/*/booking/{booking}", + pathMatcher.combine("/hotels/*/booking", "{booking}")); + assertEquals("/hotel.html", pathMatcher.combine("/*.html", "/hotel.html")); + assertEquals("/hotel.html", pathMatcher.combine("/*.html", "/hotel")); + assertEquals("/hotel.html", pathMatcher.combine("/*.html", "/hotel.*")); + // TODO this seems rather bogus, should we eagerly show an error? + assertEquals("/d/e/f/hotel.html", pathMatcher.combine("/a/b/c/*.html", "/d/e/f/hotel.*")); + assertEquals("/*.html", pathMatcher.combine("/**", "/*.html")); + assertEquals("/*.html", pathMatcher.combine("/*", "/*.html")); + assertEquals("/*.html", pathMatcher.combine("/*.*", "/*.html")); + assertEquals("/{foo}/bar", pathMatcher.combine("/{foo}", "/bar")); // SPR-8858 + assertEquals("/user/user", pathMatcher.combine("/user", "/user")); // SPR-7970 + assertEquals("/{foo:.*[^0-9].*}/edit/", + pathMatcher.combine("/{foo:.*[^0-9].*}", "/edit/")); // SPR-10062 + assertEquals("/1.0/foo/test", pathMatcher.combine("/1.0", "/foo/test")); + // SPR-10554 + assertEquals("/hotel", pathMatcher.combine("/", "/hotel")); // SPR-12975 + assertEquals("/hotel/booking", pathMatcher.combine("/hotel/", "/booking")); // SPR-12975 + assertEquals("/hotel", pathMatcher.combine("", "/hotel")); + assertEquals("/hotel", pathMatcher.combine("/hotel", "")); + // TODO Do we need special handling when patterns contain multiple dots? + } + + @Test + public void combineWithTwoFileExtensionPatterns() { + TestPathCombiner pathMatcher = new TestPathCombiner(); + exception.expect(IllegalArgumentException.class); + pathMatcher.combine("/*.html", "/*.txt"); + } + + @Test + public void patternComparator() { + Comparator comparator = PathPattern.SPECIFICITY_COMPARATOR; + + assertEquals(0, comparator.compare(parse("/hotels/new"), parse("/hotels/new"))); + + assertEquals(-1, comparator.compare(parse("/hotels/new"), parse("/hotels/*"))); + assertEquals(1, comparator.compare(parse("/hotels/*"), parse("/hotels/new"))); + assertEquals(0, comparator.compare(parse("/hotels/*"), parse("/hotels/*"))); + + assertEquals(-1, + comparator.compare(parse("/hotels/new"), parse("/hotels/{hotel}"))); + assertEquals(1, + comparator.compare(parse("/hotels/{hotel}"), parse("/hotels/new"))); + assertEquals(0, + comparator.compare(parse("/hotels/{hotel}"), parse("/hotels/{hotel}"))); + assertEquals(-1, comparator.compare(parse("/hotels/{hotel}/booking"), + parse("/hotels/{hotel}/bookings/{booking}"))); + assertEquals(1, comparator.compare(parse("/hotels/{hotel}/bookings/{booking}"), + parse("/hotels/{hotel}/booking"))); + + assertEquals(-1, + comparator.compare( + parse("/hotels/{hotel}/bookings/{booking}/cutomers/{customer}"), + parse("/**"))); + assertEquals(1, comparator.compare(parse("/**"), + parse("/hotels/{hotel}/bookings/{booking}/cutomers/{customer}"))); + assertEquals(0, comparator.compare(parse("/**"), parse("/**"))); + + assertEquals(-1, + comparator.compare(parse("/hotels/{hotel}"), parse("/hotels/*"))); + assertEquals(1, comparator.compare(parse("/hotels/*"), parse("/hotels/{hotel}"))); + + assertEquals(-1, comparator.compare(parse("/hotels/*"), parse("/hotels/*/**"))); + assertEquals(1, comparator.compare(parse("/hotels/*/**"), parse("/hotels/*"))); + +// TODO: shouldn't the wildcard lower the score? +// assertEquals(-1, +// comparator.compare(parse("/hotels/new"), parse("/hotels/new.*"))); + + // SPR-6741 + assertEquals(-1, + comparator.compare( + parse("/hotels/{hotel}/bookings/{booking}/cutomers/{customer}"), + parse("/hotels/**"))); + assertEquals(1, comparator.compare(parse("/hotels/**"), + parse("/hotels/{hotel}/bookings/{booking}/cutomers/{customer}"))); + assertEquals(1, comparator.compare(parse("/hotels/foo/bar/**"), + parse("/hotels/{hotel}"))); + assertEquals(-1, comparator.compare(parse("/hotels/{hotel}"), + parse("/hotels/foo/bar/**"))); + + // SPR-8683 + assertEquals(1, comparator.compare(parse("/**"), parse("/hotels/{hotel}"))); + + // longer is better + assertEquals(1, comparator.compare(parse("/hotels"), parse("/hotels2"))); + + // SPR-13139 + assertEquals(-1, comparator.compare(parse("*"), parse("*/**"))); + assertEquals(1, comparator.compare(parse("*/**"), parse("*"))); + } + + @Test + public void compare_spr15597() { + PathPatternParser parser = new PathPatternParser(); + PathPattern p1 = parser.parse("/{foo}"); + PathPattern p2 = parser.parse("/{foo}.*"); + PathPattern.PathMatchInfo r1 = matchAndExtract(p1, "/file.txt"); + PathPattern.PathMatchInfo r2 = matchAndExtract(p2, "/file.txt"); + + // works fine + assertEquals("file.txt", r1.getUriVariables().get("foo")); + assertEquals("file", r2.getUriVariables().get("foo")); + + // This produces 2 (see comments in https://jira.spring.io/browse/SPR-14544 ) + // Comparator patternComparator = new AntPathMatcher().getPatternComparator(""); + // System.out.println(patternComparator.compare("/{foo}","/{foo}.*")); + + assertThat(p1.compareTo(p2), Matchers.greaterThan(0)); + } + + @Test + public void patternCompareWithNull() { + assertTrue(PathPattern.SPECIFICITY_COMPARATOR.compare(null, null) == 0); + assertTrue(PathPattern.SPECIFICITY_COMPARATOR.compare(parse("/abc"), null) < 0); + assertTrue(PathPattern.SPECIFICITY_COMPARATOR.compare(null, parse("/abc")) > 0); + } + + @Test + public void patternComparatorSort() { + Comparator comparator = PathPattern.SPECIFICITY_COMPARATOR; + + List paths = new ArrayList<>(3); + PathPatternParser pp = new PathPatternParser(); + paths.add(null); + paths.add(null); + paths.sort(comparator); + assertNull(paths.get(0)); + assertNull(paths.get(1)); + paths.clear(); + + paths.add(null); + paths.add(pp.parse("/hotels/new")); + paths.sort(comparator); + assertEquals("/hotels/new", paths.get(0).getPatternString()); + assertNull(paths.get(1)); + paths.clear(); + + paths.add(pp.parse("/hotels/*")); + paths.add(pp.parse("/hotels/new")); + paths.sort(comparator); + assertEquals("/hotels/new", paths.get(0).getPatternString()); + assertEquals("/hotels/*", paths.get(1).getPatternString()); + paths.clear(); + + paths.add(pp.parse("/hotels/new")); + paths.add(pp.parse("/hotels/*")); + paths.sort(comparator); + assertEquals("/hotels/new", paths.get(0).getPatternString()); + assertEquals("/hotels/*", paths.get(1).getPatternString()); + paths.clear(); + + paths.add(pp.parse("/hotels/**")); + paths.add(pp.parse("/hotels/*")); + paths.sort(comparator); + assertEquals("/hotels/*", paths.get(0).getPatternString()); + assertEquals("/hotels/**", paths.get(1).getPatternString()); + paths.clear(); + + paths.add(pp.parse("/hotels/*")); + paths.add(pp.parse("/hotels/**")); + paths.sort(comparator); + assertEquals("/hotels/*", paths.get(0).getPatternString()); + assertEquals("/hotels/**", paths.get(1).getPatternString()); + paths.clear(); + + paths.add(pp.parse("/hotels/{hotel}")); + paths.add(pp.parse("/hotels/new")); + paths.sort(comparator); + assertEquals("/hotels/new", paths.get(0).getPatternString()); + assertEquals("/hotels/{hotel}", paths.get(1).getPatternString()); + paths.clear(); + + paths.add(pp.parse("/hotels/new")); + paths.add(pp.parse("/hotels/{hotel}")); + paths.sort(comparator); + assertEquals("/hotels/new", paths.get(0).getPatternString()); + assertEquals("/hotels/{hotel}", paths.get(1).getPatternString()); + paths.clear(); + + paths.add(pp.parse("/hotels/*")); + paths.add(pp.parse("/hotels/{hotel}")); + paths.add(pp.parse("/hotels/new")); + paths.sort(comparator); + assertEquals("/hotels/new", paths.get(0).getPatternString()); + assertEquals("/hotels/{hotel}", paths.get(1).getPatternString()); + assertEquals("/hotels/*", paths.get(2).getPatternString()); + paths.clear(); + + paths.add(pp.parse("/hotels/ne*")); + paths.add(pp.parse("/hotels/n*")); + Collections.shuffle(paths); + paths.sort(comparator); + assertEquals("/hotels/ne*", paths.get(0).getPatternString()); + assertEquals("/hotels/n*", paths.get(1).getPatternString()); + paths.clear(); + + // comparator = new PatternComparatorConsideringPath("/hotels/new.html"); + // paths.add(pp.parse("/hotels/new.*")); + // paths.add(pp.parse("/hotels/{hotel}")); + // Collections.shuffle(paths); + // Collections.sort(paths, comparator); + // assertEquals("/hotels/new.*", paths.get(0).toPatternString()); + // assertEquals("/hotels/{hotel}", paths.get(1).toPatternString()); + // paths.clear(); + + comparator = (p1, p2) -> { + int index = p1.compareTo(p2); + return (index != 0 ? index : p1.getPatternString().compareTo(p2.getPatternString())); + }; + paths.add(pp.parse("/*/login.*")); + paths.add(pp.parse("/*/endUser/action/login.*")); + paths.sort(comparator); + assertEquals("/*/endUser/action/login.*", paths.get(0).getPatternString()); + assertEquals("/*/login.*", paths.get(1).getPatternString()); + paths.clear(); + } + + @Test // SPR-13286 + public void caseInsensitive() { + PathPatternParser pp = new PathPatternParser(); + pp.setCaseSensitive(false); + PathPattern p = pp.parse("/group/{groupName}/members"); + assertMatches(p,"/group/sales/members"); + assertMatches(p,"/Group/Sales/Members"); + assertMatches(p,"/group/Sales/members"); + } + + @Test + public void parameters() { + // CaptureVariablePathElement + PathPattern.PathMatchInfo result = matchAndExtract("/abc/{var}","/abc/one;two=three;four=five"); + assertEquals("one",result.getUriVariables().get("var")); + assertEquals("three",result.getMatrixVariables().get("var").getFirst("two")); + assertEquals("five",result.getMatrixVariables().get("var").getFirst("four")); + // RegexPathElement + result = matchAndExtract("/abc/{var1}_{var2}","/abc/123_456;a=b;c=d"); + assertEquals("123",result.getUriVariables().get("var1")); + assertEquals("456",result.getUriVariables().get("var2")); + // vars associated with second variable + assertNull(result.getMatrixVariables().get("var1")); + assertNull(result.getMatrixVariables().get("var1")); + assertEquals("b",result.getMatrixVariables().get("var2").getFirst("a")); + assertEquals("d",result.getMatrixVariables().get("var2").getFirst("c")); + // CaptureTheRestPathElement + result = matchAndExtract("/{*var}","/abc/123_456;a=b;c=d"); + assertEquals("/abc/123_456",result.getUriVariables().get("var")); + assertEquals("b",result.getMatrixVariables().get("var").getFirst("a")); + assertEquals("d",result.getMatrixVariables().get("var").getFirst("c")); + result = matchAndExtract("/{*var}","/abc/123_456;a=b;c=d/789;a=e;f=g"); + assertEquals("/abc/123_456/789",result.getUriVariables().get("var")); + assertEquals("[b, e]",result.getMatrixVariables().get("var").get("a").toString()); + assertEquals("d",result.getMatrixVariables().get("var").getFirst("c")); + assertEquals("g",result.getMatrixVariables().get("var").getFirst("f")); + + result = matchAndExtract("/abc/{var}","/abc/one"); + assertEquals("one",result.getUriVariables().get("var")); + assertNull(result.getMatrixVariables().get("var")); + + result = matchAndExtract("",""); + assertNotNull(result); + result = matchAndExtract("","/"); + assertNotNull(result); + } + + private PathPattern.PathMatchInfo matchAndExtract(String pattern, String path) { + return parse(pattern).matchAndExtract(PathPatternTests.toPathContainer(path)); + } + + private PathPattern parse(String path) { + PathPatternParser pp = new PathPatternParser(); + pp.setMatchOptionalTrailingSeparator(true); + return pp.parse(path); + } + + public static PathContainer toPathContainer(String path) { + if (path == null) { + return null; + } + return PathContainer.parsePath(path); + } + + private void checkMatches(String uriTemplate, String path) { + PathPatternParser parser = new PathPatternParser(); + parser.setMatchOptionalTrailingSeparator(true); + PathPattern p = parser.parse(uriTemplate); + PathContainer pc = toPathContainer(path); + assertTrue(p.matches(pc)); + } + + private void checkNoMatch(String uriTemplate, String path) { + PathPatternParser p = new PathPatternParser(); + PathPattern pattern = p.parse(uriTemplate); + PathContainer PathContainer = toPathContainer(path); + assertFalse(pattern.matches(PathContainer)); + } + + private PathPattern.PathMatchInfo checkCapture(String uriTemplate, String path, String... keyValues) { + PathPatternParser parser = new PathPatternParser(); + PathPattern pattern = parser.parse(uriTemplate); + PathPattern.PathMatchInfo matchResult = pattern.matchAndExtract(toPathContainer(path)); + Map expectedKeyValues = new HashMap<>(); + for (int i = 0; i < keyValues.length; i += 2) { + expectedKeyValues.put(keyValues[i], keyValues[i + 1]); + } + for (Map.Entry me : expectedKeyValues.entrySet()) { + String value = matchResult.getUriVariables().get(me.getKey()); + if (value == null) { + fail("Did not find key '" + me.getKey() + "' in captured variables: " + + matchResult.getUriVariables()); + } + if (!value.equals(me.getValue())) { + fail("Expected value '" + me.getValue() + "' for key '" + me.getKey() + + "' but was '" + value + "'"); + } + } + return matchResult; + } + + private void checkExtractPathWithinPattern(String pattern, String path, String expected) { + PathPatternParser ppp = new PathPatternParser(); + PathPattern pp = ppp.parse(pattern); + String s = pp.extractPathWithinPattern(toPathContainer(path)).value(); + assertEquals(expected, s); + } + + private PathRemainingMatchInfo getPathRemaining(String pattern, String path) { + return parse(pattern).matchStartOfPath(toPathContainer(path)); + } + + private PathRemainingMatchInfo getPathRemaining(PathPattern pattern, String path) { + return pattern.matchStartOfPath(toPathContainer(path)); + } + + private PathPattern.PathMatchInfo matchAndExtract(PathPattern p, String path) { + return p.matchAndExtract(toPathContainer(path)); + } + + private String elementsToString(List elements) { + StringBuilder s = new StringBuilder(); + for (Element element: elements) { + s.append("[").append(element.value()).append("]"); + } + return s.toString(); + } + + + static class TestPathCombiner { + + PathPatternParser pp = new PathPatternParser(); + + public String combine(String string1, String string2) { + PathPattern pattern1 = pp.parse(string1); + PathPattern pattern2 = pp.parse(string2); + return pattern1.combine(pattern2).getPatternString(); + } + + } + +} diff --git a/spring-web/src/test/kotlin/org/springframework/web/client/RestOperationsExtensionsTests.kt b/spring-web/src/test/kotlin/org/springframework/web/client/RestOperationsExtensionsTests.kt new file mode 100644 index 0000000000000000000000000000000000000000..fe1aca80b7553fd62f1303692e4b323a0ba204dc --- /dev/null +++ b/spring-web/src/test/kotlin/org/springframework/web/client/RestOperationsExtensionsTests.kt @@ -0,0 +1,264 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client + +import com.nhaarman.mockito_kotlin.mock +import org.junit.Assert +import org.junit.Test +import org.junit.runner.RunWith +import org.mockito.Answers +import org.mockito.Mock +import org.mockito.Mockito.* +import org.mockito.junit.MockitoJUnitRunner +import org.springframework.core.ParameterizedTypeReference +import org.springframework.http.HttpEntity +import org.springframework.http.HttpMethod +import org.springframework.http.RequestEntity +import org.springframework.util.ReflectionUtils +import java.net.URI +import kotlin.reflect.full.createType +import kotlin.reflect.jvm.kotlinFunction + +/** + * Mock object based tests for [RestOperations] Kotlin extensions. + * + * @author Sebastien Deleuze + */ +@RunWith(MockitoJUnitRunner::class) +class RestOperationsExtensionsTests { + + @Mock(answer = Answers.RETURNS_MOCKS) + lateinit var template: RestOperations + + @Test + fun `getForObject with reified type parameters, String and varargs`() { + val url = "https://spring.io" + val var1 = "var1" + val var2 = "var2" + template.getForObject(url, var1, var2) + verify(template, times(1)).getForObject(url, Foo::class.java, var1, var2) + } + + @Test + fun `getForObject with reified type parameters, String and Map`() { + val url = "https://spring.io" + val vars = mapOf(Pair("key1", "value1"), Pair("key2", "value2")) + template.getForObject(url, vars) + verify(template, times(1)).getForObject(url, Foo::class.java, vars) + } + + @Test + fun `getForObject with reified type parameters and URI`() { + val url = URI("https://spring.io") + template.getForObject(url) + verify(template, times(1)).getForObject(url, Foo::class.java) + } + + @Test + fun `getForEntity with reified type parameters, String and URI`() { + val url = URI("https://spring.io") + template.getForEntity(url) + verify(template, times(1)).getForEntity(url, Foo::class.java) + } + + @Test + fun `getForEntity with reified type parameters, String and varargs`() { + val url = "https://spring.io" + val var1 = "var1" + val var2 = "var2" + template.getForEntity(url, var1, var2) + verify(template, times(1)).getForEntity(url, Foo::class.java, var1, var2) + } + + @Test + fun `getForEntity with reified type parameters and Map`() { + val url = "https://spring.io" + val vars = mapOf(Pair("key1", "value1"), Pair("key2", "value2")) + template.getForEntity(url, vars) + verify(template, times(1)).getForEntity(url, Foo::class.java, vars) + } + + @Test + fun `patchForObject with reified type parameters, String and varargs`() { + val url = "https://spring.io" + val body: Any = "body" + val var1 = "var1" + val var2 = "var2" + template.patchForObject(url, body, var1, var2) + verify(template, times(1)).patchForObject(url, body, Foo::class.java, var1, var2) + } + + @Test + fun `patchForObject with reified type parameters, String and Map`() { + val url = "https://spring.io" + val body: Any = "body" + val vars = mapOf(Pair("key1", "value1"), Pair("key2", "value2")) + template.patchForObject(url, body, vars) + verify(template, times(1)).patchForObject(url, body, Foo::class.java, vars) + } + + @Test + fun `patchForObject with reified type parameters and String`() { + val url = "https://spring.io" + val body: Any = "body" + template.patchForObject(url, body) + verify(template, times(1)).patchForObject(url, body, Foo::class.java) + } + + @Test + fun `patchForObject with reified type parameters`() { + val url = "https://spring.io" + template.patchForObject(url) + verify(template, times(1)).patchForObject(url, null, Foo::class.java) + } + + @Test + fun `postForObject with reified type parameters, String and varargs`() { + val url = "https://spring.io" + val body: Any = "body" + val var1 = "var1" + val var2 = "var2" + template.postForObject(url, body, var1, var2) + verify(template, times(1)).postForObject(url, body, Foo::class.java, var1, var2) + } + + @Test + fun `postForObject with reified type parameters, String and Map`() { + val url = "https://spring.io" + val body: Any = "body" + val vars = mapOf(Pair("key1", "value1"), Pair("key2", "value2")) + template.postForObject(url, body, vars) + verify(template, times(1)).postForObject(url, body, Foo::class.java, vars) + } + + @Test + fun `postForObject with reified type parameters and String`() { + val url = "https://spring.io" + val body: Any = "body" + template.postForObject(url, body) + verify(template, times(1)).postForObject(url, body, Foo::class.java) + } + + @Test + fun `postForObject with reified type parameters`() { + val url = "https://spring.io" + template.postForObject(url) + verify(template, times(1)).postForObject(url, null, Foo::class.java) + } + + @Test + fun `postForEntity with reified type parameters, String and varargs`() { + val url = "https://spring.io" + val body: Any = "body" + val var1 = "var1" + val var2 = "var2" + template.postForEntity(url, body, var1, var2) + verify(template, times(1)).postForEntity(url, body, Foo::class.java, var1, var2) + } + + @Test + fun `postForEntity with reified type parameters, String and Map`() { + val url = "https://spring.io" + val body: Any = "body" + val vars = mapOf(Pair("key1", "value1"), Pair("key2", "value2")) + template.postForEntity(url, body, vars) + verify(template, times(1)).postForEntity(url, body, Foo::class.java, vars) + } + + @Test + fun `postForEntity with reified type parameters and String`() { + val url = "https://spring.io" + val body: Any = "body" + template.postForEntity(url, body) + verify(template, times(1)).postForEntity(url, body, Foo::class.java) + } + + @Test + fun `postForEntity with reified type parameters`() { + val url = "https://spring.io" + template.postForEntity(url) + verify(template, times(1)).postForEntity(url, null, Foo::class.java) + } + + @Test + fun `exchange with reified type parameters, String, HttpMethod, HttpEntity and varargs`() { + val url = "https://spring.io" + val method = HttpMethod.GET + val entity = mock>() + val var1 = "var1" + val var2 = "var2" + template.exchange>(url, method, entity, var1, var2) + verify(template, times(1)).exchange(url, method, entity, + object : ParameterizedTypeReference>() {}, var1, var2) + } + + @Test + fun `exchange with reified type parameters, String, HttpMethod, HttpEntity and Map`() { + val url = "https://spring.io" + val method = HttpMethod.GET + val entity = mock>() + val vars = mapOf(Pair("key1", "value1"), Pair("key2", "value2")) + template.exchange>(url, method, entity, vars) + verify(template, times(1)).exchange(url, method, entity, + object : ParameterizedTypeReference>() {}, vars) + } + + @Test + fun `exchange with reified type parameters, String, HttpMethod and HttpEntity`() { + val url = "https://spring.io" + val method = HttpMethod.GET + val entity = mock>() + template.exchange>(url, method, entity) + verify(template, times(1)).exchange(url, method, entity, + object : ParameterizedTypeReference>() {}) + } + + @Test + fun `exchange with reified type parameters, String and HttpMethod`() { + val url = "https://spring.io" + val method = HttpMethod.GET + template.exchange>(url, method) + verify(template, times(1)).exchange(url, method, null, + object : ParameterizedTypeReference>() {}) + } + + @Test + fun `exchange with reified type parameters, String and HttpEntity`() { + val entity = mock>() + template.exchange>(entity) + verify(template, times(1)).exchange(entity, + object : ParameterizedTypeReference>() {}) + } + + @Test + fun `RestOperations are available`() { + val extensions = Class.forName("org.springframework.web.client.RestOperationsExtensionsKt") + ReflectionUtils.doWithMethods(RestOperations::class.java) { method -> + arrayOf(ParameterizedTypeReference::class, Class::class).forEach { kClass -> + if (method.parameterTypes.contains(kClass.java)) { + val parameters = mutableListOf>(RestOperations::class.java).apply { addAll(method.parameterTypes.filter { it != kClass.java }) } + val f = extensions.getDeclaredMethod(method.name, *parameters.toTypedArray()).kotlinFunction!! + Assert.assertEquals(1, f.typeParameters.size) + Assert.assertEquals(listOf(Any::class.createType()), f.typeParameters[0].upperBounds) + } + } + } + } + + class Foo + +} diff --git a/spring-web/src/test/kotlin/org/springframework/web/method/annotation/RequestParamMethodArgumentResolverKotlinTests.kt b/spring-web/src/test/kotlin/org/springframework/web/method/annotation/RequestParamMethodArgumentResolverKotlinTests.kt new file mode 100644 index 0000000000000000000000000000000000000000..328d43d1e65d0b0fde10886ba4d6fdcdafbe0896 --- /dev/null +++ b/spring-web/src/test/kotlin/org/springframework/web/method/annotation/RequestParamMethodArgumentResolverKotlinTests.kt @@ -0,0 +1,233 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.method.annotation + +import org.junit.Assert.assertEquals +import org.junit.Assert.assertNull +import org.junit.Before +import org.junit.Test +import org.springframework.core.MethodParameter +import org.springframework.core.annotation.SynthesizingMethodParameter +import org.springframework.core.convert.support.DefaultConversionService +import org.springframework.http.HttpMethod +import org.springframework.http.MediaType +import org.springframework.mock.web.test.MockHttpServletRequest +import org.springframework.mock.web.test.MockHttpServletResponse +import org.springframework.mock.web.test.MockMultipartFile +import org.springframework.mock.web.test.MockMultipartHttpServletRequest +import org.springframework.util.ReflectionUtils +import org.springframework.web.bind.MissingServletRequestParameterException +import org.springframework.web.bind.annotation.RequestParam +import org.springframework.web.bind.support.ConfigurableWebBindingInitializer +import org.springframework.web.bind.support.DefaultDataBinderFactory +import org.springframework.web.bind.support.WebDataBinderFactory +import org.springframework.web.context.request.NativeWebRequest +import org.springframework.web.context.request.ServletWebRequest +import org.springframework.web.multipart.MultipartFile +import org.springframework.web.multipart.support.MissingServletRequestPartException + +/** + * Kotlin test fixture for [RequestParamMethodArgumentResolver]. + * + * @author Sebastien Deleuze + */ +class RequestParamMethodArgumentResolverKotlinTests { + + lateinit var resolver: RequestParamMethodArgumentResolver + lateinit var webRequest: NativeWebRequest + lateinit var binderFactory: WebDataBinderFactory + lateinit var request: MockHttpServletRequest + + lateinit var nullableParamRequired: MethodParameter + lateinit var nullableParamNotRequired: MethodParameter + lateinit var nonNullableParamRequired: MethodParameter + lateinit var nonNullableParamNotRequired: MethodParameter + + lateinit var nullableMultipartParamRequired: MethodParameter + lateinit var nullableMultipartParamNotRequired: MethodParameter + lateinit var nonNullableMultipartParamRequired: MethodParameter + lateinit var nonNullableMultipartParamNotRequired: MethodParameter + + + @Before + fun setup() { + resolver = RequestParamMethodArgumentResolver(null, true) + request = MockHttpServletRequest() + val initializer = ConfigurableWebBindingInitializer() + initializer.conversionService = DefaultConversionService() + binderFactory = DefaultDataBinderFactory(initializer) + webRequest = ServletWebRequest(request, MockHttpServletResponse()) + + val method = ReflectionUtils.findMethod(javaClass, "handle", String::class.java, + String::class.java, String::class.java, String::class.java, + MultipartFile::class.java, MultipartFile::class.java, + MultipartFile::class.java, MultipartFile::class.java)!! + + nullableParamRequired = SynthesizingMethodParameter(method, 0) + nullableParamNotRequired = SynthesizingMethodParameter(method, 1) + nonNullableParamRequired = SynthesizingMethodParameter(method, 2) + nonNullableParamNotRequired = SynthesizingMethodParameter(method, 3) + + nullableMultipartParamRequired = SynthesizingMethodParameter(method, 4) + nullableMultipartParamNotRequired = SynthesizingMethodParameter(method, 5) + nonNullableMultipartParamRequired = SynthesizingMethodParameter(method, 6) + nonNullableMultipartParamNotRequired = SynthesizingMethodParameter(method, 7) + } + + @Test + fun resolveNullableRequiredWithParameter() { + request.addParameter("name", "123") + var result = resolver.resolveArgument(nullableParamRequired, null, webRequest, binderFactory) + assertEquals("123", result) + } + + @Test + fun resolveNullableRequiredWithoutParameter() { + var result = resolver.resolveArgument(nullableParamRequired, null, webRequest, binderFactory) + assertNull(result) + } + + @Test + fun resolveNullableNotRequiredWithParameter() { + request.addParameter("name", "123") + var result = resolver.resolveArgument(nullableParamNotRequired, null, webRequest, binderFactory) + assertEquals("123", result) + } + + @Test + fun resolveNullableNotRequiredWithoutParameter() { + var result = resolver.resolveArgument(nullableParamNotRequired, null, webRequest, binderFactory) + assertNull(result) + } + + @Test + fun resolveNonNullableRequiredWithParameter() { + request.addParameter("name", "123") + var result = resolver.resolveArgument(nonNullableParamRequired, null, webRequest, binderFactory) + assertEquals("123", result) + } + + @Test(expected = MissingServletRequestParameterException::class) + fun resolveNonNullableRequiredWithoutParameter() { + resolver.resolveArgument(nonNullableParamRequired, null, webRequest, binderFactory) + } + + @Test + fun resolveNonNullableNotRequiredWithParameter() { + request.addParameter("name", "123") + var result = resolver.resolveArgument(nonNullableParamNotRequired, null, webRequest, binderFactory) + assertEquals("123", result) + } + + @Test(expected = TypeCastException::class) + fun resolveNonNullableNotRequiredWithoutParameter() { + resolver.resolveArgument(nonNullableParamNotRequired, null, webRequest, binderFactory) as String + } + + + @Test + fun resolveNullableRequiredWithMultipartParameter() { + val request = MockMultipartHttpServletRequest() + val expected = MockMultipartFile("mfile", "Hello World".toByteArray()) + request.addFile(expected) + webRequest = ServletWebRequest(request) + + var result = resolver.resolveArgument(nullableMultipartParamRequired, null, webRequest, binderFactory) + assertEquals(expected, result) + } + + @Test + fun resolveNullableRequiredWithoutMultipartParameter() { + request.method = HttpMethod.POST.name + request.contentType = MediaType.MULTIPART_FORM_DATA_VALUE + + var result = resolver.resolveArgument(nullableMultipartParamRequired, null, webRequest, binderFactory) + assertNull(result) + } + + @Test + fun resolveNullableNotRequiredWithMultipartParameter() { + val request = MockMultipartHttpServletRequest() + val expected = MockMultipartFile("mfile", "Hello World".toByteArray()) + request.addFile(expected) + webRequest = ServletWebRequest(request) + + var result = resolver.resolveArgument(nullableMultipartParamNotRequired, null, webRequest, binderFactory) + assertEquals(expected, result) + } + + @Test + fun resolveNullableNotRequiredWithoutMultipartParameter() { + request.method = HttpMethod.POST.name + request.contentType = MediaType.MULTIPART_FORM_DATA_VALUE + + var result = resolver.resolveArgument(nullableMultipartParamNotRequired, null, webRequest, binderFactory) + assertNull(result) + } + + @Test + fun resolveNonNullableRequiredWithMultipartParameter() { + val request = MockMultipartHttpServletRequest() + val expected = MockMultipartFile("mfile", "Hello World".toByteArray()) + request.addFile(expected) + webRequest = ServletWebRequest(request) + + var result = resolver.resolveArgument(nonNullableMultipartParamRequired, null, webRequest, binderFactory) + assertEquals(expected, result) + } + + @Test(expected = MissingServletRequestPartException::class) + fun resolveNonNullableRequiredWithoutMultipartParameter() { + request.method = HttpMethod.POST.name + request.contentType = MediaType.MULTIPART_FORM_DATA_VALUE + resolver.resolveArgument(nonNullableMultipartParamRequired, null, webRequest, binderFactory) + } + + @Test + fun resolveNonNullableNotRequiredWithMultipartParameter() { + val request = MockMultipartHttpServletRequest() + val expected = MockMultipartFile("mfile", "Hello World".toByteArray()) + request.addFile(expected) + webRequest = ServletWebRequest(request) + + var result = resolver.resolveArgument(nonNullableMultipartParamNotRequired, null, webRequest, binderFactory) + assertEquals(expected, result) + } + + @Test(expected = TypeCastException::class) + fun resolveNonNullableNotRequiredWithoutMultipartParameter() { + request.method = HttpMethod.POST.name + request.contentType = MediaType.MULTIPART_FORM_DATA_VALUE + resolver.resolveArgument(nonNullableMultipartParamNotRequired, null, webRequest, binderFactory) as MultipartFile + } + + + @Suppress("unused_parameter") + fun handle( + @RequestParam("name") nullableParamRequired: String?, + @RequestParam("name", required = false) nullableParamNotRequired: String?, + @RequestParam("name") nonNullableParamRequired: String, + @RequestParam("name", required = false) nonNullableParamNotRequired: String, + + @RequestParam("mfile") nullableMultipartParamRequired: MultipartFile?, + @RequestParam("mfile", required = false) nullableMultipartParamNotRequired: MultipartFile?, + @RequestParam("mfile") nonNullableMultipartParamRequired: MultipartFile, + @RequestParam("mfile", required = false) nonNullableMultipartParamNotRequired: MultipartFile) { + } + +} + diff --git a/spring-web/src/test/proto/sample.proto b/spring-web/src/test/proto/sample.proto new file mode 100644 index 0000000000000000000000000000000000000000..812b4c2e4df299f33e0028a2946e57e09157332d --- /dev/null +++ b/spring-web/src/test/proto/sample.proto @@ -0,0 +1,12 @@ +option java_package = "org.springframework.protobuf"; +option java_outer_classname = "OuterSample"; +option java_multiple_files = true; + +message Msg { + optional string foo = 1; + optional SecondMsg blah = 2; +} + +message SecondMsg { + optional int32 blah = 1; +} diff --git a/spring-web/src/test/resources/log4j2-test.xml b/spring-web/src/test/resources/log4j2-test.xml new file mode 100644 index 0000000000000000000000000000000000000000..f37050e01ad6655cba09de6c59ce0702836065c1 --- /dev/null +++ b/spring-web/src/test/resources/log4j2-test.xml @@ -0,0 +1,18 @@ + + + + + + + + + + + + + + + + + + diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/foo.txt b/spring-web/src/test/resources/org/springframework/http/codec/multipart/foo.txt new file mode 100644 index 0000000000000000000000000000000000000000..28256c4a9fecedce674b8197be9329931af05bf0 --- /dev/null +++ b/spring-web/src/test/resources/org/springframework/http/codec/multipart/foo.txt @@ -0,0 +1 @@ +Lorem Ipsum. \ No newline at end of file diff --git a/spring-web/src/test/resources/org/springframework/http/codec/multipart/invalid.multipart b/spring-web/src/test/resources/org/springframework/http/codec/multipart/invalid.multipart new file mode 100644 index 0000000000000000000000000000000000000000..7eaa1f997efc9f0b22badc379c8c1bcf81c04bcf --- /dev/null +++ b/spring-web/src/test/resources/org/springframework/http/codec/multipart/invalid.multipart @@ -0,0 +1,5 @@ +--NbjrKgjbsaMLdnMxMfDpD6myWomYc0qNX0w +Content-Disposition: form-data; name="part-00-name" + +post-payload-text-23456789ABCDEF:post-payload-0001-3456789ABCDEF:post-payload-0002-3456789ABCDEF:post-payload-0003-3456789ABCDEF +--NbjrKgjbsaMLdnMxMfDpD6myWomYc diff --git a/spring-web/src/test/resources/org/springframework/http/converter/byterangeresource.txt b/spring-web/src/test/resources/org/springframework/http/converter/byterangeresource.txt new file mode 100644 index 0000000000000000000000000000000000000000..84bbb9ddaf74cecd89c7874138994ec1170f4d2b --- /dev/null +++ b/spring-web/src/test/resources/org/springframework/http/converter/byterangeresource.txt @@ -0,0 +1 @@ +Spring Framework test resource content. \ No newline at end of file diff --git a/spring-web/src/test/resources/org/springframework/http/converter/feed/atom.xml b/spring-web/src/test/resources/org/springframework/http/converter/feed/atom.xml new file mode 100644 index 0000000000000000000000000000000000000000..0c315ca1d2a980c5c00e703ca35454836a9d6c8b --- /dev/null +++ b/spring-web/src/test/resources/org/springframework/http/converter/feed/atom.xml @@ -0,0 +1,13 @@ + + + title + subtitle + + id1 + title1 + + + id2 + title2 + + diff --git a/spring-web/src/test/resources/org/springframework/http/converter/feed/rss.xml b/spring-web/src/test/resources/org/springframework/http/converter/feed/rss.xml new file mode 100644 index 0000000000000000000000000000000000000000..49c3053c4311816f364f45e69074e3317ac817fe --- /dev/null +++ b/spring-web/src/test/resources/org/springframework/http/converter/feed/rss.xml @@ -0,0 +1,14 @@ + + + + title + https://example.com + description + + title1 + + + title2 + + + \ No newline at end of file diff --git a/spring-web/src/test/resources/org/springframework/http/converter/logo.jpg b/spring-web/src/test/resources/org/springframework/http/converter/logo.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8a70e6af172588135f7f6d62c01519a638392b4a Binary files /dev/null and b/spring-web/src/test/resources/org/springframework/http/converter/logo.jpg differ diff --git a/spring-web/src/test/resources/org/springframework/http/converter/xml/external.txt b/spring-web/src/test/resources/org/springframework/http/converter/xml/external.txt new file mode 100644 index 0000000000000000000000000000000000000000..67d8e4dbe09dc52489bd81b45909b63bb146894d --- /dev/null +++ b/spring-web/src/test/resources/org/springframework/http/converter/xml/external.txt @@ -0,0 +1 @@ +Foo Bar \ No newline at end of file diff --git a/spring-web/src/test/resources/org/springframework/http/server/reactive/spring.png b/spring-web/src/test/resources/org/springframework/http/server/reactive/spring.png new file mode 100644 index 0000000000000000000000000000000000000000..2fec781a5e31ff09dd56ed4e01c1ee1927c67b2b Binary files /dev/null and b/spring-web/src/test/resources/org/springframework/http/server/reactive/spring.png differ diff --git a/spring-web/src/test/resources/org/springframework/web/context/request/requestScopeTests.xml b/spring-web/src/test/resources/org/springframework/web/context/request/requestScopeTests.xml new file mode 100644 index 0000000000000000000000000000000000000000..1b9769e7d1ca83d6e7821bbedf25f2bbb717879e --- /dev/null +++ b/spring-web/src/test/resources/org/springframework/web/context/request/requestScopeTests.xml @@ -0,0 +1,41 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/spring-web/src/test/resources/org/springframework/web/context/request/requestScopedProxyTests.xml b/spring-web/src/test/resources/org/springframework/web/context/request/requestScopedProxyTests.xml new file mode 100644 index 0000000000000000000000000000000000000000..d357806bddf7671d01b3449dbe87f494c917f453 --- /dev/null +++ b/spring-web/src/test/resources/org/springframework/web/context/request/requestScopedProxyTests.xml @@ -0,0 +1,55 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/spring-web/src/test/resources/org/springframework/web/context/request/sessionScopeTests.xml b/spring-web/src/test/resources/org/springframework/web/context/request/sessionScopeTests.xml new file mode 100644 index 0000000000000000000000000000000000000000..a371316fd7aa95d4c198dbce57e48f15d770d9b7 --- /dev/null +++ b/spring-web/src/test/resources/org/springframework/web/context/request/sessionScopeTests.xml @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/spring-web/src/test/resources/org/springframework/web/util/HtmlCharacterEntityReferences.dtd b/spring-web/src/test/resources/org/springframework/web/util/HtmlCharacterEntityReferences.dtd new file mode 100644 index 0000000000000000000000000000000000000000..20ea159c655d68a70fa44f31515f1e9196f423b5 --- /dev/null +++ b/spring-web/src/test/resources/org/springframework/web/util/HtmlCharacterEntityReferences.dtd @@ -0,0 +1,521 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/spring-web/src/test/resources/org/springframework/web/util/testlog4j.properties b/spring-web/src/test/resources/org/springframework/web/util/testlog4j.properties new file mode 100644 index 0000000000000000000000000000000000000000..1121ae21b3cd601b05bd170514366ab3b99e38f0 --- /dev/null +++ b/spring-web/src/test/resources/org/springframework/web/util/testlog4j.properties @@ -0,0 +1,5 @@ +log4j.rootCategory=DEBUG, mock + +log4j.appender.mock=org.springframework.web.util.MockLog4jAppender + +log4j.logger.org.springframework.mock.web=WARN diff --git a/spring-webflux/spring-webflux.gradle b/spring-webflux/spring-webflux.gradle new file mode 100644 index 0000000000000000000000000000000000000000..18f5bf2f33f4ba5ce46eb3257527a0a76c0e8b48 --- /dev/null +++ b/spring-webflux/spring-webflux.gradle @@ -0,0 +1,66 @@ +description = "Spring WebFlux" + +dependencyManagement { + imports { + mavenBom "io.projectreactor:reactor-bom:${reactorVersion}" + mavenBom "io.netty:netty-bom:${nettyVersion}" + mavenBom "org.eclipse.jetty:jetty-bom:${jettyVersion}" + } +} + +dependencies { + compile(project(":spring-beans")) + compile(project(":spring-core")) + compile(project(":spring-web")) + compile("io.projectreactor:reactor-core") + optional(project(":spring-context")) + optional(project(":spring-context-support")) // for FreeMarker support + optional("javax.servlet:javax.servlet-api:4.0.1") + optional("javax.websocket:javax.websocket-api:1.1") + optional("org.webjars:webjars-locator-core:0.37") + optional("org.freemarker:freemarker:${freemarkerVersion}") + optional("com.fasterxml.jackson.core:jackson-databind:${jackson2Version}") + optional("com.fasterxml.jackson.dataformat:jackson-dataformat-smile:${jackson2Version}") + optional("io.reactivex:rxjava:${rxjavaVersion}") + optional("io.reactivex:rxjava-reactive-streams:${rxjavaAdapterVersion}") + optional("io.projectreactor.netty:reactor-netty") + optional("org.apache.tomcat:tomcat-websocket:${tomcatVersion}") { + exclude group: "org.apache.tomcat", module: "tomcat-websocket-api" + exclude group: "org.apache.tomcat", module: "tomcat-servlet-api" + } + optional("org.eclipse.jetty.websocket:websocket-server") { + exclude group: "javax.servlet", module: "javax.servlet" + } + optional("io.undertow:undertow-websockets-jsr:${undertowVersion}") { + exclude group: "org.jboss.spec.javax.websocket", module: "jboss-websocket-api_1.1_spec" + } + optional("org.apache.httpcomponents:httpclient:4.5.10") { + exclude group: "commons-logging", module: "commons-logging" + } + optional("org.jetbrains.kotlin:kotlin-reflect:${kotlinVersion}") + optional("org.jetbrains.kotlin:kotlin-stdlib:${kotlinVersion}") + optional("com.google.protobuf:protobuf-java-util:3.6.1") + testCompile("javax.xml.bind:jaxb-api:2.3.1") + testCompile("com.fasterxml:aalto-xml:1.1.1") + testCompile("org.hibernate:hibernate-validator:6.0.21.Final") + testCompile "io.reactivex.rxjava2:rxjava:${rxjava2Version}" + testCompile("io.projectreactor:reactor-test") + testCompile("io.undertow:undertow-core:${undertowVersion}") + testCompile("org.apache.tomcat.embed:tomcat-embed-core:${tomcatVersion}") + testCompile("org.apache.tomcat:tomcat-util:${tomcatVersion}") + testCompile("org.eclipse.jetty:jetty-server") + testCompile("org.eclipse.jetty:jetty-servlet") + testCompile("org.eclipse.jetty:jetty-reactive-httpclient:1.0.3") + testCompile("com.squareup.okhttp3:mockwebserver:3.14.7") + testCompile("org.jetbrains.kotlin:kotlin-script-runtime:${kotlinVersion}") + testRuntime("org.jetbrains.kotlin:kotlin-script-util:${kotlinVersion}") + testRuntime("org.jetbrains.kotlin:kotlin-compiler:${kotlinVersion}") + testRuntime("org.jruby:jruby:9.2.7.0") + testRuntime("org.python:jython-standalone:2.7.1") + testRuntime("org.synchronoss.cloud:nio-multipart-parser:1.1.0") + testRuntime("org.webjars:underscorejs:1.8.3") + testRuntime("org.glassfish:javax.el:3.0.1-b08") + testRuntime("com.sun.xml.bind:jaxb-core:2.3.0.1") + testRuntime("com.sun.xml.bind:jaxb-impl:2.3.0.1") + testRuntime("com.sun.activation:javax.activation:1.2.0") +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/BindingContext.java b/spring-webflux/src/main/java/org/springframework/web/reactive/BindingContext.java new file mode 100644 index 0000000000000000000000000000000000000000..cb0849ae81b7d9f6d46abe6d122f0191126acb55 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/BindingContext.java @@ -0,0 +1,109 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive; + +import org.springframework.lang.Nullable; +import org.springframework.ui.Model; +import org.springframework.validation.support.BindingAwareConcurrentModel; +import org.springframework.web.bind.support.WebBindingInitializer; +import org.springframework.web.bind.support.WebExchangeDataBinder; +import org.springframework.web.server.ServerErrorException; +import org.springframework.web.server.ServerWebExchange; + +/** + * Context to assist with binding request data onto Objects and provide access + * to a shared {@link Model} with controller-specific attributes. + * + *

Provides methods to create a {@link WebExchangeDataBinder} for a specific + * target, command Object to apply data binding and validation to, or without a + * target Object for simple type conversion from request values. + * + *

Container for the default model for the request. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class BindingContext { + + @Nullable + private final WebBindingInitializer initializer; + + private final Model model = new BindingAwareConcurrentModel(); + + + /** + * Create a new {@code BindingContext}. + */ + public BindingContext() { + this(null); + } + + /** + * Create a new {@code BindingContext} with the given initializer. + * @param initializer the binding initializer to apply (may be {@code null}) + */ + public BindingContext(@Nullable WebBindingInitializer initializer) { + this.initializer = initializer; + } + + + /** + * Return the default model. + */ + public Model getModel() { + return this.model; + } + + + /** + * Create a {@link WebExchangeDataBinder} to apply data binding and + * validation with on the target, command object. + * @param exchange the current exchange + * @param target the object to create a data binder for + * @param name the name of the target object + * @return the created data binder + * @throws ServerErrorException if {@code @InitBinder} method invocation fails + */ + public WebExchangeDataBinder createDataBinder(ServerWebExchange exchange, @Nullable Object target, String name) { + WebExchangeDataBinder dataBinder = new WebExchangeDataBinder(target, name); + if (this.initializer != null) { + this.initializer.initBinder(dataBinder); + } + return initDataBinder(dataBinder, exchange); + } + + /** + * Initialize the data binder instance for the given exchange. + * @throws ServerErrorException if {@code @InitBinder} method invocation fails + */ + protected WebExchangeDataBinder initDataBinder(WebExchangeDataBinder binder, ServerWebExchange exchange) { + return binder; + } + + /** + * Create a {@link WebExchangeDataBinder} without a target object for type + * conversion of request values to simple types. + * @param exchange the current exchange + * @param name the name of the target object + * @return the created data binder + * @throws ServerErrorException if {@code @InitBinder} method invocation fails + */ + public WebExchangeDataBinder createDataBinder(ServerWebExchange exchange, String name) { + return createDataBinder(exchange, null, name); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/DispatcherHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/DispatcherHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..27f699b3cab245b7d677595ba5dc7cbb5c381ead --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/DispatcherHandler.java @@ -0,0 +1,190 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; +import org.springframework.web.server.ResponseStatusException; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebHandler; +import org.springframework.web.server.adapter.WebHttpHandlerBuilder; + +/** + * Central dispatcher for HTTP request handlers/controllers. Dispatches to + * registered handlers for processing a request, providing convenient mapping + * facilities. + * + *

{@code DispatcherHandler} discovers the delegate components it needs from + * Spring configuration. It detects the following in the application context: + *

    + *
  • {@link HandlerMapping} -- map requests to handler objects + *
  • {@link HandlerAdapter} -- for using any handler interface + *
  • {@link HandlerResultHandler} -- process handler return values + *
+ * + *

{@code DispatcherHandler} is also designed to be a Spring bean itself and + * implements {@link ApplicationContextAware} for access to the context it runs + * in. If {@code DispatcherHandler} is declared with the bean name "webHandler" + * it is discovered by {@link WebHttpHandlerBuilder#applicationContext} which + * creates a processing chain together with {@code WebFilter}, + * {@code WebExceptionHandler} and others. + * + *

A {@code DispatcherHandler} bean declaration is included in + * {@link org.springframework.web.reactive.config.EnableWebFlux @EnableWebFlux} + * configuration. + * + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + * @author Juergen Hoeller + * @since 5.0 + * @see WebHttpHandlerBuilder#applicationContext(ApplicationContext) + */ +public class DispatcherHandler implements WebHandler, ApplicationContextAware { + + @SuppressWarnings("ThrowableInstanceNeverThrown") + private static final Exception HANDLER_NOT_FOUND_EXCEPTION = + new ResponseStatusException(HttpStatus.NOT_FOUND, "No matching handler"); + + + @Nullable + private List handlerMappings; + + @Nullable + private List handlerAdapters; + + @Nullable + private List resultHandlers; + + + /** + * Create a new {@code DispatcherHandler} which needs to be configured with + * an {@link ApplicationContext} through {@link #setApplicationContext}. + */ + public DispatcherHandler() { + } + + /** + * Create a new {@code DispatcherHandler} for the given {@link ApplicationContext}. + * @param applicationContext the application context to find the handler beans in + */ + public DispatcherHandler(ApplicationContext applicationContext) { + initStrategies(applicationContext); + } + + + /** + * Return all {@link HandlerMapping} beans detected by type in the + * {@link #setApplicationContext injected context} and also + * {@link AnnotationAwareOrderComparator#sort(List) sorted}. + *

Note: This method may return {@code null} if invoked + * prior to {@link #setApplicationContext(ApplicationContext)}. + * @return immutable list with the configured mappings or {@code null} + */ + @Nullable + public final List getHandlerMappings() { + return this.handlerMappings; + } + + @Override + public void setApplicationContext(ApplicationContext applicationContext) { + initStrategies(applicationContext); + } + + + protected void initStrategies(ApplicationContext context) { + Map mappingBeans = BeanFactoryUtils.beansOfTypeIncludingAncestors( + context, HandlerMapping.class, true, false); + + ArrayList mappings = new ArrayList<>(mappingBeans.values()); + AnnotationAwareOrderComparator.sort(mappings); + this.handlerMappings = Collections.unmodifiableList(mappings); + + Map adapterBeans = BeanFactoryUtils.beansOfTypeIncludingAncestors( + context, HandlerAdapter.class, true, false); + + this.handlerAdapters = new ArrayList<>(adapterBeans.values()); + AnnotationAwareOrderComparator.sort(this.handlerAdapters); + + Map beans = BeanFactoryUtils.beansOfTypeIncludingAncestors( + context, HandlerResultHandler.class, true, false); + + this.resultHandlers = new ArrayList<>(beans.values()); + AnnotationAwareOrderComparator.sort(this.resultHandlers); + } + + + @Override + public Mono handle(ServerWebExchange exchange) { + if (this.handlerMappings == null) { + return createNotFoundError(); + } + return Flux.fromIterable(this.handlerMappings) + .concatMap(mapping -> mapping.getHandler(exchange)) + .next() + .switchIfEmpty(createNotFoundError()) + .flatMap(handler -> invokeHandler(exchange, handler)) + .flatMap(result -> handleResult(exchange, result)); + } + + private Mono createNotFoundError() { + return Mono.defer(() -> { + Exception ex = new ResponseStatusException(HttpStatus.NOT_FOUND, "No matching handler"); + return Mono.error(ex); + }); + } + + private Mono invokeHandler(ServerWebExchange exchange, Object handler) { + if (this.handlerAdapters != null) { + for (HandlerAdapter handlerAdapter : this.handlerAdapters) { + if (handlerAdapter.supports(handler)) { + return handlerAdapter.handle(exchange, handler); + } + } + } + return Mono.error(new IllegalStateException("No HandlerAdapter: " + handler)); + } + + private Mono handleResult(ServerWebExchange exchange, HandlerResult result) { + return getResultHandler(result).handleResult(exchange, result) + .onErrorResume(ex -> result.applyExceptionHandler(ex).flatMap(exceptionResult -> + getResultHandler(exceptionResult).handleResult(exchange, exceptionResult))); + } + + private HandlerResultHandler getResultHandler(HandlerResult handlerResult) { + if (this.resultHandlers != null) { + for (HandlerResultHandler resultHandler : this.resultHandlers) { + if (resultHandler.supports(handlerResult)) { + return resultHandler; + } + } + } + throw new IllegalStateException("No HandlerResultHandler for " + handlerResult.getReturnValue()); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/HandlerAdapter.java b/spring-webflux/src/main/java/org/springframework/web/reactive/HandlerAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..71ec10478bcf3694b5ff1d2d42b41a0be459c695 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/HandlerAdapter.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive; + +import java.util.function.Function; + +import reactor.core.publisher.Mono; + +import org.springframework.web.server.ServerWebExchange; + +/** + * Contract that decouples the {@link DispatcherHandler} from the details of + * invoking a handler and makes it possible to support any handler type. + * + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + * @since 5.0 + */ +public interface HandlerAdapter { + + /** + * Whether this {@code HandlerAdapter} supports the given {@code handler}. + * @param handler handler object to check + * @return whether or not the handler is supported + */ + boolean supports(Object handler); + + /** + * Handle the request with the given handler. + *

Implementations are encouraged to handle exceptions resulting from the + * invocation of a handler in order and if necessary to return an alternate + * result that represents an error response. + *

Furthermore since an async {@code HandlerResult} may produce an error + * later during result handling implementations are also encouraged to + * {@link HandlerResult#setExceptionHandler(Function) set an exception + * handler} on the {@code HandlerResult} so that may also be applied later + * after result handling. + * @param exchange current server exchange + * @param handler the selected handler which must have been previously + * checked via {@link #supports(Object)} + * @return {@link Mono} that emits a single {@code HandlerResult} or none if + * the request has been fully handled and doesn't require further handling. + */ + Mono handle(ServerWebExchange exchange, Object handler); + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/HandlerMapping.java b/spring-webflux/src/main/java/org/springframework/web/reactive/HandlerMapping.java new file mode 100644 index 0000000000000000000000000000000000000000..b56f7872f41872edf758fad46a91adb7b6b3befc --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/HandlerMapping.java @@ -0,0 +1,95 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive; + +import reactor.core.publisher.Mono; + +import org.springframework.web.server.ServerWebExchange; + +/** + * Interface to be implemented by objects that define a mapping between + * requests and handler objects. + * + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + * @since 5.0 + */ +public interface HandlerMapping { + + /** + * Name of the {@link ServerWebExchange#getAttributes() attribute} that + * contains the mapped handler for the best matching pattern. + */ + String BEST_MATCHING_HANDLER_ATTRIBUTE = HandlerMapping.class.getName() + ".bestMatchingHandler"; + + /** + * Name of the {@link ServerWebExchange#getAttributes() attribute} that + * contains the best matching pattern within the handler mapping. + */ + String BEST_MATCHING_PATTERN_ATTRIBUTE = HandlerMapping.class.getName() + ".bestMatchingPattern"; + + /** + * Name of the {@link ServerWebExchange#getAttributes() attribute} that + * contains the path within the handler mapping, in case of a pattern match + * such as {@code "/static/**"} or the full relevant URI otherwise. + *

Note: This attribute is not required to be supported by all + * HandlerMapping implementations. URL-based HandlerMappings will + * typically support it but handlers should not necessarily expect + * this request attribute to be present in all scenarios. + */ + String PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE = HandlerMapping.class.getName() + ".pathWithinHandlerMapping"; + + /** + * Name of the {@link ServerWebExchange#getAttributes() attribute} that + * contains the URI templates map mapping variable names to values. + *

Note: This attribute is not required to be supported by all + * HandlerMapping implementations. URL-based HandlerMappings will + * typically support it, but handlers should not necessarily expect + * this request attribute to be present in all scenarios. + */ + String URI_TEMPLATE_VARIABLES_ATTRIBUTE = HandlerMapping.class.getName() + ".uriTemplateVariables"; + + /** + * Name of the {@link ServerWebExchange#getAttributes() attribute} that + * contains a map with URI variable names and a corresponding MultiValueMap + * of URI matrix variables for each. + *

Note: This attribute is not required to be supported by all + * HandlerMapping implementations and may also not be present depending on + * whether the HandlerMapping is configured to keep matrix variable content + * in the request URI. + */ + String MATRIX_VARIABLES_ATTRIBUTE = HandlerMapping.class.getName() + ".matrixVariables"; + + /** + * Name of the {@link ServerWebExchange#getAttributes() attribute} containing + * the set of producible MediaType's applicable to the mapped handler. + *

Note: This attribute is not required to be supported by all + * HandlerMapping implementations. Handlers should not necessarily expect + * this request attribute to be present in all scenarios. + */ + String PRODUCIBLE_MEDIA_TYPES_ATTRIBUTE = HandlerMapping.class.getName() + ".producibleMediaTypes"; + + + /** + * Return a handler for this request. + * @param exchange current server exchange + * @return a {@link Mono} that emits one value or none in case the request + * cannot be resolved to a handler + */ + Mono getHandler(ServerWebExchange exchange); + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/HandlerResult.java b/spring-webflux/src/main/java/org/springframework/web/reactive/HandlerResult.java new file mode 100644 index 0000000000000000000000000000000000000000..3426dd6851518816ad0267e701a5dcb6b0252dba --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/HandlerResult.java @@ -0,0 +1,155 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive; + +import java.util.function.Function; + +import reactor.core.publisher.Mono; + +import org.springframework.core.MethodParameter; +import org.springframework.core.ResolvableType; +import org.springframework.lang.Nullable; +import org.springframework.ui.Model; +import org.springframework.util.Assert; + +/** + * Represent the result of the invocation of a handler or a handler method. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class HandlerResult { + + private final Object handler; + + @Nullable + private final Object returnValue; + + private final ResolvableType returnType; + + private final BindingContext bindingContext; + + @Nullable + private Function> exceptionHandler; + + + /** + * Create a new {@code HandlerResult}. + * @param handler the handler that handled the request + * @param returnValue the return value from the handler possibly {@code null} + * @param returnType the return value type + */ + public HandlerResult(Object handler, @Nullable Object returnValue, MethodParameter returnType) { + this(handler, returnValue, returnType, null); + } + + /** + * Create a new {@code HandlerResult}. + * @param handler the handler that handled the request + * @param returnValue the return value from the handler possibly {@code null} + * @param returnType the return value type + * @param context the binding context used for request handling + */ + public HandlerResult(Object handler, @Nullable Object returnValue, MethodParameter returnType, + @Nullable BindingContext context) { + + Assert.notNull(handler, "'handler' is required"); + Assert.notNull(returnType, "'returnType' is required"); + this.handler = handler; + this.returnValue = returnValue; + this.returnType = ResolvableType.forMethodParameter(returnType); + this.bindingContext = (context != null ? context : new BindingContext()); + } + + + /** + * Return the handler that handled the request. + */ + public Object getHandler() { + return this.handler; + } + + /** + * Return the value returned from the handler, if any. + */ + @Nullable + public Object getReturnValue() { + return this.returnValue; + } + + /** + * Return the type of the value returned from the handler -- e.g. the return + * type declared on a controller method's signature. Also see + * {@link #getReturnTypeSource()} to obtain the underlying + * {@link MethodParameter} for the return type. + */ + public ResolvableType getReturnType() { + return this.returnType; + } + + /** + * Return the {@link MethodParameter} from which {@link #getReturnType() + * returnType} was created. + */ + public MethodParameter getReturnTypeSource() { + return (MethodParameter) this.returnType.getSource(); + } + + /** + * Return the BindingContext used for request handling. + */ + public BindingContext getBindingContext() { + return this.bindingContext; + } + + /** + * Return the model used for request handling. This is a shortcut for + * {@code getBindingContext().getModel()}. + */ + public Model getModel() { + return this.bindingContext.getModel(); + } + + /** + * Configure an exception handler that may be used to produce an alternative + * result when result handling fails. Especially for an async return value + * errors may occur after the invocation of the handler. + * @param function the error handler + * @return the current instance + */ + public HandlerResult setExceptionHandler(Function> function) { + this.exceptionHandler = function; + return this; + } + + /** + * Whether there is an exception handler. + */ + public boolean hasExceptionHandler() { + return (this.exceptionHandler != null); + } + + /** + * Apply the exception handler and return the alternative result. + * @param failure the exception + * @return the new result or the same error if there is no exception handler + */ + public Mono applyExceptionHandler(Throwable failure) { + return (this.exceptionHandler != null ? this.exceptionHandler.apply(failure) : Mono.error(failure)); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/HandlerResultHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/HandlerResultHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..71e640c6c344c609722cddc42096ca39b794bbb0 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/HandlerResultHandler.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive; + +import reactor.core.publisher.Mono; + +import org.springframework.web.server.ServerWebExchange; + +/** + * Process the {@link HandlerResult}, usually returned by an {@link HandlerAdapter}. + * + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + * @since 5.0 + */ +public interface HandlerResultHandler { + + /** + * Whether this handler supports the given {@link HandlerResult}. + * @param result result object to check + * @return whether or not this object can use the given result + */ + boolean supports(HandlerResult result); + + /** + * Process the given result modifying response headers and/or writing data + * to the response. + * @param exchange current server exchange + * @param result the result from the handling + * @return {@code Mono} to indicate when request handling is complete. + */ + Mono handleResult(ServerWebExchange exchange, HandlerResult result); + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/accept/FixedContentTypeResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/accept/FixedContentTypeResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..3648d76ef2415e1d6d80de0723f1fc41d5ae2b73 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/accept/FixedContentTypeResolver.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.accept; + +import java.util.Collections; +import java.util.List; + +import org.springframework.http.MediaType; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; + +/** + * Resolver that always resolves to a fixed list of media types. This can be + * used as the "last in line" strategy providing a fallback for when the client + * has not requested any media types. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class FixedContentTypeResolver implements RequestedContentTypeResolver { + + private final List contentTypes; + + + /** + * Constructor with a single default {@code MediaType}. + */ + public FixedContentTypeResolver(MediaType mediaType) { + this(Collections.singletonList(mediaType)); + } + + /** + * Constructor with an ordered List of default {@code MediaType}'s to return + * for use in applications that support a variety of content types. + *

Consider appending {@link MediaType#ALL} at the end if destinations + * are present which do not support any of the other default media types. + */ + public FixedContentTypeResolver(List contentTypes) { + Assert.notNull(contentTypes, "'contentTypes' must not be null"); + this.contentTypes = Collections.unmodifiableList(contentTypes); + } + + + /** + * Return the configured list of media types. + */ + public List getContentTypes() { + return this.contentTypes; + } + + + @Override + public List resolveMediaTypes(ServerWebExchange exchange) { + return this.contentTypes; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/accept/HeaderContentTypeResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/accept/HeaderContentTypeResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..88bc3ab250f0aa3350f218b11c58957807ba4dda --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/accept/HeaderContentTypeResolver.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.accept; + +import java.util.List; + +import org.springframework.http.InvalidMediaTypeException; +import org.springframework.http.MediaType; +import org.springframework.util.CollectionUtils; +import org.springframework.web.server.NotAcceptableStatusException; +import org.springframework.web.server.ServerWebExchange; + +/** + * Resolver that looks at the 'Accept' header of the request. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class HeaderContentTypeResolver implements RequestedContentTypeResolver { + + @Override + public List resolveMediaTypes(ServerWebExchange exchange) throws NotAcceptableStatusException { + try { + List mediaTypes = exchange.getRequest().getHeaders().getAccept(); + MediaType.sortBySpecificityAndQuality(mediaTypes); + return (!CollectionUtils.isEmpty(mediaTypes) ? mediaTypes : MEDIA_TYPE_ALL_LIST); + } + catch (InvalidMediaTypeException ex) { + String value = exchange.getRequest().getHeaders().getFirst("Accept"); + throw new NotAcceptableStatusException( + "Could not parse 'Accept' header [" + value + "]: " + ex.getMessage()); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/accept/ParameterContentTypeResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/accept/ParameterContentTypeResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..12bd6b4082380f091c636fbaf2e776eaf36e8529 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/accept/ParameterContentTypeResolver.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.reactive.accept; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.springframework.http.MediaType; +import org.springframework.http.MediaTypeFactory; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.server.NotAcceptableStatusException; +import org.springframework.web.server.ServerWebExchange; + +/** + * Resolver that checks a query parameter and uses it to lookup a matching + * MediaType. Lookup keys can be registered or as a fallback + * {@link MediaTypeFactory} can be used to perform a lookup. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ParameterContentTypeResolver implements RequestedContentTypeResolver { + + /** Primary lookup for media types by key (e.g. "json" -> "application/json") */ + private final Map mediaTypes = new ConcurrentHashMap<>(64); + + private String parameterName = "format"; + + + public ParameterContentTypeResolver(Map mediaTypes) { + mediaTypes.forEach((key, value) -> this.mediaTypes.put(formatKey(key), value)); + } + + private static String formatKey(String key) { + return key.toLowerCase(Locale.ENGLISH); + } + + + /** + * Set the name of the parameter to use to determine requested media types. + *

By default this is set to {@literal "format"}. + */ + public void setParameterName(String parameterName) { + Assert.notNull(parameterName, "'parameterName' is required"); + this.parameterName = parameterName; + } + + public String getParameterName() { + return this.parameterName; + } + + + @Override + public List resolveMediaTypes(ServerWebExchange exchange) throws NotAcceptableStatusException { + String key = exchange.getRequest().getQueryParams().getFirst(getParameterName()); + if (!StringUtils.hasText(key)) { + return MEDIA_TYPE_ALL_LIST; + } + key = formatKey(key); + MediaType match = this.mediaTypes.get(key); + if (match == null) { + match = MediaTypeFactory.getMediaType("filename." + key) + .orElseThrow(() -> { + List supported = new ArrayList<>(this.mediaTypes.values()); + return new NotAcceptableStatusException(supported); + }); + } + this.mediaTypes.putIfAbsent(key, match); + return Collections.singletonList(match); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/accept/RequestedContentTypeResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/accept/RequestedContentTypeResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..717237900fc87050c3cc4076979d77e711e153ae --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/accept/RequestedContentTypeResolver.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.accept; + +import java.util.Collections; +import java.util.List; + +import org.springframework.http.MediaType; +import org.springframework.web.server.NotAcceptableStatusException; +import org.springframework.web.server.ServerWebExchange; + +/** + * Strategy to resolve the requested media types for a {@code ServerWebExchange}. + * + *

See {@link RequestedContentTypeResolverBuilder} to create a sequence of + * strategies. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface RequestedContentTypeResolver { + + /** + * A singleton list with {@link MediaType#ALL} that is returned from + * {@link #resolveMediaTypes} when no specific media types are requested. + * @since 5.0.5 + */ + List MEDIA_TYPE_ALL_LIST = Collections.singletonList(MediaType.ALL); + + + /** + * Resolve the given request to a list of requested media types. The returned + * list is ordered by specificity first and by quality parameter second. + * @param exchange the current exchange + * @return the requested media types, or {@link #MEDIA_TYPE_ALL_LIST} if none + * were requested. + * @throws NotAcceptableStatusException if the requested media type is invalid + */ + List resolveMediaTypes(ServerWebExchange exchange); + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/accept/RequestedContentTypeResolverBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/accept/RequestedContentTypeResolverBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..e82ce8a6bcf257f69cfc01a6555823dad978ff44 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/accept/RequestedContentTypeResolverBuilder.java @@ -0,0 +1,158 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.accept; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; + +/** + * Builder for a composite {@link RequestedContentTypeResolver} that delegates + * to other resolvers each implementing a different strategy to determine the + * requested content type -- e.g. Accept header, query parameter, or other. + * + *

Use builder methods to add resolvers in the desired order. For a given + * request he first resolver to return a list that is not empty and does not + * consist of just {@link MediaType#ALL}, will be used. + * + *

By default, if no resolvers are explicitly configured, the builder will + * add {@link HeaderContentTypeResolver}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class RequestedContentTypeResolverBuilder { + + private final List> candidates = new ArrayList<>(); + + + /** + * Add a resolver to get the requested content type from a query parameter. + * By default the query parameter name is {@code "format"}. + */ + public ParameterResolverConfigurer parameterResolver() { + ParameterResolverConfigurer parameterBuilder = new ParameterResolverConfigurer(); + this.candidates.add(parameterBuilder::createResolver); + return parameterBuilder; + } + + /** + * Add resolver to get the requested content type from the + * {@literal "Accept"} header. + */ + public void headerResolver() { + this.candidates.add(HeaderContentTypeResolver::new); + } + + /** + * Add resolver that returns a fixed set of media types. + * @param mediaTypes the media types to use + */ + public void fixedResolver(MediaType... mediaTypes) { + this.candidates.add(() -> new FixedContentTypeResolver(Arrays.asList(mediaTypes))); + } + + /** + * Add a custom resolver. + * @param resolver the resolver to add + */ + public void resolver(RequestedContentTypeResolver resolver) { + this.candidates.add(() -> resolver); + } + + /** + * Build a {@link RequestedContentTypeResolver} that delegates to the list + * of resolvers configured through this builder. + */ + public RequestedContentTypeResolver build() { + List resolvers = (!this.candidates.isEmpty() ? + this.candidates.stream().map(Supplier::get).collect(Collectors.toList()) : + Collections.singletonList(new HeaderContentTypeResolver())); + + return exchange -> { + for (RequestedContentTypeResolver resolver : resolvers) { + List mediaTypes = resolver.resolveMediaTypes(exchange); + if (mediaTypes.equals(RequestedContentTypeResolver.MEDIA_TYPE_ALL_LIST)) { + continue; + } + return mediaTypes; + } + return RequestedContentTypeResolver.MEDIA_TYPE_ALL_LIST; + }; + } + + + /** + * Helper to create and configure {@link ParameterContentTypeResolver}. + */ + public static class ParameterResolverConfigurer { + + private final Map mediaTypes = new HashMap<>(); + + @Nullable + private String parameterName; + + /** + * Configure a mapping between a lookup key (extracted from a query + * parameter value) and a corresponding {@code MediaType}. + * @param key the lookup key + * @param mediaType the MediaType for that key + */ + public ParameterResolverConfigurer mediaType(String key, MediaType mediaType) { + this.mediaTypes.put(key, mediaType); + return this; + } + + /** + * Map-based variant of {@link #mediaType(String, MediaType)}. + * @param mediaTypes the mappings to copy + */ + public ParameterResolverConfigurer mediaType(Map mediaTypes) { + this.mediaTypes.putAll(mediaTypes); + return this; + } + + /** + * Set the name of the parameter to use to determine requested media types. + *

By default this is set to {@literal "format"}. + */ + public ParameterResolverConfigurer parameterName(String parameterName) { + this.parameterName = parameterName; + return this; + } + + /** + * Private factory method to create the resolver. + */ + private RequestedContentTypeResolver createResolver() { + ParameterContentTypeResolver resolver = new ParameterContentTypeResolver(this.mediaTypes); + if (this.parameterName != null) { + resolver.setParameterName(this.parameterName); + } + return resolver; + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/accept/package-info.java b/spring-webflux/src/main/java/org/springframework/web/reactive/accept/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..a9693904f158e12f487a37c6b46f37851b0e6b48 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/accept/package-info.java @@ -0,0 +1,11 @@ +/** + * {@link org.springframework.web.reactive.accept.RequestedContentTypeResolver} + * strategy and implementations to resolve the requested content type for a + * given request. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.reactive.accept; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/config/CorsRegistration.java b/spring-webflux/src/main/java/org/springframework/web/reactive/config/CorsRegistration.java new file mode 100644 index 0000000000000000000000000000000000000000..06301f82f255685589700c360cebcac38ef099f5 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/config/CorsRegistration.java @@ -0,0 +1,141 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import java.util.Arrays; + +import org.springframework.web.cors.CorsConfiguration; + +/** + * Assists with the creation of a {@link CorsConfiguration} instance for a given + * URL path pattern. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 5.0 + * @see CorsConfiguration + * @see CorsRegistry + */ +public class CorsRegistration { + + private final String pathPattern; + + private final CorsConfiguration config; + + + public CorsRegistration(String pathPattern) { + this.pathPattern = pathPattern; + // Same implicit default values as the @CrossOrigin annotation + allows simple methods + this.config = new CorsConfiguration().applyPermitDefaultValues(); + } + + + /** + * The list of allowed origins that be specific origins, e.g. + * {@code "https://domain1.com"}, or {@code "*"} for all origins. + *

A matched origin is listed in the {@code Access-Control-Allow-Origin} + * response header of preflight actual CORS requests. + *

By default all origins are allowed. + *

Note: CORS checks use values from "Forwarded" + * (RFC 7239), + * "X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" headers, + * if present, in order to reflect the client-originated address. + * Consider using the {@code ForwardedHeaderFilter} in order to choose from a + * central place whether to extract and use, or to discard such headers. + * See the Spring Framework reference for more on this filter. + */ + public CorsRegistration allowedOrigins(String... origins) { + this.config.setAllowedOrigins(Arrays.asList(origins)); + return this; + } + + /** + * Set the HTTP methods to allow, e.g. {@code "GET"}, {@code "POST"}, etc. + *

The special value {@code "*"} allows all methods. + *

By default "simple" methods {@code GET}, {@code HEAD}, and {@code POST} + * are allowed. + */ + public CorsRegistration allowedMethods(String... methods) { + this.config.setAllowedMethods(Arrays.asList(methods)); + return this; + } + + /** + * Set the list of headers that a pre-flight request can list as allowed + * for use during an actual request. + *

The special value {@code "*"} may be used to allow all headers. + *

A header name is not required to be listed if it is one of: + * {@code Cache-Control}, {@code Content-Language}, {@code Expires}, + * {@code Last-Modified}, or {@code Pragma} as per the CORS spec. + *

By default all headers are allowed. + */ + public CorsRegistration allowedHeaders(String... headers) { + this.config.setAllowedHeaders(Arrays.asList(headers)); + return this; + } + + /** + * Set the list of response headers other than "simple" headers, i.e. + * {@code Cache-Control}, {@code Content-Language}, {@code Content-Type}, + * {@code Expires}, {@code Last-Modified}, or {@code Pragma}, that an + * actual response might have and can be exposed. + *

The special value {@code "*"} allows all headers to be exposed for + * non-credentialed requests. + *

By default this is not set. + */ + public CorsRegistration exposedHeaders(String... headers) { + this.config.setExposedHeaders(Arrays.asList(headers)); + return this; + } + + /** + * Whether the browser should send credentials, such as cookies along with + * cross domain requests, to the annotated endpoint. The configured value is + * set on the {@code Access-Control-Allow-Credentials} response header of + * preflight requests. + *

NOTE: Be aware that this option establishes a high + * level of trust with the configured domains and also increases the surface + * attack of the web application by exposing sensitive user-specific + * information such as cookies and CSRF tokens. + *

By default this is not set in which case the + * {@code Access-Control-Allow-Credentials} header is also not set and + * credentials are therefore not allowed. + */ + public CorsRegistration allowCredentials(boolean allowCredentials) { + this.config.setAllowCredentials(allowCredentials); + return this; + } + + /** + * Configure how long in seconds the response from a pre-flight request + * can be cached by clients. + *

By default this is set to 1800 seconds (30 minutes). + */ + public CorsRegistration maxAge(long maxAge) { + this.config.setMaxAge(maxAge); + return this; + } + + protected String getPathPattern() { + return this.pathPattern; + } + + protected CorsConfiguration getCorsConfiguration() { + return this.config; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/config/CorsRegistry.java b/spring-webflux/src/main/java/org/springframework/web/reactive/config/CorsRegistry.java new file mode 100644 index 0000000000000000000000000000000000000000..263e66c15ee902b105ecdad985e6458ddddd3769 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/config/CorsRegistry.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.springframework.web.cors.CorsConfiguration; + +/** + * Assists with the registration of global, URL pattern based + * {@link CorsConfiguration} mappings. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class CorsRegistry { + + private final List registrations = new ArrayList<>(); + + + /** + * Enable cross-origin request handling for the specified path pattern. + *

Exact path mapping URIs (such as {@code "/admin"}) are supported as + * well as Ant-style path patterns (such as {@code "/admin/**"}). + *

By default, the {@code CorsConfiguration} for this mapping is + * initialized with default values as described in + * {@link CorsConfiguration#applyPermitDefaultValues()}. + */ + public CorsRegistration addMapping(String pathPattern) { + CorsRegistration registration = new CorsRegistration(pathPattern); + this.registrations.add(registration); + return registration; + } + + /** + * Return the registered {@link CorsConfiguration} objects, + * keyed by path pattern. + */ + protected Map getCorsConfigurations() { + Map configs = new LinkedHashMap<>(this.registrations.size()); + for (CorsRegistration registration : this.registrations) { + configs.put(registration.getPathPattern(), registration.getCorsConfiguration()); + } + return configs; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/config/DelegatingWebFluxConfiguration.java b/spring-webflux/src/main/java/org/springframework/web/reactive/config/DelegatingWebFluxConfiguration.java new file mode 100644 index 0000000000000000000000000000000000000000..f754040fa71cfe16c21ea6881dc75b9141131845 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/config/DelegatingWebFluxConfiguration.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import java.util.List; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Configuration; +import org.springframework.format.FormatterRegistry; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.util.CollectionUtils; +import org.springframework.validation.MessageCodesResolver; +import org.springframework.validation.Validator; +import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder; +import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer; + +/** + * A subclass of {@code WebFluxConfigurationSupport} that detects and delegates + * to all beans of type {@link WebFluxConfigurer} allowing them to customize the + * configuration provided by {@code WebFluxConfigurationSupport}. This is the + * class actually imported by {@link EnableWebFlux @EnableWebFlux}. + * + * @author Brian Clozel + * @since 5.0 + */ +@Configuration +public class DelegatingWebFluxConfiguration extends WebFluxConfigurationSupport { + + private final WebFluxConfigurerComposite configurers = new WebFluxConfigurerComposite(); + + + @Autowired(required = false) + public void setConfigurers(List configurers) { + if (!CollectionUtils.isEmpty(configurers)) { + this.configurers.addWebFluxConfigurers(configurers); + } + } + + + @Override + protected void configureContentTypeResolver(RequestedContentTypeResolverBuilder builder) { + this.configurers.configureContentTypeResolver(builder); + } + + @Override + protected void addCorsMappings(CorsRegistry registry) { + this.configurers.addCorsMappings(registry); + } + + @Override + public void configurePathMatching(PathMatchConfigurer configurer) { + this.configurers.configurePathMatching(configurer); + } + + @Override + protected void addResourceHandlers(ResourceHandlerRegistry registry) { + this.configurers.addResourceHandlers(registry); + } + + @Override + protected void configureArgumentResolvers(ArgumentResolverConfigurer configurer) { + this.configurers.configureArgumentResolvers(configurer); + } + + @Override + protected void configureHttpMessageCodecs(ServerCodecConfigurer configurer) { + this.configurers.configureHttpMessageCodecs(configurer); + } + + @Override + protected void addFormatters(FormatterRegistry registry) { + this.configurers.addFormatters(registry); + } + + @Override + protected Validator getValidator() { + Validator validator = this.configurers.getValidator(); + return (validator != null ? validator : super.getValidator()); + } + + @Override + protected MessageCodesResolver getMessageCodesResolver() { + MessageCodesResolver messageCodesResolver = this.configurers.getMessageCodesResolver(); + return (messageCodesResolver != null ? messageCodesResolver : super.getMessageCodesResolver()); + } + + @Override + protected void configureViewResolvers(ViewResolverRegistry registry) { + this.configurers.configureViewResolvers(registry); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/config/EnableWebFlux.java b/spring-webflux/src/main/java/org/springframework/web/reactive/config/EnableWebFlux.java new file mode 100644 index 0000000000000000000000000000000000000000..2dfd10b7edfd99f7d60b0626e9936afef9f5d307 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/config/EnableWebFlux.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.context.annotation.Import; + +/** + * Adding this annotation to an {@code @Configuration} class imports the Spring + * WebFlux configuration from {@link WebFluxConfigurationSupport} that enables + * use of annotated controllers and functional endpoints. + * + *

For example: + * + *

+ * @Configuration
+ * @EnableWebFlux
+ * @ComponentScan(basePackageClasses = MyConfiguration.class)
+ * public class MyConfiguration {
+ * }
+ * 
+ * + *

To customize the imported configuration, implement + * {@link WebFluxConfigurer} and one or more of its methods: + * + *

+ * @Configuration
+ * @EnableWebFlux
+ * @ComponentScan(basePackageClasses = MyConfiguration.class)
+ * public class MyConfiguration implements WebFluxConfigurer {
+ *
+ * 	   @Override
+ * 	   public void configureMessageWriters(List<HttpMessageWriter<?>> messageWriters) {
+ *         messageWriters.add(new MyHttpMessageWriter());
+ * 	   }
+ *
+ * 	   // ...
+ * }
+ * 
+ * + *

Only one {@code @Configuration} class should have the {@code @EnableWebFlux} + * annotation in order to import the Spring WebFlux configuration. There can + * however be multiple {@code @Configuration} classes that implement + * {@code WebFluxConfigurer} that customize the provided configuration. + * + *

If {@code WebFluxConfigurer} does not expose some setting that needs to be + * configured, consider switching to an advanced mode by removing the + * {@code @EnableWebFlux} annotation and extending directly from + * {@link WebFluxConfigurationSupport} or {@link DelegatingWebFluxConfiguration} -- + * the latter allows detecting and delegating to one or more + * {@code WebFluxConfigurer} configuration classes. + * + * @author Brian Clozel + * @author Rossen Stoyanchev + * @since 5.0 + * @see WebFluxConfigurer + * @see WebFluxConfigurationSupport + * @see DelegatingWebFluxConfiguration + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +@Documented +@Import(DelegatingWebFluxConfiguration.class) +public @interface EnableWebFlux { +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/config/PathMatchConfigurer.java b/spring-webflux/src/main/java/org/springframework/web/reactive/config/PathMatchConfigurer.java new file mode 100644 index 0000000000000000000000000000000000000000..7f9988916e4d3ea28dc6da43a957d95f1516912f --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/config/PathMatchConfigurer.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.function.Predicate; + +import org.springframework.lang.Nullable; + +/** + * Assist with configuring {@code HandlerMapping}'s with path matching options. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + */ +public class PathMatchConfigurer { + + @Nullable + private Boolean trailingSlashMatch; + + + @Nullable + private Boolean caseSensitiveMatch; + + @Nullable + private Map>> pathPrefixes; + + + /** + * Whether to match to URLs irrespective of their case. + * If enabled a method mapped to "/users" won't match to "/Users/". + *

The default value is {@code false}. + */ + public PathMatchConfigurer setUseCaseSensitiveMatch(Boolean caseSensitiveMatch) { + this.caseSensitiveMatch = caseSensitiveMatch; + return this; + } + + /** + * Whether to match to URLs irrespective of the presence of a trailing slash. + * If enabled a method mapped to "/users" also matches to "/users/". + *

The default value is {@code true}. + */ + public PathMatchConfigurer setUseTrailingSlashMatch(Boolean trailingSlashMatch) { + this.trailingSlashMatch = trailingSlashMatch; + return this; + } + + /** + * Configure a path prefix to apply to matching controller methods. + *

Prefixes are used to enrich the mappings of every {@code @RequestMapping} + * method whose controller type is matched by the corresponding + * {@code Predicate}. The prefix for the first matching predicate is used. + *

Consider using {@link org.springframework.web.method.HandlerTypePredicate + * HandlerTypePredicate} to group controllers. + * @param prefix the path prefix to apply + * @param predicate a predicate for matching controller types + * @since 5.1 + */ + public PathMatchConfigurer addPathPrefix(String prefix, Predicate> predicate) { + if (this.pathPrefixes == null) { + this.pathPrefixes = new LinkedHashMap<>(); + } + this.pathPrefixes.put(prefix, predicate); + return this; + } + + + @Nullable + protected Boolean isUseTrailingSlashMatch() { + return this.trailingSlashMatch; + } + + @Nullable + protected Boolean isUseCaseSensitiveMatch() { + return this.caseSensitiveMatch; + } + + @Nullable + protected Map>> getPathPrefixes() { + return this.pathPrefixes; + } +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/config/ResourceChainRegistration.java b/spring-webflux/src/main/java/org/springframework/web/reactive/config/ResourceChainRegistration.java new file mode 100644 index 0000000000000000000000000000000000000000..fe31f1963618eb467b8fa7c0628322443dcd8f7e --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/config/ResourceChainRegistration.java @@ -0,0 +1,133 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import java.util.ArrayList; +import java.util.List; + +import org.springframework.cache.Cache; +import org.springframework.cache.concurrent.ConcurrentMapCache; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.web.reactive.resource.CachingResourceResolver; +import org.springframework.web.reactive.resource.CachingResourceTransformer; +import org.springframework.web.reactive.resource.CssLinkResourceTransformer; +import org.springframework.web.reactive.resource.PathResourceResolver; +import org.springframework.web.reactive.resource.ResourceResolver; +import org.springframework.web.reactive.resource.ResourceTransformer; +import org.springframework.web.reactive.resource.VersionResourceResolver; +import org.springframework.web.reactive.resource.WebJarsResourceResolver; + +/** + * Assists with the registration of resource resolvers and transformers. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ResourceChainRegistration { + + private static final String DEFAULT_CACHE_NAME = "spring-resource-chain-cache"; + + private static final boolean isWebJarsAssetLocatorPresent = ClassUtils.isPresent( + "org.webjars.WebJarAssetLocator", ResourceChainRegistration.class.getClassLoader()); + + + private final List resolvers = new ArrayList<>(4); + + private final List transformers = new ArrayList<>(4); + + private boolean hasVersionResolver; + + private boolean hasPathResolver; + + private boolean hasCssLinkTransformer; + + private boolean hasWebjarsResolver; + + + public ResourceChainRegistration(boolean cacheResources) { + this(cacheResources, cacheResources ? new ConcurrentMapCache(DEFAULT_CACHE_NAME) : null); + } + + public ResourceChainRegistration(boolean cacheResources, @Nullable Cache cache) { + Assert.isTrue(!cacheResources || cache != null, "'cache' is required when cacheResources=true"); + if (cacheResources) { + this.resolvers.add(new CachingResourceResolver(cache)); + this.transformers.add(new CachingResourceTransformer(cache)); + } + } + + + /** + * Add a resource resolver to the chain. + * @param resolver the resolver to add + * @return the current instance for chained method invocation + */ + public ResourceChainRegistration addResolver(ResourceResolver resolver) { + Assert.notNull(resolver, "The provided ResourceResolver should not be null"); + this.resolvers.add(resolver); + if (resolver instanceof VersionResourceResolver) { + this.hasVersionResolver = true; + } + else if (resolver instanceof PathResourceResolver) { + this.hasPathResolver = true; + } + else if (resolver instanceof WebJarsResourceResolver) { + this.hasWebjarsResolver = true; + } + return this; + } + + /** + * Add a resource transformer to the chain. + * @param transformer the transformer to add + * @return the current instance for chained method invocation + */ + public ResourceChainRegistration addTransformer(ResourceTransformer transformer) { + Assert.notNull(transformer, "The provided ResourceTransformer should not be null"); + this.transformers.add(transformer); + if (transformer instanceof CssLinkResourceTransformer) { + this.hasCssLinkTransformer = true; + } + return this; + } + + protected List getResourceResolvers() { + if (!this.hasPathResolver) { + List result = new ArrayList<>(this.resolvers); + if (isWebJarsAssetLocatorPresent && !this.hasWebjarsResolver) { + result.add(new WebJarsResourceResolver()); + } + result.add(new PathResourceResolver()); + return result; + } + return this.resolvers; + } + + protected List getResourceTransformers() { + if (this.hasVersionResolver && !this.hasCssLinkTransformer) { + List result = new ArrayList<>(this.transformers); + boolean hasTransformers = !this.transformers.isEmpty(); + boolean hasCaching = hasTransformers && this.transformers.get(0) instanceof CachingResourceTransformer; + result.add(hasCaching ? 1 : 0, new CssLinkResourceTransformer()); + return result; + } + return this.transformers; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/config/ResourceHandlerRegistration.java b/spring-webflux/src/main/java/org/springframework/web/reactive/config/ResourceHandlerRegistration.java new file mode 100644 index 0000000000000000000000000000000000000000..a4982ac857991efd1aeac7b1c2f827786c96bfb5 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/config/ResourceHandlerRegistration.java @@ -0,0 +1,159 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.springframework.cache.Cache; +import org.springframework.core.io.Resource; +import org.springframework.core.io.ResourceLoader; +import org.springframework.http.CacheControl; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.reactive.resource.ResourceWebHandler; + +/** + * Assist with creating and configuring a static resources handler. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ResourceHandlerRegistration { + + private final ResourceLoader resourceLoader; + + private final String[] pathPatterns; + + private final List locationValues = new ArrayList<>(); + + @Nullable + private CacheControl cacheControl; + + @Nullable + private ResourceChainRegistration resourceChainRegistration; + + + /** + * Create a {@link ResourceHandlerRegistration} instance. + * @param resourceLoader a resource loader for turning a String location + * into a {@link Resource} + * @param pathPatterns one or more resource URL path patterns + */ + public ResourceHandlerRegistration(ResourceLoader resourceLoader, String... pathPatterns) { + Assert.notNull(resourceLoader, "ResourceLoader is required"); + Assert.notEmpty(pathPatterns, "At least one path pattern is required for resource handling"); + this.resourceLoader = resourceLoader; + this.pathPatterns = pathPatterns; + } + + + /** + * Add one or more resource locations from which to serve static content. + * Each location must point to a valid directory. Multiple locations may + * be specified as a comma-separated list, and the locations will be checked + * for a given resource in the order specified. + * + *

For example, {{@code "/"}, + * {@code "classpath:/META-INF/public-web-resources/"}} allows resources to + * be served both from the web application root and from any JAR on the + * classpath that contains a {@code /META-INF/public-web-resources/} directory, + * with resources in the web application root taking precedence. + * @return the same {@link ResourceHandlerRegistration} instance, for + * chained method invocation + */ + public ResourceHandlerRegistration addResourceLocations(String... resourceLocations) { + this.locationValues.addAll(Arrays.asList(resourceLocations)); + return this; + } + + /** + * Specify the {@link CacheControl} which should be used + * by the resource handler. + * @param cacheControl the CacheControl configuration to use + * @return the same {@link ResourceHandlerRegistration} instance, for + * chained method invocation + */ + public ResourceHandlerRegistration setCacheControl(CacheControl cacheControl) { + this.cacheControl = cacheControl; + return this; + } + + /** + * Configure a chain of resource resolvers and transformers to use. This + * can be useful, for example, to apply a version strategy to resource URLs. + *

If this method is not invoked, by default only a simple + * {@code PathResourceResolver} is used in order to match URL paths to + * resources under the configured locations. + * @param cacheResources whether to cache the result of resource resolution; + * setting this to "true" is recommended for production (and "false" for + * development, especially when applying a version strategy) + * @return the same {@link ResourceHandlerRegistration} instance, for + * chained method invocation + */ + public ResourceChainRegistration resourceChain(boolean cacheResources) { + this.resourceChainRegistration = new ResourceChainRegistration(cacheResources); + return this.resourceChainRegistration; + } + + /** + * Configure a chain of resource resolvers and transformers to use. This + * can be useful, for example, to apply a version strategy to resource URLs. + *

If this method is not invoked, by default only a simple + * {@code PathResourceResolver} is used in order to match URL paths to + * resources under the configured locations. + * @param cacheResources whether to cache the result of resource resolution; + * setting this to "true" is recommended for production (and "false" for + * development, especially when applying a version strategy + * @param cache the cache to use for storing resolved and transformed resources; + * by default a {@link org.springframework.cache.concurrent.ConcurrentMapCache} + * is used. Since Resources aren't serializable and can be dependent on the + * application host, one should not use a distributed cache but rather an + * in-memory cache. + * @return the same {@link ResourceHandlerRegistration} instance, for chained method invocation + */ + public ResourceChainRegistration resourceChain(boolean cacheResources, Cache cache) { + this.resourceChainRegistration = new ResourceChainRegistration(cacheResources, cache); + return this.resourceChainRegistration; + } + + /** + * Returns the URL path patterns for the resource handler. + */ + protected String[] getPathPatterns() { + return this.pathPatterns; + } + + /** + * Returns a {@link ResourceWebHandler} instance. + */ + protected ResourceWebHandler getRequestHandler() { + ResourceWebHandler handler = new ResourceWebHandler(); + handler.setLocationValues(this.locationValues); + handler.setResourceLoader(this.resourceLoader); + if (this.resourceChainRegistration != null) { + handler.setResourceResolvers(this.resourceChainRegistration.getResourceResolvers()); + handler.setResourceTransformers(this.resourceChainRegistration.getResourceTransformers()); + } + if (this.cacheControl != null) { + handler.setCacheControl(this.cacheControl); + } + return handler; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/config/ResourceHandlerRegistry.java b/spring-webflux/src/main/java/org/springframework/web/reactive/config/ResourceHandlerRegistry.java new file mode 100644 index 0000000000000000000000000000000000000000..cd9dc05a2393f50a9a8dbee9701545d18800e082 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/config/ResourceHandlerRegistry.java @@ -0,0 +1,161 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.springframework.beans.factory.BeanInitializationException; +import org.springframework.core.Ordered; +import org.springframework.core.io.ResourceLoader; +import org.springframework.lang.Nullable; +import org.springframework.web.reactive.handler.AbstractUrlHandlerMapping; +import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping; +import org.springframework.web.reactive.resource.ResourceTransformerSupport; +import org.springframework.web.reactive.resource.ResourceUrlProvider; +import org.springframework.web.reactive.resource.ResourceWebHandler; +import org.springframework.web.server.WebHandler; + +/** + * Stores registrations of resource handlers for serving static resources such + * as images, css files and others through Spring WebFlux including setting cache + * headers optimized for efficient loading in a web browser. Resources can be + * served out of locations under web application root, from the classpath, and + * others. + * + *

To create a resource handler, use {@link #addResourceHandler(String...)} + * providing the URL path patterns for which the handler should be invoked to + * serve static resources (e.g. {@code "/resources/**"}). + * + *

Then use additional methods on the returned + * {@link ResourceHandlerRegistration} to add one or more locations from which + * to serve static content from (e.g. {{@code "/"}, + * {@code "classpath:/META-INF/public-web-resources/"}}) or to specify a cache + * period for served resources. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + */ +public class ResourceHandlerRegistry { + + private final ResourceLoader resourceLoader; + + private final List registrations = new ArrayList<>(); + + private int order = Ordered.LOWEST_PRECEDENCE - 1; + + @Nullable + private ResourceUrlProvider resourceUrlProvider; + + + /** + * Create a new resource handler registry for the given resource loader + * (typically an application context). + * @param resourceLoader the resource loader to use + */ + public ResourceHandlerRegistry(ResourceLoader resourceLoader) { + this.resourceLoader = resourceLoader; + } + + /** + * Configure the {@link ResourceUrlProvider} that can be used by + * {@link org.springframework.web.reactive.resource.ResourceTransformer} instances. + * @param resourceUrlProvider the resource URL provider to use + * @since 5.0.11 + */ + public void setResourceUrlProvider(@Nullable ResourceUrlProvider resourceUrlProvider) { + this.resourceUrlProvider = resourceUrlProvider; + } + + + + /** + * Add a resource handler for serving static resources based on the specified + * URL path patterns. The handler will be invoked for every incoming request + * that matches to one of the specified path patterns. + *

Patterns like {@code "/static/**"} or {@code "/css/{filename:\\w+\\.css}"} + * are allowed. See {@link org.springframework.web.util.pattern.PathPattern} + * for more details on the syntax. + * @return a {@link ResourceHandlerRegistration} to use to further configure + * the registered resource handler + */ + public ResourceHandlerRegistration addResourceHandler(String... patterns) { + ResourceHandlerRegistration registration = new ResourceHandlerRegistration(this.resourceLoader, patterns); + this.registrations.add(registration); + return registration; + } + + /** + * Whether a resource handler has already been registered for the given path pattern. + */ + public boolean hasMappingForPattern(String pathPattern) { + for (ResourceHandlerRegistration registration : this.registrations) { + if (Arrays.asList(registration.getPathPatterns()).contains(pathPattern)) { + return true; + } + } + return false; + } + + /** + * Specify the order to use for resource handling relative to other + * {@code HandlerMapping}s configured in the Spring configuration. + *

The default value used is {@code Integer.MAX_VALUE-1}. + */ + public ResourceHandlerRegistry setOrder(int order) { + this.order = order; + return this; + } + + /** + * Return a handler mapping with the mapped resource handlers; or {@code null} in case + * of no registrations. + */ + @Nullable + protected AbstractUrlHandlerMapping getHandlerMapping() { + if (this.registrations.isEmpty()) { + return null; + } + Map urlMap = new LinkedHashMap<>(); + for (ResourceHandlerRegistration registration : this.registrations) { + for (String pathPattern : registration.getPathPatterns()) { + ResourceWebHandler handler = registration.getRequestHandler(); + handler.getResourceTransformers().forEach(transformer -> { + if (transformer instanceof ResourceTransformerSupport) { + ((ResourceTransformerSupport) transformer).setResourceUrlProvider(this.resourceUrlProvider); + } + }); + try { + handler.afterPropertiesSet(); + } + catch (Throwable ex) { + throw new BeanInitializationException("Failed to init ResourceHttpRequestHandler", ex); + } + urlMap.put(pathPattern, handler); + } + } + SimpleUrlHandlerMapping handlerMapping = new SimpleUrlHandlerMapping(); + handlerMapping.setOrder(this.order); + handlerMapping.setUrlMap(urlMap); + return handlerMapping; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/config/UrlBasedViewResolverRegistration.java b/spring-webflux/src/main/java/org/springframework/web/reactive/config/UrlBasedViewResolverRegistration.java new file mode 100644 index 0000000000000000000000000000000000000000..9b1ec5fb17a41791b28f4226e52bf4e017326822 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/config/UrlBasedViewResolverRegistration.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import org.springframework.util.Assert; +import org.springframework.web.reactive.result.view.UrlBasedViewResolver; + +/** + * Assist with configuring properties of a {@link UrlBasedViewResolver}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class UrlBasedViewResolverRegistration { + + private final UrlBasedViewResolver viewResolver; + + + public UrlBasedViewResolverRegistration(UrlBasedViewResolver viewResolver) { + Assert.notNull(viewResolver, "ViewResolver must not be null"); + this.viewResolver = viewResolver; + } + + + /** + * Set the prefix that gets prepended to view names when building a URL. + * @see UrlBasedViewResolver#setPrefix + */ + public UrlBasedViewResolverRegistration prefix(String prefix) { + this.viewResolver.setPrefix(prefix); + return this; + } + + /** + * Set the suffix that gets appended to view names when building a URL. + * @see UrlBasedViewResolver#setSuffix + */ + public UrlBasedViewResolverRegistration suffix(String suffix) { + this.viewResolver.setSuffix(suffix); + return this; + } + + /** + * Set the view class that should be used to create views. + * @see UrlBasedViewResolver#setViewClass + */ + public UrlBasedViewResolverRegistration viewClass(Class viewClass) { + this.viewResolver.setViewClass(viewClass); + return this; + } + + /** + * Set the view names (or name patterns) that can be handled by this view + * resolver. View names can contain simple wildcards such that 'my*', '*Report' + * and '*Repo*' will all match the view name 'myReport'. + * @see UrlBasedViewResolver#setViewNames + */ + public UrlBasedViewResolverRegistration viewNames(String... viewNames) { + this.viewResolver.setViewNames(viewNames); + return this; + } + + protected UrlBasedViewResolver getViewResolver() { + return this.viewResolver; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/config/ViewResolverRegistry.java b/spring-webflux/src/main/java/org/springframework/web/reactive/config/ViewResolverRegistry.java new file mode 100644 index 0000000000000000000000000000000000000000..af695751d1b881d58aa17c6eb1e3ffbe91057005 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/config/ViewResolverRegistry.java @@ -0,0 +1,186 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.BeanInitializationException; +import org.springframework.context.ApplicationContext; +import org.springframework.core.Ordered; +import org.springframework.lang.Nullable; +import org.springframework.util.ObjectUtils; +import org.springframework.web.reactive.result.view.HttpMessageWriterView; +import org.springframework.web.reactive.result.view.UrlBasedViewResolver; +import org.springframework.web.reactive.result.view.View; +import org.springframework.web.reactive.result.view.ViewResolver; +import org.springframework.web.reactive.result.view.freemarker.FreeMarkerConfigurer; +import org.springframework.web.reactive.result.view.freemarker.FreeMarkerViewResolver; +import org.springframework.web.reactive.result.view.script.ScriptTemplateConfigurer; +import org.springframework.web.reactive.result.view.script.ScriptTemplateViewResolver; + +/** + * Assist with the configuration of a chain of {@link ViewResolver}'s supporting + * different template mechanisms. + * + *

In addition, you can also configure {@link #defaultViews(View...) + * defaultViews} for rendering according to the requested content type, e.g. + * JSON, XML, etc. + * + * @author Rossen Stoyanchev + * @author Sebastien Deleuze + * @since 5.0 + */ +public class ViewResolverRegistry { + + @Nullable + private final ApplicationContext applicationContext; + + private final List viewResolvers = new ArrayList<>(4); + + private final List defaultViews = new ArrayList<>(4); + + @Nullable + private Integer order; + + + public ViewResolverRegistry(@Nullable ApplicationContext applicationContext) { + this.applicationContext = applicationContext; + } + + + /** + * Register a {@code FreeMarkerViewResolver} with a ".ftl" suffix. + *

Note that you must also configure FreeMarker by + * adding a {@link FreeMarkerConfigurer} bean. + */ + public UrlBasedViewResolverRegistration freeMarker() { + if (!checkBeanOfType(FreeMarkerConfigurer.class)) { + throw new BeanInitializationException("In addition to a FreeMarker view resolver " + + "there must also be a single FreeMarkerConfig bean in this web application context " + + "(or its parent): FreeMarkerConfigurer is the usual implementation. " + + "This bean may be given any name."); + } + FreeMarkerRegistration registration = new FreeMarkerRegistration(); + UrlBasedViewResolver resolver = registration.getViewResolver(); + if (this.applicationContext != null) { + resolver.setApplicationContext(this.applicationContext); + } + this.viewResolvers.add(resolver); + return registration; + } + + /** + * Register a script template view resolver with an empty default view name prefix and suffix. + *

Note that you must also configure script templating by + * adding a {@link ScriptTemplateConfigurer} bean. + * @since 5.0.4 + */ + public UrlBasedViewResolverRegistration scriptTemplate() { + if (!checkBeanOfType(ScriptTemplateConfigurer.class)) { + throw new BeanInitializationException("In addition to a script template view resolver " + + "there must also be a single ScriptTemplateConfig bean in this web application context " + + "(or its parent): ScriptTemplateConfigurer is the usual implementation. " + + "This bean may be given any name."); + } + ScriptRegistration registration = new ScriptRegistration(); + UrlBasedViewResolver resolver = registration.getViewResolver(); + if (this.applicationContext != null) { + resolver.setApplicationContext(this.applicationContext); + } + this.viewResolvers.add(resolver); + return registration; + } + + /** + * Register a {@link ViewResolver} bean instance. This may be useful to + * configure a 3rd party resolver implementation or as an alternative to + * other registration methods in this class when they don't expose some + * more advanced property that needs to be set. + */ + public void viewResolver(ViewResolver viewResolver) { + this.viewResolvers.add(viewResolver); + } + + /** + * Set default views associated with any view name and selected based on the + * best match for the requested content type. + *

Use {@link HttpMessageWriterView + * HttpMessageWriterView} to adapt and use any existing + * {@code HttpMessageWriter} (e.g. JSON, XML) as a {@code View}. + */ + public void defaultViews(View... defaultViews) { + this.defaultViews.addAll(Arrays.asList(defaultViews)); + } + + /** + * Whether any view resolvers have been registered. + */ + public boolean hasRegistrations() { + return (!this.viewResolvers.isEmpty()); + } + + /** + * Set the order for the + * {@link org.springframework.web.reactive.result.view.ViewResolutionResultHandler + * ViewResolutionResultHandler}. + *

By default this property is not set, which means the result handler is + * ordered at {@link Ordered#LOWEST_PRECEDENCE}. + */ + public void order(int order) { + this.order = order; + } + + + private boolean checkBeanOfType(Class beanType) { + return (this.applicationContext == null || + !ObjectUtils.isEmpty(BeanFactoryUtils.beanNamesForTypeIncludingAncestors( + this.applicationContext, beanType, false, false))); + } + + protected int getOrder() { + return (this.order != null ? this.order : Ordered.LOWEST_PRECEDENCE); + } + + protected List getViewResolvers() { + return this.viewResolvers; + } + + protected List getDefaultViews() { + return this.defaultViews; + } + + + private static class FreeMarkerRegistration extends UrlBasedViewResolverRegistration { + + public FreeMarkerRegistration() { + super(new FreeMarkerViewResolver()); + getViewResolver().setSuffix(".ftl"); + } + } + + private static class ScriptRegistration extends UrlBasedViewResolverRegistration { + + public ScriptRegistration() { + super(new ScriptTemplateViewResolver()); + getViewResolver(); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/config/WebFluxConfigurationSupport.java b/spring-webflux/src/main/java/org/springframework/web/reactive/config/WebFluxConfigurationSupport.java new file mode 100644 index 0000000000000000000000000000000000000000..07e2e62f0ef437b4de145b6d1b14cf5866510f8c --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/config/WebFluxConfigurationSupport.java @@ -0,0 +1,495 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; + +import reactor.core.publisher.Mono; + +import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.BeanInitializationException; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.context.annotation.Bean; +import org.springframework.core.ReactiveAdapterRegistry; +import org.springframework.core.annotation.Order; +import org.springframework.core.convert.converter.Converter; +import org.springframework.core.io.DefaultResourceLoader; +import org.springframework.core.io.ResourceLoader; +import org.springframework.format.Formatter; +import org.springframework.format.FormatterRegistry; +import org.springframework.format.support.DefaultFormattingConversionService; +import org.springframework.format.support.FormattingConversionService; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.validation.Errors; +import org.springframework.validation.MessageCodesResolver; +import org.springframework.validation.Validator; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.support.ConfigurableWebBindingInitializer; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.reactive.DispatcherHandler; +import org.springframework.web.reactive.HandlerMapping; +import org.springframework.web.reactive.accept.RequestedContentTypeResolver; +import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder; +import org.springframework.web.reactive.function.server.support.HandlerFunctionAdapter; +import org.springframework.web.reactive.function.server.support.RouterFunctionMapping; +import org.springframework.web.reactive.function.server.support.ServerResponseResultHandler; +import org.springframework.web.reactive.handler.AbstractHandlerMapping; +import org.springframework.web.reactive.handler.WebFluxResponseStatusExceptionHandler; +import org.springframework.web.reactive.resource.ResourceUrlProvider; +import org.springframework.web.reactive.result.SimpleHandlerAdapter; +import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer; +import org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter; +import org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerMapping; +import org.springframework.web.reactive.result.method.annotation.ResponseBodyResultHandler; +import org.springframework.web.reactive.result.method.annotation.ResponseEntityResultHandler; +import org.springframework.web.reactive.result.view.ViewResolutionResultHandler; +import org.springframework.web.reactive.result.view.ViewResolver; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebExceptionHandler; +import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; +import org.springframework.web.server.i18n.LocaleContextResolver; + +/** + * The main class for Spring WebFlux configuration. + * + *

Import directly or extend and override protected methods to customize. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + */ +public class WebFluxConfigurationSupport implements ApplicationContextAware { + + @Nullable + private Map corsConfigurations; + + @Nullable + private PathMatchConfigurer pathMatchConfigurer; + + @Nullable + private ViewResolverRegistry viewResolverRegistry; + + @Nullable + private ApplicationContext applicationContext; + + + @Override + public void setApplicationContext(@Nullable ApplicationContext applicationContext) { + this.applicationContext = applicationContext; + if (applicationContext != null) { + Assert.state(!applicationContext.containsBean("mvcContentNegotiationManager"), + "The Java/XML config for Spring MVC and Spring WebFlux cannot both be enabled, " + + "e.g. via @EnableWebMvc and @EnableWebFlux, in the same application."); + } + } + + @Nullable + public final ApplicationContext getApplicationContext() { + return this.applicationContext; + } + + + @Bean + public DispatcherHandler webHandler() { + return new DispatcherHandler(); + } + + @Bean + @Order(0) + public WebExceptionHandler responseStatusExceptionHandler() { + return new WebFluxResponseStatusExceptionHandler(); + } + + @Bean + public RequestMappingHandlerMapping requestMappingHandlerMapping() { + RequestMappingHandlerMapping mapping = createRequestMappingHandlerMapping(); + mapping.setOrder(0); + mapping.setContentTypeResolver(webFluxContentTypeResolver()); + mapping.setCorsConfigurations(getCorsConfigurations()); + + PathMatchConfigurer configurer = getPathMatchConfigurer(); + Boolean useTrailingSlashMatch = configurer.isUseTrailingSlashMatch(); + if (useTrailingSlashMatch != null) { + mapping.setUseTrailingSlashMatch(useTrailingSlashMatch); + } + Boolean useCaseSensitiveMatch = configurer.isUseCaseSensitiveMatch(); + if (useCaseSensitiveMatch != null) { + mapping.setUseCaseSensitiveMatch(useCaseSensitiveMatch); + } + Map>> pathPrefixes = configurer.getPathPrefixes(); + if (pathPrefixes != null) { + mapping.setPathPrefixes(pathPrefixes); + } + + return mapping; + } + + /** + * Override to plug a sub-class of {@link RequestMappingHandlerMapping}. + */ + protected RequestMappingHandlerMapping createRequestMappingHandlerMapping() { + return new RequestMappingHandlerMapping(); + } + + @Bean + public RequestedContentTypeResolver webFluxContentTypeResolver() { + RequestedContentTypeResolverBuilder builder = new RequestedContentTypeResolverBuilder(); + configureContentTypeResolver(builder); + return builder.build(); + } + + /** + * Override to configure how the requested content type is resolved. + */ + protected void configureContentTypeResolver(RequestedContentTypeResolverBuilder builder) { + } + + /** + * Callback for building the global CORS configuration. This method is final. + * Use {@link #addCorsMappings(CorsRegistry)} to customize the CORS conifg. + */ + protected final Map getCorsConfigurations() { + if (this.corsConfigurations == null) { + CorsRegistry registry = new CorsRegistry(); + addCorsMappings(registry); + this.corsConfigurations = registry.getCorsConfigurations(); + } + return this.corsConfigurations; + } + + /** + * Override this method to configure cross origin requests processing. + * @see CorsRegistry + */ + protected void addCorsMappings(CorsRegistry registry) { + } + + /** + * Callback for building the {@link PathMatchConfigurer}. This method is + * final, use {@link #configurePathMatching} to customize path matching. + */ + protected final PathMatchConfigurer getPathMatchConfigurer() { + if (this.pathMatchConfigurer == null) { + this.pathMatchConfigurer = new PathMatchConfigurer(); + configurePathMatching(this.pathMatchConfigurer); + } + return this.pathMatchConfigurer; + } + + /** + * Override to configure path matching options. + */ + public void configurePathMatching(PathMatchConfigurer configurer) { + } + + @Bean + public RouterFunctionMapping routerFunctionMapping() { + RouterFunctionMapping mapping = createRouterFunctionMapping(); + mapping.setOrder(-1); // go before RequestMappingHandlerMapping + mapping.setMessageReaders(serverCodecConfigurer().getReaders()); + mapping.setCorsConfigurations(getCorsConfigurations()); + + return mapping; + } + + /** + * Override to plug a sub-class of {@link RouterFunctionMapping}. + */ + protected RouterFunctionMapping createRouterFunctionMapping() { + return new RouterFunctionMapping(); + } + + /** + * Return a handler mapping ordered at Integer.MAX_VALUE-1 with mapped + * resource handlers. To configure resource handling, override + * {@link #addResourceHandlers}. + */ + @Bean + public HandlerMapping resourceHandlerMapping() { + ResourceLoader resourceLoader = this.applicationContext; + if (resourceLoader == null) { + resourceLoader = new DefaultResourceLoader(); + } + ResourceHandlerRegistry registry = new ResourceHandlerRegistry(resourceLoader); + registry.setResourceUrlProvider(resourceUrlProvider()); + addResourceHandlers(registry); + + AbstractHandlerMapping handlerMapping = registry.getHandlerMapping(); + if (handlerMapping != null) { + PathMatchConfigurer configurer = getPathMatchConfigurer(); + Boolean useTrailingSlashMatch = configurer.isUseTrailingSlashMatch(); + Boolean useCaseSensitiveMatch = configurer.isUseCaseSensitiveMatch(); + if (useTrailingSlashMatch != null) { + handlerMapping.setUseTrailingSlashMatch(useTrailingSlashMatch); + } + if (useCaseSensitiveMatch != null) { + handlerMapping.setUseCaseSensitiveMatch(useCaseSensitiveMatch); + } + } + else { + handlerMapping = new EmptyHandlerMapping(); + } + return handlerMapping; + } + + @Bean + public ResourceUrlProvider resourceUrlProvider() { + return new ResourceUrlProvider(); + } + + /** + * Override this method to add resource handlers for serving static resources. + * @see ResourceHandlerRegistry + */ + protected void addResourceHandlers(ResourceHandlerRegistry registry) { + } + + @Bean + public RequestMappingHandlerAdapter requestMappingHandlerAdapter() { + RequestMappingHandlerAdapter adapter = createRequestMappingHandlerAdapter(); + adapter.setMessageReaders(serverCodecConfigurer().getReaders()); + adapter.setWebBindingInitializer(getConfigurableWebBindingInitializer()); + adapter.setReactiveAdapterRegistry(webFluxAdapterRegistry()); + + ArgumentResolverConfigurer configurer = new ArgumentResolverConfigurer(); + configureArgumentResolvers(configurer); + adapter.setArgumentResolverConfigurer(configurer); + + return adapter; + } + + /** + * Override to plug a sub-class of {@link RequestMappingHandlerAdapter}. + */ + protected RequestMappingHandlerAdapter createRequestMappingHandlerAdapter() { + return new RequestMappingHandlerAdapter(); + } + + /** + * Configure resolvers for custom controller method arguments. + */ + protected void configureArgumentResolvers(ArgumentResolverConfigurer configurer) { + } + + /** + * Return the configurer for HTTP message readers and writers. + *

Use {@link #configureHttpMessageCodecs(ServerCodecConfigurer)} to + * configure the readers and writers. + */ + @Bean + public ServerCodecConfigurer serverCodecConfigurer() { + ServerCodecConfigurer serverCodecConfigurer = ServerCodecConfigurer.create(); + configureHttpMessageCodecs(serverCodecConfigurer); + return serverCodecConfigurer; + } + + /** + * Override to plug a sub-class of {@link LocaleContextResolver}. + */ + protected LocaleContextResolver createLocaleContextResolver() { + return new AcceptHeaderLocaleContextResolver(); + } + + @Bean + public LocaleContextResolver localeContextResolver() { + return createLocaleContextResolver(); + } + + /** + * Override to configure the HTTP message readers and writers to use. + */ + protected void configureHttpMessageCodecs(ServerCodecConfigurer configurer) { + } + + /** + * Return the {@link ConfigurableWebBindingInitializer} to use for + * initializing all {@link WebDataBinder} instances. + */ + protected ConfigurableWebBindingInitializer getConfigurableWebBindingInitializer() { + ConfigurableWebBindingInitializer initializer = new ConfigurableWebBindingInitializer(); + initializer.setConversionService(webFluxConversionService()); + initializer.setValidator(webFluxValidator()); + MessageCodesResolver messageCodesResolver = getMessageCodesResolver(); + if (messageCodesResolver != null) { + initializer.setMessageCodesResolver(messageCodesResolver); + } + return initializer; + } + + /** + * Return a {@link FormattingConversionService} for use with annotated controllers. + *

See {@link #addFormatters} as an alternative to overriding this method. + */ + @Bean + public FormattingConversionService webFluxConversionService() { + FormattingConversionService service = new DefaultFormattingConversionService(); + addFormatters(service); + return service; + } + + /** + * Override this method to add custom {@link Converter} and/or {@link Formatter} + * delegates to the common {@link FormattingConversionService}. + * @see #webFluxConversionService() + */ + protected void addFormatters(FormatterRegistry registry) { + } + + /** + * Return a {@link ReactiveAdapterRegistry} to adapting reactive types. + */ + @Bean + public ReactiveAdapterRegistry webFluxAdapterRegistry() { + return new ReactiveAdapterRegistry(); + } + + /** + * Return a global {@link Validator} instance for example for validating + * {@code @RequestBody} method arguments. + *

Delegates to {@link #getValidator()} first. If that returns {@code null} + * checks the classpath for the presence of a JSR-303 implementations + * before creating a {@code OptionalValidatorFactoryBean}. If a JSR-303 + * implementation is not available, a "no-op" {@link Validator} is returned. + */ + @Bean + public Validator webFluxValidator() { + Validator validator = getValidator(); + if (validator == null) { + if (ClassUtils.isPresent("javax.validation.Validator", getClass().getClassLoader())) { + Class clazz; + try { + String name = "org.springframework.validation.beanvalidation.OptionalValidatorFactoryBean"; + clazz = ClassUtils.forName(name, getClass().getClassLoader()); + } + catch (ClassNotFoundException | LinkageError ex) { + throw new BeanInitializationException("Failed to resolve default validator class", ex); + } + validator = (Validator) BeanUtils.instantiateClass(clazz); + } + else { + validator = new NoOpValidator(); + } + } + return validator; + } + + /** + * Override this method to provide a custom {@link Validator}. + */ + @Nullable + protected Validator getValidator() { + return null; + } + + /** + * Override this method to provide a custom {@link MessageCodesResolver}. + */ + @Nullable + protected MessageCodesResolver getMessageCodesResolver() { + return null; + } + + @Bean + public HandlerFunctionAdapter handlerFunctionAdapter() { + return new HandlerFunctionAdapter(); + } + + @Bean + public SimpleHandlerAdapter simpleHandlerAdapter() { + return new SimpleHandlerAdapter(); + } + + @Bean + public ResponseEntityResultHandler responseEntityResultHandler() { + return new ResponseEntityResultHandler(serverCodecConfigurer().getWriters(), + webFluxContentTypeResolver(), webFluxAdapterRegistry()); + } + + @Bean + public ResponseBodyResultHandler responseBodyResultHandler() { + return new ResponseBodyResultHandler(serverCodecConfigurer().getWriters(), + webFluxContentTypeResolver(), webFluxAdapterRegistry()); + } + + @Bean + public ViewResolutionResultHandler viewResolutionResultHandler() { + ViewResolverRegistry registry = getViewResolverRegistry(); + List resolvers = registry.getViewResolvers(); + ViewResolutionResultHandler handler = new ViewResolutionResultHandler( + resolvers, webFluxContentTypeResolver(), webFluxAdapterRegistry()); + handler.setDefaultViews(registry.getDefaultViews()); + handler.setOrder(registry.getOrder()); + return handler; + } + + @Bean + public ServerResponseResultHandler serverResponseResultHandler() { + List resolvers = getViewResolverRegistry().getViewResolvers(); + ServerResponseResultHandler handler = new ServerResponseResultHandler(); + handler.setMessageWriters(serverCodecConfigurer().getWriters()); + handler.setViewResolvers(resolvers); + return handler; + } + + /** + * Callback for building the {@link ViewResolverRegistry}. This method is final, + * use {@link #configureViewResolvers} to customize view resolvers. + */ + protected final ViewResolverRegistry getViewResolverRegistry() { + if (this.viewResolverRegistry == null) { + this.viewResolverRegistry = new ViewResolverRegistry(this.applicationContext); + configureViewResolvers(this.viewResolverRegistry); + } + return this.viewResolverRegistry; + } + + /** + * Configure view resolution for supporting template engines. + * @see ViewResolverRegistry + */ + protected void configureViewResolvers(ViewResolverRegistry registry) { + } + + + private static final class EmptyHandlerMapping extends AbstractHandlerMapping { + + @Override + public Mono getHandlerInternal(ServerWebExchange exchange) { + return Mono.empty(); + } + } + + + private static final class NoOpValidator implements Validator { + + @Override + public boolean supports(Class clazz) { + return false; + } + + @Override + public void validate(@Nullable Object target, Errors errors) { + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/config/WebFluxConfigurer.java b/spring-webflux/src/main/java/org/springframework/web/reactive/config/WebFluxConfigurer.java new file mode 100644 index 0000000000000000000000000000000000000000..e33b07b801dacfef3ce192899f494239f7ca940c --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/config/WebFluxConfigurer.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.format.Formatter; +import org.springframework.format.FormatterRegistry; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.lang.Nullable; +import org.springframework.validation.MessageCodesResolver; +import org.springframework.validation.Validator; +import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder; +import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer; + +/** + * Defines callback methods to customize the configuration for WebFlux + * applications enabled via {@link EnableWebFlux @EnableWebFlux}. + * + *

{@code @EnableWebFlux}-annotated configuration classes may implement + * this interface to be called back and given a chance to customize the + * default configuration. Consider implementing this interface and + * overriding the relevant methods for your needs. + * + * @author Brian Clozel + * @author Rossen Stoyanchev + * @since 5.0 + * @see WebFluxConfigurationSupport + * @see DelegatingWebFluxConfiguration + */ +public interface WebFluxConfigurer { + + /** + * Configure how the content type requested for the response is resolved + * when handling requests with annotated controllers. + * @param builder for configuring the resolvers to use + */ + default void configureContentTypeResolver(RequestedContentTypeResolverBuilder builder) { + } + + /** + * Configure "global" cross origin request processing. + *

The configured readers and writers will apply to all requests including + * annotated controllers and functional endpoints. Annotated controllers can + * further declare more fine-grained configuration via + * {@link org.springframework.web.bind.annotation.CrossOrigin @CrossOrigin}. + * @see CorsRegistry + */ + default void addCorsMappings(CorsRegistry registry) { + } + + /** + * Configure path matching options. + *

The configured path matching options will be used for mapping to + * annotated controllers and also + * {@link #addResourceHandlers(ResourceHandlerRegistry) static resources}. + * @param configurer the {@link PathMatchConfigurer} instance + */ + default void configurePathMatching(PathMatchConfigurer configurer) { + } + + /** + * Add resource handlers for serving static resources. + * @see ResourceHandlerRegistry + */ + default void addResourceHandlers(ResourceHandlerRegistry registry) { + } + + /** + * Configure resolvers for custom {@code @RequestMapping} method arguments. + * @param configurer to configurer to use + */ + default void configureArgumentResolvers(ArgumentResolverConfigurer configurer) { + } + + /** + * Configure custom HTTP message readers and writers or override built-in ones. + *

The configured readers and writers will be used for both annotated + * controllers and functional endpoints. + * @param configurer the configurer to use + */ + default void configureHttpMessageCodecs(ServerCodecConfigurer configurer) { + } + + /** + * Add custom {@link Converter Converters} and {@link Formatter Formatters} for + * performing type conversion and formatting of annotated controller method arguments. + */ + default void addFormatters(FormatterRegistry registry) { + } + + /** + * Provide a custom {@link Validator}. + *

By default a validator for standard bean validation is created if + * bean validation API is present on the classpath. + *

The configured validator is used for validating annotated controller + * method arguments. + */ + @Nullable + default Validator getValidator() { + return null; + } + + /** + * Provide a custom {@link MessageCodesResolver} to use for data binding in + * annotated controller method arguments instead of the one created by + * default in {@link org.springframework.validation.DataBinder}. + */ + @Nullable + default MessageCodesResolver getMessageCodesResolver() { + return null; + } + + /** + * Configure view resolution for rendering responses with a view and a model, + * where the view is typically an HTML template but could also be based on + * an HTTP message writer (e.g. JSON, XML). + *

The configured view resolvers will be used for both annotated + * controllers and functional endpoints. + */ + default void configureViewResolvers(ViewResolverRegistry registry) { + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/config/WebFluxConfigurerComposite.java b/spring-webflux/src/main/java/org/springframework/web/reactive/config/WebFluxConfigurerComposite.java new file mode 100644 index 0000000000000000000000000000000000000000..2e1eed91cd480144aa202d870a6ec80a80c08117 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/config/WebFluxConfigurerComposite.java @@ -0,0 +1,118 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.springframework.format.FormatterRegistry; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; +import org.springframework.validation.MessageCodesResolver; +import org.springframework.validation.Validator; +import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder; +import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer; + +/** + * A {@link WebFluxConfigurer} that delegates to one or more others. + * + * @author Brian Clozel + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class WebFluxConfigurerComposite implements WebFluxConfigurer { + + private final List delegates = new ArrayList<>(); + + + public void addWebFluxConfigurers(List configurers) { + if (!CollectionUtils.isEmpty(configurers)) { + this.delegates.addAll(configurers); + } + } + + + @Override + public void configureContentTypeResolver(RequestedContentTypeResolverBuilder builder) { + this.delegates.forEach(delegate -> delegate.configureContentTypeResolver(builder)); + } + + @Override + public void addCorsMappings(CorsRegistry registry) { + this.delegates.forEach(delegate -> delegate.addCorsMappings(registry)); + } + + @Override + public void configurePathMatching(PathMatchConfigurer configurer) { + this.delegates.forEach(delegate -> delegate.configurePathMatching(configurer)); + } + + @Override + public void addResourceHandlers(ResourceHandlerRegistry registry) { + this.delegates.forEach(delegate -> delegate.addResourceHandlers(registry)); + } + + @Override + public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) { + this.delegates.forEach(delegate -> delegate.configureArgumentResolvers(configurer)); + } + + @Override + public void configureHttpMessageCodecs(ServerCodecConfigurer configurer) { + this.delegates.forEach(delegate -> delegate.configureHttpMessageCodecs(configurer)); + } + + @Override + public void addFormatters(FormatterRegistry registry) { + this.delegates.forEach(delegate -> delegate.addFormatters(registry)); + } + + @Override + public Validator getValidator() { + return createSingleBean(WebFluxConfigurer::getValidator, Validator.class); + } + + @Override + public MessageCodesResolver getMessageCodesResolver() { + return createSingleBean(WebFluxConfigurer::getMessageCodesResolver, MessageCodesResolver.class); + } + + @Override + public void configureViewResolvers(ViewResolverRegistry registry) { + this.delegates.forEach(delegate -> delegate.configureViewResolvers(registry)); + } + + @Nullable + private T createSingleBean(Function factory, Class beanType) { + List result = this.delegates.stream().map(factory).filter(Objects::nonNull).collect(Collectors.toList()); + if (result.isEmpty()) { + return null; + } + else if (result.size() == 1) { + return result.get(0); + } + else { + throw new IllegalStateException("More than one WebFluxConfigurer implements " + + beanType.getSimpleName() + " factory method."); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/config/package-info.java b/spring-webflux/src/main/java/org/springframework/web/reactive/config/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..9722fd8a1448d8c0f87c44bc46cc0a6ba2f8fc81 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/config/package-info.java @@ -0,0 +1,9 @@ +/** + * Spring WebFlux configuration infrastructure. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.reactive.config; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractor.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractor.java new file mode 100644 index 0000000000000000000000000000000000000000..91b6a9fb5684c94b0e69d6b3e49e48b02a7d7433 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractor.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.server.reactive.ServerHttpResponse; + +/** + * A function that can extract data from a {@link ReactiveHttpInputMessage} body. + * + * @author Arjen Poutsma + * @since 5.0 + * @param the type of data to extract + * @param the type of {@link ReactiveHttpInputMessage} this extractor can be applied to + * @see BodyExtractors + */ +@FunctionalInterface +public interface BodyExtractor { + + /** + * Extract from the given input message. + * @param inputMessage the request to extract from + * @param context the configuration to use + * @return the extracted data + */ + T extract(M inputMessage, Context context); + + + /** + * Defines the context used during the extraction. + */ + interface Context { + + /** + * Return the {@link HttpMessageReader HttpMessageReaders} to be used for body extraction. + * @return the stream of message readers + */ + List> messageReaders(); + + /** + * Optionally return the {@link ServerHttpResponse}, if present. + */ + Optional serverResponse(); + + /** + * Return the map of hints to use to customize body extraction. + */ + Map hints(); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java new file mode 100644 index 0000000000000000000000000000000000000000..2f1159cd33a547643866264c727d09921dbbb451 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java @@ -0,0 +1,281 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function; + +import java.util.List; +import java.util.Optional; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.http.client.reactive.ClientHttpResponse; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.multipart.Part; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.util.MultiValueMap; + +/** + * Static factory methods for {@link BodyExtractor} implementations. + * + * @author Arjen Poutsma + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + */ +public abstract class BodyExtractors { + + private static final ResolvableType FORM_DATA_TYPE = + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class); + + private static final ResolvableType MULTIPART_DATA_TYPE = ResolvableType.forClassWithGenerics( + MultiValueMap.class, String.class, Part.class); + + private static final ResolvableType PART_TYPE = ResolvableType.forClass(Part.class); + + private static final ResolvableType VOID_TYPE = ResolvableType.forClass(Void.class); + + + /** + * Extractor to decode the input content into {@code Mono}. + * @param elementClass the class of the element type to decode to + * @param the element type to decode to + * @return {@code BodyExtractor} for {@code Mono} + */ + public static BodyExtractor, ReactiveHttpInputMessage> toMono(Class elementClass) { + return toMono(ResolvableType.forClass(elementClass)); + } + + /** + * Variant of {@link #toMono(Class)} for type information with generics. + * @param typeRef the type reference for the type to decode to + * @param the element type to decode to + * @return {@code BodyExtractor} for {@code Mono} + */ + public static BodyExtractor, ReactiveHttpInputMessage> toMono(ParameterizedTypeReference typeRef) { + return toMono(ResolvableType.forType(typeRef.getType())); + } + + private static BodyExtractor, ReactiveHttpInputMessage> toMono(ResolvableType elementType) { + return (inputMessage, context) -> + readWithMessageReaders(inputMessage, context, elementType, + (HttpMessageReader reader) -> readToMono(inputMessage, context, elementType, reader), + ex -> Mono.from(unsupportedErrorHandler(inputMessage, ex)), + skipBodyAsMono(inputMessage)); + } + + /** + * Extractor to decode the input content into {@code Flux}. + * @param elementClass the class of the element type to decode to + * @param the element type to decode to + * @return {@code BodyExtractor} for {@code Flux} + */ + public static BodyExtractor, ReactiveHttpInputMessage> toFlux(Class elementClass) { + return toFlux(ResolvableType.forClass(elementClass)); + } + + /** + * Variant of {@link #toFlux(Class)} for type information with generics. + * @param typeRef the type reference for the type to decode to + * @param the element type to decode to + * @return {@code BodyExtractor} for {@code Flux} + */ + public static BodyExtractor, ReactiveHttpInputMessage> toFlux(ParameterizedTypeReference typeRef) { + return toFlux(ResolvableType.forType(typeRef.getType())); + } + + @SuppressWarnings("unchecked") + private static BodyExtractor, ReactiveHttpInputMessage> toFlux(ResolvableType elementType) { + return (inputMessage, context) -> + readWithMessageReaders(inputMessage, context, elementType, + (HttpMessageReader reader) -> readToFlux(inputMessage, context, elementType, reader), + ex -> unsupportedErrorHandler(inputMessage, ex), + skipBodyAsFlux(inputMessage)); + } + + + // Extractors for specific content .. + + /** + * Extractor to read form data into {@code MultiValueMap}. + *

As of 5.1 this method can also be used on the client side to read form + * data from a server response (e.g. OAuth). + * @return {@code BodyExtractor} for form data + */ + public static BodyExtractor>, ReactiveHttpInputMessage> toFormData() { + return (message, context) -> { + ResolvableType elementType = FORM_DATA_TYPE; + MediaType mediaType = MediaType.APPLICATION_FORM_URLENCODED; + HttpMessageReader> reader = findReader(elementType, mediaType, context); + return readToMono(message, context, elementType, reader); + }; + } + + /** + * Extractor to read multipart data into a {@code MultiValueMap}. + * @return {@code BodyExtractor} for multipart data + */ + // Parameterized for server-side use + public static BodyExtractor>, ServerHttpRequest> toMultipartData() { + return (serverRequest, context) -> { + ResolvableType elementType = MULTIPART_DATA_TYPE; + MediaType mediaType = MediaType.MULTIPART_FORM_DATA; + HttpMessageReader> reader = findReader(elementType, mediaType, context); + return readToMono(serverRequest, context, elementType, reader); + }; + } + + /** + * Extractor to read multipart data into {@code Flux}. + * @return {@code BodyExtractor} for multipart request parts + */ + // Parameterized for server-side use + public static BodyExtractor, ServerHttpRequest> toParts() { + return (serverRequest, context) -> { + ResolvableType elementType = PART_TYPE; + MediaType mediaType = MediaType.MULTIPART_FORM_DATA; + HttpMessageReader reader = findReader(elementType, mediaType, context); + return readToFlux(serverRequest, context, elementType, reader); + }; + } + + /** + * Extractor that returns the raw {@link DataBuffer DataBuffers}. + *

Note: the data buffers should be + * {@link org.springframework.core.io.buffer.DataBufferUtils#release(DataBuffer) + * released} after being used. + * @return {@code BodyExtractor} for data buffers + */ + public static BodyExtractor, ReactiveHttpInputMessage> toDataBuffers() { + return (inputMessage, context) -> inputMessage.getBody(); + } + + + // Private support methods + + private static > S readWithMessageReaders( + ReactiveHttpInputMessage message, BodyExtractor.Context context, ResolvableType elementType, + Function, S> readerFunction, + Function errorFunction, + Supplier emptySupplier) { + + if (VOID_TYPE.equals(elementType)) { + return emptySupplier.get(); + } + MediaType contentType = Optional.ofNullable(message.getHeaders().getContentType()) + .orElse(MediaType.APPLICATION_OCTET_STREAM); + + return context.messageReaders().stream() + .filter(reader -> reader.canRead(elementType, contentType)) + .findFirst() + .map(BodyExtractors::cast) + .map(readerFunction) + .orElseGet(() -> { + List mediaTypes = context.messageReaders().stream() + .flatMap(reader -> reader.getReadableMediaTypes().stream()) + .collect(Collectors.toList()); + return errorFunction.apply( + new UnsupportedMediaTypeException(contentType, mediaTypes, elementType)); + }); + } + + private static Mono readToMono(ReactiveHttpInputMessage message, BodyExtractor.Context context, + ResolvableType type, HttpMessageReader reader) { + + return context.serverResponse() + .map(response -> reader.readMono(type, type, (ServerHttpRequest) message, response, context.hints())) + .orElseGet(() -> reader.readMono(type, message, context.hints())); + } + + private static Flux readToFlux(ReactiveHttpInputMessage message, BodyExtractor.Context context, + ResolvableType type, HttpMessageReader reader) { + + return context.serverResponse() + .map(response -> reader.read(type, type, (ServerHttpRequest) message, response, context.hints())) + .orElseGet(() -> reader.read(type, message, context.hints())); + } + + private static Flux unsupportedErrorHandler( + ReactiveHttpInputMessage message, UnsupportedMediaTypeException ex) { + + Flux result; + if (message.getHeaders().getContentType() == null) { + // Maybe it's okay there is no content type, if there is no content.. + result = message.getBody().map(buffer -> { + DataBufferUtils.release(buffer); + throw ex; + }); + } + else { + result = message instanceof ClientHttpResponse ? + consumeAndCancel(message).thenMany(Flux.error(ex)) : Flux.error(ex); + } + return result; + } + + private static HttpMessageReader findReader( + ResolvableType elementType, MediaType mediaType, BodyExtractor.Context context) { + + return context.messageReaders().stream() + .filter(messageReader -> messageReader.canRead(elementType, mediaType)) + .findFirst() + .map(BodyExtractors::cast) + .orElseThrow(() -> new IllegalStateException( + "No HttpMessageReader for \"" + mediaType + "\" and \"" + elementType + "\"")); + } + + @SuppressWarnings("unchecked") + private static HttpMessageReader cast(HttpMessageReader reader) { + return (HttpMessageReader) reader; + } + + private static Supplier> skipBodyAsFlux(ReactiveHttpInputMessage message) { + return message instanceof ClientHttpResponse ? + () -> consumeAndCancel(message).thenMany(Mono.empty()) : Flux::empty; + } + + @SuppressWarnings("unchecked") + private static Supplier> skipBodyAsMono(ReactiveHttpInputMessage message) { + return message instanceof ClientHttpResponse ? + () -> consumeAndCancel(message).then(Mono.empty()) : Mono::empty; + } + + private static Mono consumeAndCancel(ReactiveHttpInputMessage message) { + return message.getBody() + .map(buffer -> { + DataBufferUtils.release(buffer); + throw new ReadCancellationException(); + }) + .onErrorResume(ReadCancellationException.class, ex -> Mono.empty()) + .then(); + } + + @SuppressWarnings("serial") + private static class ReadCancellationException extends RuntimeException { + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserter.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserter.java new file mode 100644 index 0000000000000000000000000000000000000000..082ab61d8506fc45c73f0ccf50ff13540739c213 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserter.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import reactor.core.publisher.Mono; + +import org.springframework.http.ReactiveHttpOutputMessage; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.server.reactive.ServerHttpRequest; + +/** + * A combination of functions that can populate a {@link ReactiveHttpOutputMessage} body. + * + * @author Arjen Poutsma + * @since 5.0 + * @param the type of data to insert + * @param the type of {@link ReactiveHttpOutputMessage} this inserter can be applied to + * @see BodyInserters + */ +@FunctionalInterface +public interface BodyInserter { + + /** + * Insert into the given output message. + * @param outputMessage the response to insert into + * @param context the context to use + * @return a {@code Mono} that indicates completion or error + */ + Mono insert(M outputMessage, Context context); + + + /** + * Defines the context used during the insertion. + */ + interface Context { + + /** + * Return the {@link HttpMessageWriter HttpMessageWriters} to be used for response body conversion. + * @return the stream of message writers + */ + List> messageWriters(); + + /** + * Optionally return the {@link ServerHttpRequest}, if present. + */ + Optional serverRequest(); + + /** + * Return the map of hints to use for response body conversion. + */ + Map hints(); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserters.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserters.java new file mode 100644 index 0000000000000000000000000000000000000000..72442eaf78ef77112e898193e57559666722fcf5 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserters.java @@ -0,0 +1,468 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function; + +import java.util.List; +import java.util.stream.Collectors; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; +import org.springframework.core.io.Resource; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.HttpEntity; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpOutputMessage; +import org.springframework.http.client.MultipartBodyBuilder; +import org.springframework.http.client.reactive.ClientHttpRequest; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * Static factory methods for {@link BodyInserter} implementations. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @since 5.0 + */ +public abstract class BodyInserters { + + private static final ResolvableType RESOURCE_TYPE = ResolvableType.forClass(Resource.class); + + private static final ResolvableType SSE_TYPE = ResolvableType.forClass(ServerSentEvent.class); + + private static final ResolvableType FORM_DATA_TYPE = + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class); + + private static final ResolvableType MULTIPART_DATA_TYPE = ResolvableType.forClassWithGenerics( + MultiValueMap.class, String.class, Object.class); + + private static final BodyInserter EMPTY_INSERTER = + (response, context) -> response.setComplete(); + + + /** + * Inserter that does not write. + * @return the inserter + */ + @SuppressWarnings("unchecked") + public static BodyInserter empty() { + return (BodyInserter) EMPTY_INSERTER; + } + + /** + * Inserter to write the given object. + *

Alternatively, consider using the {@code syncBody(Object)} shortcuts on + * {@link org.springframework.web.reactive.function.client.WebClient WebClient} and + * {@link org.springframework.web.reactive.function.server.ServerResponse ServerResponse}. + * @param body the body to write to the response + * @param the type of the body + * @return the inserter to write a single object + */ + public static BodyInserter fromObject(T body) { + return (message, context) -> + writeWithMessageWriters(message, context, Mono.just(body), ResolvableType.forInstance(body)); + } + + /** + * Inserter to write the given {@link Publisher}. + *

Alternatively, consider using the {@code body} shortcuts on + * {@link org.springframework.web.reactive.function.client.WebClient WebClient} and + * {@link org.springframework.web.reactive.function.server.ServerResponse ServerResponse}. + * @param publisher the publisher to write with + * @param elementClass the type of elements in the publisher + * @param the type of the elements contained in the publisher + * @param

the {@code Publisher} type + * @return the inserter to write a {@code Publisher} + */ + public static > BodyInserter fromPublisher( + P publisher, Class elementClass) { + + Assert.notNull(publisher, "Publisher must not be null"); + Assert.notNull(elementClass, "Element Class must not be null"); + return (message, context) -> + writeWithMessageWriters(message, context, publisher, ResolvableType.forClass(elementClass)); + } + + /** + * Inserter to write the given {@link Publisher}. + *

Alternatively, consider using the {@code body} shortcuts on + * {@link org.springframework.web.reactive.function.client.WebClient WebClient} and + * {@link org.springframework.web.reactive.function.server.ServerResponse ServerResponse}. + * @param publisher the publisher to write with + * @param typeReference the type of elements contained in the publisher + * @param the type of the elements contained in the publisher + * @param

the {@code Publisher} type + * @return the inserter to write a {@code Publisher} + */ + public static > BodyInserter fromPublisher( + P publisher, ParameterizedTypeReference typeReference) { + + Assert.notNull(publisher, "Publisher must not be null"); + Assert.notNull(typeReference, "ParameterizedTypeReference must not be null"); + return (message, context) -> + writeWithMessageWriters(message, context, publisher, ResolvableType.forType(typeReference.getType())); + } + + /** + * Inserter to write the given {@code Resource}. + *

If the resource can be resolved to a {@linkplain Resource#getFile() file}, it will + * be copied using zero-copy. + * @param resource the resource to write to the output message + * @param the type of the {@code Resource} + * @return the inserter to write a {@code Publisher} + */ + public static BodyInserter fromResource(T resource) { + Assert.notNull(resource, "Resource must not be null"); + return (outputMessage, context) -> { + ResolvableType elementType = RESOURCE_TYPE; + HttpMessageWriter writer = findWriter(context, elementType, null); + return write(Mono.just(resource), elementType, null, outputMessage, context, writer); + }; + } + + /** + * Inserter to write the given {@code ServerSentEvent} publisher. + *

Alternatively, you can provide event data objects via + * {@link #fromPublisher(Publisher, Class)}, and set the "Content-Type" to + * {@link MediaType#TEXT_EVENT_STREAM text/event-stream}. + * @param eventsPublisher the {@code ServerSentEvent} publisher to write to the response body + * @param the type of the data elements in the {@link ServerSentEvent} + * @return the inserter to write a {@code ServerSentEvent} publisher + * @see Server-Sent Events W3C recommendation + */ + // Parameterized for server-side use + public static >> BodyInserter fromServerSentEvents( + S eventsPublisher) { + + Assert.notNull(eventsPublisher, "Publisher must not be null"); + return (serverResponse, context) -> { + ResolvableType elementType = SSE_TYPE; + MediaType mediaType = MediaType.TEXT_EVENT_STREAM; + HttpMessageWriter> writer = findWriter(context, elementType, mediaType); + return write(eventsPublisher, elementType, mediaType, serverResponse, context, writer); + }; + } + + /** + * Return a {@link FormInserter} to write the given {@code MultiValueMap} + * as URL-encoded form data. The returned inserter allows for additional + * entries to be added via {@link FormInserter#with(String, Object)}. + *

Note that you can also use the {@code syncBody(Object)} method in the + * request builders of both the {@code WebClient} and {@code WebTestClient}. + * In that case the setting of the request content type is also not required, + * just be sure the map contains String values only or otherwise it would be + * interpreted as a multipart request. + * @param formData the form data to write to the output message + * @return the inserter that allows adding more form data + */ + public static FormInserter fromFormData(MultiValueMap formData) { + return new DefaultFormInserter().with(formData); + } + + /** + * Return a {@link FormInserter} to write the given key-value pair as + * URL-encoded form data. The returned inserter allows for additional + * entries to be added via {@link FormInserter#with(String, Object)}. + * @param name the key to add to the form + * @param value the value to add to the form + * @return the inserter that allows adding more form data + */ + public static FormInserter fromFormData(String name, String value) { + Assert.notNull(name, "'name' must not be null"); + Assert.notNull(value, "'value' must not be null"); + return new DefaultFormInserter().with(name, value); + } + + /** + * Return a {@link MultipartInserter} to write the given + * {@code MultiValueMap} as multipart data. Values in the map can be an + * Object or an {@link HttpEntity}. + *

Note that you can also build the multipart data externally with + * {@link MultipartBodyBuilder}, and pass the resulting map directly to the + * {@code syncBody(Object)} shortcut method in {@code WebClient}. + * @param multipartData the form data to write to the output message + * @return the inserter that allows adding more parts + * @see MultipartBodyBuilder + */ + public static MultipartInserter fromMultipartData(MultiValueMap multipartData) { + Assert.notNull(multipartData, "'multipartData' must not be null"); + return new DefaultMultipartInserter().withInternal(multipartData); + } + + /** + * Return a {@link MultipartInserter} to write the given parts, + * as multipart data. Values in the map can be an Object or an + * {@link HttpEntity}. + *

Note that you can also build the multipart data externally with + * {@link MultipartBodyBuilder}, and pass the resulting map directly to the + * {@code syncBody(Object)} shortcut method in {@code WebClient}. + * @param name the part name + * @param value the part value, an Object or {@code HttpEntity} + * @return the inserter that allows adding more parts + */ + public static MultipartInserter fromMultipartData(String name, Object value) { + Assert.notNull(name, "'name' must not be null"); + Assert.notNull(value, "'value' must not be null"); + return new DefaultMultipartInserter().with(name, value); + } + + /** + * Return a {@link MultipartInserter} to write the given asynchronous parts, + * as multipart data. + *

Note that you can also build the multipart data externally with + * {@link MultipartBodyBuilder}, and pass the resulting map directly to the + * {@code syncBody(Object)} shortcut method in {@code WebClient}. + * @param name the part name + * @param publisher the publisher that forms the part value + * @param elementClass the class contained in the {@code publisher} + * @return the inserter that allows adding more parts + */ + public static > MultipartInserter fromMultipartAsyncData( + String name, P publisher, Class elementClass) { + + return new DefaultMultipartInserter().withPublisher(name, publisher, elementClass); + } + + /** + * Variant of {@link #fromMultipartAsyncData(String, Publisher, Class)} that + * accepts a {@link ParameterizedTypeReference} for the element type, which + * allows specifying generic type information. + *

Note that you can also build the multipart data externally with + * {@link MultipartBodyBuilder}, and pass the resulting map directly to the + * {@code syncBody(Object)} shortcut method in {@code WebClient}. + * @param name the part name + * @param publisher the publisher that forms the part value + * @param typeReference the type contained in the {@code publisher} + * @return the inserter that allows adding more parts + */ + public static > MultipartInserter fromMultipartAsyncData( + String name, P publisher, ParameterizedTypeReference typeReference) { + + return new DefaultMultipartInserter().withPublisher(name, publisher, typeReference); + } + + /** + * Inserter to write the given {@code Publisher} to the body. + * @param publisher the data buffer publisher to write + * @param the type of the publisher + * @return the inserter to write directly to the body + * @see ReactiveHttpOutputMessage#writeWith(Publisher) + */ + public static > BodyInserter fromDataBuffers( + T publisher) { + + Assert.notNull(publisher, "Publisher must not be null"); + return (outputMessage, context) -> outputMessage.writeWith(publisher); + } + + + private static

, M extends ReactiveHttpOutputMessage> Mono writeWithMessageWriters( + M outputMessage, BodyInserter.Context context, P body, ResolvableType bodyType) { + + MediaType mediaType = outputMessage.getHeaders().getContentType(); + return context.messageWriters().stream() + .filter(messageWriter -> messageWriter.canWrite(bodyType, mediaType)) + .findFirst() + .map(BodyInserters::cast) + .map(writer -> write(body, bodyType, mediaType, outputMessage, context, writer)) + .orElseGet(() -> Mono.error(unsupportedError(bodyType, context, mediaType))); + } + + private static UnsupportedMediaTypeException unsupportedError(ResolvableType bodyType, + BodyInserter.Context context, @Nullable MediaType mediaType) { + + List supportedMediaTypes = context.messageWriters().stream() + .flatMap(reader -> reader.getWritableMediaTypes().stream()) + .collect(Collectors.toList()); + + return new UnsupportedMediaTypeException(mediaType, supportedMediaTypes, bodyType); + } + + private static Mono write(Publisher input, ResolvableType type, + @Nullable MediaType mediaType, ReactiveHttpOutputMessage message, + BodyInserter.Context context, HttpMessageWriter writer) { + + return context.serverRequest() + .map(request -> { + ServerHttpResponse response = (ServerHttpResponse) message; + return writer.write(input, type, type, mediaType, request, response, context.hints()); + }) + .orElseGet(() -> writer.write(input, type, mediaType, message, context.hints())); + } + + private static HttpMessageWriter findWriter( + BodyInserter.Context context, ResolvableType elementType, @Nullable MediaType mediaType) { + + return context.messageWriters().stream() + .filter(messageWriter -> messageWriter.canWrite(elementType, mediaType)) + .findFirst() + .map(BodyInserters::cast) + .orElseThrow(() -> new IllegalStateException( + "No HttpMessageWriter for \"" + mediaType + "\" and \"" + elementType + "\"")); + } + + @SuppressWarnings("unchecked") + private static HttpMessageWriter cast(HttpMessageWriter messageWriter) { + return (HttpMessageWriter) messageWriter; + } + + + /** + * Extension of {@link BodyInserter} that allows for adding form data or + * multipart form data. + * + * @param the value type + */ + public interface FormInserter extends BodyInserter, ClientHttpRequest> { + + // FormInserter is parameterized to ClientHttpRequest (for client-side use only) + + /** + * Adds the specified key-value pair to the form. + * @param key the key to be added + * @param value the value to be added + * @return this inserter for adding more parts + */ + FormInserter with(String key, T value); + + /** + * Adds the specified values to the form. + * @param values the values to be added + * @return this inserter for adding more parts + */ + FormInserter with(MultiValueMap values); + + } + + + /** + * Extension of {@link FormInserter} that allows for adding asynchronous parts. + */ + public interface MultipartInserter extends FormInserter { + + /** + * Add an asynchronous part with {@link Publisher}-based content. + * @param name the name of the part to add + * @param publisher the part contents + * @param elementClass the type of elements contained in the publisher + * @return this inserter for adding more parts + */ + > MultipartInserter withPublisher(String name, P publisher, + Class elementClass); + + /** + * Variant of {@link #withPublisher(String, Publisher, Class)} that accepts a + * {@link ParameterizedTypeReference} for the element type, which allows + * specifying generic type information. + * @param name the key to be added + * @param publisher the publisher to be added as value + * @param typeReference the type of elements contained in {@code publisher} + * @return this inserter for adding more parts + */ + > MultipartInserter withPublisher(String name, P publisher, + ParameterizedTypeReference typeReference); + + } + + + private static class DefaultFormInserter implements FormInserter { + + private final MultiValueMap data = new LinkedMultiValueMap<>(); + + @Override + public FormInserter with(String key, @Nullable String value) { + this.data.add(key, value); + return this; + } + + @Override + public FormInserter with(MultiValueMap values) { + this.data.addAll(values); + return this; + } + + @Override + public Mono insert(ClientHttpRequest outputMessage, Context context) { + HttpMessageWriter> messageWriter = + findWriter(context, FORM_DATA_TYPE, MediaType.APPLICATION_FORM_URLENCODED); + return messageWriter.write(Mono.just(this.data), FORM_DATA_TYPE, + MediaType.APPLICATION_FORM_URLENCODED, + outputMessage, context.hints()); + } + } + + + private static class DefaultMultipartInserter implements MultipartInserter { + + private final MultipartBodyBuilder builder = new MultipartBodyBuilder(); + + @Override + public MultipartInserter with(String key, Object value) { + this.builder.part(key, value); + return this; + } + + @Override + public MultipartInserter with(MultiValueMap values) { + return withInternal(values); + } + + @SuppressWarnings("unchecked") + private MultipartInserter withInternal(MultiValueMap values) { + values.forEach((key, valueList) -> { + for (Object value : valueList) { + this.builder.part(key, value); + } + }); + return this; + } + + @Override + public > MultipartInserter withPublisher( + String name, P publisher, Class elementClass) { + + this.builder.asyncPart(name, publisher, elementClass); + return this; + } + + @Override + public > MultipartInserter withPublisher( + String name, P publisher, ParameterizedTypeReference typeReference) { + + this.builder.asyncPart(name, publisher, typeReference); + return this; + } + + @Override + public Mono insert(ClientHttpRequest outputMessage, Context context) { + HttpMessageWriter>> messageWriter = + findWriter(context, MULTIPART_DATA_TYPE, MediaType.MULTIPART_FORM_DATA); + MultiValueMap> body = this.builder.build(); + return messageWriter.write(Mono.just(body), MULTIPART_DATA_TYPE, + MediaType.MULTIPART_FORM_DATA, outputMessage, context.hints()); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/UnsupportedMediaTypeException.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/UnsupportedMediaTypeException.java new file mode 100644 index 0000000000000000000000000000000000000000..8c863c567ec0415f8d6979f40fd07cb868a7ad87 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/UnsupportedMediaTypeException.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function; + +import java.util.Collections; +import java.util.List; + +import org.springframework.core.NestedRuntimeException; +import org.springframework.core.ResolvableType; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; + +/** + * Exception thrown to indicate that a {@code Content-Type} is not supported. + * + * @author Arjen Poutsma + * @since 5.0 + */ +@SuppressWarnings("serial") +public class UnsupportedMediaTypeException extends NestedRuntimeException { + + @Nullable + private final MediaType contentType; + + private final List supportedMediaTypes; + + @Nullable + private final ResolvableType bodyType; + + + /** + * Constructor for when the specified Content-Type is invalid. + */ + public UnsupportedMediaTypeException(String reason) { + super(reason); + this.contentType = null; + this.supportedMediaTypes = Collections.emptyList(); + this.bodyType = null; + } + + /** + * Constructor for when the Content-Type can be parsed but is not supported. + */ + public UnsupportedMediaTypeException(@Nullable MediaType contentType, List supportedTypes) { + this(contentType, supportedTypes, null); + } + + /** + * Constructor for when trying to encode from or decode to a specific Java type. + * @since 5.1 + */ + public UnsupportedMediaTypeException(@Nullable MediaType contentType, List supportedTypes, + @Nullable ResolvableType bodyType) { + + super(initReason(contentType, bodyType)); + this.contentType = contentType; + this.supportedMediaTypes = Collections.unmodifiableList(supportedTypes); + this.bodyType = bodyType; + } + + private static String initReason(@Nullable MediaType contentType, @Nullable ResolvableType bodyType) { + return "Content type '" + (contentType != null ? contentType : "") + "' not supported" + + (bodyType != null ? " for bodyType=" + bodyType.toString() : ""); + } + + + /** + * Return the request Content-Type header if it was parsed successfully, + * or {@code null} otherwise. + */ + @Nullable + public MediaType getContentType() { + return this.contentType; + } + + /** + * Return the list of supported content types in cases when the Content-Type + * header is parsed but not supported, or an empty list otherwise. + */ + public List getSupportedMediaTypes() { + return this.supportedMediaTypes; + } + + /** + * Return the body type in the context of which this exception was generated. + * This is applicable when the exception was raised as a result trying to + * encode from or decode to a specific Java type. + * @return the body type, or {@code null} if not available + * @since 5.1 + */ + @Nullable + public ResolvableType getBodyType() { + return this.bodyType; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ClientRequest.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ClientRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..a0c8d007a8bb28cc40616b738e1317e1601e0dc0 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ClientRequest.java @@ -0,0 +1,261 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import java.net.URI; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.client.reactive.ClientHttpRequest; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyInserter; + +/** + * Represents a typed, immutable, client-side HTTP request, as executed by the + * {@link ExchangeFunction}. Instances of this interface can be created via static + * builder methods. + * + *

Note that applications are more likely to perform requests through + * {@link WebClient} rather than using this directly. + * + * @author Brian Clozel + * @author Arjen Poutsma + * @since 5.0 + */ +public interface ClientRequest { + + /** + * Name of {@link #attributes() attribute} whose value can be used to + * correlate log messages for this request. Use {@link #logPrefix()} to + * obtain a consistently formatted prefix based on this attribute. + * @since 5.1 + * @see #logPrefix() + */ + String LOG_ID_ATTRIBUTE = ClientRequest.class.getName() + ".LOG_ID"; + + + /** + * Return the HTTP method. + */ + HttpMethod method(); + + /** + * Return the request URI. + */ + URI url(); + + /** + * Return the headers of this request. + */ + HttpHeaders headers(); + + /** + * Return the cookies of this request. + */ + MultiValueMap cookies(); + + /** + * Return the body inserter of this request. + */ + BodyInserter body(); + + /** + * Return the request attribute value if present. + * @param name the attribute name + * @return the attribute value + */ + default Optional attribute(String name) { + return Optional.ofNullable(attributes().get(name)); + } + + /** + * Return the attributes of this request. + */ + Map attributes(); + + /** + * Return a log message prefix to use to correlate messages for this request. + * The prefix is based on the value of the attribute {@link #LOG_ID_ATTRIBUTE} + * along with some extra formatting so that the prefix can be conveniently + * prepended with no further formatting no separators required. + * @return the log message prefix or an empty String if the + * {@link #LOG_ID_ATTRIBUTE} is not set. + * @since 5.1 + */ + String logPrefix(); + + /** + * Write this request to the given {@link ClientHttpRequest}. + * @param request the client http request to write to + * @param strategies the strategies to use when writing + * @return {@code Mono} to indicate when writing is complete + */ + Mono writeTo(ClientHttpRequest request, ExchangeStrategies strategies); + + + // Static builder methods + + /** + * Create a builder with the method, URI, headers, and cookies of the given request. + * @param other the request to copy the method, URI, headers, and cookies from + * @return the created builder + */ + static Builder from(ClientRequest other) { + return new DefaultClientRequestBuilder(other); + } + + /** + * Create a builder with the given method and url. + * @param method the HTTP method (GET, POST, etc) + * @param url the url (as a URI instance) + * @return the created builder + * @deprecated in favor of {@link #create(HttpMethod, URI)} + */ + @Deprecated + static Builder method(HttpMethod method, URI url) { + return new DefaultClientRequestBuilder(method, url); + } + + /** + * Create a request builder with the given method and url. + * @param method the HTTP method (GET, POST, etc) + * @param url the url (as a URI instance) + * @return the created builder + */ + static Builder create(HttpMethod method, URI url) { + return new DefaultClientRequestBuilder(method, url); + } + + + /** + * Defines a builder for a request. + */ + interface Builder { + + /** + * Set the method of the request. + * @param method the new method + * @return this builder + * @since 5.0.1 + */ + Builder method(HttpMethod method); + + /** + * Set the url of the request. + * @param url the new url + * @return this builder + * @since 5.0.1 + */ + Builder url(URI url); + + /** + * Add the given header value(s) under the given name. + * @param headerName the header name + * @param headerValues the header value(s) + * @return this builder + * @see HttpHeaders#add(String, String) + */ + Builder header(String headerName, String... headerValues); + + /** + * Manipulate this request's headers with the given consumer. The + * headers provided to the consumer are "live", so that the consumer can be used to + * {@linkplain HttpHeaders#set(String, String) overwrite} existing header values, + * {@linkplain HttpHeaders#remove(Object) remove} values, or use any of the other + * {@link HttpHeaders} methods. + * @param headersConsumer a function that consumes the {@code HttpHeaders} + * @return this builder + */ + Builder headers(Consumer headersConsumer); + + /** + * Add a cookie with the given name and value(s). + * @param name the cookie name + * @param values the cookie value(s) + * @return this builder + */ + Builder cookie(String name, String... values); + + /** + * Manipulate this request's cookies with the given consumer. The + * map provided to the consumer is "live", so that the consumer can be used to + * {@linkplain MultiValueMap#set(Object, Object) overwrite} existing header values, + * {@linkplain MultiValueMap#remove(Object) remove} values, or use any of the other + * {@link MultiValueMap} methods. + * @param cookiesConsumer a function that consumes the cookies map + * @return this builder + */ + Builder cookies(Consumer> cookiesConsumer); + + /** + * Set the body of the request to the given {@code BodyInserter}. + * @param inserter the {@code BodyInserter} that writes to the request + * @return this builder + */ + Builder body(BodyInserter inserter); + + /** + * Set the body of the request to the given {@code Publisher} and return it. + * @param publisher the {@code Publisher} to write to the request + * @param elementClass the class of elements contained in the publisher + * @param the type of the elements contained in the publisher + * @param

the type of the {@code Publisher} + * @return the built request + */ + > Builder body(P publisher, Class elementClass); + + /** + * Set the body of the request to the given {@code Publisher} and return it. + * @param publisher the {@code Publisher} to write to the request + * @param typeReference a type reference describing the elements contained in the publisher + * @param the type of the elements contained in the publisher + * @param

the type of the {@code Publisher} + * @return the built request + */ + > Builder body(P publisher, ParameterizedTypeReference typeReference); + + /** + * Set the attribute with the given name to the given value. + * @param name the name of the attribute to add + * @param value the value of the attribute to add + * @return this builder + */ + Builder attribute(String name, Object value); + + /** + * Manipulate the request attributes with the given consumer. The attributes provided to + * the consumer are "live", so that the consumer can be used to inspect attributes, + * remove attributes, or use any of the other map-provided methods. + * @param attributesConsumer a function that consumes the attributes + * @return this builder + */ + Builder attributes(Consumer> attributesConsumer); + + /** + * Build the request. + */ + ClientRequest build(); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ClientResponse.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ClientResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..57ca3749279729530880912601bc91acdc1523ff --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ClientResponse.java @@ -0,0 +1,347 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.function.Consumer; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseCookie; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.reactive.ClientHttpResponse; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyExtractor; + +/** + * Represents an HTTP response, as returned by {@link WebClient} and also + * {@link ExchangeFunction}. Provides access to the response status and headers, + * and also methods to consume the response body. + * + *

NOTE: When given access to a {@link ClientResponse}, + * through the {@code WebClient} + * {@link WebClient.RequestHeadersSpec#exchange() exchange()} method, + * you must always use one of the body or toEntity methods to ensure resources + * are released and avoid potential issues with HTTP connection pooling. + * You can use {@code bodyToMono(Void.class)} if no response content is + * expected. However keep in mind that if the response does have content, the + * connection will be closed and will not be placed back in the pool. + * + * @author Brian Clozel + * @author Arjen Poutsma + * @since 5.0 + */ +public interface ClientResponse { + + /** + * Return the HTTP status code as an {@link HttpStatus} enum value. + * @return the HTTP status as an HttpStatus enum value (never {@code null}) + * @throws IllegalArgumentException in case of an unknown HTTP status code + * @since #getRawStatusCode() + * @see HttpStatus#valueOf(int) + */ + HttpStatus statusCode(); + + /** + * Return the (potentially non-standard) status code of this response. + * @return the HTTP status as an integer value + * @since 5.1 + * @see #statusCode() + * @see HttpStatus#resolve(int) + */ + int rawStatusCode(); + + /** + * Return the headers of this response. + */ + Headers headers(); + + /** + * Return the cookies of this response. + */ + MultiValueMap cookies(); + + /** + * Return the strategies used to convert the body of this response. + */ + ExchangeStrategies strategies(); + + /** + * Extract the body with the given {@code BodyExtractor}. + * @param extractor the {@code BodyExtractor} that reads from the response + * @param the type of the body returned + * @return the extracted body + */ + T body(BodyExtractor extractor); + + /** + * Extract the body to a {@code Mono}. + * @param elementClass the class of element in the {@code Mono} + * @param the element type + * @return a mono containing the body of the given type {@code T} + */ + Mono bodyToMono(Class elementClass); + + /** + * Extract the body to a {@code Mono}. + * @param typeReference a type reference describing the expected response body type + * @param the element type + * @return a mono containing the body of the given type {@code T} + */ + Mono bodyToMono(ParameterizedTypeReference typeReference); + + /** + * Extract the body to a {@code Flux}. + * @param elementClass the class of element in the {@code Flux} + * @param the element type + * @return a flux containing the body of the given type {@code T} + */ + Flux bodyToFlux(Class elementClass); + + /** + * Extract the body to a {@code Flux}. + * @param typeReference a type reference describing the expected response body type + * @param the element type + * @return a flux containing the body of the given type {@code T} + */ + Flux bodyToFlux(ParameterizedTypeReference typeReference); + + /** + * Return this response as a delayed {@code ResponseEntity}. + * @param bodyType the expected response body type + * @param response body type + * @return {@code Mono} with the {@code ResponseEntity} + */ + Mono> toEntity(Class bodyType); + + /** + * Return this response as a delayed {@code ResponseEntity}. + * @param typeReference a type reference describing the expected response body type + * @param response body type + * @return {@code Mono} with the {@code ResponseEntity} + */ + Mono> toEntity(ParameterizedTypeReference typeReference); + + /** + * Return this response as a delayed list of {@code ResponseEntity}s. + * @param elementType the expected response body list element type + * @param the type of elements in the list + * @return {@code Mono} with the list of {@code ResponseEntity}s + */ + Mono>> toEntityList(Class elementType); + + /** + * Return this response as a delayed list of {@code ResponseEntity}s. + * @param typeReference a type reference describing the expected response body type + * @param the type of elements in the list + * @return {@code Mono} with the list of {@code ResponseEntity}s + */ + Mono>> toEntityList(ParameterizedTypeReference typeReference); + + + // Static builder methods + + /** + * Create a builder with the status, headers, and cookies of the given response. + * @param other the response to copy the status, headers, and cookies from + * @return the created builder + */ + static Builder from(ClientResponse other) { + return new DefaultClientResponseBuilder(other); + } + + /** + * Create a response builder with the given status code and using default strategies for + * reading the body. + * @param statusCode the status code + * @return the created builder + */ + static Builder create(HttpStatus statusCode) { + return create(statusCode, ExchangeStrategies.withDefaults()); + } + + /** + * Create a response builder with the given status code and strategies for reading the body. + * @param statusCode the status code + * @param strategies the strategies + * @return the created builder + */ + static Builder create(HttpStatus statusCode, ExchangeStrategies strategies) { + return new DefaultClientResponseBuilder(strategies).statusCode(statusCode); + } + + /** + * Create a response builder with the given raw status code and strategies for reading the body. + * @param statusCode the status code + * @param strategies the strategies + * @return the created builder + * @since 5.1.9 + */ + static Builder create(int statusCode, ExchangeStrategies strategies) { + return new DefaultClientResponseBuilder(strategies).rawStatusCode(statusCode); + } + + /** + * Create a response builder with the given status code and message body readers. + * @param statusCode the status code + * @param messageReaders the message readers + * @return the created builder + */ + static Builder create(HttpStatus statusCode, List> messageReaders) { + return create(statusCode, new ExchangeStrategies() { + @Override + public List> messageReaders() { + return messageReaders; + } + @Override + public List> messageWriters() { + // not used in the response + return Collections.emptyList(); + } + }); + } + + + /** + * Represents the headers of the HTTP response. + * @see ClientResponse#headers() + */ + interface Headers { + + /** + * Return the length of the body in bytes, as specified by the + * {@code Content-Length} header. + */ + OptionalLong contentLength(); + + /** + * Return the {@linkplain MediaType media type} of the body, as specified + * by the {@code Content-Type} header. + */ + Optional contentType(); + + /** + * Return the header value(s), if any, for the header of the given name. + *

Return an empty list if no header values are found. + * @param headerName the header name + */ + List header(String headerName); + + /** + * Return the headers as an {@link HttpHeaders} instance. + */ + HttpHeaders asHttpHeaders(); + } + + + /** + * Defines a builder for a response. + */ + interface Builder { + + /** + * Set the status code of the response. + * @param statusCode the new status code + * @return this builder + */ + Builder statusCode(HttpStatus statusCode); + + /** + * Set the raw status code of the response. + * @param statusCode the new status code + * @return this builder + * @since 5.1.9 + */ + Builder rawStatusCode(int statusCode); + + /** + * Add the given header value(s) under the given name. + * @param headerName the header name + * @param headerValues the header value(s) + * @return this builder + * @see HttpHeaders#add(String, String) + */ + Builder header(String headerName, String... headerValues); + + /** + * Manipulate this response's headers with the given consumer. + *

The headers provided to the consumer are "live", so that the consumer + * can be used to {@linkplain HttpHeaders#set(String, String) overwrite} + * existing header values, {@linkplain HttpHeaders#remove(Object) remove} + * values, or use any of the other {@link HttpHeaders} methods. + * @param headersConsumer a function that consumes the {@code HttpHeaders} + * @return this builder + */ + Builder headers(Consumer headersConsumer); + + /** + * Add a cookie with the given name and value(s). + * @param name the cookie name + * @param values the cookie value(s) + * @return this builder + */ + Builder cookie(String name, String... values); + + /** + * Manipulate this response's cookies with the given consumer. + *

The map provided to the consumer is "live", so that the consumer can be used to + * {@linkplain MultiValueMap#set(Object, Object) overwrite} existing cookie values, + * {@linkplain MultiValueMap#remove(Object) remove} values, or use any of the other + * {@link MultiValueMap} methods. + * @param cookiesConsumer a function that consumes the cookies map + * @return this builder + */ + Builder cookies(Consumer> cookiesConsumer); + + /** + * Set the body of the response. + *

Calling this methods will + * {@linkplain org.springframework.core.io.buffer.DataBufferUtils#release(DataBuffer) release} + * the existing body of the builder. + * @param body the new body + * @return this builder + */ + Builder body(Flux body); + + /** + * Set the body of the response to the UTF-8 encoded bytes of the given string. + *

Calling this methods will + * {@linkplain org.springframework.core.io.buffer.DataBufferUtils#release(DataBuffer) release} + * the existing body of the builder. + * @param body the new body + * @return this builder + */ + Builder body(String body); + + /** + * Build the response. + */ + ClientResponse build(); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientRequestBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientRequestBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..c3882b12a642667b9ab969df1062d6d4fedd6373 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientRequestBuilder.java @@ -0,0 +1,266 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import java.net.URI; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.codec.Hints; +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.client.reactive.ClientHttpRequest; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; +import org.springframework.web.reactive.function.BodyInserter; +import org.springframework.web.reactive.function.BodyInserters; + +/** + * Default implementation of {@link ClientRequest.Builder}. + * + * @author Arjen Poutsma + * @since 5.0 + */ +final class DefaultClientRequestBuilder implements ClientRequest.Builder { + + private HttpMethod method; + + private URI url; + + private final HttpHeaders headers = new HttpHeaders(); + + private final MultiValueMap cookies = new LinkedMultiValueMap<>(); + + private final Map attributes = new LinkedHashMap<>(); + + private BodyInserter body = BodyInserters.empty(); + + + public DefaultClientRequestBuilder(ClientRequest other) { + Assert.notNull(other, "ClientRequest must not be null"); + this.method = other.method(); + this.url = other.url(); + headers(headers -> headers.addAll(other.headers())); + cookies(cookies -> cookies.addAll(other.cookies())); + attributes(attributes -> attributes.putAll(other.attributes())); + body(other.body()); + } + + public DefaultClientRequestBuilder(HttpMethod method, URI url) { + Assert.notNull(method, "HttpMethod must not be null"); + Assert.notNull(url, "URI must not be null"); + this.method = method; + this.url = url; + } + + + @Override + public ClientRequest.Builder method(HttpMethod method) { + Assert.notNull(method, "HttpMethod must not be null"); + this.method = method; + return this; + } + + @Override + public ClientRequest.Builder url(URI url) { + Assert.notNull(url, "URI must not be null"); + this.url = url; + return this; + } + + @Override + public ClientRequest.Builder header(String headerName, String... headerValues) { + for (String headerValue : headerValues) { + this.headers.add(headerName, headerValue); + } + return this; + } + + @Override + public ClientRequest.Builder headers(Consumer headersConsumer) { + headersConsumer.accept(this.headers); + return this; + } + + @Override + public ClientRequest.Builder cookie(String name, String... values) { + for (String value : values) { + this.cookies.add(name, value); + } + return this; + } + + @Override + public ClientRequest.Builder cookies(Consumer> cookiesConsumer) { + cookiesConsumer.accept(this.cookies); + return this; + } + + @Override + public > ClientRequest.Builder body(P publisher, Class elementClass) { + this.body = BodyInserters.fromPublisher(publisher, elementClass); + return this; + } + + @Override + public > ClientRequest.Builder body( + P publisher, ParameterizedTypeReference typeReference) { + + this.body = BodyInserters.fromPublisher(publisher, typeReference); + return this; + } + + @Override + public ClientRequest.Builder attribute(String name, Object value) { + this.attributes.put(name, value); + return this; + } + + @Override + public ClientRequest.Builder attributes(Consumer> attributesConsumer) { + attributesConsumer.accept(this.attributes); + return this; + } + + @Override + public ClientRequest.Builder body(BodyInserter inserter) { + this.body = inserter; + return this; + } + + @Override + public ClientRequest build() { + return new BodyInserterRequest(this.method, this.url, this.headers, this.cookies, this.body, this.attributes); + } + + + private static class BodyInserterRequest implements ClientRequest { + + private final HttpMethod method; + + private final URI url; + + private final HttpHeaders headers; + + private final MultiValueMap cookies; + + private final BodyInserter body; + + private final Map attributes; + + private final String logPrefix; + + public BodyInserterRequest(HttpMethod method, URI url, HttpHeaders headers, + MultiValueMap cookies, BodyInserter body, + Map attributes) { + + this.method = method; + this.url = url; + this.headers = HttpHeaders.readOnlyHttpHeaders(headers); + this.cookies = CollectionUtils.unmodifiableMultiValueMap(cookies); + this.body = body; + this.attributes = Collections.unmodifiableMap(attributes); + + Object id = attributes.computeIfAbsent(LOG_ID_ATTRIBUTE, name -> ObjectUtils.getIdentityHexString(this)); + this.logPrefix = "[" + id + "] "; + } + + @Override + public HttpMethod method() { + return this.method; + } + + @Override + public URI url() { + return this.url; + } + + @Override + public HttpHeaders headers() { + return this.headers; + } + + @Override + public MultiValueMap cookies() { + return this.cookies; + } + + @Override + public BodyInserter body() { + return this.body; + } + + @Override + public Map attributes() { + return this.attributes; + } + + @Override + public String logPrefix() { + return this.logPrefix; + } + + @Override + public Mono writeTo(ClientHttpRequest request, ExchangeStrategies strategies) { + HttpHeaders requestHeaders = request.getHeaders(); + if (!this.headers.isEmpty()) { + this.headers.entrySet().stream() + .filter(entry -> !requestHeaders.containsKey(entry.getKey())) + .forEach(entry -> requestHeaders + .put(entry.getKey(), entry.getValue())); + } + + MultiValueMap requestCookies = request.getCookies(); + if (!this.cookies.isEmpty()) { + this.cookies.forEach((name, values) -> values.forEach(value -> { + HttpCookie cookie = new HttpCookie(name, value); + requestCookies.add(name, cookie); + })); + } + + return this.body.insert(request, new BodyInserter.Context() { + @Override + public List> messageWriters() { + return strategies.messageWriters(); + } + @Override + public Optional serverRequest() { + return Optional.empty(); + } + @Override + public Map hints() { + return Hints.from(Hints.LOG_PREFIX_HINT, logPrefix()); + } + }); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientResponse.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..ba14fe11e7b9f188bae94d4a5739aa2311dfb156 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientResponse.java @@ -0,0 +1,215 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalLong; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.codec.Hints; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseCookie; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.reactive.ClientHttpResponse; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyExtractor; +import org.springframework.web.reactive.function.BodyExtractors; + +/** + * Default implementation of {@link ClientResponse}. + * + * @author Arjen Poutsma + * @author Brian Clozel + * @since 5.0 + */ +class DefaultClientResponse implements ClientResponse { + + private final ClientHttpResponse response; + + private final Headers headers; + + private final ExchangeStrategies strategies; + + private final String logPrefix; + + + public DefaultClientResponse(ClientHttpResponse response, ExchangeStrategies strategies, String logPrefix) { + this.response = response; + this.strategies = strategies; + this.headers = new DefaultHeaders(); + this.logPrefix = logPrefix; + } + + + @Override + public ExchangeStrategies strategies() { + return this.strategies; + } + + @Override + public HttpStatus statusCode() { + return this.response.getStatusCode(); + } + + @Override + public int rawStatusCode() { + return this.response.getRawStatusCode(); + } + + @Override + public Headers headers() { + return this.headers; + } + + @Override + public MultiValueMap cookies() { + return this.response.getCookies(); + } + + @Override + public T body(BodyExtractor extractor) { + return extractor.extract(this.response, new BodyExtractor.Context() { + @Override + public List> messageReaders() { + return strategies.messageReaders(); + } + @Override + public Optional serverResponse() { + return Optional.empty(); + } + @Override + public Map hints() { + return Hints.from(Hints.LOG_PREFIX_HINT, logPrefix); + } + }); + } + + @Override + public Mono bodyToMono(Class elementClass) { + return body(BodyExtractors.toMono(elementClass)); + } + + @Override + public Mono bodyToMono(ParameterizedTypeReference typeReference) { + return body(BodyExtractors.toMono(typeReference)); + } + + @Override + public Flux bodyToFlux(Class elementClass) { + return body(BodyExtractors.toFlux(elementClass)); + } + + @Override + public Flux bodyToFlux(ParameterizedTypeReference typeReference) { + return body(BodyExtractors.toFlux(typeReference)); + } + + @Override + public Mono> toEntity(Class bodyType) { + return toEntityInternal(bodyToMono(bodyType)); + } + + @Override + public Mono> toEntity(ParameterizedTypeReference typeReference) { + return toEntityInternal(bodyToMono(typeReference)); + } + + private Mono> toEntityInternal(Mono bodyMono) { + HttpHeaders headers = headers().asHttpHeaders(); + int status = rawStatusCode(); + return bodyMono + .map(body -> createEntity(body, headers, status)) + .switchIfEmpty(Mono.defer( + () -> Mono.just(createEntity(headers, status)))); + } + + @Override + public Mono>> toEntityList(Class responseType) { + return toEntityListInternal(bodyToFlux(responseType)); + } + + @Override + public Mono>> toEntityList(ParameterizedTypeReference typeReference) { + return toEntityListInternal(bodyToFlux(typeReference)); + } + + private Mono>> toEntityListInternal(Flux bodyFlux) { + HttpHeaders headers = headers().asHttpHeaders(); + int status = rawStatusCode(); + return bodyFlux + .collectList() + .map(body -> createEntity(body, headers, status)); + } + + private ResponseEntity createEntity(HttpHeaders headers, int status) { + HttpStatus resolvedStatus = HttpStatus.resolve(status); + return resolvedStatus != null + ? new ResponseEntity<>(headers, resolvedStatus) + : ResponseEntity.status(status).headers(headers).build(); + } + + private ResponseEntity createEntity(T body, HttpHeaders headers, int status) { + HttpStatus resolvedStatus = HttpStatus.resolve(status); + return resolvedStatus != null + ? new ResponseEntity<>(body, headers, resolvedStatus) + : ResponseEntity.status(status).headers(headers).body(body); + } + + + private class DefaultHeaders implements Headers { + + private HttpHeaders delegate() { + return response.getHeaders(); + } + + @Override + public OptionalLong contentLength() { + return toOptionalLong(delegate().getContentLength()); + } + + @Override + public Optional contentType() { + return Optional.ofNullable(delegate().getContentType()); + } + + @Override + public List header(String headerName) { + List headerValues = delegate().get(headerName); + return (headerValues != null ? headerValues : Collections.emptyList()); + } + + @Override + public HttpHeaders asHttpHeaders() { + return HttpHeaders.readOnlyHttpHeaders(delegate()); + } + + private OptionalLong toOptionalLong(long value) { + return (value != -1 ? OptionalLong.of(value) : OptionalLong.empty()); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientResponseBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientResponseBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..13d56dbf6878b8d74be4f3b936399bc9de6802df --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientResponseBuilder.java @@ -0,0 +1,191 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import java.nio.charset.StandardCharsets; +import java.util.function.Consumer; + +import reactor.core.publisher.Flux; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseCookie; +import org.springframework.http.client.reactive.ClientHttpResponse; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * Default implementation of {@link ClientResponse.Builder}. + * + * @author Arjen Poutsma + * @since 5.0.5 + */ +final class DefaultClientResponseBuilder implements ClientResponse.Builder { + + private ExchangeStrategies strategies; + + private int statusCode = 200; + + private final HttpHeaders headers = new HttpHeaders(); + + private final MultiValueMap cookies = new LinkedMultiValueMap<>(); + + private Flux body = Flux.empty(); + + + DefaultClientResponseBuilder(ExchangeStrategies strategies) { + Assert.notNull(strategies, "ExchangeStrategies must not be null"); + this.strategies = strategies; + } + + DefaultClientResponseBuilder(ClientResponse other) { + Assert.notNull(other, "ClientResponse must not be null"); + this.strategies = other.strategies(); + this.statusCode = other.rawStatusCode(); + this.headers.addAll(other.headers().asHttpHeaders()); + this.cookies.addAll(other.cookies()); + } + + + @Override + public DefaultClientResponseBuilder statusCode(HttpStatus statusCode) { + return rawStatusCode(statusCode.value()); + } + + @Override + public DefaultClientResponseBuilder rawStatusCode(int statusCode) { + Assert.isTrue(statusCode >= 100 && statusCode < 600, "StatusCode must be between 1xx and 5xx"); + this.statusCode = statusCode; + return this; + } + + @Override + public ClientResponse.Builder header(String headerName, String... headerValues) { + for (String headerValue : headerValues) { + this.headers.add(headerName, headerValue); + } + return this; + } + + @Override + public ClientResponse.Builder headers(Consumer headersConsumer) { + headersConsumer.accept(this.headers); + return this; + } + + @Override + public DefaultClientResponseBuilder cookie(String name, String... values) { + for (String value : values) { + this.cookies.add(name, ResponseCookie.from(name, value).build()); + } + return this; + } + + @Override + public ClientResponse.Builder cookies(Consumer> cookiesConsumer) { + cookiesConsumer.accept(this.cookies); + return this; + } + + @Override + public ClientResponse.Builder body(Flux body) { + Assert.notNull(body, "Body must not be null"); + releaseBody(); + this.body = body; + return this; + } + + @Override + public ClientResponse.Builder body(String body) { + Assert.notNull(body, "Body must not be null"); + releaseBody(); + DataBufferFactory dataBufferFactory = new DefaultDataBufferFactory(); + this.body = Flux.just(body). + map(s -> { + byte[] bytes = body.getBytes(StandardCharsets.UTF_8); + return dataBufferFactory.wrap(bytes); + }); + return this; + } + + private void releaseBody() { + this.body.subscribe(DataBufferUtils.releaseConsumer()); + } + + @Override + public ClientResponse build() { + ClientHttpResponse httpResponse = + new BuiltClientHttpResponse(this.statusCode, this.headers, this.cookies, this.body); + + // When building ClientResponse manually, the ClientRequest.logPrefix() has to be passed, + // e.g. via ClientResponse.Builder, but this (builder) is not used currently. + return new DefaultClientResponse(httpResponse, this.strategies, ""); + } + + + private static class BuiltClientHttpResponse implements ClientHttpResponse { + + private final int statusCode; + + private final HttpHeaders headers; + + private final MultiValueMap cookies; + + private final Flux body; + + BuiltClientHttpResponse(int statusCode, HttpHeaders headers, + MultiValueMap cookies, Flux body) { + + this.statusCode = statusCode; + this.headers = HttpHeaders.readOnlyHttpHeaders(headers); + this.cookies = CollectionUtils.unmodifiableMultiValueMap(cookies); + this.body = body; + } + + @Override + public HttpStatus getStatusCode() { + return HttpStatus.valueOf(this.statusCode); + } + + @Override + public int getRawStatusCode() { + return this.statusCode; + } + + @Override + public HttpHeaders getHeaders() { + return this.headers; + } + + @Override + public MultiValueMap getCookies() { + return this.cookies; + } + + @Override + public Flux getBody() { + return this.body; + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultExchangeStrategiesBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultExchangeStrategiesBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..e5703203fc5d7b286fe26c0603a4f6583c6448d2 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultExchangeStrategiesBuilder.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Consumer; + +import org.springframework.http.codec.ClientCodecConfigurer; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.HttpMessageWriter; + +/** + * Default implementation of {@link ExchangeStrategies.Builder}. + * + * @author Arjen Poutsma + * @since 5.0 + */ +final class DefaultExchangeStrategiesBuilder implements ExchangeStrategies.Builder { + + final static ExchangeStrategies DEFAULT_EXCHANGE_STRATEGIES; + + static { + DefaultExchangeStrategiesBuilder builder = new DefaultExchangeStrategiesBuilder(); + builder.defaultConfiguration(); + DEFAULT_EXCHANGE_STRATEGIES = builder.build(); + } + + + private final ClientCodecConfigurer codecConfigurer; + + + public DefaultExchangeStrategiesBuilder() { + this.codecConfigurer = ClientCodecConfigurer.create(); + this.codecConfigurer.registerDefaults(false); + } + + private DefaultExchangeStrategiesBuilder(DefaultExchangeStrategies other) { + this.codecConfigurer = other.codecConfigurer.clone(); + } + + + public void defaultConfiguration() { + this.codecConfigurer.registerDefaults(true); + } + + @Override + public ExchangeStrategies.Builder codecs(Consumer consumer) { + consumer.accept(this.codecConfigurer); + return this; + } + + @Override + public ExchangeStrategies build() { + return new DefaultExchangeStrategies(this.codecConfigurer); + } + + + private static class DefaultExchangeStrategies implements ExchangeStrategies { + + private final ClientCodecConfigurer codecConfigurer; + + private final List> readers; + + private final List> writers; + + + public DefaultExchangeStrategies(ClientCodecConfigurer codecConfigurer) { + this.codecConfigurer = codecConfigurer; + this.readers = unmodifiableCopy(this.codecConfigurer.getReaders()); + this.writers = unmodifiableCopy(this.codecConfigurer.getWriters()); + } + + private static List unmodifiableCopy(List list) { + return Collections.unmodifiableList(new ArrayList<>(list)); + } + + + @Override + public List> messageReaders() { + return this.readers; + } + + @Override + public List> messageWriters() { + return this.writers; + } + + @Override + public Builder mutate() { + return new DefaultExchangeStrategiesBuilder(this); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java new file mode 100644 index 0000000000000000000000000000000000000000..b2d6508bd9945eb1ce8ab332f133e24b89549c62 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java @@ -0,0 +1,557 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import java.net.URI; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.IntPredicate; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRequest; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.client.reactive.ClientHttpRequest; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MimeType; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyExtractors; +import org.springframework.web.reactive.function.BodyInserter; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.util.DefaultUriBuilderFactory; +import org.springframework.web.util.UriBuilder; +import org.springframework.web.util.UriBuilderFactory; + +/** + * Default implementation of {@link WebClient}. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + */ +class DefaultWebClient implements WebClient { + + private static final String URI_TEMPLATE_ATTRIBUTE = WebClient.class.getName() + ".uriTemplate"; + + private static final Mono NO_HTTP_CLIENT_RESPONSE_ERROR = Mono.error( + new IllegalStateException("The underlying HTTP client completed without emitting a response.")); + + + private final ExchangeFunction exchangeFunction; + + private final UriBuilderFactory uriBuilderFactory; + + @Nullable + private final HttpHeaders defaultHeaders; + + @Nullable + private final MultiValueMap defaultCookies; + + @Nullable + private final Consumer> defaultRequest; + + private final DefaultWebClientBuilder builder; + + + DefaultWebClient(ExchangeFunction exchangeFunction, @Nullable UriBuilderFactory factory, + @Nullable HttpHeaders defaultHeaders, @Nullable MultiValueMap defaultCookies, + @Nullable Consumer> defaultRequest, DefaultWebClientBuilder builder) { + + this.exchangeFunction = exchangeFunction; + this.uriBuilderFactory = (factory != null ? factory : new DefaultUriBuilderFactory()); + this.defaultHeaders = defaultHeaders; + this.defaultCookies = defaultCookies; + this.defaultRequest = defaultRequest; + this.builder = builder; + } + + + @Override + public RequestHeadersUriSpec get() { + return methodInternal(HttpMethod.GET); + } + + @Override + public RequestHeadersUriSpec head() { + return methodInternal(HttpMethod.HEAD); + } + + @Override + public RequestBodyUriSpec post() { + return methodInternal(HttpMethod.POST); + } + + @Override + public RequestBodyUriSpec put() { + return methodInternal(HttpMethod.PUT); + } + + @Override + public RequestBodyUriSpec patch() { + return methodInternal(HttpMethod.PATCH); + } + + @Override + public RequestHeadersUriSpec delete() { + return methodInternal(HttpMethod.DELETE); + } + + @Override + public RequestHeadersUriSpec options() { + return methodInternal(HttpMethod.OPTIONS); + } + + @Override + public RequestBodyUriSpec method(HttpMethod httpMethod) { + return methodInternal(httpMethod); + } + + private RequestBodyUriSpec methodInternal(HttpMethod httpMethod) { + return new DefaultRequestBodyUriSpec(httpMethod); + } + + @Override + public Builder mutate() { + return new DefaultWebClientBuilder(this.builder); + } + + + private class DefaultRequestBodyUriSpec implements RequestBodyUriSpec { + + private final HttpMethod httpMethod; + + @Nullable + private URI uri; + + @Nullable + private HttpHeaders headers; + + @Nullable + private MultiValueMap cookies; + + @Nullable + private BodyInserter inserter; + + private final Map attributes = new LinkedHashMap<>(4); + + DefaultRequestBodyUriSpec(HttpMethod httpMethod) { + this.httpMethod = httpMethod; + } + + @Override + public RequestBodySpec uri(String uriTemplate, Object... uriVariables) { + attribute(URI_TEMPLATE_ATTRIBUTE, uriTemplate); + return uri(uriBuilderFactory.expand(uriTemplate, uriVariables)); + } + + @Override + public RequestBodySpec uri(String uriTemplate, Map uriVariables) { + attribute(URI_TEMPLATE_ATTRIBUTE, uriTemplate); + return uri(uriBuilderFactory.expand(uriTemplate, uriVariables)); + } + + @Override + public RequestBodySpec uri(Function uriFunction) { + return uri(uriFunction.apply(uriBuilderFactory.builder())); + } + + @Override + public RequestBodySpec uri(URI uri) { + this.uri = uri; + return this; + } + + private HttpHeaders getHeaders() { + if (this.headers == null) { + this.headers = new HttpHeaders(); + } + return this.headers; + } + + private MultiValueMap getCookies() { + if (this.cookies == null) { + this.cookies = new LinkedMultiValueMap<>(4); + } + return this.cookies; + } + + @Override + public DefaultRequestBodyUriSpec header(String headerName, String... headerValues) { + for (String headerValue : headerValues) { + getHeaders().add(headerName, headerValue); + } + return this; + } + + @Override + public DefaultRequestBodyUriSpec headers(Consumer headersConsumer) { + headersConsumer.accept(getHeaders()); + return this; + } + + @Override + public RequestBodySpec attribute(String name, Object value) { + this.attributes.put(name, value); + return this; + } + + @Override + public RequestBodySpec attributes(Consumer> attributesConsumer) { + attributesConsumer.accept(this.attributes); + return this; + } + + @Override + public DefaultRequestBodyUriSpec accept(MediaType... acceptableMediaTypes) { + getHeaders().setAccept(Arrays.asList(acceptableMediaTypes)); + return this; + } + + @Override + public DefaultRequestBodyUriSpec acceptCharset(Charset... acceptableCharsets) { + getHeaders().setAcceptCharset(Arrays.asList(acceptableCharsets)); + return this; + } + + @Override + public DefaultRequestBodyUriSpec contentType(MediaType contentType) { + getHeaders().setContentType(contentType); + return this; + } + + @Override + public DefaultRequestBodyUriSpec contentLength(long contentLength) { + getHeaders().setContentLength(contentLength); + return this; + } + + @Override + public DefaultRequestBodyUriSpec cookie(String name, String value) { + getCookies().add(name, value); + return this; + } + + @Override + public DefaultRequestBodyUriSpec cookies(Consumer> cookiesConsumer) { + cookiesConsumer.accept(getCookies()); + return this; + } + + @Override + public DefaultRequestBodyUriSpec ifModifiedSince(ZonedDateTime ifModifiedSince) { + getHeaders().setIfModifiedSince(ifModifiedSince); + return this; + } + + @Override + public DefaultRequestBodyUriSpec ifNoneMatch(String... ifNoneMatches) { + getHeaders().setIfNoneMatch(Arrays.asList(ifNoneMatches)); + return this; + } + + @Override + public RequestHeadersSpec body(BodyInserter inserter) { + this.inserter = inserter; + return this; + } + + @Override + public > RequestHeadersSpec body( + P publisher, ParameterizedTypeReference typeReference) { + + this.inserter = BodyInserters.fromPublisher(publisher, typeReference); + return this; + } + + @Override + public > RequestHeadersSpec body(P publisher, Class elementClass) { + this.inserter = BodyInserters.fromPublisher(publisher, elementClass); + return this; + } + + @Override + public RequestHeadersSpec syncBody(Object body) { + Assert.isTrue(!(body instanceof Publisher), + "Please specify the element class by using body(Publisher, Class)"); + this.inserter = BodyInserters.fromObject(body); + return this; + } + + @Override + public Mono exchange() { + ClientRequest request = (this.inserter != null ? + initRequestBuilder().body(this.inserter).build() : + initRequestBuilder().build()); + return Mono.defer(() -> exchangeFunction.exchange(request)) + .switchIfEmpty(NO_HTTP_CLIENT_RESPONSE_ERROR); + } + + private ClientRequest.Builder initRequestBuilder() { + if (defaultRequest != null) { + defaultRequest.accept(this); + } + return ClientRequest.create(this.httpMethod, initUri()) + .headers(headers -> headers.addAll(initHeaders())) + .cookies(cookies -> cookies.addAll(initCookies())) + .attributes(attributes -> attributes.putAll(this.attributes)); + } + + private URI initUri() { + return (this.uri != null ? this.uri : uriBuilderFactory.expand("")); + } + + private HttpHeaders initHeaders() { + if (CollectionUtils.isEmpty(this.headers)) { + return (defaultHeaders != null ? defaultHeaders : new HttpHeaders()); + } + else if (CollectionUtils.isEmpty(defaultHeaders)) { + return this.headers; + } + else { + HttpHeaders result = new HttpHeaders(); + result.putAll(defaultHeaders); + result.putAll(this.headers); + return result; + } + } + + private MultiValueMap initCookies() { + if (CollectionUtils.isEmpty(this.cookies)) { + return (defaultCookies != null ? defaultCookies : new LinkedMultiValueMap<>()); + } + else if (CollectionUtils.isEmpty(defaultCookies)) { + return this.cookies; + } + else { + MultiValueMap result = new LinkedMultiValueMap<>(); + result.putAll(defaultCookies); + result.putAll(this.cookies); + return result; + } + } + + @Override + public ResponseSpec retrieve() { + return new DefaultResponseSpec(exchange(), this::createRequest); + } + + private HttpRequest createRequest() { + return new HttpRequest() { + private final URI uri = initUri(); + private final HttpHeaders headers = initHeaders(); + + @Override + public HttpMethod getMethod() { + return httpMethod; + } + @Override + public String getMethodValue() { + return httpMethod.name(); + } + @Override + public URI getURI() { + return this.uri; + } + @Override + public HttpHeaders getHeaders() { + return this.headers; + } + }; + } + } + + + private static class DefaultResponseSpec implements ResponseSpec { + + private static final IntPredicate STATUS_CODE_ERROR = (value -> value >= 400); + + private static final StatusHandler DEFAULT_STATUS_HANDLER = + new StatusHandler(STATUS_CODE_ERROR, DefaultResponseSpec::createResponseException); + + private final Mono responseMono; + + private final Supplier requestSupplier; + + private final List statusHandlers = new ArrayList<>(1); + + DefaultResponseSpec(Mono responseMono, Supplier requestSupplier) { + this.responseMono = responseMono; + this.requestSupplier = requestSupplier; + this.statusHandlers.add(DEFAULT_STATUS_HANDLER); + } + + @Override + public ResponseSpec onStatus(Predicate statusPredicate, + Function> exceptionFunction) { + + return onRawStatus(toIntPredicate(statusPredicate), exceptionFunction); + } + + private static IntPredicate toIntPredicate(Predicate predicate) { + return value -> { + HttpStatus status = HttpStatus.resolve(value); + return (status != null && predicate.test(status)); + }; + } + + @Override + public ResponseSpec onRawStatus(IntPredicate statusCodePredicate, + Function> exceptionFunction) { + + if (this.statusHandlers.size() == 1 && this.statusHandlers.get(0) == DEFAULT_STATUS_HANDLER) { + this.statusHandlers.clear(); + } + this.statusHandlers.add(new StatusHandler(statusCodePredicate, + (clientResponse, request) -> exceptionFunction.apply(clientResponse))); + return this; + } + + @Override + public Mono bodyToMono(Class bodyType) { + return this.responseMono.flatMap(response -> handleBody(response, + response.bodyToMono(bodyType), mono -> mono.flatMap(Mono::error))); + } + + @Override + public Mono bodyToMono(ParameterizedTypeReference bodyType) { + return this.responseMono.flatMap(response -> + handleBody(response, response.bodyToMono(bodyType), mono -> mono.flatMap(Mono::error))); + } + + @Override + public Flux bodyToFlux(Class elementType) { + return this.responseMono.flatMapMany(response -> + handleBody(response, response.bodyToFlux(elementType), mono -> mono.flatMapMany(Flux::error))); + } + + @Override + public Flux bodyToFlux(ParameterizedTypeReference elementType) { + return this.responseMono.flatMapMany(response -> handleBody(response, + response.bodyToFlux(elementType), mono -> mono.flatMapMany(Flux::error))); + } + + private > T handleBody(ClientResponse response, + T bodyPublisher, Function, T> errorFunction) { + + int statusCode = response.rawStatusCode(); + for (StatusHandler handler : this.statusHandlers) { + if (handler.test(statusCode)) { + HttpRequest request = this.requestSupplier.get(); + Mono exMono; + try { + exMono = handler.apply(response, request); + exMono = exMono.flatMap(ex -> drainBody(response, ex)); + exMono = exMono.onErrorResume(ex -> drainBody(response, ex)); + } + catch (Throwable ex2) { + exMono = drainBody(response, ex2); + } + return errorFunction.apply(exMono); + } + } + return bodyPublisher; + } + + @SuppressWarnings("unchecked") + private Mono drainBody(ClientResponse response, Throwable ex) { + // Ensure the body is drained, even if the StatusHandler didn't consume it, + // but ignore exception, in case the handler did consume. + return (Mono) response.bodyToMono(Void.class) + .onErrorResume(ex2 -> Mono.empty()).thenReturn(ex); + } + + private static Mono createResponseException( + ClientResponse response, HttpRequest request) { + + return DataBufferUtils.join(response.body(BodyExtractors.toDataBuffers())) + .map(dataBuffer -> { + byte[] bytes = new byte[dataBuffer.readableByteCount()]; + dataBuffer.read(bytes); + DataBufferUtils.release(dataBuffer); + return bytes; + }) + .defaultIfEmpty(new byte[0]) + .map(bodyBytes -> { + Charset charset = response.headers().contentType() + .map(MimeType::getCharset) + .orElse(StandardCharsets.ISO_8859_1); + if (HttpStatus.resolve(response.rawStatusCode()) != null) { + return WebClientResponseException.create( + response.statusCode().value(), + response.statusCode().getReasonPhrase(), + response.headers().asHttpHeaders(), + bodyBytes, + charset, + request); + } + else { + return new UnknownHttpStatusCodeException( + response.rawStatusCode(), + response.headers().asHttpHeaders(), + bodyBytes, + charset, + request); + } + }); + } + + + private static class StatusHandler { + + private final IntPredicate predicate; + + private final BiFunction> exceptionFunction; + + public StatusHandler(IntPredicate predicate, + BiFunction> exceptionFunction) { + + this.predicate = predicate; + this.exceptionFunction = exceptionFunction; + } + + public boolean test(int status) { + return this.predicate.test(status); + } + + public Mono apply(ClientResponse response, HttpRequest request) { + return this.exceptionFunction.apply(response, request); + } + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClientBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClientBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..3b52090258b8829bb9f817252ae08b1cd1119947 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClientBuilder.java @@ -0,0 +1,305 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.client.reactive.ClientHttpConnector; +import org.springframework.http.client.reactive.JettyClientHttpConnector; +import org.springframework.http.client.reactive.ReactorClientHttpConnector; +import org.springframework.http.codec.ClientCodecConfigurer; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.util.DefaultUriBuilderFactory; +import org.springframework.web.util.UriBuilderFactory; + +/** + * Default implementation of {@link WebClient.Builder}. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + */ +final class DefaultWebClientBuilder implements WebClient.Builder { + + private static final boolean reactorClientPresent; + + private static final boolean jettyClientPresent; + + static { + ClassLoader loader = DefaultWebClientBuilder.class.getClassLoader(); + reactorClientPresent = ClassUtils.isPresent("reactor.netty.http.client.HttpClient", loader); + jettyClientPresent = ClassUtils.isPresent("org.eclipse.jetty.client.HttpClient", loader); + } + + + @Nullable + private String baseUrl; + + @Nullable + private Map defaultUriVariables; + + @Nullable + private UriBuilderFactory uriBuilderFactory; + + @Nullable + private HttpHeaders defaultHeaders; + + @Nullable + private MultiValueMap defaultCookies; + + @Nullable + private Consumer> defaultRequest; + + @Nullable + private List filters; + + @Nullable + private ClientHttpConnector connector; + + @Nullable + private ExchangeStrategies strategies; + + @Nullable + private List> strategiesConfigurers; + + @Nullable + private ExchangeFunction exchangeFunction; + + + public DefaultWebClientBuilder() { + } + + public DefaultWebClientBuilder(DefaultWebClientBuilder other) { + Assert.notNull(other, "DefaultWebClientBuilder must not be null"); + + this.baseUrl = other.baseUrl; + this.defaultUriVariables = (other.defaultUriVariables != null ? + new LinkedHashMap<>(other.defaultUriVariables) : null); + this.uriBuilderFactory = other.uriBuilderFactory; + + if (other.defaultHeaders != null) { + this.defaultHeaders = new HttpHeaders(); + this.defaultHeaders.putAll(other.defaultHeaders); + } + else { + this.defaultHeaders = null; + } + + this.defaultCookies = (other.defaultCookies != null ? + new LinkedMultiValueMap<>(other.defaultCookies) : null); + this.defaultRequest = other.defaultRequest; + this.filters = (other.filters != null ? new ArrayList<>(other.filters) : null); + + this.connector = other.connector; + this.strategies = other.strategies; + this.strategiesConfigurers = (other.strategiesConfigurers != null ? + new ArrayList<>(other.strategiesConfigurers) : null); + this.exchangeFunction = other.exchangeFunction; + } + + + @Override + public WebClient.Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + @Override + public WebClient.Builder defaultUriVariables(Map defaultUriVariables) { + this.defaultUriVariables = defaultUriVariables; + return this; + } + + @Override + public WebClient.Builder uriBuilderFactory(UriBuilderFactory uriBuilderFactory) { + this.uriBuilderFactory = uriBuilderFactory; + return this; + } + + @Override + public WebClient.Builder defaultHeader(String header, String... values) { + initHeaders().put(header, Arrays.asList(values)); + return this; + } + + @Override + public WebClient.Builder defaultHeaders(Consumer headersConsumer) { + headersConsumer.accept(initHeaders()); + return this; + } + + private HttpHeaders initHeaders() { + if (this.defaultHeaders == null) { + this.defaultHeaders = new HttpHeaders(); + } + return this.defaultHeaders; + } + + @Override + public WebClient.Builder defaultCookie(String cookie, String... values) { + initCookies().addAll(cookie, Arrays.asList(values)); + return this; + } + + @Override + public WebClient.Builder defaultCookies(Consumer> cookiesConsumer) { + cookiesConsumer.accept(initCookies()); + return this; + } + + private MultiValueMap initCookies() { + if (this.defaultCookies == null) { + this.defaultCookies = new LinkedMultiValueMap<>(4); + } + return this.defaultCookies; + } + + @Override + public WebClient.Builder defaultRequest(Consumer> defaultRequest) { + this.defaultRequest = this.defaultRequest != null ? + this.defaultRequest.andThen(defaultRequest) : defaultRequest; + return this; + } + + @Override + public WebClient.Builder filter(ExchangeFilterFunction filter) { + Assert.notNull(filter, "ExchangeFilterFunction must not be null"); + initFilters().add(filter); + return this; + } + + @Override + public WebClient.Builder filters(Consumer> filtersConsumer) { + filtersConsumer.accept(initFilters()); + return this; + } + + private List initFilters() { + if (this.filters == null) { + this.filters = new ArrayList<>(); + } + return this.filters; + } + + @Override + public WebClient.Builder clientConnector(ClientHttpConnector connector) { + this.connector = connector; + return this; + } + + @Override + public WebClient.Builder codecs(Consumer configurer) { + if (this.strategiesConfigurers == null) { + this.strategiesConfigurers = new ArrayList<>(4); + } + this.strategiesConfigurers.add(builder -> builder.codecs(configurer)); + return this; + } + + @Override + public WebClient.Builder exchangeStrategies(ExchangeStrategies strategies) { + this.strategies = strategies; + return this; + } + + @Override + @Deprecated + public WebClient.Builder exchangeStrategies(Consumer configurer) { + if (this.strategiesConfigurers == null) { + this.strategiesConfigurers = new ArrayList<>(4); + } + this.strategiesConfigurers.add(configurer); + return this; + } + + @Override + public WebClient.Builder exchangeFunction(ExchangeFunction exchangeFunction) { + this.exchangeFunction = exchangeFunction; + return this; + } + + @Override + public WebClient.Builder apply(Consumer builderConsumer) { + builderConsumer.accept(this); + return this; + } + + @Override + public WebClient.Builder clone() { + return new DefaultWebClientBuilder(this); + } + + @Override + public WebClient build() { + ExchangeFunction exchange = (this.exchangeFunction == null ? + ExchangeFunctions.create(getOrInitConnector(), initExchangeStrategies()) : + this.exchangeFunction); + ExchangeFunction filteredExchange = (this.filters != null ? this.filters.stream() + .reduce(ExchangeFilterFunction::andThen) + .map(filter -> filter.apply(exchange)) + .orElse(exchange) : exchange); + return new DefaultWebClient(filteredExchange, initUriBuilderFactory(), + this.defaultHeaders != null ? HttpHeaders.readOnlyHttpHeaders(this.defaultHeaders) : null, + this.defaultCookies != null ? CollectionUtils.unmodifiableMultiValueMap(this.defaultCookies) : null, + this.defaultRequest, new DefaultWebClientBuilder(this)); + } + + private ClientHttpConnector getOrInitConnector() { + if (this.connector != null) { + return this.connector; + } + else if (reactorClientPresent) { + return new ReactorClientHttpConnector(); + } + else if (jettyClientPresent) { + return new JettyClientHttpConnector(); + } + throw new IllegalStateException("No suitable default ClientHttpConnector found"); + } + + private ExchangeStrategies initExchangeStrategies() { + if (CollectionUtils.isEmpty(this.strategiesConfigurers)) { + return (this.strategies != null ? this.strategies : ExchangeStrategies.withDefaults()); + } + ExchangeStrategies.Builder builder = + (this.strategies != null ? this.strategies.mutate() : ExchangeStrategies.builder()); + this.strategiesConfigurers.forEach(configurer -> configurer.accept(builder)); + return builder.build(); + } + + private UriBuilderFactory initUriBuilderFactory() { + if (this.uriBuilderFactory != null) { + return this.uriBuilderFactory; + } + DefaultUriBuilderFactory factory = (this.baseUrl != null ? + new DefaultUriBuilderFactory(this.baseUrl) : new DefaultUriBuilderFactory()); + factory.setDefaultUriVariables(this.defaultUriVariables); + return factory; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFilterFunction.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFilterFunction.java new file mode 100644 index 0000000000000000000000000000000000000000..12fb186a539f477bec00cf0a4a1ebceecf21443f --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFilterFunction.java @@ -0,0 +1,93 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import java.util.function.Function; + +import reactor.core.publisher.Mono; + +import org.springframework.util.Assert; + +/** + * Represents a function that filters an {@linkplain ExchangeFunction exchange function}. + *

The filter is executed when a {@code Subscriber} subscribes to the + * {@code Publisher} returned by the {@code WebClient}. + * + * @author Arjen Poutsma + * @since 5.0 + */ +@FunctionalInterface +public interface ExchangeFilterFunction { + + /** + * Apply this filter to the given request and exchange function. + *

The given {@linkplain ExchangeFunction} represents the next entity + * in the chain, to be invoked via + * {@linkplain ExchangeFunction#exchange(ClientRequest) invoked} in order to + * proceed with the exchange, or not invoked to shortcut the chain. + * @param request the current request + * @param next the next exchange function in the chain + * @return the filtered response + */ + Mono filter(ClientRequest request, ExchangeFunction next); + + /** + * Return a composed filter function that first applies this filter, and + * then applies the given {@code "after"} filter. + * @param afterFilter the filter to apply after this filter + * @return the composed filter + */ + default ExchangeFilterFunction andThen(ExchangeFilterFunction afterFilter) { + Assert.notNull(afterFilter, "ExchangeFilterFunction must not be null"); + return (request, next) -> + filter(request, afterRequest -> afterFilter.filter(afterRequest, next)); + } + + /** + * Apply this filter to the given {@linkplain ExchangeFunction}, resulting + * in a filtered exchange function. + * @param exchange the exchange function to filter + * @return the filtered exchange function + */ + default ExchangeFunction apply(ExchangeFunction exchange) { + Assert.notNull(exchange, "ExchangeFunction must not be null"); + return request -> this.filter(request, exchange); + } + + /** + * Adapt the given request processor function to a filter function that only + * operates on the {@code ClientRequest}. + * @param processor the request processor + * @return the resulting filter adapter + */ + static ExchangeFilterFunction ofRequestProcessor(Function> processor) { + Assert.notNull(processor, "ClientRequest Function must not be null"); + return (request, next) -> processor.apply(request).flatMap(next::exchange); + } + + /** + * Adapt the given response processor function to a filter function that + * only operates on the {@code ClientResponse}. + * @param processor the response processor + * @return the resulting filter adapter + */ + static ExchangeFilterFunction ofResponseProcessor(Function> processor) { + Assert.notNull(processor, "ClientResponse Function must not be null"); + return (request, next) -> next.exchange(request).flatMap(processor); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctions.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctions.java new file mode 100644 index 0000000000000000000000000000000000000000..6418883cf90d4d6a20c26c73e6421f52ee778863 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctions.java @@ -0,0 +1,191 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import java.nio.charset.Charset; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Predicate; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.util.Assert; +import org.springframework.web.reactive.function.BodyExtractors; + +/** + * Static factory methods providing access to built-in implementations of + * {@link ExchangeFilterFunction} for basic authentication, error handling, etc. + * + * @author Rob Winch + * @author Arjen Poutsma + * @since 5.0 + */ +public abstract class ExchangeFilterFunctions { + + /** + * Name of the request attribute with {@link Credentials} for {@link #basicAuthentication()}. + * @deprecated as of Spring 5.1 in favor of using + * {@link HttpHeaders#setBasicAuth(String, String)} while building the request. + */ + @Deprecated + public static final String BASIC_AUTHENTICATION_CREDENTIALS_ATTRIBUTE = + ExchangeFilterFunctions.class.getName() + ".basicAuthenticationCredentials"; + + + /** + * Consume up to the specified number of bytes from the response body and + * cancel if any more data arrives. + *

Internally delegates to {@link DataBufferUtils#takeUntilByteCount}. + * @param maxByteCount the limit as number of bytes + * @return the filter to limit the response size with + * @since 5.1 + */ + public static ExchangeFilterFunction limitResponseSize(long maxByteCount) { + return (request, next) -> + next.exchange(request).map(response -> { + Flux body = response.body(BodyExtractors.toDataBuffers()); + body = DataBufferUtils.takeUntilByteCount(body, maxByteCount); + return ClientResponse.from(response).body(body).build(); + }); + } + + /** + * Return a filter that generates an error signal when the given + * {@link HttpStatus} predicate matches. + * @param statusPredicate the predicate to check the HTTP status with + * @param exceptionFunction the function that to create the exception + * @return the filter to generate an error signal + */ + public static ExchangeFilterFunction statusError(Predicate statusPredicate, + Function exceptionFunction) { + + Assert.notNull(statusPredicate, "Predicate must not be null"); + Assert.notNull(exceptionFunction, "Function must not be null"); + + return ExchangeFilterFunction.ofResponseProcessor( + response -> (statusPredicate.test(response.statusCode()) ? + Mono.error(exceptionFunction.apply(response)) : Mono.just(response))); + } + + /** + * Return a filter that applies HTTP Basic Authentication to the request + * headers via {@link HttpHeaders#setBasicAuth(String, String)}. + * @param user the user + * @param password the password + * @return the filter to add authentication headers with + * @see HttpHeaders#setBasicAuth(String, String) + * @see HttpHeaders#setBasicAuth(String, String, Charset) + */ + public static ExchangeFilterFunction basicAuthentication(String user, String password) { + return (request, next) -> + next.exchange(ClientRequest.from(request) + .headers(headers -> headers.setBasicAuth(user, password)) + .build()); + } + + + /** + * Variant of {@link #basicAuthentication(String, String)} that looks up + * the {@link Credentials Credentials} in a + * {@link #BASIC_AUTHENTICATION_CREDENTIALS_ATTRIBUTE request attribute}. + * @return the filter to use + * @see Credentials + * @deprecated as of Spring 5.1 in favor of using + * {@link HttpHeaders#setBasicAuth(String, String)} while building the request. + */ + @Deprecated + public static ExchangeFilterFunction basicAuthentication() { + return (request, next) -> { + Object attr = request.attributes().get(BASIC_AUTHENTICATION_CREDENTIALS_ATTRIBUTE); + if (attr instanceof Credentials) { + Credentials cred = (Credentials) attr; + return next.exchange(ClientRequest.from(request) + .headers(headers -> headers.setBasicAuth(cred.username, cred.password)) + .build()); + } + else { + return next.exchange(request); + } + }; + } + + + /** + * Stores user and password for HTTP basic authentication. + * @deprecated as of Spring 5.1 in favor of using + * {@link HttpHeaders#setBasicAuth(String, String)} while building the request. + */ + @Deprecated + public static final class Credentials { + + private final String username; + + private final String password; + + /** + * Create a new {@code Credentials} instance with the given username and password. + * @param username the username + * @param password the password + */ + public Credentials(String username, String password) { + Assert.notNull(username, "'username' must not be null"); + Assert.notNull(password, "'password' must not be null"); + this.username = username; + this.password = password; + } + + /** + * Return a {@literal Consumer} that stores the given user and password + * as a request attribute of type {@code Credentials} that is in turn + * used by {@link ExchangeFilterFunctions#basicAuthentication()}. + * @param user the user + * @param password the password + * @return a consumer that can be passed into + * {@linkplain ClientRequest.Builder#attributes(java.util.function.Consumer)} + * @see ClientRequest.Builder#attributes(java.util.function.Consumer) + * @see #BASIC_AUTHENTICATION_CREDENTIALS_ATTRIBUTE + */ + public static Consumer> basicAuthenticationCredentials(String user, String password) { + Credentials credentials = new Credentials(user, password); + return (map -> map.put(BASIC_AUTHENTICATION_CREDENTIALS_ATTRIBUTE, credentials)); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof Credentials)) { + return false; + } + Credentials otherCred = (Credentials) other; + return (this.username.equals(otherCred.username) && this.password.equals(otherCred.password)); + } + + @Override + public int hashCode() { + return 31 * this.username.hashCode() + this.password.hashCode(); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFunction.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFunction.java new file mode 100644 index 0000000000000000000000000000000000000000..9078462f867dc28ac97f56887bef7a19bece6af3 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFunction.java @@ -0,0 +1,59 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import reactor.core.publisher.Mono; + +/** + * Represents a function that exchanges a {@linkplain ClientRequest request} for a (delayed) + * {@linkplain ClientResponse}. Can be used as an alternative to {@link WebClient}. + * + *

For example: + *

+ * ExchangeFunction exchangeFunction = ExchangeFunctions.create(new ReactorClientHttpConnector());
+ * ClientRequest<Void> request = ClientRequest.method(HttpMethod.GET, "https://example.com/resource").build();
+ *
+ * Mono<String> result = exchangeFunction
+ *     .exchange(request)
+ *     .then(response -> response.bodyToMono(String.class));
+ * 
+ * + * @author Arjen Poutsma + * @since 5.0 + */ +@FunctionalInterface +public interface ExchangeFunction { + + /** + * Exchange the given request for a response mono. + * @param request the request to exchange + * @return the delayed response + */ + Mono exchange(ClientRequest request); + + /** + * Filters this exchange function with the given {@code ExchangeFilterFunction}, resulting in a + * filtered {@code ExchangeFunction}. + * @param filter the filter to apply to this exchange + * @return the filtered exchange + * @see ExchangeFilterFunction#apply(ExchangeFunction) + */ + default ExchangeFunction filter(ExchangeFilterFunction filter) { + return filter.apply(this); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFunctions.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFunctions.java new file mode 100644 index 0000000000000000000000000000000000000000..35e09233dcf66abff31db7574c2ff97962606024 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFunctions.java @@ -0,0 +1,133 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import java.net.URI; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import reactor.core.publisher.Mono; + +import org.springframework.core.log.LogFormatUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.client.reactive.ClientHttpConnector; +import org.springframework.http.client.reactive.ClientHttpResponse; +import org.springframework.http.codec.LoggingCodecSupport; +import org.springframework.util.Assert; + +/** + * Static factory methods to create an {@link ExchangeFunction}. + * + * @author Arjen Poutsma + * @author Rossen Stoyanchev + * @since 5.0 + */ +public abstract class ExchangeFunctions { + + private static final Log logger = LogFactory.getLog(ExchangeFunctions.class); + + + /** + * Create an {@code ExchangeFunction} with the given {@code ClientHttpConnector}. + * This is the same as calling + * {@link #create(ClientHttpConnector, ExchangeStrategies)} and passing + * {@link ExchangeStrategies#withDefaults()}. + * @param connector the connector to use for connecting to servers + * @return the created {@code ExchangeFunction} + */ + public static ExchangeFunction create(ClientHttpConnector connector) { + return create(connector, ExchangeStrategies.withDefaults()); + } + + /** + * Create an {@code ExchangeFunction} with the given + * {@code ClientHttpConnector} and {@code ExchangeStrategies}. + * @param connector the connector to use for connecting to servers + * @param strategies the {@code ExchangeStrategies} to use + * @return the created {@code ExchangeFunction} + */ + public static ExchangeFunction create(ClientHttpConnector connector, ExchangeStrategies strategies) { + return new DefaultExchangeFunction(connector, strategies); + } + + + private static class DefaultExchangeFunction implements ExchangeFunction { + + private final ClientHttpConnector connector; + + private final ExchangeStrategies strategies; + + private boolean enableLoggingRequestDetails; + + + public DefaultExchangeFunction(ClientHttpConnector connector, ExchangeStrategies strategies) { + Assert.notNull(connector, "ClientHttpConnector must not be null"); + Assert.notNull(strategies, "ExchangeStrategies must not be null"); + this.connector = connector; + this.strategies = strategies; + + strategies.messageWriters().stream() + .filter(LoggingCodecSupport.class::isInstance) + .forEach(reader -> { + if (((LoggingCodecSupport) reader).isEnableLoggingRequestDetails()) { + this.enableLoggingRequestDetails = true; + } + }); + } + + + @Override + public Mono exchange(ClientRequest clientRequest) { + Assert.notNull(clientRequest, "ClientRequest must not be null"); + HttpMethod httpMethod = clientRequest.method(); + URI url = clientRequest.url(); + String logPrefix = clientRequest.logPrefix(); + + return this.connector + .connect(httpMethod, url, httpRequest -> clientRequest.writeTo(httpRequest, this.strategies)) + .doOnRequest(n -> logRequest(clientRequest)) + .doOnCancel(() -> logger.debug(logPrefix + "Cancel signal (to close connection)")) + .map(httpResponse -> { + logResponse(httpResponse, logPrefix); + return new DefaultClientResponse(httpResponse, this.strategies, logPrefix); + }); + } + + private void logRequest(ClientRequest request) { + LogFormatUtils.traceDebug(logger, traceOn -> + request.logPrefix() + "HTTP " + request.method() + " " + request.url() + + (traceOn ? ", headers=" + formatHeaders(request.headers()) : "") + ); + } + + private void logResponse(ClientHttpResponse response, String logPrefix) { + LogFormatUtils.traceDebug(logger, traceOn -> { + int code = response.getRawStatusCode(); + HttpStatus status = HttpStatus.resolve(code); + return logPrefix + "Response " + (status != null ? status : code) + + (traceOn ? ", headers=" + formatHeaders(response.getHeaders()) : ""); + }); + } + + private String formatHeaders(HttpHeaders headers) { + return this.enableLoggingRequestDetails ? headers.toString() : headers.isEmpty() ? "{}" : "{masked}"; + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeStrategies.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeStrategies.java new file mode 100644 index 0000000000000000000000000000000000000000..acf32d0959ae2fac2c99c0a438abdd395e559aa9 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeStrategies.java @@ -0,0 +1,108 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import java.util.List; +import java.util.function.Consumer; + +import org.springframework.http.codec.ClientCodecConfigurer; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.HttpMessageWriter; + +/** + * Provides strategies for use in an {@link ExchangeFunction}. + * + *

To create an instance, see the static methods {@link #withDefaults()}, + * {@link #builder()}, and {@link #empty()}. + * + * @author Brian Clozel + * @author Arjen Poutsma + * @since 5.0 + */ +public interface ExchangeStrategies { + + /** + * Return {@link HttpMessageReader HttpMessageReaders} to read and decode the response body with. + * @return the message readers + */ + List> messageReaders(); + + /** + * Return {@link HttpMessageWriter HttpMessageWriters} to write and encode the request body with. + * @return the message writers + */ + List> messageWriters(); + + /** + * Return a builder to create a new {@link ExchangeStrategies} instance + * replicated from the current instance. + * @since 5.1.12 + */ + default Builder mutate() { + throw new UnsupportedOperationException(); + } + + + // Static builder methods + + /** + * Return an {@code ExchangeStrategies} instance with default configuration + * provided by {@link ClientCodecConfigurer}. + */ + static ExchangeStrategies withDefaults() { + return DefaultExchangeStrategiesBuilder.DEFAULT_EXCHANGE_STRATEGIES; + } + + /** + * Return a builder pre-configured with default configuration to start. + * This is the same as {@link #withDefaults()} but returns a mutable builder + * for further customizations. + */ + static Builder builder() { + DefaultExchangeStrategiesBuilder builder = new DefaultExchangeStrategiesBuilder(); + builder.defaultConfiguration(); + return builder; + } + + /** + * Return a builder with empty configuration to start. + */ + static Builder empty() { + return new DefaultExchangeStrategiesBuilder(); + } + + + /** + * A mutable builder for an {@link ExchangeStrategies}. + */ + interface Builder { + + /** + * Customize the list of client-side HTTP message readers and writers. + * @param consumer the consumer to customize the codecs + * @return this builder + */ + Builder codecs(Consumer consumer); + + /** + * Builds the {@link ExchangeStrategies}. + * @return the built strategies + */ + ExchangeStrategies build(); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/UnknownHttpStatusCodeException.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/UnknownHttpStatusCodeException.java new file mode 100644 index 0000000000000000000000000000000000000000..50c53a52f683cee0a7e55de3a6622869dd3ea404 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/UnknownHttpStatusCodeException.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import java.nio.charset.Charset; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpRequest; +import org.springframework.lang.Nullable; + +/** + * Exception thrown when an unknown (or custom) HTTP status code is received. + * + * @author Brian Clozel + * @since 5.1 + */ +public class UnknownHttpStatusCodeException extends WebClientResponseException { + + private static final long serialVersionUID = 2407169540168185007L; + + + /** + * Create a new instance of the {@code UnknownHttpStatusCodeException} with the given + * parameters. + */ + public UnknownHttpStatusCodeException( + int statusCode, HttpHeaders headers, byte[] responseBody, Charset responseCharset) { + + super("Unknown status code [" + statusCode + "]", statusCode, "", + headers, responseBody, responseCharset); + } + + /** + * Create a new instance of the {@code UnknownHttpStatusCodeException} with the given + * parameters. + * @since 5.1.4 + */ + public UnknownHttpStatusCodeException( + int statusCode, HttpHeaders headers, byte[] responseBody, Charset responseCharset, + @Nullable HttpRequest request) { + + super("Unknown status code [" + statusCode + "]", statusCode, "", + headers, responseBody, responseCharset, request); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java new file mode 100644 index 0000000000000000000000000000000000000000..eda41805c757e75bd2bab65c7bf0e15a7cc70bb6 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java @@ -0,0 +1,711 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import java.net.URI; +import java.nio.charset.Charset; +import java.time.ZonedDateTime; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.IntPredicate; +import java.util.function.Predicate; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.client.reactive.ClientHttpConnector; +import org.springframework.http.client.reactive.ClientHttpRequest; +import org.springframework.http.codec.ClientCodecConfigurer; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyInserter; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.util.UriBuilder; +import org.springframework.web.util.UriBuilderFactory; + +/** + * Non-blocking, reactive client to perform HTTP requests, exposing a fluent, + * reactive API over underlying HTTP client libraries such as Reactor Netty. + * + *

Use static factory methods {@link #create()} or {@link #create(String)}, + * or {@link WebClient#builder()} to prepare an instance. + * + *

For examples with a response body see: + *

    + *
  • {@link RequestHeadersSpec#retrieve() retrieve()} + *
  • {@link RequestHeadersSpec#exchange() exchange()} + *
+ *

For examples with a request body see: + *

    + *
  • {@link RequestBodySpec#body(Publisher, Class) body(Publisher,Class)} + *
  • {@link RequestBodySpec#syncBody(Object) syncBody(Object)} + *
  • {@link RequestBodySpec#body(BodyInserter) body(BodyInserter)} + *
+ * + * @author Rossen Stoyanchev + * @author Arjen Poutsma + * @author Brian Clozel + * @since 5.0 + */ +public interface WebClient { + + /** + * Start building an HTTP GET request. + * @return a spec for specifying the target URL + */ + RequestHeadersUriSpec get(); + + /** + * Start building an HTTP HEAD request. + * @return a spec for specifying the target URL + */ + RequestHeadersUriSpec head(); + + /** + * Start building an HTTP POST request. + * @return a spec for specifying the target URL + */ + RequestBodyUriSpec post(); + + /** + * Start building an HTTP PUT request. + * @return a spec for specifying the target URL + */ + RequestBodyUriSpec put(); + + /** + * Start building an HTTP PATCH request. + * @return a spec for specifying the target URL + */ + RequestBodyUriSpec patch(); + + /** + * Start building an HTTP DELETE request. + * @return a spec for specifying the target URL + */ + RequestHeadersUriSpec delete(); + + /** + * Start building an HTTP OPTIONS request. + * @return a spec for specifying the target URL + */ + RequestHeadersUriSpec options(); + + /** + * Start building a request for the given {@code HttpMethod}. + * @return a spec for specifying the target URL + */ + RequestBodyUriSpec method(HttpMethod method); + + + /** + * Return a builder to create a new {@code WebClient} whose settings are + * replicated from the current {@code WebClient}. + */ + Builder mutate(); + + + // Static, factory methods + + /** + * Create a new {@code WebClient} with Reactor Netty by default. + * @see #create(String) + * @see #builder() + */ + static WebClient create() { + return new DefaultWebClientBuilder().build(); + } + + /** + * Variant of {@link #create()} that accepts a default base URL. For more + * details see {@link Builder#baseUrl(String) Builder.baseUrl(String)}. + * @param baseUrl the base URI for all requests + * @see #builder() + */ + static WebClient create(String baseUrl) { + return new DefaultWebClientBuilder().baseUrl(baseUrl).build(); + } + + /** + * Obtain a {@code WebClient} builder. + */ + static WebClient.Builder builder() { + return new DefaultWebClientBuilder(); + } + + + /** + * A mutable builder for creating a {@link WebClient}. + */ + interface Builder { + + /** + * Configure a base URL for requests performed through the client. + * + *

For example given base URL "https://abc.go.com/v1": + *

+		 * Mono<Account> result = client.get().uri("/accounts/{id}", 43)
+		 *         .retrieve()
+		 *         .bodyToMono(Account.class);
+		 *
+		 * // Result: https://abc.go.com/v1/accounts/43
+		 *
+		 * Flux<Account> result = client.get()
+		 *         .uri(builder -> builder.path("/accounts").queryParam("q", "12").build())
+		 *         .retrieve()
+		 *         .bodyToFlux(Account.class);
+		 *
+		 * // Result: https://abc.go.com/v1/accounts?q=12
+		 * 
+ * + *

The base URL can be overridden with an absolute URI: + *

+		 * Mono<Account> result = client.get().uri("https://xyz.com/path")
+		 *         .retrieve()
+		 *         .bodyToMono(Account.class);
+		 *
+		 * // Result: https://xyz.com/path
+		 * 
+ * + *

Or partially overridden with a {@code UriBuilder}: + *

+		 * Flux<Account> result = client.get()
+		 *         .uri(builder -> builder.replacePath("/v2/accounts").queryParam("q", "12").build())
+		 *         .retrieve()
+		 *         .bodyToFlux(Account.class);
+		 *
+		 * // Result: https://abc.go.com/v2/accounts?q=12
+		 * 
+ * + * @see #defaultUriVariables(Map) + * @see #uriBuilderFactory(UriBuilderFactory) + */ + Builder baseUrl(String baseUrl); + + /** + * Configure default URI variable values that will be used when expanding + * URI templates using a {@link Map}. + * @param defaultUriVariables the default values to use + * @see #baseUrl(String) + * @see #uriBuilderFactory(UriBuilderFactory) + */ + Builder defaultUriVariables(Map defaultUriVariables); + + /** + * Provide a pre-configured {@link UriBuilderFactory} instance. This is + * an alternative to and effectively overrides the following: + *
    + *
  • {@link #baseUrl(String)} + *
  • {@link #defaultUriVariables(Map)}. + *
+ * @param uriBuilderFactory the URI builder factory to use + * @see #baseUrl(String) + * @see #defaultUriVariables(Map) + */ + Builder uriBuilderFactory(UriBuilderFactory uriBuilderFactory); + + /** + * Global option to specify a header to be added to every request, + * if the request does not already contain such a header. + * @param header the header name + * @param values the header values + */ + Builder defaultHeader(String header, String... values); + + /** + * Provides access to every {@link #defaultHeader(String, String...)} + * declared so far with the possibility to add, replace, or remove. + * @param headersConsumer the consumer + */ + Builder defaultHeaders(Consumer headersConsumer); + + /** + * Global option to specify a cookie to be added to every request, + * if the request does not already contain such a cookie. + * @param cookie the cookie name + * @param values the cookie values + */ + Builder defaultCookie(String cookie, String... values); + + /** + * Provides access to every {@link #defaultCookie(String, String...)} + * declared so far with the possibility to add, replace, or remove. + * @param cookiesConsumer a function that consumes the cookies map + */ + Builder defaultCookies(Consumer> cookiesConsumer); + + /** + * Provide a consumer to modify every request being built just before the + * call to {@link RequestHeadersSpec#exchange() exchange()}. + * @param defaultRequest the consumer to use for modifying requests + * @since 5.1 + */ + Builder defaultRequest(Consumer> defaultRequest); + + /** + * Add the given filter to the filter chain. + * @param filter the filter to be added to the chain + */ + Builder filter(ExchangeFilterFunction filter); + + /** + * Manipulate the filters with the given consumer. The list provided to + * the consumer is "live", so that the consumer can be used to remove + * filters, change ordering, etc. + * @param filtersConsumer a function that consumes the filter list + * @return this builder + */ + Builder filters(Consumer> filtersConsumer); + + /** + * Configure the {@link ClientHttpConnector} to use. This is useful for + * plugging in and/or customizing options of the underlying HTTP client + * library (e.g. SSL). + *

By default this is set to + * {@link org.springframework.http.client.reactive.ReactorClientHttpConnector + * ReactorClientHttpConnector}. + * @param connector the connector to use + */ + Builder clientConnector(ClientHttpConnector connector); + + /** + * Configure the codecs for the {@code WebClient} in the + * {@link #exchangeStrategies(ExchangeStrategies) underlying} + * {@code ExchangeStrategies}. + * @param configurer the configurer to apply + * @since 5.1.13 + */ + Builder codecs(Consumer configurer); + + /** + * Configure the {@link ExchangeStrategies} to use. + *

For most cases, prefer using {@link #codecs(Consumer)} which allows + * customizing the codecs in the {@code ExchangeStrategies} rather than + * replace them. That ensures multiple parties can contribute to codecs + * configuration. + *

By default this is set to {@link ExchangeStrategies#withDefaults()}. + * @param strategies the strategies to use + */ + Builder exchangeStrategies(ExchangeStrategies strategies); + + /** + * Customize the strategies configured via + * {@link #exchangeStrategies(ExchangeStrategies)}. This method is + * designed for use in scenarios where multiple parties wish to update + * the {@code ExchangeStrategies}. + * @deprecated as of 5.1.13 in favor of {@link #codecs(Consumer)} + */ + @Deprecated + Builder exchangeStrategies(Consumer configurer); + + /** + * Provide an {@link ExchangeFunction} pre-configured with + * {@link ClientHttpConnector} and {@link ExchangeStrategies}. + *

This is an alternative to, and effectively overrides + * {@link #clientConnector}, and + * {@link #exchangeStrategies(ExchangeStrategies)}. + * @param exchangeFunction the exchange function to use + */ + Builder exchangeFunction(ExchangeFunction exchangeFunction); + + /** + * Apply the given {@code Consumer} to this builder instance. + *

This can be useful for applying pre-packaged customizations. + * @param builderConsumer the consumer to apply + */ + Builder apply(Consumer builderConsumer); + + /** + * Clone this {@code WebClient.Builder}. + */ + Builder clone(); + + /** + * Builder the {@link WebClient} instance. + */ + WebClient build(); + } + + + /** + * Contract for specifying the URI for a request. + * @param a self reference to the spec type + */ + interface UriSpec> { + + /** + * Specify the URI using an absolute, fully constructed {@link URI}. + */ + S uri(URI uri); + + /** + * Specify the URI for the request using a URI template and URI variables. + * If a {@link UriBuilderFactory} was configured for the client (e.g. + * with a base URI) it will be used to expand the URI template. + */ + S uri(String uri, Object... uriVariables); + + /** + * Specify the URI for the request using a URI template and URI variables. + * If a {@link UriBuilderFactory} was configured for the client (e.g. + * with a base URI) it will be used to expand the URI template. + */ + S uri(String uri, Map uriVariables); + + /** + * Build the URI for the request using the {@link UriBuilderFactory} + * configured for this client. + */ + S uri(Function uriFunction); + } + + + /** + * Contract for specifying request headers leading up to the exchange. + * @param a self reference to the spec type + */ + interface RequestHeadersSpec> { + + /** + * Set the list of acceptable {@linkplain MediaType media types}, as + * specified by the {@code Accept} header. + * @param acceptableMediaTypes the acceptable media types + * @return this builder + */ + S accept(MediaType... acceptableMediaTypes); + + /** + * Set the list of acceptable {@linkplain Charset charsets}, as specified + * by the {@code Accept-Charset} header. + * @param acceptableCharsets the acceptable charsets + * @return this builder + */ + S acceptCharset(Charset... acceptableCharsets); + + /** + * Add a cookie with the given name and value. + * @param name the cookie name + * @param value the cookie value + * @return this builder + */ + S cookie(String name, String value); + + /** + * Provides access to every cookie declared so far with the possibility + * to add, replace, or remove values. + * @param cookiesConsumer the consumer to provide access to + * @return this builder + */ + S cookies(Consumer> cookiesConsumer); + + /** + * Set the value of the {@code If-Modified-Since} header. + *

The date should be specified as the number of milliseconds since + * January 1, 1970 GMT. + * @param ifModifiedSince the new value of the header + * @return this builder + */ + S ifModifiedSince(ZonedDateTime ifModifiedSince); + + /** + * Set the values of the {@code If-None-Match} header. + * @param ifNoneMatches the new value of the header + * @return this builder + */ + S ifNoneMatch(String... ifNoneMatches); + + /** + * Add the given, single header value under the given name. + * @param headerName the header name + * @param headerValues the header value(s) + * @return this builder + */ + S header(String headerName, String... headerValues); + + /** + * Provides access to every header declared so far with the possibility + * to add, replace, or remove values. + * @param headersConsumer the consumer to provide access to + * @return this builder + */ + S headers(Consumer headersConsumer); + + /** + * Set the attribute with the given name to the given value. + * @param name the name of the attribute to add + * @param value the value of the attribute to add + * @return this builder + */ + S attribute(String name, Object value); + + /** + * Provides access to every attribute declared so far with the + * possibility to add, replace, or remove values. + * @param attributesConsumer the consumer to provide access to + * @return this builder + */ + S attributes(Consumer> attributesConsumer); + + /** + * Perform the HTTP request and retrieve the response body: + *

+		 * Mono<Person> bodyMono = client.get()
+		 *     .uri("/persons/1")
+		 *     .accept(MediaType.APPLICATION_JSON)
+		 *     .retrieve()
+		 *     .bodyToMono(Person.class);
+		 * 
+ *

This method is a shortcut to using {@link #exchange()} and + * decoding the response body through {@link ClientResponse}. + * @return {@code ResponseSpec} to specify how to decode the body + * @see #exchange() + */ + ResponseSpec retrieve(); + + /** + * Perform the HTTP request and return a {@link ClientResponse} with the + * response status and headers. You can then use methods of the response + * to consume the body: + *

+		 * Mono<Person> mono = client.get()
+		 *     .uri("/persons/1")
+		 *     .accept(MediaType.APPLICATION_JSON)
+		 *     .exchange()
+		 *     .flatMap(response -> response.bodyToMono(Person.class));
+		 *
+		 * Flux<Person> flux = client.get()
+		 *     .uri("/persons")
+		 *     .accept(MediaType.APPLICATION_STREAM_JSON)
+		 *     .exchange()
+		 *     .flatMapMany(response -> response.bodyToFlux(Person.class));
+		 * 
+ *

NOTE: You must always use one of the body or + * entity methods of the response to ensure resources are released. + * See {@link ClientResponse} for more details. + * @return a {@code Mono} for the response + * @see #retrieve() + */ + Mono exchange(); + } + + + /** + * Contract for specifying request headers and body leading up to the exchange. + */ + interface RequestBodySpec extends RequestHeadersSpec { + + /** + * Set the length of the body in bytes, as specified by the + * {@code Content-Length} header. + * @param contentLength the content length + * @return this builder + * @see HttpHeaders#setContentLength(long) + */ + RequestBodySpec contentLength(long contentLength); + + /** + * Set the {@linkplain MediaType media type} of the body, as specified + * by the {@code Content-Type} header. + * @param contentType the content type + * @return this builder + * @see HttpHeaders#setContentType(MediaType) + */ + RequestBodySpec contentType(MediaType contentType); + + /** + * Set the body of the request using the given body inserter. + * {@link BodyInserters} provides access to built-in implementations of + * {@link BodyInserter}. + * @param inserter the body inserter to use for the request body + * @return this builder + * @see org.springframework.web.reactive.function.BodyInserters + */ + RequestHeadersSpec body(BodyInserter inserter); + + /** + * A shortcut for {@link #body(BodyInserter)} with a + * {@linkplain BodyInserters#fromPublisher Publisher inserter}. + * For example: + *

+		 * Mono<Person> personMono = ... ;
+		 *
+		 * Mono<Void> result = client.post()
+		 *     .uri("/persons/{id}", id)
+		 *     .contentType(MediaType.APPLICATION_JSON)
+		 *     .body(personMono, Person.class)
+		 *     .retrieve()
+		 *     .bodyToMono(Void.class);
+		 * 
+ * @param publisher the {@code Publisher} to write to the request + * @param elementClass the class of elements contained in the publisher + * @param the type of the elements contained in the publisher + * @param

the type of the {@code Publisher} + * @return this builder + */ + > RequestHeadersSpec body(P publisher, Class elementClass); + + /** + * A variant of {@link #body(Publisher, Class)} that allows providing + * element type information that includes generics via a + * {@link ParameterizedTypeReference}. + * @param publisher the {@code Publisher} to write to the request + * @param typeReference the type reference of elements contained in the publisher + * @param the type of the elements contained in the publisher + * @param

the type of the {@code Publisher} + * @return this builder + */ + > RequestHeadersSpec body(P publisher, + ParameterizedTypeReference typeReference); + + /** + * A shortcut for {@link #body(BodyInserter)} with an + * {@linkplain BodyInserters#fromObject Object inserter}. + * For example: + *

+		 * Person person = ... ;
+		 *
+		 * Mono<Void> result = client.post()
+		 *     .uri("/persons/{id}", id)
+		 *     .contentType(MediaType.APPLICATION_JSON)
+		 *     .syncBody(person)
+		 *     .retrieve()
+		 *     .bodyToMono(Void.class);
+		 * 
+ *

For multipart requests, provide a + * {@link org.springframework.util.MultiValueMap MultiValueMap}. The + * values in the {@code MultiValueMap} can be any Object representing + * the body of the part, or an + * {@link org.springframework.http.HttpEntity HttpEntity} representing + * a part with body and headers. The {@code MultiValueMap} can be built + * with {@link org.springframework.http.client.MultipartBodyBuilder + * MultipartBodyBuilder}. + * @param body the {@code Object} to write to the request + * @return this builder + */ + RequestHeadersSpec syncBody(Object body); + } + + + /** + * Contract for specifying response operations following the exchange. + */ + interface ResponseSpec { + + /** + * Register a custom error function that gets invoked when the given {@link HttpStatus} + * predicate applies. The exception returned from the function will be returned from + * {@link #bodyToMono(Class)} and {@link #bodyToFlux(Class)}. + *

By default, an error handler is registered that throws a + * {@link WebClientResponseException} when the response status code is 4xx or 5xx. + * @param statusPredicate a predicate that indicates whether {@code exceptionFunction} + * applies + *

NOTE: if the response is expected to have content, + * the exceptionFunction should consume it. If not, the content will be + * automatically drained to ensure resources are released. + * @param exceptionFunction the function that returns the exception + * @return this builder + */ + ResponseSpec onStatus(Predicate statusPredicate, + Function> exceptionFunction); + + /** + * Register a custom error function that gets invoked when the given raw status code + * predicate applies. The exception returned from the function will be returned from + * {@link #bodyToMono(Class)} and {@link #bodyToFlux(Class)}. + *

By default, an error handler is registered that throws a + * {@link WebClientResponseException} when the response status code is 4xx or 5xx. + * @param statusCodePredicate a predicate of the raw status code that indicates + * whether {@code exceptionFunction} applies. + *

NOTE: if the response is expected to have content, + * the exceptionFunction should consume it. If not, the content will be + * automatically drained to ensure resources are released. + * @param exceptionFunction the function that returns the exception + * @return this builder + * @since 5.1.9 + */ + ResponseSpec onRawStatus(IntPredicate statusCodePredicate, + Function> exceptionFunction); + + /** + * Extract the body to a {@code Mono}. By default, if the response has status code 4xx or + * 5xx, the {@code Mono} will contain a {@link WebClientException}. This can be overridden + * with {@link #onStatus(Predicate, Function)}. + * @param bodyType the expected response body type + * @param response body type + * @return a mono containing the body, or a {@link WebClientResponseException} if the + * status code is 4xx or 5xx + */ + Mono bodyToMono(Class bodyType); + + /** + * Extract the body to a {@code Mono}. By default, if the response has status code 4xx or + * 5xx, the {@code Mono} will contain a {@link WebClientException}. This can be overridden + * with {@link #onStatus(Predicate, Function)}. + * @param typeReference a type reference describing the expected response body type + * @param response body type + * @return a mono containing the body, or a {@link WebClientResponseException} if the + * status code is 4xx or 5xx + */ + Mono bodyToMono(ParameterizedTypeReference typeReference); + + /** + * Extract the body to a {@code Flux}. By default, if the response has status code 4xx or + * 5xx, the {@code Flux} will contain a {@link WebClientException}. This can be overridden + * with {@link #onStatus(Predicate, Function)}. + * @param elementType the type of element in the response + * @param the type of elements in the response + * @return a flux containing the body, or a {@link WebClientResponseException} if the + * status code is 4xx or 5xx + */ + Flux bodyToFlux(Class elementType); + + /** + * Extract the body to a {@code Flux}. By default, if the response has status code 4xx or + * 5xx, the {@code Flux} will contain a {@link WebClientException}. This can be overridden + * with {@link #onStatus(Predicate, Function)}. + * @param typeReference a type reference describing the expected response body type + * @param the type of elements in the response + * @return a flux containing the body, or a {@link WebClientResponseException} if the + * status code is 4xx or 5xx + */ + Flux bodyToFlux(ParameterizedTypeReference typeReference); + } + + + /** + * Contract for specifying request headers and URI for a request. + * @param a self reference to the spec type + */ + interface RequestHeadersUriSpec> + extends UriSpec, RequestHeadersSpec { + } + + + /** + * Contract for specifying request headers, body and URI for a request. + */ + interface RequestBodyUriSpec extends RequestBodySpec, RequestHeadersUriSpec { + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClientException.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClientException.java new file mode 100644 index 0000000000000000000000000000000000000000..1e8f5fde34b0d78d6c3f9a287f588ec6c3e60589 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClientException.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import org.springframework.core.NestedRuntimeException; + +/** + * Abstract base class for exception published by {@link WebClient} in case of errors. + * + * @author Arjen Poutsma + * @since 5.0 + */ +public abstract class WebClientException extends NestedRuntimeException { + + private static final long serialVersionUID = 472776714118912855L; + + /** + * Construct a new instance of {@code WebClientException} with the given message. + * @param msg the message + */ + public WebClientException(String msg) { + super(msg); + } + + /** + * Construct a new instance of {@code WebClientException} with the given message + * and exception. + * @param msg the message + * @param ex the exception + */ + public WebClientException(String msg, Throwable ex) { + super(msg, ex); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClientResponseException.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClientResponseException.java new file mode 100644 index 0000000000000000000000000000000000000000..da3135e70799cf2f23430de4225e8fe4b624380e --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClientResponseException.java @@ -0,0 +1,435 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpRequest; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; + +/** + * Exceptions that contain actual HTTP response data. + * + * @author Arjen Poutsma + * @since 5.0 + */ +public class WebClientResponseException extends WebClientException { + + private static final long serialVersionUID = 4127543205414951611L; + + + private final int statusCode; + + private final String statusText; + + private final byte[] responseBody; + + private final HttpHeaders headers; + + private final Charset responseCharset; + + @Nullable + private final HttpRequest request; + + + /** + * Constructor with response data only, and a default message. + * @since 5.1 + */ + public WebClientResponseException(int statusCode, String statusText, + @Nullable HttpHeaders headers, @Nullable byte[] body, @Nullable Charset charset) { + + this(statusCode, statusText, headers, body, charset, null); + } + + /** + * Constructor with response data only, and a default message. + * @since 5.1.4 + */ + public WebClientResponseException(int statusCode, String statusText, + @Nullable HttpHeaders headers, @Nullable byte[] body, @Nullable Charset charset, + @Nullable HttpRequest request) { + + this(statusCode + " " + statusText, statusCode, statusText, headers, body, charset, request); + } + + /** + * Constructor with a prepared message. + */ + public WebClientResponseException(String message, int statusCode, String statusText, + @Nullable HttpHeaders headers, @Nullable byte[] responseBody, @Nullable Charset charset) { + this(message, statusCode, statusText, headers, responseBody, charset, null); + } + + /** + * Constructor with a prepared message. + * @since 5.1.4 + */ + public WebClientResponseException(String message, int statusCode, String statusText, + @Nullable HttpHeaders headers, @Nullable byte[] responseBody, @Nullable Charset charset, + @Nullable HttpRequest request) { + + super(message); + + this.statusCode = statusCode; + this.statusText = statusText; + this.headers = (headers != null ? headers : HttpHeaders.EMPTY); + this.responseBody = (responseBody != null ? responseBody : new byte[0]); + this.responseCharset = (charset != null ? charset : StandardCharsets.ISO_8859_1); + this.request = request; + } + + + /** + * Return the HTTP status code value. + * @throws IllegalArgumentException in case of an unknown HTTP status code + */ + public HttpStatus getStatusCode() { + return HttpStatus.valueOf(this.statusCode); + } + + /** + * Return the raw HTTP status code value. + */ + public int getRawStatusCode() { + return this.statusCode; + } + + /** + * Return the HTTP status text. + */ + public String getStatusText() { + return this.statusText; + } + + /** + * Return the HTTP response headers. + */ + public HttpHeaders getHeaders() { + return this.headers; + } + + /** + * Return the response body as a byte array. + */ + public byte[] getResponseBodyAsByteArray() { + return this.responseBody; + } + + /** + * Return the response body as a string. + */ + public String getResponseBodyAsString() { + return new String(this.responseBody, this.responseCharset); + } + + /** + * Return the corresponding request. + * @since 5.1.4 + */ + @Nullable + public HttpRequest getRequest() { + return this.request; + } + + /** + * Create {@code WebClientResponseException} or an HTTP status specific subclass. + * @since 5.1 + */ + public static WebClientResponseException create( + int statusCode, String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset) { + + return create(statusCode, statusText, headers, body, charset, null); + } + + /** + * Create {@code WebClientResponseException} or an HTTP status specific subclass. + * @since 5.1.4 + */ + public static WebClientResponseException create( + int statusCode, String statusText, HttpHeaders headers, byte[] body, + @Nullable Charset charset, @Nullable HttpRequest request) { + + HttpStatus httpStatus = HttpStatus.resolve(statusCode); + if (httpStatus != null) { + switch (httpStatus) { + case BAD_REQUEST: + return new WebClientResponseException.BadRequest(statusText, headers, body, charset, request); + case UNAUTHORIZED: + return new WebClientResponseException.Unauthorized(statusText, headers, body, charset, request); + case FORBIDDEN: + return new WebClientResponseException.Forbidden(statusText, headers, body, charset, request); + case NOT_FOUND: + return new WebClientResponseException.NotFound(statusText, headers, body, charset, request); + case METHOD_NOT_ALLOWED: + return new WebClientResponseException.MethodNotAllowed(statusText, headers, body, charset, request); + case NOT_ACCEPTABLE: + return new WebClientResponseException.NotAcceptable(statusText, headers, body, charset, request); + case CONFLICT: + return new WebClientResponseException.Conflict(statusText, headers, body, charset, request); + case GONE: + return new WebClientResponseException.Gone(statusText, headers, body, charset, request); + case UNSUPPORTED_MEDIA_TYPE: + return new WebClientResponseException.UnsupportedMediaType(statusText, headers, body, charset, request); + case TOO_MANY_REQUESTS: + return new WebClientResponseException.TooManyRequests(statusText, headers, body, charset, request); + case UNPROCESSABLE_ENTITY: + return new WebClientResponseException.UnprocessableEntity(statusText, headers, body, charset, request); + case INTERNAL_SERVER_ERROR: + return new WebClientResponseException.InternalServerError(statusText, headers, body, charset, request); + case NOT_IMPLEMENTED: + return new WebClientResponseException.NotImplemented(statusText, headers, body, charset, request); + case BAD_GATEWAY: + return new WebClientResponseException.BadGateway(statusText, headers, body, charset, request); + case SERVICE_UNAVAILABLE: + return new WebClientResponseException.ServiceUnavailable(statusText, headers, body, charset, request); + case GATEWAY_TIMEOUT: + return new WebClientResponseException.GatewayTimeout(statusText, headers, body, charset, request); + } + } + return new WebClientResponseException(statusCode, statusText, headers, body, charset, request); + } + + + + // Subclasses for specific, client-side, HTTP status codes + + /** + * {@link WebClientResponseException} for status HTTP 400 Bad Request. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class BadRequest extends WebClientResponseException { + + BadRequest(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset, + @Nullable HttpRequest request) { + super(HttpStatus.BAD_REQUEST.value(), statusText, headers, body, charset, request); + } + + } + + /** + * {@link WebClientResponseException} for status HTTP 401 Unauthorized. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class Unauthorized extends WebClientResponseException { + + Unauthorized(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset, + @Nullable HttpRequest request) { + super(HttpStatus.UNAUTHORIZED.value(), statusText, headers, body, charset, request); + } + } + + /** + * {@link WebClientResponseException} for status HTTP 403 Forbidden. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class Forbidden extends WebClientResponseException { + + Forbidden(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset, + @Nullable HttpRequest request) { + super(HttpStatus.FORBIDDEN.value(), statusText, headers, body, charset, request); + } + } + + /** + * {@link WebClientResponseException} for status HTTP 404 Not Found. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class NotFound extends WebClientResponseException { + + NotFound(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset, + @Nullable HttpRequest request) { + super(HttpStatus.NOT_FOUND.value(), statusText, headers, body, charset, request); + } + } + + /** + * {@link WebClientResponseException} for status HTTP 405 Method Not Allowed. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class MethodNotAllowed extends WebClientResponseException { + + MethodNotAllowed(String statusText, HttpHeaders headers, byte[] body, + @Nullable Charset charset, @Nullable HttpRequest request) { + super(HttpStatus.METHOD_NOT_ALLOWED.value(), statusText, headers, body, charset, + request); + } + } + + /** + * {@link WebClientResponseException} for status HTTP 406 Not Acceptable. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class NotAcceptable extends WebClientResponseException { + + NotAcceptable(String statusText, HttpHeaders headers, byte[] body, + @Nullable Charset charset, @Nullable HttpRequest request) { + super(HttpStatus.NOT_ACCEPTABLE.value(), statusText, headers, body, charset, request); + } + } + + /** + * {@link WebClientResponseException} for status HTTP 409 Conflict. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class Conflict extends WebClientResponseException { + + Conflict(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset, + @Nullable HttpRequest request) { + super(HttpStatus.CONFLICT.value(), statusText, headers, body, charset, request); + } + } + + /** + * {@link WebClientResponseException} for status HTTP 410 Gone. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class Gone extends WebClientResponseException { + + Gone(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset, + @Nullable HttpRequest request) { + super(HttpStatus.GONE.value(), statusText, headers, body, charset, request); + } + } + + /** + * {@link WebClientResponseException} for status HTTP 415 Unsupported Media Type. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class UnsupportedMediaType extends WebClientResponseException { + + UnsupportedMediaType(String statusText, HttpHeaders headers, byte[] body, + @Nullable Charset charset, @Nullable HttpRequest request) { + + super(HttpStatus.UNSUPPORTED_MEDIA_TYPE.value(), statusText, headers, body, charset, + request); + } + } + + /** + * {@link WebClientResponseException} for status HTTP 422 Unprocessable Entity. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class UnprocessableEntity extends WebClientResponseException { + + UnprocessableEntity(String statusText, HttpHeaders headers, byte[] body, + @Nullable Charset charset, @Nullable HttpRequest request) { + super(HttpStatus.UNPROCESSABLE_ENTITY.value(), statusText, headers, body, charset, + request); + } + } + + /** + * {@link WebClientResponseException} for status HTTP 429 Too Many Requests. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class TooManyRequests extends WebClientResponseException { + + TooManyRequests(String statusText, HttpHeaders headers, byte[] body, + @Nullable Charset charset, @Nullable HttpRequest request) { + super(HttpStatus.TOO_MANY_REQUESTS.value(), statusText, headers, body, charset, + request); + } + } + + + + // Subclasses for specific, server-side, HTTP status codes + + /** + * {@link WebClientResponseException} for status HTTP 500 Internal Server Error. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class InternalServerError extends WebClientResponseException { + + InternalServerError(String statusText, HttpHeaders headers, byte[] body, + @Nullable Charset charset, @Nullable HttpRequest request) { + super(HttpStatus.INTERNAL_SERVER_ERROR.value(), statusText, headers, body, charset, + request); + } + } + + /** + * {@link WebClientResponseException} for status HTTP 501 Not Implemented. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class NotImplemented extends WebClientResponseException { + + NotImplemented(String statusText, HttpHeaders headers, byte[] body, + @Nullable Charset charset, @Nullable HttpRequest request) { + super(HttpStatus.NOT_IMPLEMENTED.value(), statusText, headers, body, charset, request); + } + } + + /** + * {@link WebClientResponseException} for status HTTP HTTP 502 Bad Gateway. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class BadGateway extends WebClientResponseException { + + BadGateway(String statusText, HttpHeaders headers, byte[] body, @Nullable Charset charset, + @Nullable HttpRequest request) { + super(HttpStatus.BAD_GATEWAY.value(), statusText, headers, body, charset, request); + } + } + + /** + * {@link WebClientResponseException} for status HTTP 503 Service Unavailable. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class ServiceUnavailable extends WebClientResponseException { + + ServiceUnavailable(String statusText, HttpHeaders headers, byte[] body, + @Nullable Charset charset, @Nullable HttpRequest request) { + super(HttpStatus.SERVICE_UNAVAILABLE.value(), statusText, headers, body, charset, + request); + } + } + + /** + * {@link WebClientResponseException} for status HTTP 504 Gateway Timeout. + * @since 5.1 + */ + @SuppressWarnings("serial") + public static class GatewayTimeout extends WebClientResponseException { + + GatewayTimeout(String statusText, HttpHeaders headers, byte[] body, + @Nullable Charset charset, @Nullable HttpRequest request) { + super(HttpStatus.GATEWAY_TIMEOUT.value(), statusText, headers, body, charset, + request); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/package-info.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..7c67791d7c173f1dd4f9e94c5bb083512ef52df5 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/package-info.java @@ -0,0 +1,11 @@ +/** + * Provides a reactive {@link org.springframework.web.reactive.function.client.WebClient} + * that builds on top of the + * {@code org.springframework.http.client.reactive} reactive HTTP adapter layer. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.reactive.function.client; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/support/ClientResponseWrapper.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/support/ClientResponseWrapper.java new file mode 100644 index 0000000000000000000000000000000000000000..f255977d82f05bb1be7e008c5082976117a639f6 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/support/ClientResponseWrapper.java @@ -0,0 +1,181 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.client.support; + +import java.util.List; +import java.util.Optional; +import java.util.OptionalLong; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseCookie; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.reactive.ClientHttpResponse; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyExtractor; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.ExchangeStrategies; + +/** + * Implementation of the {@link ClientResponse} interface that can be subclassed + * to adapt the request in a + * {@link org.springframework.web.reactive.function.client.ExchangeFilterFunction exchange filter function}. + * All methods default to calling through to the wrapped request. + * + * @author Arjen Poutsma + * @since 5.0.5 + */ +public class ClientResponseWrapper implements ClientResponse { + + private final ClientResponse delegate; + + + /** + * Create a new {@code ClientResponseWrapper} that wraps the given response. + * @param delegate the response to wrap + */ + public ClientResponseWrapper(ClientResponse delegate) { + Assert.notNull(delegate, "Delegate is required"); + this.delegate = delegate; + } + + + /** + * Return the wrapped request. + */ + public ClientResponse response() { + return this.delegate; + } + + @Override + public ExchangeStrategies strategies() { + return this.delegate.strategies(); + } + + @Override + public HttpStatus statusCode() { + return this.delegate.statusCode(); + } + + @Override + public int rawStatusCode() { + return this.delegate.rawStatusCode(); + } + + @Override + public Headers headers() { + return this.delegate.headers(); + } + + @Override + public MultiValueMap cookies() { + return this.delegate.cookies(); + } + + @Override + public T body(BodyExtractor extractor) { + return this.delegate.body(extractor); + } + + @Override + public Mono bodyToMono(Class elementClass) { + return this.delegate.bodyToMono(elementClass); + } + + @Override + public Mono bodyToMono(ParameterizedTypeReference typeReference) { + return this.delegate.bodyToMono(typeReference); + } + + @Override + public Flux bodyToFlux(Class elementClass) { + return this.delegate.bodyToFlux(elementClass); + } + + @Override + public Flux bodyToFlux(ParameterizedTypeReference typeReference) { + return this.delegate.bodyToFlux(typeReference); + } + + @Override + public Mono> toEntity(Class bodyType) { + return this.delegate.toEntity(bodyType); + } + + @Override + public Mono> toEntity(ParameterizedTypeReference typeReference) { + return this.delegate.toEntity(typeReference); + } + + @Override + public Mono>> toEntityList(Class elementType) { + return this.delegate.toEntityList(elementType); + } + + @Override + public Mono>> toEntityList(ParameterizedTypeReference typeReference) { + return this.delegate.toEntityList(typeReference); + } + + /** + * Implementation of the {@code Headers} interface that can be subclassed + * to adapt the headers in a + * {@link org.springframework.web.reactive.function.client.ExchangeFilterFunction exchange filter function}. + * All methods default to calling through to the wrapped request. + */ + public static class HeadersWrapper implements ClientResponse.Headers { + + private final Headers headers; + + + /** + * Create a new {@code HeadersWrapper} that wraps the given request. + * @param headers the headers to wrap + */ + public HeadersWrapper(Headers headers) { + this.headers = headers; + } + + + @Override + public OptionalLong contentLength() { + return this.headers.contentLength(); + } + + @Override + public Optional contentType() { + return this.headers.contentType(); + } + + @Override + public List header(String headerName) { + return this.headers.header(headerName); + } + + @Override + public HttpHeaders asHttpHeaders() { + return this.headers.asHttpHeaders(); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/support/package-info.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/support/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..f1df06d3885d9eb872511a44b38e2d837795674c --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/support/package-info.java @@ -0,0 +1,10 @@ +/** + * Classes supporting the {@code org.springframework.web.reactive.function.client} package. + * Contains a {@code ClientResponse} wrapper to adapt a request. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.reactive.function.client.support; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/package-info.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..b7ef8b31baa645fab09251a44623dd374a0b7b98 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/package-info.java @@ -0,0 +1,9 @@ +/** + * Provides a foundation for both the reactive client and server subpackages. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.reactive.function; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultEntityResponseBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultEntityResponseBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..d43a4e71f1a52a0b3fa05ae5d57bd240aef25256 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultEntityResponseBuilder.java @@ -0,0 +1,251 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.net.URI; +import java.time.Instant; +import java.time.ZonedDateTime; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Consumer; + +import reactor.core.publisher.Mono; + +import org.springframework.core.codec.Hints; +import org.springframework.http.CacheControl; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseCookie; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyInserter; +import org.springframework.web.server.ServerWebExchange; + +/** + * Default {@link EntityResponse.Builder} implementation. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 5.0 + * @param a self reference to the builder type + */ +class DefaultEntityResponseBuilder implements EntityResponse.Builder { + + private final T entity; + + private final BodyInserter inserter; + + private int status = HttpStatus.OK.value(); + + private final HttpHeaders headers = new HttpHeaders(); + + private final MultiValueMap cookies = new LinkedMultiValueMap<>(); + + private final Map hints = new HashMap<>(); + + + public DefaultEntityResponseBuilder(T entity, BodyInserter inserter) { + this.entity = entity; + this.inserter = inserter; + } + + + @Override + public EntityResponse.Builder status(HttpStatus status) { + Assert.notNull(status, "HttpStatus must not be null"); + this.status = status.value(); + return this; + } + + @Override + public EntityResponse.Builder status(int status) { + this.status = status; + return this; + } + + @Override + public EntityResponse.Builder cookie(ResponseCookie cookie) { + Assert.notNull(cookie, "ResponseCookie must not be null"); + this.cookies.add(cookie.getName(), cookie); + return this; + } + + @Override + public EntityResponse.Builder cookies(Consumer> cookiesConsumer) { + cookiesConsumer.accept(this.cookies); + return this; + } + + @Override + public EntityResponse.Builder header(String headerName, String... headerValues) { + for (String headerValue : headerValues) { + this.headers.add(headerName, headerValue); + } + return this; + } + + @Override + public EntityResponse.Builder headers(HttpHeaders headers) { + this.headers.putAll(headers); + return this; + } + + @Override + public EntityResponse.Builder allow(HttpMethod... allowedMethods) { + this.headers.setAllow(new LinkedHashSet<>(Arrays.asList(allowedMethods))); + return this; + } + + @Override + public EntityResponse.Builder allow(Set allowedMethods) { + this.headers.setAllow(allowedMethods); + return this; + } + + @Override + public EntityResponse.Builder contentLength(long contentLength) { + this.headers.setContentLength(contentLength); + return this; + } + + @Override + public EntityResponse.Builder contentType(MediaType contentType) { + this.headers.setContentType(contentType); + return this; + } + + @Override + public EntityResponse.Builder eTag(String etag) { + if (!etag.startsWith("\"") && !etag.startsWith("W/\"")) { + etag = "\"" + etag; + } + if (!etag.endsWith("\"")) { + etag = etag + "\""; + } + this.headers.setETag(etag); + return this; + } + + @Override + public EntityResponse.Builder hint(String key, Object value) { + this.hints.put(key, value); + return this; + } + + @Override + public EntityResponse.Builder hints(Consumer> hintsConsumer) { + hintsConsumer.accept(this.hints); + return this; + } + + @Override + public EntityResponse.Builder lastModified(ZonedDateTime lastModified) { + this.headers.setLastModified(lastModified); + return this; + } + + @Override + public EntityResponse.Builder lastModified(Instant lastModified) { + this.headers.setLastModified(lastModified); + return this; + } + + @Override + public EntityResponse.Builder location(URI location) { + this.headers.setLocation(location); + return this; + } + + @Override + public EntityResponse.Builder cacheControl(CacheControl cacheControl) { + this.headers.setCacheControl(cacheControl); + return this; + } + + @Override + public EntityResponse.Builder varyBy(String... requestHeaders) { + this.headers.setVary(Arrays.asList(requestHeaders)); + return this; + } + + @Override + public Mono> build() { + return Mono.just(new DefaultEntityResponse( + this.status, this.headers, this.cookies, this.entity, this.inserter, this.hints)); + } + + + private static final class DefaultEntityResponse + extends DefaultServerResponseBuilder.AbstractServerResponse + implements EntityResponse { + + private final T entity; + + private final BodyInserter inserter; + + + public DefaultEntityResponse(int statusCode, HttpHeaders headers, + MultiValueMap cookies, T entity, + BodyInserter inserter, Map hints) { + + super(statusCode, headers, cookies, hints); + this.entity = entity; + this.inserter = inserter; + } + + @Override + public T entity() { + return this.entity; + } + + @Override + public BodyInserter inserter() { + return this.inserter; + } + + @Override + protected Mono writeToInternal(ServerWebExchange exchange, Context context) { + return inserter().insert(exchange.getResponse(), new BodyInserter.Context() { + @Override + public List> messageWriters() { + return context.messageWriters(); + } + @Override + public Optional serverRequest() { + return Optional.of(exchange.getRequest()); + } + @Override + public Map hints() { + hints.put(Hints.LOG_PREFIX_HINT, exchange.getLogPrefix()); + return hints; + } + }); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultHandlerStrategiesBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultHandlerStrategiesBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..3565d7f7679ed3ec0d981867473ca02cf0c8da3b --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultHandlerStrategiesBuilder.java @@ -0,0 +1,172 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Consumer; + +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.util.Assert; +import org.springframework.web.reactive.handler.WebFluxResponseStatusExceptionHandler; +import org.springframework.web.reactive.result.view.ViewResolver; +import org.springframework.web.server.WebExceptionHandler; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; +import org.springframework.web.server.i18n.LocaleContextResolver; + +/** + * Default implementation of {@link HandlerStrategies.Builder}. + * + * @author Arjen Poutsma + * @since 5.0 + */ +class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder { + + private final ServerCodecConfigurer codecConfigurer = ServerCodecConfigurer.create(); + + private final List viewResolvers = new ArrayList<>(); + + private final List webFilters = new ArrayList<>(); + + private final List exceptionHandlers = new ArrayList<>(); + + private LocaleContextResolver localeContextResolver = new AcceptHeaderLocaleContextResolver(); + + + public DefaultHandlerStrategiesBuilder() { + this.codecConfigurer.registerDefaults(false); + } + + + public void defaultConfiguration() { + this.codecConfigurer.registerDefaults(true); + this.exceptionHandlers.add(new WebFluxResponseStatusExceptionHandler()); + this.localeContextResolver = new AcceptHeaderLocaleContextResolver(); + } + + @Override + public HandlerStrategies.Builder codecs(Consumer consumer) { + consumer.accept(this.codecConfigurer); + return this; + } + + @Override + public HandlerStrategies.Builder viewResolver(ViewResolver viewResolver) { + Assert.notNull(viewResolver, "ViewResolver must not be null"); + this.viewResolvers.add(viewResolver); + return this; + } + + @Override + public HandlerStrategies.Builder webFilter(WebFilter filter) { + Assert.notNull(filter, "WebFilter must not be null"); + this.webFilters.add(filter); + return this; + } + + @Override + public HandlerStrategies.Builder exceptionHandler(WebExceptionHandler exceptionHandler) { + Assert.notNull(exceptionHandler, "WebExceptionHandler must not be null"); + this.exceptionHandlers.add(exceptionHandler); + return this; + } + + @Override + public HandlerStrategies.Builder localeContextResolver(LocaleContextResolver localeContextResolver) { + Assert.notNull(localeContextResolver, "LocaleContextResolver must not be null"); + this.localeContextResolver = localeContextResolver; + return this; + } + + @Override + public HandlerStrategies build() { + return new DefaultHandlerStrategies(this.codecConfigurer.getReaders(), + this.codecConfigurer.getWriters(), this.viewResolvers, this.webFilters, + this.exceptionHandlers, this.localeContextResolver); + } + + + private static class DefaultHandlerStrategies implements HandlerStrategies { + + private final List> messageReaders; + + private final List> messageWriters; + + private final List viewResolvers; + + private final List webFilters; + + private final List exceptionHandlers; + + private final LocaleContextResolver localeContextResolver; + + public DefaultHandlerStrategies( + List> messageReaders, + List> messageWriters, + List viewResolvers, + List webFilters, + List exceptionHandlers, + LocaleContextResolver localeContextResolver) { + + this.messageReaders = unmodifiableCopy(messageReaders); + this.messageWriters = unmodifiableCopy(messageWriters); + this.viewResolvers = unmodifiableCopy(viewResolvers); + this.webFilters = unmodifiableCopy(webFilters); + this.exceptionHandlers = unmodifiableCopy(exceptionHandlers); + this.localeContextResolver = localeContextResolver; + } + + private static List unmodifiableCopy(List list) { + return Collections.unmodifiableList(new ArrayList<>(list)); + } + + @Override + public List> messageReaders() { + return this.messageReaders; + } + + @Override + public List> messageWriters() { + return this.messageWriters; + } + + @Override + public List viewResolvers() { + return this.viewResolvers; + } + + @Override + public List webFilters() { + return this.webFilters; + } + + @Override + public List exceptionHandlers() { + return this.exceptionHandlers; + } + + @Override + public LocaleContextResolver localeContextResolver() { + return this.localeContextResolver; + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultRenderingResponseBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultRenderingResponseBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..a67ae486532ff412d9ad41a3d8d53b9edfc9261a --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultRenderingResponseBuilder.java @@ -0,0 +1,207 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.Consumer; +import java.util.stream.Stream; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.context.i18n.LocaleContextHolder; +import org.springframework.core.Conventions; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseCookie; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.result.view.ViewResolver; +import org.springframework.web.server.ServerWebExchange; + +/** + * Default {@link RenderingResponse.Builder} implementation. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 5.0 + */ +final class DefaultRenderingResponseBuilder implements RenderingResponse.Builder { + + private final String name; + + private int status = HttpStatus.OK.value(); + + private final HttpHeaders headers = new HttpHeaders(); + + private final MultiValueMap cookies = new LinkedMultiValueMap<>(); + + private final Map model = new LinkedHashMap<>(); + + + public DefaultRenderingResponseBuilder(RenderingResponse other) { + Assert.notNull(other, "RenderingResponse must not be null"); + this.name = other.name(); + this.status = (other instanceof DefaultRenderingResponse ? + ((DefaultRenderingResponse) other).statusCode : other.statusCode().value()); + this.headers.putAll(other.headers()); + this.model.putAll(other.model()); + } + + public DefaultRenderingResponseBuilder(String name) { + Assert.notNull(name, "Name must not be null"); + this.name = name; + } + + + @Override + public RenderingResponse.Builder status(HttpStatus status) { + Assert.notNull(status, "HttpStatus must not be null"); + this.status = status.value(); + return this; + } + + @Override + public RenderingResponse.Builder status(int status) { + this.status = status; + return this; + } + + @Override + public RenderingResponse.Builder cookie(ResponseCookie cookie) { + Assert.notNull(cookie, "ResponseCookie must not be null"); + this.cookies.add(cookie.getName(), cookie); + return this; + } + + @Override + public RenderingResponse.Builder cookies(Consumer> cookiesConsumer) { + cookiesConsumer.accept(this.cookies); + return this; + } + + @Override + public RenderingResponse.Builder modelAttribute(Object attribute) { + Assert.notNull(attribute, "Attribute must not be null"); + if (attribute instanceof Collection && ((Collection) attribute).isEmpty()) { + return this; + } + return modelAttribute(Conventions.getVariableName(attribute), attribute); + } + + @Override + public RenderingResponse.Builder modelAttribute(String name, @Nullable Object value) { + Assert.notNull(name, "Name must not be null"); + this.model.put(name, value); + return this; + } + + @Override + public RenderingResponse.Builder modelAttributes(Object... attributes) { + modelAttributes(Arrays.asList(attributes)); + return this; + } + + @Override + public RenderingResponse.Builder modelAttributes(Collection attributes) { + attributes.forEach(this::modelAttribute); + return this; + } + + @Override + public RenderingResponse.Builder modelAttributes(Map attributes) { + this.model.putAll(attributes); + return this; + } + + @Override + public RenderingResponse.Builder header(String headerName, String... headerValues) { + for (String headerValue : headerValues) { + this.headers.add(headerName, headerValue); + } + return this; + } + + @Override + public RenderingResponse.Builder headers(HttpHeaders headers) { + this.headers.putAll(headers); + return this; + } + + @Override + public Mono build() { + return Mono.just( + new DefaultRenderingResponse(this.status, this.headers, this.cookies, this.name, this.model)); + } + + + private static final class DefaultRenderingResponse extends DefaultServerResponseBuilder.AbstractServerResponse + implements RenderingResponse { + + private final String name; + + private final Map model; + + public DefaultRenderingResponse(int statusCode, HttpHeaders headers, + MultiValueMap cookies, String name, Map model) { + + super(statusCode, headers, cookies, Collections.emptyMap()); + this.name = name; + this.model = Collections.unmodifiableMap(new LinkedHashMap<>(model)); + } + + @Override + public String name() { + return this.name; + } + + @Override + public Map model() { + return this.model; + } + + @Override + protected Mono writeToInternal(ServerWebExchange exchange, Context context) { + MediaType contentType = exchange.getResponse().getHeaders().getContentType(); + Locale locale = LocaleContextHolder.getLocale(exchange.getLocaleContext()); + Stream viewResolverStream = context.viewResolvers().stream(); + + return Flux.fromStream(viewResolverStream) + .concatMap(viewResolver -> viewResolver.resolveViewName(name(), locale)) + .next() + .switchIfEmpty(Mono.error(() -> + new IllegalArgumentException("Could not resolve view with name '" + name() + "'"))) + .flatMap(view -> { + List mediaTypes = view.getSupportedMediaTypes(); + return view.render(model(), + contentType == null && !mediaTypes.isEmpty() ? mediaTypes.get(0) : contentType, + exchange); + }); + } + + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequest.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..cdfa46cb324f04891cfc042be4dc438103ba69fe --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequest.java @@ -0,0 +1,297 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.nio.charset.Charset; +import java.security.Principal; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.function.Function; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.codec.DecodingException; +import org.springframework.core.codec.Hints; +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpRange; +import org.springframework.http.MediaType; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.multipart.Part; +import org.springframework.http.server.PathContainer; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyExtractor; +import org.springframework.web.reactive.function.BodyExtractors; +import org.springframework.web.reactive.function.UnsupportedMediaTypeException; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.ServerWebInputException; +import org.springframework.web.server.UnsupportedMediaTypeStatusException; +import org.springframework.web.server.WebSession; +import org.springframework.web.util.UriBuilder; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * {@code ServerRequest} implementation based on a {@link ServerWebExchange}. + * + * @author Arjen Poutsma + * @since 5.0 + */ +class DefaultServerRequest implements ServerRequest { + + private static final Function ERROR_MAPPER = + ex -> (ex.getContentType() != null ? + new UnsupportedMediaTypeStatusException( + ex.getContentType(), ex.getSupportedMediaTypes(), ex.getBodyType()) : + new UnsupportedMediaTypeStatusException(ex.getMessage())); + + private static final Function DECODING_MAPPER = + ex -> new ServerWebInputException("Failed to read HTTP message", null, ex); + + + private final ServerWebExchange exchange; + + private final Headers headers; + + private final List> messageReaders; + + + DefaultServerRequest(ServerWebExchange exchange, List> messageReaders) { + this.exchange = exchange; + this.messageReaders = Collections.unmodifiableList(new ArrayList<>(messageReaders)); + this.headers = new DefaultHeaders(); + } + + + @Override + public String methodName() { + return request().getMethodValue(); + } + + @Override + public URI uri() { + return request().getURI(); + } + + @Override + public UriBuilder uriBuilder() { + return UriComponentsBuilder.fromUri(uri()); + } + + @Override + public PathContainer pathContainer() { + return request().getPath(); + } + + @Override + public Headers headers() { + return this.headers; + } + + @Override + public MultiValueMap cookies() { + return request().getCookies(); + } + + @Override + public Optional remoteAddress() { + return Optional.ofNullable(request().getRemoteAddress()); + } + + @Override + public List> messageReaders() { + return this.messageReaders; + } + + @Override + public T body(BodyExtractor extractor) { + return bodyInternal(extractor, Hints.from(Hints.LOG_PREFIX_HINT, exchange().getLogPrefix())); + } + + @Override + public T body(BodyExtractor extractor, Map hints) { + hints = Hints.merge(hints, Hints.LOG_PREFIX_HINT, exchange().getLogPrefix()); + return bodyInternal(extractor, hints); + } + + private T bodyInternal(BodyExtractor extractor, Map hints) { + return extractor.extract(request(), + new BodyExtractor.Context() { + @Override + public List> messageReaders() { + return messageReaders; + } + @Override + public Optional serverResponse() { + return Optional.of(exchange().getResponse()); + } + @Override + public Map hints() { + return hints; + } + }); + } + + @Override + public Mono bodyToMono(Class elementClass) { + Mono mono = body(BodyExtractors.toMono(elementClass)); + return mono.onErrorMap(UnsupportedMediaTypeException.class, ERROR_MAPPER) + .onErrorMap(DecodingException.class, DECODING_MAPPER); + } + + @Override + public Mono bodyToMono(ParameterizedTypeReference typeReference) { + Mono mono = body(BodyExtractors.toMono(typeReference)); + return mono.onErrorMap(UnsupportedMediaTypeException.class, ERROR_MAPPER) + .onErrorMap(DecodingException.class, DECODING_MAPPER); + } + + @Override + public Flux bodyToFlux(Class elementClass) { + Flux flux = body(BodyExtractors.toFlux(elementClass)); + return flux.onErrorMap(UnsupportedMediaTypeException.class, ERROR_MAPPER) + .onErrorMap(DecodingException.class, DECODING_MAPPER); + } + + @Override + public Flux bodyToFlux(ParameterizedTypeReference typeReference) { + Flux flux = body(BodyExtractors.toFlux(typeReference)); + return flux.onErrorMap(UnsupportedMediaTypeException.class, ERROR_MAPPER) + .onErrorMap(DecodingException.class, DECODING_MAPPER); + } + + @Override + public Map attributes() { + return this.exchange.getAttributes(); + } + + @Override + public MultiValueMap queryParams() { + return request().getQueryParams(); + } + + @Override + public Map pathVariables() { + return this.exchange.getAttributeOrDefault( + RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE, Collections.emptyMap()); + } + + @Override + public Mono session() { + return this.exchange.getSession(); + } + + @Override + public Mono principal() { + return this.exchange.getPrincipal(); + } + + @Override + public Mono> formData() { + return this.exchange.getFormData(); + } + + @Override + public Mono> multipartData() { + return this.exchange.getMultipartData(); + } + + private ServerHttpRequest request() { + return this.exchange.getRequest(); + } + + @Override + public ServerWebExchange exchange() { + return this.exchange; + } + + @Override + public String toString() { + return String.format("HTTP %s %s", method(), path()); + } + + + private class DefaultHeaders implements Headers { + + private HttpHeaders delegate() { + return request().getHeaders(); + } + + @Override + public List accept() { + return delegate().getAccept(); + } + + @Override + public List acceptCharset() { + return delegate().getAcceptCharset(); + } + + @Override + public List acceptLanguage() { + return delegate().getAcceptLanguage(); + } + + @Override + public OptionalLong contentLength() { + long value = delegate().getContentLength(); + return (value != -1 ? OptionalLong.of(value) : OptionalLong.empty()); + } + + @Override + public Optional contentType() { + return Optional.ofNullable(delegate().getContentType()); + } + + @Override + public InetSocketAddress host() { + return delegate().getHost(); + } + + @Override + public List range() { + return delegate().getRange(); + } + + @Override + public List header(String headerName) { + List headerValues = delegate().get(headerName); + return (headerValues != null ? headerValues : Collections.emptyList()); + } + + @Override + public HttpHeaders asHttpHeaders() { + return HttpHeaders.readOnlyHttpHeaders(delegate()); + } + + @Override + public String toString() { + return delegate().toString(); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequestBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequestBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..794dab654af4881e842c27adfaa1afdf11f45074 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequestBuilder.java @@ -0,0 +1,451 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.security.Principal; +import java.time.Instant; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.context.ApplicationContext; +import org.springframework.context.i18n.LocaleContext; +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Hints; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.InvalidMediaTypeException; +import org.springframework.http.MediaType; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.multipart.Part; +import org.springframework.http.server.RequestPath; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; +import org.springframework.web.util.UriUtils; + +/** + * Default {@link ServerRequest.Builder} implementation. + * + * @author Arjen Poutsma + * @author Sam Brannen + * @since 5.1 + */ +class DefaultServerRequestBuilder implements ServerRequest.Builder { + + private final List> messageReaders; + + private ServerWebExchange exchange; + + private String methodName; + + private URI uri; + + private final HttpHeaders headers = new HttpHeaders(); + + private final MultiValueMap cookies = new LinkedMultiValueMap<>(); + + private final Map attributes = new LinkedHashMap<>(); + + private Flux body = Flux.empty(); + + + DefaultServerRequestBuilder(ServerRequest other) { + Assert.notNull(other, "ServerRequest must not be null"); + this.messageReaders = other.messageReaders(); + this.exchange = other.exchange(); + this.methodName = other.methodName(); + this.uri = other.uri(); + this.headers.addAll(other.headers().asHttpHeaders()); + this.cookies.addAll(other.cookies()); + this.attributes.putAll(other.attributes()); + } + + + @Override + public ServerRequest.Builder method(HttpMethod method) { + Assert.notNull(method, "HttpMethod must not be null"); + this.methodName = method.name(); + return this; + } + + @Override + public ServerRequest.Builder uri(URI uri) { + Assert.notNull(uri, "URI must not be null"); + this.uri = uri; + return this; + } + + @Override + public ServerRequest.Builder header(String headerName, String... headerValues) { + for (String headerValue : headerValues) { + this.headers.add(headerName, headerValue); + } + return this; + } + + @Override + public ServerRequest.Builder headers(Consumer headersConsumer) { + headersConsumer.accept(this.headers); + return this; + } + + @Override + public ServerRequest.Builder cookie(String name, String... values) { + for (String value : values) { + this.cookies.add(name, new HttpCookie(name, value)); + } + return this; + } + + @Override + public ServerRequest.Builder cookies(Consumer> cookiesConsumer) { + cookiesConsumer.accept(this.cookies); + return this; + } + + @Override + public ServerRequest.Builder body(Flux body) { + Assert.notNull(body, "Body must not be null"); + releaseBody(); + this.body = body; + return this; + } + + @Override + public ServerRequest.Builder body(String body) { + Assert.notNull(body, "Body must not be null"); + releaseBody(); + DataBufferFactory dataBufferFactory = new DefaultDataBufferFactory(); + this.body = Flux.just(body). + map(s -> { + byte[] bytes = body.getBytes(StandardCharsets.UTF_8); + return dataBufferFactory.wrap(bytes); + }); + return this; + } + + private void releaseBody() { + this.body.subscribe(DataBufferUtils.releaseConsumer()); + } + + @Override + public ServerRequest.Builder attribute(String name, Object value) { + this.attributes.put(name, value); + return this; + } + + @Override + public ServerRequest.Builder attributes(Consumer> attributesConsumer) { + attributesConsumer.accept(this.attributes); + return this; + } + + @Override + public ServerRequest build() { + ServerHttpRequest serverHttpRequest = new BuiltServerHttpRequest(this.exchange.getRequest().getId(), + this.methodName, this.uri, this.headers, this.cookies, this.body); + ServerWebExchange exchange = new DelegatingServerWebExchange( + serverHttpRequest, this.attributes, this.exchange, this.messageReaders); + return new DefaultServerRequest(exchange, this.messageReaders); + } + + + private static class BuiltServerHttpRequest implements ServerHttpRequest { + + private static final Pattern QUERY_PATTERN = Pattern.compile("([^&=]+)(=?)([^&]+)?"); + + private final String id; + + private final String method; + + private final URI uri; + + private final RequestPath path; + + private final MultiValueMap queryParams; + + private final HttpHeaders headers; + + private final MultiValueMap cookies; + + private final Flux body; + + public BuiltServerHttpRequest(String id, String method, URI uri, HttpHeaders headers, + MultiValueMap cookies, Flux body) { + + this.id = id; + this.method = method; + this.uri = uri; + this.path = RequestPath.parse(uri, null); + this.headers = HttpHeaders.readOnlyHttpHeaders(headers); + this.cookies = unmodifiableCopy(cookies); + this.queryParams = parseQueryParams(uri); + this.body = body; + } + + private static MultiValueMap unmodifiableCopy(MultiValueMap original) { + return CollectionUtils.unmodifiableMultiValueMap(new LinkedMultiValueMap<>(original)); + } + + private static MultiValueMap parseQueryParams(URI uri) { + MultiValueMap queryParams = new LinkedMultiValueMap<>(); + String query = uri.getRawQuery(); + if (query != null) { + Matcher matcher = QUERY_PATTERN.matcher(query); + while (matcher.find()) { + String name = UriUtils.decode(matcher.group(1), StandardCharsets.UTF_8); + String eq = matcher.group(2); + String value = matcher.group(3); + if (value != null) { + value = UriUtils.decode(value, StandardCharsets.UTF_8); + } + else { + value = (StringUtils.hasLength(eq) ? "" : null); + } + queryParams.add(name, value); + } + } + return queryParams; + } + + @Override + public String getId() { + return this.id; + } + + @Override + public String getMethodValue() { + return this.method; + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public RequestPath getPath() { + return this.path; + } + + @Override + public HttpHeaders getHeaders() { + return this.headers; + } + + @Override + public MultiValueMap getCookies() { + return this.cookies; + } + + @Override + public MultiValueMap getQueryParams() { + return this.queryParams; + } + + @Override + public Flux getBody() { + return this.body; + } + } + + + private static class DelegatingServerWebExchange implements ServerWebExchange { + + private static final ResolvableType FORM_DATA_TYPE = + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class); + + private static final ResolvableType MULTIPART_DATA_TYPE = ResolvableType.forClassWithGenerics( + MultiValueMap.class, String.class, Part.class); + + private static final Mono> EMPTY_FORM_DATA = + Mono.just(CollectionUtils.unmodifiableMultiValueMap(new LinkedMultiValueMap(0))).cache(); + + private static final Mono> EMPTY_MULTIPART_DATA = + Mono.just(CollectionUtils.unmodifiableMultiValueMap(new LinkedMultiValueMap(0))).cache(); + + private final ServerHttpRequest request; + + private final Map attributes; + + private final ServerWebExchange delegate; + + private final Mono> formDataMono; + + private final Mono> multipartDataMono; + + DelegatingServerWebExchange(ServerHttpRequest request, Map attributes, + ServerWebExchange delegate, List> messageReaders) { + + this.request = request; + this.attributes = attributes; + this.delegate = delegate; + this.formDataMono = initFormData(request, messageReaders); + this.multipartDataMono = initMultipartData(request, messageReaders); + } + + @SuppressWarnings("unchecked") + private static Mono> initFormData(ServerHttpRequest request, + List> readers) { + + try { + MediaType contentType = request.getHeaders().getContentType(); + if (MediaType.APPLICATION_FORM_URLENCODED.isCompatibleWith(contentType)) { + return ((HttpMessageReader>) readers.stream() + .filter(reader -> reader.canRead(FORM_DATA_TYPE, MediaType.APPLICATION_FORM_URLENCODED)) + .findFirst() + .orElseThrow(() -> new IllegalStateException("No form data HttpMessageReader."))) + .readMono(FORM_DATA_TYPE, request, Hints.none()) + .switchIfEmpty(EMPTY_FORM_DATA) + .cache(); + } + } + catch (InvalidMediaTypeException ex) { + // Ignore + } + return EMPTY_FORM_DATA; + } + + @SuppressWarnings("unchecked") + private static Mono> initMultipartData(ServerHttpRequest request, + List> readers) { + + try { + MediaType contentType = request.getHeaders().getContentType(); + if (MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) { + return ((HttpMessageReader>) readers.stream() + .filter(reader -> reader.canRead(MULTIPART_DATA_TYPE, MediaType.MULTIPART_FORM_DATA)) + .findFirst() + .orElseThrow(() -> new IllegalStateException("No multipart HttpMessageReader."))) + .readMono(MULTIPART_DATA_TYPE, request, Hints.none()) + .switchIfEmpty(EMPTY_MULTIPART_DATA) + .cache(); + } + } + catch (InvalidMediaTypeException ex) { + // Ignore + } + return EMPTY_MULTIPART_DATA; + } + + @Override + public ServerHttpRequest getRequest() { + return this.request; + } + + @Override + public Map getAttributes() { + return this.attributes; + } + + @Override + public Mono> getFormData() { + return this.formDataMono; + } + + @Override + public Mono> getMultipartData() { + return this.multipartDataMono; + } + + // Delegating methods + + @Override + public ServerHttpResponse getResponse() { + return this.delegate.getResponse(); + } + + @Override + public Mono getSession() { + return this.delegate.getSession(); + } + + @Override + public Mono getPrincipal() { + return this.delegate.getPrincipal(); + } + + @Override + public LocaleContext getLocaleContext() { + return this.delegate.getLocaleContext(); + } + + @Nullable + @Override + public ApplicationContext getApplicationContext() { + return this.delegate.getApplicationContext(); + } + + @Override + public boolean isNotModified() { + return this.delegate.isNotModified(); + } + + @Override + public boolean checkNotModified(Instant lastModified) { + return this.delegate.checkNotModified(lastModified); + } + + @Override + public boolean checkNotModified(String etag) { + return this.delegate.checkNotModified(etag); + } + + @Override + public boolean checkNotModified(@Nullable String etag, Instant lastModified) { + return this.delegate.checkNotModified(etag, lastModified); + } + + @Override + public String transformUrl(String url) { + return this.delegate.transformUrl(url); + } + + @Override + public void addUrlTransformer(Function transformer) { + this.delegate.addUrlTransformer(transformer); + } + + @Override + public String getLogPrefix() { + return this.delegate.getLogPrefix(); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerResponseBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerResponseBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..cabbe2bf045cd98d565a1e9c3ce3decbac34745c --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerResponseBuilder.java @@ -0,0 +1,437 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.net.URI; +import java.time.Instant; +import java.time.ZonedDateTime; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.codec.Hints; +import org.springframework.http.CacheControl; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseCookie; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.server.reactive.AbstractServerHttpResponse; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyInserter; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.server.ServerWebExchange; + +/** + * Default {@link ServerResponse.BodyBuilder} implementation. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 5.0 + */ +class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder { + + private final int statusCode; + + private final HttpHeaders headers = new HttpHeaders(); + + private final MultiValueMap cookies = new LinkedMultiValueMap<>(); + + private final Map hints = new HashMap<>(); + + + public DefaultServerResponseBuilder(ServerResponse other) { + Assert.notNull(other, "ServerResponse must not be null"); + this.headers.addAll(other.headers()); + this.cookies.addAll(other.cookies()); + if (other instanceof AbstractServerResponse) { + AbstractServerResponse abstractOther = (AbstractServerResponse) other; + this.statusCode = abstractOther.statusCode; + this.hints.putAll(abstractOther.hints); + } + else { + this.statusCode = other.statusCode().value(); + } + } + + public DefaultServerResponseBuilder(HttpStatus status) { + Assert.notNull(status, "HttpStatus must not be null"); + this.statusCode = status.value(); + } + + public DefaultServerResponseBuilder(int statusCode) { + this.statusCode = statusCode; + } + + + @Override + public ServerResponse.BodyBuilder header(String headerName, String... headerValues) { + for (String headerValue : headerValues) { + this.headers.add(headerName, headerValue); + } + return this; + } + + @Override + public ServerResponse.BodyBuilder headers(Consumer headersConsumer) { + headersConsumer.accept(this.headers); + return this; + } + + @Override + public ServerResponse.BodyBuilder cookie(ResponseCookie cookie) { + Assert.notNull(cookie, "ResponseCookie must not be null"); + this.cookies.add(cookie.getName(), cookie); + return this; + } + + @Override + public ServerResponse.BodyBuilder cookies(Consumer> cookiesConsumer) { + cookiesConsumer.accept(this.cookies); + return this; + } + + @Override + public ServerResponse.BodyBuilder allow(HttpMethod... allowedMethods) { + this.headers.setAllow(new LinkedHashSet<>(Arrays.asList(allowedMethods))); + return this; + } + + @Override + public ServerResponse.BodyBuilder allow(Set allowedMethods) { + this.headers.setAllow(allowedMethods); + return this; + } + + @Override + public ServerResponse.BodyBuilder contentLength(long contentLength) { + this.headers.setContentLength(contentLength); + return this; + } + + @Override + public ServerResponse.BodyBuilder contentType(MediaType contentType) { + this.headers.setContentType(contentType); + return this; + } + + @Override + public ServerResponse.BodyBuilder eTag(String etag) { + if (!etag.startsWith("\"") && !etag.startsWith("W/\"")) { + etag = "\"" + etag; + } + if (!etag.endsWith("\"")) { + etag = etag + "\""; + } + this.headers.setETag(etag); + return this; + } + + @Override + public ServerResponse.BodyBuilder hint(String key, Object value) { + this.hints.put(key, value); + return this; + } + + @Override + public ServerResponse.BodyBuilder hints(Consumer> hintsConsumer) { + hintsConsumer.accept(this.hints); + return this; + } + + @Override + public ServerResponse.BodyBuilder lastModified(ZonedDateTime lastModified) { + this.headers.setLastModified(lastModified); + return this; + } + + @Override + public ServerResponse.BodyBuilder lastModified(Instant lastModified) { + this.headers.setLastModified(lastModified); + return this; + } + + @Override + public ServerResponse.BodyBuilder location(URI location) { + this.headers.setLocation(location); + return this; + } + + @Override + public ServerResponse.BodyBuilder cacheControl(CacheControl cacheControl) { + this.headers.setCacheControl(cacheControl); + return this; + } + + @Override + public ServerResponse.BodyBuilder varyBy(String... requestHeaders) { + this.headers.setVary(Arrays.asList(requestHeaders)); + return this; + } + + @Override + public Mono build() { + return build((exchange, handlerStrategies) -> exchange.getResponse().setComplete()); + } + + @Override + public Mono build(Publisher voidPublisher) { + Assert.notNull(voidPublisher, "Publisher must not be null"); + return build((exchange, handlerStrategies) -> + Mono.from(voidPublisher).then(exchange.getResponse().setComplete())); + } + + @Override + public Mono build( + BiFunction> writeFunction) { + + return Mono.just( + new WriterFunctionResponse(this.statusCode, this.headers, this.cookies, writeFunction)); + } + + @Override + public > Mono body(P publisher, Class elementClass) { + Assert.notNull(publisher, "Publisher must not be null"); + Assert.notNull(elementClass, "Element Class must not be null"); + + return new DefaultEntityResponseBuilder<>(publisher, + BodyInserters.fromPublisher(publisher, elementClass)) + .status(this.statusCode) + .headers(this.headers) + .cookies(cookies -> cookies.addAll(this.cookies)) + .hints(hints -> hints.putAll(this.hints)) + .build() + .map(entityResponse -> entityResponse); + } + + @Override + public > Mono body(P publisher, + ParameterizedTypeReference typeReference) { + + Assert.notNull(publisher, "Publisher must not be null"); + Assert.notNull(typeReference, "ParameterizedTypeReference must not be null"); + + return new DefaultEntityResponseBuilder<>(publisher, + BodyInserters.fromPublisher(publisher, typeReference)) + .status(this.statusCode) + .headers(this.headers) + .cookies(cookies -> cookies.addAll(this.cookies)) + .hints(hints -> hints.putAll(this.hints)) + .build() + .map(entityResponse -> entityResponse); + } + + @Override + public Mono syncBody(Object body) { + Assert.notNull(body, "Body must not be null"); + Assert.isTrue(!(body instanceof Publisher), + "Please specify the element class by using body(Publisher, Class)"); + + return new DefaultEntityResponseBuilder<>(body, + BodyInserters.fromObject(body)) + .status(this.statusCode) + .headers(this.headers) + .cookies(cookies -> cookies.addAll(this.cookies)) + .hints(hints -> hints.putAll(this.hints)) + .build() + .map(entityResponse -> entityResponse); + } + + @Override + public Mono body(BodyInserter inserter) { + return Mono.just( + new BodyInserterResponse<>(this.statusCode, this.headers, this.cookies, inserter, this.hints)); + } + + @Override + public Mono render(String name, Object... modelAttributes) { + return new DefaultRenderingResponseBuilder(name) + .status(this.statusCode) + .headers(this.headers) + .cookies(cookies -> cookies.addAll(this.cookies)) + .modelAttributes(modelAttributes) + .build() + .map(renderingResponse -> renderingResponse); + } + + @Override + public Mono render(String name, Map model) { + return new DefaultRenderingResponseBuilder(name) + .status(this.statusCode) + .headers(this.headers) + .cookies(cookies -> cookies.addAll(this.cookies)) + .modelAttributes(model) + .build() + .map(renderingResponse -> renderingResponse); + } + + + /** + * Abstract base class for {@link ServerResponse} implementations. + */ + abstract static class AbstractServerResponse implements ServerResponse { + + private static final Set SAFE_METHODS = EnumSet.of(HttpMethod.GET, HttpMethod.HEAD); + + final int statusCode; + + private final HttpHeaders headers; + + private final MultiValueMap cookies; + + final Map hints; + + + protected AbstractServerResponse( + int statusCode, HttpHeaders headers, MultiValueMap cookies, + Map hints) { + + this.statusCode = statusCode; + this.headers = HttpHeaders.readOnlyHttpHeaders(headers); + this.cookies = CollectionUtils.unmodifiableMultiValueMap(new LinkedMultiValueMap<>(cookies)); + this.hints = hints; + } + + @Override + public final HttpStatus statusCode() { + return HttpStatus.valueOf(this.statusCode); + } + + @Override + public final HttpHeaders headers() { + return this.headers; + } + + @Override + public MultiValueMap cookies() { + return this.cookies; + } + + @Override + public final Mono writeTo(ServerWebExchange exchange, Context context) { + writeStatusAndHeaders(exchange.getResponse()); + Instant lastModified = Instant.ofEpochMilli(headers().getLastModified()); + HttpMethod httpMethod = exchange.getRequest().getMethod(); + if (SAFE_METHODS.contains(httpMethod) && exchange.checkNotModified(headers().getETag(), lastModified)) { + return exchange.getResponse().setComplete(); + } + else { + return writeToInternal(exchange, context); + } + } + + private void writeStatusAndHeaders(ServerHttpResponse response) { + if (response instanceof AbstractServerHttpResponse) { + ((AbstractServerHttpResponse) response).setStatusCodeValue(this.statusCode); + } + else { + HttpStatus status = HttpStatus.resolve(this.statusCode); + if (status == null) { + throw new IllegalStateException( + "Unresolvable HttpStatus for general ServerHttpResponse: " + this.statusCode); + } + response.setStatusCode(status); + } + copy(this.headers, response.getHeaders()); + copy(this.cookies, response.getCookies()); + } + + protected abstract Mono writeToInternal(ServerWebExchange exchange, Context context); + + private static void copy(MultiValueMap src, MultiValueMap dst) { + if (!src.isEmpty()) { + src.entrySet().stream() + .filter(entry -> !dst.containsKey(entry.getKey())) + .forEach(entry -> dst.put(entry.getKey(), entry.getValue())); + } + } + } + + + private static final class WriterFunctionResponse extends AbstractServerResponse { + + private final BiFunction> writeFunction; + + public WriterFunctionResponse(int statusCode, HttpHeaders headers, + MultiValueMap cookies, + BiFunction> writeFunction) { + + super(statusCode, headers, cookies, Collections.emptyMap()); + Assert.notNull(writeFunction, "BiFunction must not be null"); + this.writeFunction = writeFunction; + } + + @Override + protected Mono writeToInternal(ServerWebExchange exchange, Context context) { + return this.writeFunction.apply(exchange, context); + } + } + + + private static final class BodyInserterResponse extends AbstractServerResponse { + + private final BodyInserter inserter; + + + public BodyInserterResponse(int statusCode, HttpHeaders headers, + MultiValueMap cookies, + BodyInserter body, Map hints) { + + super(statusCode, headers, cookies, hints); + Assert.notNull(body, "BodyInserter must not be null"); + this.inserter = body; + } + + @Override + protected Mono writeToInternal(ServerWebExchange exchange, Context context) { + return this.inserter.insert(exchange.getResponse(), new BodyInserter.Context() { + @Override + public List> messageWriters() { + return context.messageWriters(); + } + @Override + public Optional serverRequest() { + return Optional.of(exchange.getRequest()); + } + @Override + public Map hints() { + hints.put(Hints.LOG_PREFIX_HINT, exchange.getLogPrefix()); + return hints; + } + }); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/EntityResponse.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/EntityResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..ab72a30f5ac35e5f60db5112cb601c75649a67c5 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/EntityResponse.java @@ -0,0 +1,281 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.net.URI; +import java.time.Instant; +import java.time.ZonedDateTime; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.CacheControl; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseCookie; +import org.springframework.http.codec.json.Jackson2CodecSupport; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyInserter; +import org.springframework.web.reactive.function.BodyInserters; + +/** + * Entity-specific subtype of {@link ServerResponse} that exposes entity data. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 5.0 + * @param the entity type + */ +public interface EntityResponse extends ServerResponse { + + /** + * Return the entity that makes up this response. + */ + T entity(); + + /** + * Return the {@code BodyInserter} that writes the entity to the output stream. + */ + BodyInserter inserter(); + + + // Static builder methods + + /** + * Create a builder with the given object. + * @param t the object that represents the body of the response + * @param the type of the elements contained in the publisher + * @return the created builder + */ + static Builder fromObject(T t) { + return new DefaultEntityResponseBuilder<>(t, BodyInserters.fromObject(t)); + } + + /** + * Create a builder with the given publisher. + * @param publisher the publisher that represents the body of the response + * @param elementClass the class of elements contained in the publisher + * @param the type of the elements contained in the publisher + * @param

the type of the {@code Publisher} + * @return the created builder + */ + static > Builder

fromPublisher(P publisher, Class elementClass) { + return new DefaultEntityResponseBuilder<>(publisher, + BodyInserters.fromPublisher(publisher, elementClass)); + } + + /** + * Create a builder with the given publisher. + * @param publisher the publisher that represents the body of the response + * @param typeReference the type of elements contained in the publisher + * @param the type of the elements contained in the publisher + * @param

the type of the {@code Publisher} + * @return the created builder + */ + static > Builder

fromPublisher(P publisher, + ParameterizedTypeReference typeReference) { + + return new DefaultEntityResponseBuilder<>(publisher, + BodyInserters.fromPublisher(publisher, typeReference)); + } + + + /** + * Defines a builder for {@code EntityResponse}. + * + * @param a self reference to the builder type + */ + interface Builder { + + /** + * Add the given header value(s) under the given name. + * @param headerName the header name + * @param headerValues the header value(s) + * @return this builder + * @see HttpHeaders#add(String, String) + */ + Builder header(String headerName, String... headerValues); + + /** + * Copy the given headers into the entity's headers map. + * @param headers the existing HttpHeaders to copy from + * @return this builder + * @see HttpHeaders#add(String, String) + */ + Builder headers(HttpHeaders headers); + + /** + * Set the HTTP status. + * @param status the response status + * @return this builder + */ + Builder status(HttpStatus status); + + /** + * Set the HTTP status. + * @param status the response status + * @return this builder + * @since 5.0.3 + */ + Builder status(int status); + + /** + * Add the given cookie to the response. + * @param cookie the cookie to add + * @return this builder + */ + Builder cookie(ResponseCookie cookie); + + /** + * Manipulate this response's cookies with the given consumer. The + * cookies provided to the consumer are "live", so that the consumer can be used to + * {@linkplain MultiValueMap#set(Object, Object) overwrite} existing cookies, + * {@linkplain MultiValueMap#remove(Object) remove} cookies, or use any of the other + * {@link MultiValueMap} methods. + * @param cookiesConsumer a function that consumes the cookies + * @return this builder + */ + Builder cookies(Consumer> cookiesConsumer); + + /** + * Set the set of allowed {@link HttpMethod HTTP methods}, as specified + * by the {@code Allow} header. + * @param allowedMethods the allowed methods + * @return this builder + * @see HttpHeaders#setAllow(Set) + */ + Builder allow(HttpMethod... allowedMethods); + + /** + * Set the set of allowed {@link HttpMethod HTTP methods}, as specified + * by the {@code Allow} header. + * @param allowedMethods the allowed methods + * @return this builder + * @see HttpHeaders#setAllow(Set) + */ + Builder allow(Set allowedMethods); + + /** + * Set the entity tag of the body, as specified by the {@code ETag} header. + * @param etag the new entity tag + * @return this builder + * @see HttpHeaders#setETag(String) + */ + Builder eTag(String etag); + + /** + * Set the time the resource was last changed, as specified by the + * {@code Last-Modified} header. + *

The date should be specified as the number of milliseconds since + * January 1, 1970 GMT. + * @param lastModified the last modified date + * @return this builder + * @see HttpHeaders#setLastModified(long) + */ + Builder lastModified(ZonedDateTime lastModified); + + /** + * Set the time the resource was last changed, as specified by the + * {@code Last-Modified} header. + *

The date should be specified as the number of milliseconds since + * January 1, 1970 GMT. + * @param lastModified the last modified date + * @return this builder + * @since 5.1.4 + * @see HttpHeaders#setLastModified(long) + */ + Builder lastModified(Instant lastModified); + + /** + * Set the location of a resource, as specified by the {@code Location} header. + * @param location the location + * @return this builder + * @see HttpHeaders#setLocation(URI) + */ + Builder location(URI location); + + /** + * Set the caching directives for the resource, as specified by the HTTP 1.1 + * {@code Cache-Control} header. + *

A {@code CacheControl} instance can be built like + * {@code CacheControl.maxAge(3600).cachePublic().noTransform()}. + * @param cacheControl a builder for cache-related HTTP response headers + * @return this builder + * @see RFC-7234 Section 5.2 + */ + Builder cacheControl(CacheControl cacheControl); + + /** + * Configure one or more request header names (e.g. "Accept-Language") to + * add to the "Vary" response header to inform clients that the response is + * subject to content negotiation and variances based on the value of the + * given request headers. The configured request header names are added only + * if not already present in the response "Vary" header. + * @param requestHeaders request header names + * @return this builder + */ + Builder varyBy(String... requestHeaders); + + /** + * Set the length of the body in bytes, as specified by the + * {@code Content-Length} header. + * @param contentLength the content length + * @return this builder + * @see HttpHeaders#setContentLength(long) + */ + Builder contentLength(long contentLength); + + /** + * Set the {@linkplain MediaType media type} of the body, as specified by the + * {@code Content-Type} header. + * @param contentType the content type + * @return this builder + * @see HttpHeaders#setContentType(MediaType) + */ + Builder contentType(MediaType contentType); + + /** + * Add a serialization hint like {@link Jackson2CodecSupport#JSON_VIEW_HINT} to + * customize how the body will be serialized. + * @param key the hint key + * @param value the hint value + */ + Builder hint(String key, Object value); + + /** + * Customize the serialization hints with the given consumer. + * @param hintsConsumer a function that consumes the hints + * @return this builder + * @since 5.1.6 + */ + Builder hints(Consumer> hintsConsumer); + + /** + * Build the response. + * @return the built response + */ + Mono> build(); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/HandlerFilterFunction.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/HandlerFilterFunction.java new file mode 100644 index 0000000000000000000000000000000000000000..f8ee5dafc5cffb58e5215ec4c2b1f909ebed8465 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/HandlerFilterFunction.java @@ -0,0 +1,101 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.util.function.Function; + +import reactor.core.publisher.Mono; + +import org.springframework.util.Assert; +import org.springframework.web.reactive.function.server.support.ServerRequestWrapper; + +/** + * Represents a function that filters a {@linkplain HandlerFunction handler function}. + * + * @author Arjen Poutsma + * @since 5.0 + * @param the type of the {@linkplain HandlerFunction handler function} to filter + * @param the type of the response of the function + * @see RouterFunction#filter(HandlerFilterFunction) + */ +@FunctionalInterface +public interface HandlerFilterFunction { + + /** + * Apply this filter to the given handler function. The given + * {@linkplain HandlerFunction handler function} represents the next entity in the chain, + * and can be {@linkplain HandlerFunction#handle(ServerRequest) invoked} in order to + * proceed to this entity, or not invoked to block the chain. + * @param request the request + * @param next the next handler or filter function in the chain + * @return the filtered response + * @see ServerRequestWrapper + */ + Mono filter(ServerRequest request, HandlerFunction next); + + /** + * Return a composed filter function that first applies this filter, and then applies the + * {@code after} filter. + * @param after the filter to apply after this filter is applied + * @return a composed filter that first applies this function and then applies the + * {@code after} function + */ + default HandlerFilterFunction andThen(HandlerFilterFunction after) { + Assert.notNull(after, "HandlerFilterFunction must not be null"); + return (request, next) -> { + HandlerFunction nextHandler = handlerRequest -> after.filter(handlerRequest, next); + return filter(request, nextHandler); + }; + } + + /** + * Apply this filter to the given handler function, resulting in a filtered handler function. + * @param handler the handler function to filter + * @return the filtered handler function + */ + default HandlerFunction apply(HandlerFunction handler) { + Assert.notNull(handler, "HandlerFunction must not be null"); + return request -> this.filter(request, handler); + } + + /** + * Adapt the given request processor function to a filter function that only operates + * on the {@code ServerRequest}. + * @param requestProcessor the request processor + * @return the filter adaptation of the request processor + */ + static HandlerFilterFunction ofRequestProcessor( + Function> requestProcessor) { + + Assert.notNull(requestProcessor, "Function must not be null"); + return (request, next) -> requestProcessor.apply(request).flatMap(next::handle); + } + + /** + * Adapt the given response processor function to a filter function that only operates + * on the {@code ServerResponse}. + * @param responseProcessor the response processor + * @return the filter adaptation of the request processor + */ + static HandlerFilterFunction ofResponseProcessor( + Function> responseProcessor) { + + Assert.notNull(responseProcessor, "Function must not be null"); + return (request, next) -> next.handle(request).flatMap(responseProcessor); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/HandlerFunction.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/HandlerFunction.java new file mode 100644 index 0000000000000000000000000000000000000000..b73216edc0d00d26113d03de40287abf6080d9dd --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/HandlerFunction.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import reactor.core.publisher.Mono; + +/** + * Represents a function that handles a {@linkplain ServerRequest request}. + * + * @author Arjen Poutsma + * @since 5.0 + * @param the type of the response of the function + * @see RouterFunction + */ +@FunctionalInterface +public interface HandlerFunction { + + /** + * Handle the given request. + * @param request the request to handle + * @return the response + */ + Mono handle(ServerRequest request); + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/HandlerStrategies.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/HandlerStrategies.java new file mode 100644 index 0000000000000000000000000000000000000000..1d41f07d252d290c1a23990fb87d2f045527835b --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/HandlerStrategies.java @@ -0,0 +1,158 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.util.List; +import java.util.function.Consumer; + +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.web.reactive.result.view.ViewResolver; +import org.springframework.web.server.WebExceptionHandler; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.i18n.LocaleContextResolver; + +/** + * Defines the strategies to be used for processing {@link HandlerFunction HandlerFunctions}. + * + *

An instance of this class is immutable. Instances are typically created through the + * mutable {@link Builder}: either through {@link #builder()} to set up default strategies, + * or {@link #empty()} to start from scratch. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @author Sebastien Deleuze + * @since 5.0 + * @see RouterFunctions#toHttpHandler(RouterFunction, HandlerStrategies) + */ +public interface HandlerStrategies { + + /** + * Return the {@link HttpMessageReader HttpMessageReaders} to be used for request body conversion. + * @return the message readers + */ + List> messageReaders(); + + /** + * Return the {@link HttpMessageWriter HttpMessageWriters} to be used for response body conversion. + * @return the message writers + */ + List> messageWriters(); + + /** + * Return the {@link ViewResolver ViewResolvers} to be used for view name resolution. + * @return the view resolvers + */ + List viewResolvers(); + + /** + * Return the {@link WebFilter WebFilters} to be used for filtering the request and response. + * @return the web filters + */ + List webFilters(); + + /** + * Return the {@link WebExceptionHandler WebExceptionHandlers} to be used for handling exceptions. + * @return the exception handlers + */ + List exceptionHandlers(); + + /** + * Return the {@link LocaleContextResolver} to be used for resolving locale context. + * @return the locale context resolver + */ + LocaleContextResolver localeContextResolver(); + + + // Static builder methods + + /** + * Return a new {@code HandlerStrategies} with default initialization. + * @return the new {@code HandlerStrategies} + */ + static HandlerStrategies withDefaults() { + return builder().build(); + } + + /** + * Return a mutable builder for a {@code HandlerStrategies} with default initialization. + * @return the builder + */ + static Builder builder() { + DefaultHandlerStrategiesBuilder builder = new DefaultHandlerStrategiesBuilder(); + builder.defaultConfiguration(); + return builder; + } + + /** + * Return a mutable, empty builder for a {@code HandlerStrategies}. + * @return the builder + */ + static Builder empty() { + return new DefaultHandlerStrategiesBuilder(); + } + + + /** + * A mutable builder for a {@link HandlerStrategies}. + */ + interface Builder { + + /** + * Customize the list of server-side HTTP message readers and writers. + * @param consumer the consumer to customize the codecs + * @return this builder + */ + Builder codecs(Consumer consumer); + + /** + * Add the given view resolver to this builder. + * @param viewResolver the view resolver to add + * @return this builder + */ + Builder viewResolver(ViewResolver viewResolver); + + /** + * Add the given web filter to this builder. + * @param filter the filter to add + * @return this builder + */ + Builder webFilter(WebFilter filter); + + /** + * Add the given exception handler to this builder. + * @param exceptionHandler the exception handler to add + * @return this builder + */ + Builder exceptionHandler(WebExceptionHandler exceptionHandler); + + /** + * Add the given locale context resolver to this builder. + * @param localeContextResolver the locale context resolver to add + * @return this builder + */ + Builder localeContextResolver(LocaleContextResolver localeContextResolver); + + /** + * Builds the {@link HandlerStrategies}. + * @return the built strategies + */ + HandlerStrategies build(); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/PathResourceLookupFunction.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/PathResourceLookupFunction.java new file mode 100644 index 0000000000000000000000000000000000000000..0e2b719cc13e986e16c54a84da2f7ca1eab28762 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/PathResourceLookupFunction.java @@ -0,0 +1,162 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.charset.StandardCharsets; +import java.util.function.Function; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.core.io.UrlResource; +import org.springframework.http.server.PathContainer; +import org.springframework.util.Assert; +import org.springframework.util.ResourceUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.util.pattern.PathPattern; +import org.springframework.web.util.pattern.PathPatternParser; + +/** + * Lookup function used by {@link RouterFunctions#resources(String, Resource)}. + * + * @author Arjen Poutsma + * @since 5.0 + */ +class PathResourceLookupFunction implements Function> { + + private static final PathPatternParser PATTERN_PARSER = new PathPatternParser(); + + private final PathPattern pattern; + + private final Resource location; + + + public PathResourceLookupFunction(String pattern, Resource location) { + Assert.hasLength(pattern, "'pattern' must not be empty"); + Assert.notNull(location, "'location' must not be null"); + this.pattern = PATTERN_PARSER.parse(pattern); + this.location = location; + } + + + @Override + public Mono apply(ServerRequest request) { + PathContainer pathContainer = request.pathContainer(); + if (!this.pattern.matches(pathContainer)) { + return Mono.empty(); + } + + pathContainer = this.pattern.extractPathWithinPattern(pathContainer); + String path = processPath(pathContainer.value()); + if (path.contains("%")) { + path = StringUtils.uriDecode(path, StandardCharsets.UTF_8); + } + if (!StringUtils.hasLength(path) || isInvalidPath(path)) { + return Mono.empty(); + } + + try { + Resource resource = this.location.createRelative(path); + if (resource.exists() && resource.isReadable() && isResourceUnderLocation(resource)) { + return Mono.just(resource); + } + else { + return Mono.empty(); + } + } + catch (IOException ex) { + throw new UncheckedIOException(ex); + } + } + + private String processPath(String path) { + boolean slash = false; + for (int i = 0; i < path.length(); i++) { + if (path.charAt(i) == '/') { + slash = true; + } + else if (path.charAt(i) > ' ' && path.charAt(i) != 127) { + if (i == 0 || (i == 1 && slash)) { + return path; + } + path = slash ? "/" + path.substring(i) : path.substring(i); + return path; + } + } + return (slash ? "/" : ""); + } + + private boolean isInvalidPath(String path) { + if (path.contains("WEB-INF") || path.contains("META-INF")) { + return true; + } + if (path.contains(":/")) { + String relativePath = (path.charAt(0) == '/' ? path.substring(1) : path); + if (ResourceUtils.isUrl(relativePath) || relativePath.startsWith("url:")) { + return true; + } + } + if (path.contains("..") && StringUtils.cleanPath(path).contains("../")) { + return true; + } + return false; + } + + private boolean isResourceUnderLocation(Resource resource) throws IOException { + if (resource.getClass() != this.location.getClass()) { + return false; + } + + String resourcePath; + String locationPath; + + if (resource instanceof UrlResource) { + resourcePath = resource.getURL().toExternalForm(); + locationPath = StringUtils.cleanPath(this.location.getURL().toString()); + } + else if (resource instanceof ClassPathResource) { + resourcePath = ((ClassPathResource) resource).getPath(); + locationPath = StringUtils.cleanPath(((ClassPathResource) this.location).getPath()); + } + else { + resourcePath = resource.getURL().getPath(); + locationPath = StringUtils.cleanPath(this.location.getURL().getPath()); + } + + if (locationPath.equals(resourcePath)) { + return true; + } + locationPath = (locationPath.endsWith("/") || locationPath.isEmpty() ? locationPath : locationPath + "/"); + if (!resourcePath.startsWith(locationPath)) { + return false; + } + if (resourcePath.contains("%") && StringUtils.uriDecode(resourcePath, StandardCharsets.UTF_8).contains("../")) { + return false; + } + return true; + } + + + @Override + public String toString() { + return this.pattern + " -> " + this.location; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RenderingResponse.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RenderingResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..2b2ff2979c9145ea837d2616243f403e1f255925 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RenderingResponse.java @@ -0,0 +1,172 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.util.Collection; +import java.util.Map; +import java.util.function.Consumer; + +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseCookie; +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * Rendering-specific subtype of {@link ServerResponse} that exposes model and template data. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @since 5.0 + */ +public interface RenderingResponse extends ServerResponse { + + /** + * Return the name of the template to be rendered. + */ + String name(); + + /** + * Return the unmodifiable model map. + */ + Map model(); + + + // Builder + + /** + * Create a builder with the template name, status code, headers and model of the given response. + * @param other the response to copy the values from + * @return the created builder + */ + static Builder from(RenderingResponse other) { + return new DefaultRenderingResponseBuilder(other); + } + + /** + * Create a builder with the given template name. + * @param name the name of the template to render + * @return the created builder + */ + static Builder create(String name) { + return new DefaultRenderingResponseBuilder(name); + } + + + /** + * Defines a builder for {@code RenderingResponse}. + */ + interface Builder { + + /** + * Add the supplied attribute to the model using a + * {@linkplain org.springframework.core.Conventions#getVariableName generated name}. + *

Note: Empty {@link Collection Collections} are not added to + * the model when using this method because we cannot correctly determine + * the true convention name. View code should check for {@code null} rather + * than for empty collections. + * @param attribute the model attribute value (never {@code null}) + */ + Builder modelAttribute(Object attribute); + + /** + * Add the supplied attribute value under the supplied name. + * @param name the name of the model attribute (never {@code null}) + * @param value the model attribute value (can be {@code null}) + */ + Builder modelAttribute(String name, @Nullable Object value); + + /** + * Copy all attributes in the supplied array into the model, + * using attribute name generation for each element. + * @see #modelAttribute(Object) + */ + Builder modelAttributes(Object... attributes); + + /** + * Copy all attributes in the supplied {@code Collection} into the model, + * using attribute name generation for each element. + * @see #modelAttribute(Object) + */ + Builder modelAttributes(Collection attributes); + + /** + * Copy all attributes in the supplied {@code Map} into the model. + * @see #modelAttribute(String, Object) + */ + Builder modelAttributes(Map attributes); + + /** + * Add the given header value(s) under the given name. + * @param headerName the header name + * @param headerValues the header value(s) + * @return this builder + * @see HttpHeaders#add(String, String) + */ + Builder header(String headerName, String... headerValues); + + /** + * Copy the given headers into the entity's headers map. + * @param headers the existing HttpHeaders to copy from + * @return this builder + * @see HttpHeaders#add(String, String) + */ + Builder headers(HttpHeaders headers); + + /** + * Set the HTTP status. + * @param status the response status + * @return this builder + */ + Builder status(HttpStatus status); + + /** + * Set the HTTP status. + * @param status the response status + * @return this builder + * @since 5.0.3 + */ + Builder status(int status); + + /** + * Add the given cookie to the response. + * @param cookie the cookie to add + * @return this builder + */ + Builder cookie(ResponseCookie cookie); + + /** + * Manipulate this response's cookies with the given consumer. The + * cookies provided to the consumer are "live", so that the consumer can be used to + * {@linkplain MultiValueMap#set(Object, Object) overwrite} existing cookies, + * {@linkplain MultiValueMap#remove(Object) remove} cookies, or use any of the other + * {@link MultiValueMap} methods. + * @param cookiesConsumer a function that consumes the cookies + * @return this builder + */ + Builder cookies(Consumer> cookiesConsumer); + + /** + * Build the response. + * @return the built response + */ + Mono build(); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicate.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicate.java new file mode 100644 index 0000000000000000000000000000000000000000..2db4631fbe9ceb136b08676b93e3f4c5c4591327 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicate.java @@ -0,0 +1,98 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.util.Optional; + +/** + * Represents a function that evaluates on a given {@link ServerRequest}. + * Instances of this function that evaluate on common request properties + * can be found in {@link RequestPredicates}. + * + * @author Arjen Poutsma + * @since 5.0 + * @see RequestPredicates + * @see RouterFunctions#route(RequestPredicate, HandlerFunction) + * @see RouterFunctions#nest(RequestPredicate, RouterFunction) + */ +@FunctionalInterface +public interface RequestPredicate { + + /** + * Evaluate this predicate on the given request. + * @param request the request to match against + * @return {@code true} if the request matches the predicate; {@code false} otherwise + */ + boolean test(ServerRequest request); + + /** + * Return a composed request predicate that tests against both this predicate AND + * the {@code other} predicate. When evaluating the composed predicate, if this + * predicate is {@code false}, then the {@code other} predicate is not evaluated. + * @param other a predicate that will be logically-ANDed with this predicate + * @return a predicate composed of this predicate AND the {@code other} predicate + */ + default RequestPredicate and(RequestPredicate other) { + return new RequestPredicates.AndRequestPredicate(this, other); + } + + /** + * Return a predicate that represents the logical negation of this predicate. + * @return a predicate that represents the logical negation of this predicate + */ + default RequestPredicate negate() { + return new RequestPredicates.NegateRequestPredicate(this); + } + + /** + * Return a composed request predicate that tests against both this predicate OR + * the {@code other} predicate. When evaluating the composed predicate, if this + * predicate is {@code true}, then the {@code other} predicate is not evaluated. + * @param other a predicate that will be logically-ORed with this predicate + * @return a predicate composed of this predicate OR the {@code other} predicate + */ + default RequestPredicate or(RequestPredicate other) { + return new RequestPredicates.OrRequestPredicate(this, other); + } + + /** + * Transform the given request into a request used for a nested route. For instance, + * a path-based predicate can return a {@code ServerRequest} with a the path remaining + * after a match. + *

The default implementation returns an {@code Optional} wrapping the given path if + * {@link #test(ServerRequest)} evaluates to {@code true}; or {@link Optional#empty()} + * if it evaluates to {@code false}. + * @param request the request to be nested + * @return the nested request + * @see RouterFunctions#nest(RequestPredicate, RouterFunction) + */ + default Optional nest(ServerRequest request) { + return (test(request) ? Optional.of(request) : Optional.empty()); + } + + /** + * Accept the given visitor. Default implementation calls + * {@link RequestPredicates.Visitor#unknown(RequestPredicate)}; composed {@code RequestPredicate} + * implementations are expected to call {@code accept} for all components that make up this + * request predicate. + * @param visitor the visitor to accept + */ + default void accept(RequestPredicates.Visitor visitor) { + visitor.unknown(this); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java new file mode 100644 index 0000000000000000000000000000000000000000..a0e24944ff3f06036afe5ab34ca859c08d0a0c46 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java @@ -0,0 +1,1078 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.security.Principal; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; +import java.util.function.Predicate; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.multipart.Part; +import org.springframework.http.server.PathContainer; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyExtractor; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; +import org.springframework.web.util.UriBuilder; +import org.springframework.web.util.UriUtils; +import org.springframework.web.util.pattern.PathPattern; +import org.springframework.web.util.pattern.PathPatternParser; + +/** + * Implementations of {@link RequestPredicate} that implement various useful + * request matching operations, such as matching based on path, HTTP method, etc. + * + * @author Arjen Poutsma + * @since 5.0 + */ +public abstract class RequestPredicates { + + private static final Log logger = LogFactory.getLog(RequestPredicates.class); + + private static final PathPatternParser DEFAULT_PATTERN_PARSER = new PathPatternParser(); + + + /** + * Return a {@code RequestPredicate} that always matches. + * @return a predicate that always matches + */ + public static RequestPredicate all() { + return request -> true; + } + + + /** + * Return a {@code RequestPredicate} that matches if the request's + * HTTP method is equal to the given method. + * @param httpMethod the HTTP method to match against + * @return a predicate that tests against the given HTTP method + */ + public static RequestPredicate method(HttpMethod httpMethod) { + return new HttpMethodPredicate(httpMethod); + } + + /** + * Return a {@code RequestPredicate} that matches if the request's + * HTTP method is equal to one the of the given methods. + * @param httpMethods the HTTP methods to match against + * @return a predicate that tests against the given HTTP methods + * @since 5.1 + */ + public static RequestPredicate methods(HttpMethod... httpMethods) { + return new HttpMethodPredicate(httpMethods); + } + + /** + * Return a {@code RequestPredicate} that tests the request path + * against the given path pattern. + * @param pattern the pattern to match to + * @return a predicate that tests against the given path pattern + */ + public static RequestPredicate path(String pattern) { + Assert.notNull(pattern, "'pattern' must not be null"); + return pathPredicates(DEFAULT_PATTERN_PARSER).apply(pattern); + } + + /** + * Return a function that creates new path-matching {@code RequestPredicates} + * from pattern Strings using the given {@link PathPatternParser}. + *

This method can be used to specify a non-default, customized + * {@code PathPatternParser} when resolving path patterns. + * @param patternParser the parser used to parse patterns given to the returned function + * @return a function that resolves a pattern String into a path-matching + * {@code RequestPredicates} instance + */ + public static Function pathPredicates(PathPatternParser patternParser) { + Assert.notNull(patternParser, "PathPatternParser must not be null"); + return pattern -> new PathPatternPredicate(patternParser.parse(pattern)); + } + + /** + * Return a {@code RequestPredicate} that tests the request's headers + * against the given headers predicate. + * @param headersPredicate a predicate that tests against the request headers + * @return a predicate that tests against the given header predicate + */ + public static RequestPredicate headers(Predicate headersPredicate) { + return new HeadersPredicate(headersPredicate); + } + + /** + * Return a {@code RequestPredicate} that tests if the request's + * {@linkplain ServerRequest.Headers#contentType() content type} is + * {@linkplain MediaType#includes(MediaType) included} by any of the given media types. + * @param mediaTypes the media types to match the request's content type against + * @return a predicate that tests the request's content type against the given media types + */ + public static RequestPredicate contentType(MediaType... mediaTypes) { + Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); + return new ContentTypePredicate(mediaTypes); + } + + /** + * Return a {@code RequestPredicate} that tests if the request's + * {@linkplain ServerRequest.Headers#accept() accept} header is + * {@linkplain MediaType#isCompatibleWith(MediaType) compatible} with any of the given media types. + * @param mediaTypes the media types to match the request's accept header against + * @return a predicate that tests the request's accept header against the given media types + */ + public static RequestPredicate accept(MediaType... mediaTypes) { + Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); + return new AcceptPredicate(mediaTypes); + } + + /** + * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code GET} + * and the given {@code pattern} matches against the request path. + * @param pattern the path pattern to match against + * @return a predicate that matches if the request method is GET and if the given pattern + * matches against the request path + */ + public static RequestPredicate GET(String pattern) { + return method(HttpMethod.GET).and(path(pattern)); + } + + /** + * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code HEAD} + * and the given {@code pattern} matches against the request path. + * @param pattern the path pattern to match against + * @return a predicate that matches if the request method is HEAD and if the given pattern + * matches against the request path + */ + public static RequestPredicate HEAD(String pattern) { + return method(HttpMethod.HEAD).and(path(pattern)); + } + + /** + * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code POST} + * and the given {@code pattern} matches against the request path. + * @param pattern the path pattern to match against + * @return a predicate that matches if the request method is POST and if the given pattern + * matches against the request path + */ + public static RequestPredicate POST(String pattern) { + return method(HttpMethod.POST).and(path(pattern)); + } + + /** + * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code PUT} + * and the given {@code pattern} matches against the request path. + * @param pattern the path pattern to match against + * @return a predicate that matches if the request method is PUT and if the given pattern + * matches against the request path + */ + public static RequestPredicate PUT(String pattern) { + return method(HttpMethod.PUT).and(path(pattern)); + } + + /** + * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code PATCH} + * and the given {@code pattern} matches against the request path. + * @param pattern the path pattern to match against + * @return a predicate that matches if the request method is PATCH and if the given pattern + * matches against the request path + */ + public static RequestPredicate PATCH(String pattern) { + return method(HttpMethod.PATCH).and(path(pattern)); + } + + /** + * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code DELETE} + * and the given {@code pattern} matches against the request path. + * @param pattern the path pattern to match against + * @return a predicate that matches if the request method is DELETE and if the given pattern + * matches against the request path + */ + public static RequestPredicate DELETE(String pattern) { + return method(HttpMethod.DELETE).and(path(pattern)); + } + + /** + * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code OPTIONS} + * and the given {@code pattern} matches against the request path. + * @param pattern the path pattern to match against + * @return a predicate that matches if the request method is OPTIONS and if the given pattern + * matches against the request path + */ + public static RequestPredicate OPTIONS(String pattern) { + return method(HttpMethod.OPTIONS).and(path(pattern)); + } + + /** + * Return a {@code RequestPredicate} that matches if the request's path has the given extension. + * @param extension the path extension to match against, ignoring case + * @return a predicate that matches if the request's path has the given file extension + */ + public static RequestPredicate pathExtension(String extension) { + Assert.notNull(extension, "'extension' must not be null"); + return new PathExtensionPredicate(extension); + } + + /** + * Return a {@code RequestPredicate} that matches if the request's path matches the given + * predicate. + * @param extensionPredicate the predicate to test against the request path extension + * @return a predicate that matches if the given predicate matches against the request's path + * file extension + */ + public static RequestPredicate pathExtension(Predicate extensionPredicate) { + return new PathExtensionPredicate(extensionPredicate); + } + + /** + * Return a {@code RequestPredicate} that matches if the request's query parameter of the given name + * has the given value. + * @param name the name of the query parameter to test against + * @param value the value of the query parameter to test against + * @return a predicate that matches if the query parameter has the given value + * @since 5.0.7 + * @see ServerRequest#queryParam(String) + */ + public static RequestPredicate queryParam(String name, String value) { + return new QueryParamPredicate(name, value); + } + + /** + * Return a {@code RequestPredicate} that tests the request's query parameter of the given name + * against the given predicate. + * @param name the name of the query parameter to test against + * @param predicate predicate to test against the query parameter value + * @return a predicate that matches the given predicate against the query parameter of the given name + * @see ServerRequest#queryParam(String) + */ + public static RequestPredicate queryParam(String name, Predicate predicate) { + return new QueryParamPredicate(name, predicate); + } + + + private static void traceMatch(String prefix, Object desired, @Nullable Object actual, boolean match) { + if (logger.isTraceEnabled()) { + logger.trace(String.format("%s \"%s\" %s against value \"%s\"", + prefix, desired, match ? "matches" : "does not match", actual)); + } + } + + private static void restoreAttributes(ServerRequest request, Map attributes) { + request.attributes().clear(); + request.attributes().putAll(attributes); + } + + private static Map mergePathVariables(Map oldVariables, + Map newVariables) { + + if (!newVariables.isEmpty()) { + Map mergedVariables = new LinkedHashMap<>(oldVariables); + mergedVariables.putAll(newVariables); + return mergedVariables; + } + else { + return oldVariables; + } + } + + private static PathPattern mergePatterns(@Nullable PathPattern oldPattern, PathPattern newPattern) { + if (oldPattern != null) { + return oldPattern.combine(newPattern); + } + else { + return newPattern; + } + + } + + + /** + * Receives notifications from the logical structure of request predicates. + */ + public interface Visitor { + + /** + * Receive notification of an HTTP method predicate. + * @param methods the HTTP methods that make up the predicate + * @see RequestPredicates#method(HttpMethod) + */ + void method(Set methods); + + /** + * Receive notification of an path predicate. + * @param pattern the path pattern that makes up the predicate + * @see RequestPredicates#path(String) + */ + void path(String pattern); + + /** + * Receive notification of an path extension predicate. + * @param extension the path extension that makes up the predicate + * @see RequestPredicates#pathExtension(String) + */ + void pathExtension(String extension); + + /** + * Receive notification of a HTTP header predicate. + * @param name the name of the HTTP header to check + * @param value the desired value of the HTTP header + * @see RequestPredicates#headers(Predicate) + * @see RequestPredicates#contentType(MediaType...) + * @see RequestPredicates#accept(MediaType...) + */ + void header(String name, String value); + + /** + * Receive notification of a query parameter predicate. + * @param name the name of the query parameter + * @param value the desired value of the parameter + * @see RequestPredicates#queryParam(String, String) + */ + void queryParam(String name, String value); + + /** + * Receive first notification of a logical AND predicate. + * The first subsequent notification will contain the left-hand side of the AND-predicate; + * followed by {@link #and()}, followed by the right-hand side, followed by {@link #endAnd()}. + * @see RequestPredicate#and(RequestPredicate) + */ + void startAnd(); + + /** + * Receive "middle" notification of a logical AND predicate. + * The following notification contains the right-hand side, followed by {@link #endAnd()}. + * @see RequestPredicate#and(RequestPredicate) + */ + void and(); + + /** + * Receive last notification of a logical AND predicate. + * @see RequestPredicate#and(RequestPredicate) + */ + void endAnd(); + + /** + * Receive first notification of a logical OR predicate. + * The first subsequent notification will contain the left-hand side of the OR-predicate; + * the second notification contains the right-hand side, followed by {@link #endOr()}. + * @see RequestPredicate#or(RequestPredicate) + */ + void startOr(); + + /** + * Receive "middle" notification of a logical OR predicate. + * The following notification contains the right-hand side, followed by {@link #endOr()}. + * @see RequestPredicate#or(RequestPredicate) + */ + void or(); + + /** + * Receive last notification of a logical OR predicate. + * @see RequestPredicate#or(RequestPredicate) + */ + void endOr(); + + /** + * Receive first notification of a negated predicate. + * The first subsequent notification will contain the negated predicated, followed + * by {@link #endNegate()}. + * @see RequestPredicate#negate() + */ + void startNegate(); + + /** + * Receive last notification of a negated predicate. + * @see RequestPredicate#negate() + */ + void endNegate(); + + /** + * Receive first notification of an unknown predicate. + */ + void unknown(RequestPredicate predicate); + } + + + private static class HttpMethodPredicate implements RequestPredicate { + + private final Set httpMethods; + + public HttpMethodPredicate(HttpMethod httpMethod) { + Assert.notNull(httpMethod, "HttpMethod must not be null"); + this.httpMethods = EnumSet.of(httpMethod); + } + + public HttpMethodPredicate(HttpMethod... httpMethods) { + Assert.notEmpty(httpMethods, "HttpMethods must not be empty"); + + this.httpMethods = EnumSet.copyOf(Arrays.asList(httpMethods)); + } + + @Override + public boolean test(ServerRequest request) { + boolean match = this.httpMethods.contains(request.method()); + traceMatch("Method", this.httpMethods, request.method(), match); + return match; + } + + @Override + public void accept(Visitor visitor) { + visitor.method(Collections.unmodifiableSet(this.httpMethods)); + } + + @Override + public String toString() { + if (this.httpMethods.size() == 1) { + return this.httpMethods.iterator().next().toString(); + } + else { + return this.httpMethods.toString(); + } + } + } + + + private static class PathPatternPredicate implements RequestPredicate { + + private final PathPattern pattern; + + public PathPatternPredicate(PathPattern pattern) { + Assert.notNull(pattern, "'pattern' must not be null"); + this.pattern = pattern; + } + + @Override + public boolean test(ServerRequest request) { + PathContainer pathContainer = request.pathContainer(); + PathPattern.PathMatchInfo info = this.pattern.matchAndExtract(pathContainer); + traceMatch("Pattern", this.pattern.getPatternString(), request.path(), info != null); + if (info != null) { + mergeAttributes(request, info.getUriVariables(), this.pattern); + return true; + } + else { + return false; + } + } + + private static void mergeAttributes(ServerRequest request, Map variables, + PathPattern pattern) { + Map pathVariables = mergePathVariables(request.pathVariables(), variables); + request.attributes().put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE, + Collections.unmodifiableMap(pathVariables)); + + pattern = mergePatterns( + (PathPattern) request.attributes().get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE), + pattern); + request.attributes().put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern); + } + + @Override + public Optional nest(ServerRequest request) { + return Optional.ofNullable(this.pattern.matchStartOfPath(request.pathContainer())) + .map(info -> new SubPathServerRequestWrapper(request, info, this.pattern)); + } + + @Override + public void accept(Visitor visitor) { + visitor.path(this.pattern.getPatternString()); + } + + @Override + public String toString() { + return this.pattern.getPatternString(); + } + } + + + private static class HeadersPredicate implements RequestPredicate { + + private final Predicate headersPredicate; + + public HeadersPredicate(Predicate headersPredicate) { + Assert.notNull(headersPredicate, "Predicate must not be null"); + this.headersPredicate = headersPredicate; + } + + @Override + public boolean test(ServerRequest request) { + return this.headersPredicate.test(request.headers()); + } + + @Override + public String toString() { + return this.headersPredicate.toString(); + } + } + + private static class ContentTypePredicate extends HeadersPredicate { + + private final Set mediaTypes; + + public ContentTypePredicate(MediaType... mediaTypes) { + this(new HashSet<>(Arrays.asList(mediaTypes))); + } + + private ContentTypePredicate(Set mediaTypes) { + super(headers -> { + MediaType contentType = + headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM); + boolean match = mediaTypes.stream() + .anyMatch(mediaType -> mediaType.includes(contentType)); + traceMatch("Content-Type", mediaTypes, contentType, match); + return match; + }); + this.mediaTypes = mediaTypes; + } + + @Override + public void accept(Visitor visitor) { + visitor.header(HttpHeaders.CONTENT_TYPE, + (this.mediaTypes.size() == 1) ? + this.mediaTypes.iterator().next().toString() : + this.mediaTypes.toString()); + } + + @Override + public String toString() { + return String.format("Content-Type: %s", + (this.mediaTypes.size() == 1) ? + this.mediaTypes.iterator().next().toString() : + this.mediaTypes.toString()); + } + } + + private static class AcceptPredicate extends HeadersPredicate { + + private final Set mediaTypes; + + public AcceptPredicate(MediaType... mediaTypes) { + this(new HashSet<>(Arrays.asList(mediaTypes))); + } + + private AcceptPredicate(Set mediaTypes) { + super(headers -> { + List acceptedMediaTypes = acceptedMediaTypes(headers); + boolean match = acceptedMediaTypes.stream() + .anyMatch(acceptedMediaType -> mediaTypes.stream() + .anyMatch(acceptedMediaType::isCompatibleWith)); + traceMatch("Accept", mediaTypes, acceptedMediaTypes, match); + return match; + }); + this.mediaTypes = mediaTypes; + } + + @NonNull + private static List acceptedMediaTypes(ServerRequest.Headers headers) { + List acceptedMediaTypes = headers.accept(); + if (acceptedMediaTypes.isEmpty()) { + acceptedMediaTypes = Collections.singletonList(MediaType.ALL); + } + else { + MediaType.sortBySpecificityAndQuality(acceptedMediaTypes); + } + return acceptedMediaTypes; + } + + @Override + public void accept(Visitor visitor) { + visitor.header(HttpHeaders.ACCEPT, + (this.mediaTypes.size() == 1) ? + this.mediaTypes.iterator().next().toString() : + this.mediaTypes.toString()); + } + + @Override + public String toString() { + return String.format("Accept: %s", + (this.mediaTypes.size() == 1) ? + this.mediaTypes.iterator().next().toString() : + this.mediaTypes.toString()); + } + } + + + private static class PathExtensionPredicate implements RequestPredicate { + + private final Predicate extensionPredicate; + + @Nullable + private final String extension; + + public PathExtensionPredicate(Predicate extensionPredicate) { + Assert.notNull(extensionPredicate, "Predicate must not be null"); + this.extensionPredicate = extensionPredicate; + this.extension = null; + } + + public PathExtensionPredicate(String extension) { + Assert.notNull(extension, "Extension must not be null"); + + this.extensionPredicate = s -> { + boolean match = extension.equalsIgnoreCase(s); + traceMatch("Extension", extension, s, match); + return match; + }; + this.extension = extension; + } + + @Override + public boolean test(ServerRequest request) { + String pathExtension = UriUtils.extractFileExtension(request.path()); + return this.extensionPredicate.test(pathExtension); + } + + @Override + public void accept(Visitor visitor) { + visitor.pathExtension( + (this.extension != null) ? + this.extension : + this.extensionPredicate.toString()); + } + + @Override + public String toString() { + return String.format("*.%s", + (this.extension != null) ? + this.extension : + this.extensionPredicate); + } + + } + + + private static class QueryParamPredicate implements RequestPredicate { + + private final String name; + + private final Predicate valuePredicate; + + @Nullable + private final String value; + + public QueryParamPredicate(String name, Predicate valuePredicate) { + Assert.notNull(name, "Name must not be null"); + Assert.notNull(valuePredicate, "Predicate must not be null"); + this.name = name; + this.valuePredicate = valuePredicate; + this.value = null; + } + + public QueryParamPredicate(String name, String value) { + Assert.notNull(name, "Name must not be null"); + Assert.notNull(value, "Value must not be null"); + this.name = name; + this.valuePredicate = value::equals; + this.value = value; + } + + @Override + public boolean test(ServerRequest request) { + Optional s = request.queryParam(this.name); + return s.filter(this.valuePredicate).isPresent(); + } + + @Override + public void accept(Visitor visitor) { + visitor.queryParam(this.name, + (this.value != null) ? + this.value : + this.valuePredicate.toString()); + } + + @Override + public String toString() { + return String.format("?%s %s", this.name, + (this.value != null) ? + this.value : + this.valuePredicate); + } + } + + + /** + * {@link RequestPredicate} for where both {@code left} and {@code right} predicates + * must match. + */ + static class AndRequestPredicate implements RequestPredicate { + + private final RequestPredicate left; + + private final RequestPredicate right; + + public AndRequestPredicate(RequestPredicate left, RequestPredicate right) { + Assert.notNull(left, "Left RequestPredicate must not be null"); + Assert.notNull(right, "Right RequestPredicate must not be null"); + this.left = left; + this.right = right; + } + + @Override + public boolean test(ServerRequest request) { + Map oldAttributes = new HashMap<>(request.attributes()); + + if (this.left.test(request) && this.right.test(request)) { + return true; + } + restoreAttributes(request, oldAttributes); + return false; + } + + @Override + public Optional nest(ServerRequest request) { + return this.left.nest(request).flatMap(this.right::nest); + } + + @Override + public void accept(Visitor visitor) { + visitor.startAnd(); + this.left.accept(visitor); + visitor.and(); + this.right.accept(visitor); + visitor.endAnd(); + } + + @Override + public String toString() { + return String.format("(%s && %s)", this.left, this.right); + } + } + + /** + * {@link RequestPredicate} that negates a delegate predicate. + */ + static class NegateRequestPredicate implements RequestPredicate { + private final RequestPredicate delegate; + + public NegateRequestPredicate(RequestPredicate delegate) { + Assert.notNull(delegate, "Delegate must not be null"); + this.delegate = delegate; + } + + @Override + public boolean test(ServerRequest request) { + Map oldAttributes = new HashMap<>(request.attributes()); + boolean result = !this.delegate.test(request); + if (!result) { + restoreAttributes(request, oldAttributes); + } + return result; + } + + @Override + public void accept(Visitor visitor) { + visitor.startNegate(); + this.delegate.accept(visitor); + visitor.endNegate(); + } + + @Override + public String toString() { + return "!" + this.delegate.toString(); + } + } + + /** + * {@link RequestPredicate} where either {@code left} or {@code right} predicates + * may match. + */ + static class OrRequestPredicate implements RequestPredicate { + + private final RequestPredicate left; + + private final RequestPredicate right; + + public OrRequestPredicate(RequestPredicate left, RequestPredicate right) { + Assert.notNull(left, "Left RequestPredicate must not be null"); + Assert.notNull(right, "Right RequestPredicate must not be null"); + this.left = left; + this.right = right; + } + + @Override + public boolean test(ServerRequest request) { + Map oldAttributes = new HashMap<>(request.attributes()); + + if (this.left.test(request)) { + return true; + } + else { + restoreAttributes(request, oldAttributes); + if (this.right.test(request)) { + return true; + } + } + restoreAttributes(request, oldAttributes); + return false; + } + + @Override + public Optional nest(ServerRequest request) { + Optional leftResult = this.left.nest(request); + if (leftResult.isPresent()) { + return leftResult; + } + else { + return this.right.nest(request); + } + } + + @Override + public void accept(Visitor visitor) { + visitor.startOr(); + this.left.accept(visitor); + visitor.or(); + this.right.accept(visitor); + visitor.endOr(); + } + + + @Override + public String toString() { + return String.format("(%s || %s)", this.left, this.right); + } + } + + + private static class SubPathServerRequestWrapper implements ServerRequest { + + private final ServerRequest request; + + private final PathContainer pathContainer; + + private final Map attributes; + + public SubPathServerRequestWrapper(ServerRequest request, + PathPattern.PathRemainingMatchInfo info, PathPattern pattern) { + this.request = request; + this.pathContainer = new SubPathContainer(info.getPathRemaining()); + this.attributes = mergeAttributes(request, info.getUriVariables(), pattern); + } + + private static Map mergeAttributes(ServerRequest request, + Map pathVariables, PathPattern pattern) { + Map result = new ConcurrentHashMap<>(request.attributes()); + + result.put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE, + mergePathVariables(request.pathVariables(), pathVariables)); + + pattern = mergePatterns( + (PathPattern) request.attributes().get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE), + pattern); + result.put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern); + return result; + } + + @Override + public HttpMethod method() { + return this.request.method(); + } + + @Override + public String methodName() { + return this.request.methodName(); + } + + @Override + public URI uri() { + return this.request.uri(); + } + + @Override + public UriBuilder uriBuilder() { + return this.request.uriBuilder(); + } + + @Override + public String path() { + return this.pathContainer.value(); + } + + @Override + public PathContainer pathContainer() { + return this.pathContainer; + } + + @Override + public Headers headers() { + return this.request.headers(); + } + + @Override + public MultiValueMap cookies() { + return this.request.cookies(); + } + + @Override + public Optional remoteAddress() { + return this.request.remoteAddress(); + } + + @Override + public List> messageReaders() { + return this.request.messageReaders(); + } + + @Override + public T body(BodyExtractor extractor) { + return this.request.body(extractor); + } + + @Override + public T body(BodyExtractor extractor, Map hints) { + return this.request.body(extractor, hints); + } + + @Override + public Mono bodyToMono(Class elementClass) { + return this.request.bodyToMono(elementClass); + } + + @Override + public Mono bodyToMono(ParameterizedTypeReference typeReference) { + return this.request.bodyToMono(typeReference); + } + + @Override + public Flux bodyToFlux(Class elementClass) { + return this.request.bodyToFlux(elementClass); + } + + @Override + public Flux bodyToFlux(ParameterizedTypeReference typeReference) { + return this.request.bodyToFlux(typeReference); + } + + @Override + public Map attributes() { + return this.attributes; + } + + @Override + public Optional queryParam(String name) { + return this.request.queryParam(name); + } + + @Override + public MultiValueMap queryParams() { + return this.request.queryParams(); + } + + @Override + @SuppressWarnings("unchecked") + public Map pathVariables() { + return (Map) this.attributes.getOrDefault( + RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE, Collections.emptyMap()); + + } + + @Override + public Mono session() { + return this.request.session(); + } + + @Override + public Mono principal() { + return this.request.principal(); + } + + @Override + public Mono> formData() { + return this.request.formData(); + } + + @Override + public Mono> multipartData() { + return this.request.multipartData(); + } + + @Override + public ServerWebExchange exchange() { + return this.request.exchange(); + } + + @Override + public String toString() { + return method() + " " + path(); + } + + private static class SubPathContainer implements PathContainer { + + private static final PathContainer.Separator SEPARATOR = () -> "/"; + + + private final String value; + + private final List elements; + + public SubPathContainer(PathContainer original) { + this.value = prefixWithSlash(original.value()); + this.elements = prependWithSeparator(original.elements()); + } + + private static String prefixWithSlash(String path) { + if (!path.startsWith("/")) { + path = "/" + path; + } + return path; + } + + private static List prependWithSeparator(List elements) { + List result = new ArrayList<>(elements); + if (result.isEmpty() || !(result.get(0) instanceof Separator)) { + result.add(0, SEPARATOR); + } + return Collections.unmodifiableList(result); + } + + + @Override + public String value() { + return this.value; + } + + @Override + public List elements() { + return this.elements; + } + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ResourceHandlerFunction.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ResourceHandlerFunction.java new file mode 100644 index 0000000000000000000000000000000000000000..95b91fb0f44876520a488d466bc8830c2c6ddf87 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ResourceHandlerFunction.java @@ -0,0 +1,144 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.URL; +import java.util.EnumSet; +import java.util.Set; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; +import org.springframework.web.reactive.function.BodyInserters; + +/** + * Resource-based implementation of {@link HandlerFunction}. + * + * @author Arjen Poutsma + * @since 5.0 + */ +class ResourceHandlerFunction implements HandlerFunction { + + private static final Set SUPPORTED_METHODS = + EnumSet.of(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.OPTIONS); + + + private final Resource resource; + + + public ResourceHandlerFunction(Resource resource) { + this.resource = resource; + } + + + @Override + public Mono handle(ServerRequest request) { + HttpMethod method = request.method(); + if (method != null) { + switch (method) { + case GET: + return EntityResponse.fromObject(this.resource).build() + .map(response -> response); + case HEAD: + Resource headResource = new HeadMethodResource(this.resource); + return EntityResponse.fromObject(headResource).build() + .map(response -> response); + case OPTIONS: + return ServerResponse.ok() + .allow(SUPPORTED_METHODS) + .body(BodyInserters.empty()); + } + } + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED) + .allow(SUPPORTED_METHODS) + .body(BodyInserters.empty()); + } + + + private static class HeadMethodResource implements Resource { + + private static final byte[] EMPTY = new byte[0]; + + private final Resource delegate; + + public HeadMethodResource(Resource delegate) { + this.delegate = delegate; + } + + @Override + public InputStream getInputStream() throws IOException { + return new ByteArrayInputStream(EMPTY); + } + + // delegation + + @Override + public boolean exists() { + return this.delegate.exists(); + } + + @Override + public URL getURL() throws IOException { + return this.delegate.getURL(); + } + + @Override + public URI getURI() throws IOException { + return this.delegate.getURI(); + } + + @Override + public File getFile() throws IOException { + return this.delegate.getFile(); + } + + @Override + public long contentLength() throws IOException { + return this.delegate.contentLength(); + } + + @Override + public long lastModified() throws IOException { + return this.delegate.lastModified(); + } + + @Override + public Resource createRelative(String relativePath) throws IOException { + return this.delegate.createRelative(relativePath); + } + + @Override + @Nullable + public String getFilename() { + return this.delegate.getFilename(); + } + + @Override + public String getDescription() { + return this.delegate.getDescription(); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunction.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunction.java new file mode 100644 index 0000000000000000000000000000000000000000..ba1482639ce1d987c0fbe47764cf00a57f033ab7 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunction.java @@ -0,0 +1,118 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import reactor.core.publisher.Mono; + +/** + * Represents a function that routes to a {@linkplain HandlerFunction handler function}. + * + * @author Arjen Poutsma + * @since 5.0 + * @param the type of the {@linkplain HandlerFunction handler function} to route to + * @see RouterFunctions + */ +@FunctionalInterface +public interface RouterFunction { + + /** + * Return the {@linkplain HandlerFunction handler function} that matches the given request. + * @param request the request to route + * @return an {@code Mono} describing the {@code HandlerFunction} that matches this request, + * or an empty {@code Mono} if there is no match + */ + Mono> route(ServerRequest request); + + /** + * Return a composed routing function that first invokes this function, + * and then invokes the {@code other} function (of the same response type {@code T}) + * if this route had {@linkplain Mono#empty() no result}. + * @param other the function of type {@code T} to apply when this function has no result + * @return a composed function that first routes with this function and then the + * {@code other} function if this function has no result + * @see #andOther(RouterFunction) + */ + default RouterFunction and(RouterFunction other) { + return new RouterFunctions.SameComposedRouterFunction<>(this, other); + } + + /** + * Return a composed routing function that first invokes this function, + * and then invokes the {@code other} function (of a different response type) if this route had + * {@linkplain Mono#empty() no result}. + * @param other the function to apply when this function has no result + * @return a composed function that first routes with this function and then the + * {@code other} function if this function has no result + * @see #and(RouterFunction) + */ + default RouterFunction andOther(RouterFunction other) { + return new RouterFunctions.DifferentComposedRouterFunction(this, other); + } + + /** + * Return a composed routing function that routes to the given handler function if this + * route does not match and the given request predicate applies. This method is a convenient + * combination of {@link #and(RouterFunction)} and + * {@link RouterFunctions#route(RequestPredicate, HandlerFunction)}. + * @param predicate the predicate to test if this route does not match + * @param handlerFunction the handler function to route to if this route does not match and + * the predicate applies + * @return a composed function that route to {@code handlerFunction} if this route does not + * match and if {@code predicate} applies + */ + default RouterFunction andRoute(RequestPredicate predicate, HandlerFunction handlerFunction) { + return and(RouterFunctions.route(predicate, handlerFunction)); + } + + /** + * Return a composed routing function that routes to the given router function if this + * route does not match and the given request predicate applies. This method is a convenient + * combination of {@link #and(RouterFunction)} and + * {@link RouterFunctions#nest(RequestPredicate, RouterFunction)}. + * @param predicate the predicate to test if this route does not match + * @param routerFunction the router function to route to if this route does not match and + * the predicate applies + * @return a composed function that route to {@code routerFunction} if this route does not + * match and if {@code predicate} applies + */ + default RouterFunction andNest(RequestPredicate predicate, RouterFunction routerFunction) { + return and(RouterFunctions.nest(predicate, routerFunction)); + } + + /** + * Filter all {@linkplain HandlerFunction handler functions} routed by this function with the given + * {@linkplain HandlerFilterFunction filter function}. + * @param the filter return type + * @param filterFunction the filter to apply + * @return the filtered routing function + */ + default RouterFunction filter(HandlerFilterFunction filterFunction) { + return new RouterFunctions.FilteredRouterFunction<>(this, filterFunction); + } + + /** + * Accept the given visitor. Default implementation calls + * {@link RouterFunctions.Visitor#unknown(RouterFunction)}; composed {@code RouterFunction} + * implementations are expected to call {@code accept} for all components that make up this + * router function. + * @param visitor the visitor to accept + */ + default void accept(RouterFunctions.Visitor visitor) { + visitor.unknown(this); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctionBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctionBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..d9cf733b28b2eb23f0f5732105298cafd30edcf5 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctionBuilder.java @@ -0,0 +1,255 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.util.Assert; + +/** + * Default implementation of {@link RouterFunctions.Builder}. + * + * @author Arjen Poutsma + * @since 5.1 + */ +class RouterFunctionBuilder implements RouterFunctions.Builder { + + private List> routerFunctions = new ArrayList<>(); + + private List> filterFunctions = new ArrayList<>(); + + + @Override + public RouterFunctions.Builder add(RouterFunction routerFunction) { + Assert.notNull(routerFunction, "RouterFunction must not be null"); + this.routerFunctions.add(routerFunction); + return this; + } + + private RouterFunctions.Builder add(RequestPredicate predicate, + HandlerFunction handlerFunction) { + + this.routerFunctions.add(RouterFunctions.route(predicate, handlerFunction)); + return this; + } + + @Override + public RouterFunctions.Builder GET(String pattern, HandlerFunction handlerFunction) { + return add(RequestPredicates.GET(pattern), handlerFunction); + } + + @Override + public RouterFunctions.Builder GET(String pattern, RequestPredicate predicate, + HandlerFunction handlerFunction) { + + return add(RequestPredicates.GET(pattern).and(predicate), handlerFunction); + } + + @Override + public RouterFunctions.Builder HEAD(String pattern, HandlerFunction handlerFunction) { + return add(RequestPredicates.HEAD(pattern), handlerFunction); + } + + @Override + public RouterFunctions.Builder HEAD(String pattern, RequestPredicate predicate, + HandlerFunction handlerFunction) { + + return add(RequestPredicates.HEAD(pattern).and(predicate), handlerFunction); + } + + @Override + public RouterFunctions.Builder POST(String pattern, HandlerFunction handlerFunction) { + return add(RequestPredicates.POST(pattern), handlerFunction); + } + + @Override + public RouterFunctions.Builder POST(String pattern, RequestPredicate predicate, + HandlerFunction handlerFunction) { + + return add(RequestPredicates.POST(pattern).and(predicate), handlerFunction); + } + + @Override + public RouterFunctions.Builder PUT(String pattern, HandlerFunction handlerFunction) { + return add(RequestPredicates.PUT(pattern), handlerFunction); + } + + @Override + public RouterFunctions.Builder PUT(String pattern, RequestPredicate predicate, + HandlerFunction handlerFunction) { + + return add(RequestPredicates.PUT(pattern).and(predicate), handlerFunction); + } + + @Override + public RouterFunctions.Builder PATCH(String pattern, HandlerFunction handlerFunction) { + return add(RequestPredicates.PATCH(pattern), handlerFunction); + } + + @Override + public RouterFunctions.Builder PATCH(String pattern, RequestPredicate predicate, + HandlerFunction handlerFunction) { + + return add(RequestPredicates.PATCH(pattern).and(predicate), handlerFunction); + } + + @Override + public RouterFunctions.Builder DELETE(String pattern, HandlerFunction handlerFunction) { + return add(RequestPredicates.DELETE(pattern), handlerFunction); + } + + @Override + public RouterFunctions.Builder DELETE(String pattern, RequestPredicate predicate, + HandlerFunction handlerFunction) { + + return add(RequestPredicates.DELETE(pattern).and(predicate), handlerFunction); + } + + @Override + public RouterFunctions.Builder OPTIONS(String pattern, HandlerFunction handlerFunction) { + return add(RequestPredicates.OPTIONS(pattern), handlerFunction); + } + + @Override + public RouterFunctions.Builder OPTIONS(String pattern, RequestPredicate predicate, + HandlerFunction handlerFunction) { + + return add(RequestPredicates.OPTIONS(pattern).and(predicate), handlerFunction); + } + + @Override + public RouterFunctions.Builder resources(String pattern, Resource location) { + return add(RouterFunctions.resources(pattern, location)); + } + + @Override + public RouterFunctions.Builder resources(Function> lookupFunction) { + return add(RouterFunctions.resources(lookupFunction)); + } + + @Override + public RouterFunctions.Builder nest(RequestPredicate predicate, + Consumer builderConsumer) { + + Assert.notNull(builderConsumer, "Consumer must not be null"); + + RouterFunctionBuilder nestedBuilder = new RouterFunctionBuilder(); + builderConsumer.accept(nestedBuilder); + RouterFunction nestedRoute = nestedBuilder.build(); + this.routerFunctions.add(RouterFunctions.nest(predicate, nestedRoute)); + return this; + } + + @Override + public RouterFunctions.Builder nest(RequestPredicate predicate, + Supplier> routerFunctionSupplier) { + + Assert.notNull(routerFunctionSupplier, "RouterFunction Supplier must not be null"); + + RouterFunction nestedRoute = routerFunctionSupplier.get(); + this.routerFunctions.add(RouterFunctions.nest(predicate, nestedRoute)); + return this; + } + + @Override + public RouterFunctions.Builder path(String pattern, + Consumer builderConsumer) { + + return nest(RequestPredicates.path(pattern), builderConsumer); + } + + @Override + public RouterFunctions.Builder path(String pattern, + Supplier> routerFunctionSupplier) { + + return nest(RequestPredicates.path(pattern), routerFunctionSupplier); + } + + @Override + public RouterFunctions.Builder filter(HandlerFilterFunction filterFunction) { + Assert.notNull(filterFunction, "HandlerFilterFunction must not be null"); + + this.filterFunctions.add(filterFunction); + return this; + } + + @Override + public RouterFunctions.Builder before(Function requestProcessor) { + Assert.notNull(requestProcessor, "RequestProcessor must not be null"); + return filter((request, next) -> next.handle(requestProcessor.apply(request))); + } + + @Override + public RouterFunctions.Builder after( + BiFunction responseProcessor) { + + Assert.notNull(responseProcessor, "ResponseProcessor must not be null"); + return filter((request, next) -> next.handle(request) + .map(serverResponse -> responseProcessor.apply(request, serverResponse))); + } + + @Override + public RouterFunctions.Builder onError(Predicate predicate, + BiFunction> responseProvider) { + + Assert.notNull(predicate, "Predicate must not be null"); + Assert.notNull(responseProvider, "ResponseProvider must not be null"); + + return filter((request, next) -> next.handle(request) + .onErrorResume(predicate, t -> responseProvider.apply(t, request))); + } + + @Override + public RouterFunctions.Builder onError(Class exceptionType, + BiFunction> responseProvider) { + + Assert.notNull(exceptionType, "ExceptionType must not be null"); + Assert.notNull(responseProvider, "ResponseProvider must not be null"); + + return filter((request, next) -> next.handle(request) + .onErrorResume(exceptionType, t -> responseProvider.apply(t, request))); + } + + @Override + public RouterFunction build() { + RouterFunction result = this.routerFunctions.stream() + .reduce(RouterFunction::and) + .orElseThrow(IllegalStateException::new); + + if (this.filterFunctions.isEmpty()) { + return result; + } + else { + HandlerFilterFunction filter = + this.filterFunctions.stream() + .reduce(HandlerFilterFunction::andThen) + .orElseThrow(IllegalStateException::new); + + return result.filter(filter); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctions.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctions.java new file mode 100644 index 0000000000000000000000000000000000000000..9b869c5ba8dc5c7012770fe86675b214ed7d8e9c --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctions.java @@ -0,0 +1,1004 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.util.Assert; +import org.springframework.web.reactive.result.view.ViewResolver; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebHandler; +import org.springframework.web.server.adapter.WebHttpHandlerBuilder; + +/** + * Central entry point to Spring's functional web framework. + * Exposes routing functionality, such as to {@linkplain #route() create} a + * {@code RouterFunction} using a discoverable builder-style API, to + * {@linkplain #route(RequestPredicate, HandlerFunction) create} a {@code RouterFunction} + * given a {@code RequestPredicate} and {@code HandlerFunction}, and to do further + * {@linkplain #nest(RequestPredicate, RouterFunction) subrouting} on an existing routing + * function. + * + *

Additionally, this class can {@linkplain #toHttpHandler(RouterFunction) transform} a + * {@code RouterFunction} into an {@code HttpHandler}, which can be run in Servlet 3.1+, + * Reactor, or Undertow. + * + * @author Arjen Poutsma + * @since 5.0 + */ +public abstract class RouterFunctions { + + private static final Log logger = LogFactory.getLog(RouterFunctions.class); + + /** + * Name of the {@link ServerWebExchange} attribute that contains the {@link ServerRequest}. + */ + public static final String REQUEST_ATTRIBUTE = RouterFunctions.class.getName() + ".request"; + + /** + * Name of the {@link ServerWebExchange} attribute that contains the URI + * templates map, mapping variable names to values. + */ + public static final String URI_TEMPLATE_VARIABLES_ATTRIBUTE = + RouterFunctions.class.getName() + ".uriTemplateVariables"; + + /** + * Name of the {@link ServerWebExchange#getAttributes() attribute} that + * contains the matching pattern, as a {@link org.springframework.web.util.pattern.PathPattern}. + */ + public static final String MATCHING_PATTERN_ATTRIBUTE = + RouterFunctions.class.getName() + ".matchingPattern"; + + + private static final HandlerFunction NOT_FOUND_HANDLER = + request -> ServerResponse.notFound().build(); + + + /** + * Offers a discoverable way to create router functions through a builder-style interface. + * @return a router function builder + * @since 5.1 + */ + public static Builder route() { + return new RouterFunctionBuilder(); + } + + /** + * Route to the given handler function if the given request predicate applies. + *

For instance, the following example routes GET requests for "/user" to the + * {@code listUsers} method in {@code userController}: + *

+	 * RouterFunction<ServerResponse> route =
+	 *     RouterFunctions.route(RequestPredicates.GET("/user"), userController::listUsers);
+	 * 
+ * @param predicate the predicate to test + * @param handlerFunction the handler function to route to if the predicate applies + * @param the type of response returned by the handler function + * @return a router function that routes to {@code handlerFunction} if + * {@code predicate} evaluates to {@code true} + * @see RequestPredicates + */ + public static RouterFunction route( + RequestPredicate predicate, HandlerFunction handlerFunction) { + + return new DefaultRouterFunction<>(predicate, handlerFunction); + } + + /** + * Route to the given router function if the given request predicate applies. This method can be + * used to create nested routes, where a group of routes share a common path + * (prefix), header, or other request predicate. + *

For instance, the following example first creates a composed route that resolves to + * {@code listUsers} for a GET, and {@code createUser} for a POST. This composed route then gets + * nested with a "/user" path predicate, so that GET requests for "/user" will list users, + * and POST request for "/user" will create a new user. + *

+	 * RouterFunction<ServerResponse> userRoutes =
+	 *   RouterFunctions.route(RequestPredicates.method(HttpMethod.GET), this::listUsers)
+	 *     .andRoute(RequestPredicates.method(HttpMethod.POST), this::createUser);
+	 * RouterFunction<ServerResponse> nestedRoute =
+	 *   RouterFunctions.nest(RequestPredicates.path("/user"), userRoutes);
+	 * 
+ * @param predicate the predicate to test + * @param routerFunction the nested router function to delegate to if the predicate applies + * @param the type of response returned by the handler function + * @return a router function that routes to {@code routerFunction} if + * {@code predicate} evaluates to {@code true} + * @see RequestPredicates + */ + public static RouterFunction nest( + RequestPredicate predicate, RouterFunction routerFunction) { + + return new DefaultNestedRouterFunction<>(predicate, routerFunction); + } + + /** + * Route requests that match the given pattern to resources relative to the given root location. + * For instance + *
+	 * Resource location = new FileSystemResource("public-resources/");
+	 * RouterFunction<ServerResponse> resources = RouterFunctions.resources("/resources/**", location);
+     * 
+ * @param pattern the pattern to match + * @param location the location directory relative to which resources should be resolved + * @return a router function that routes to resources + * @see #resourceLookupFunction(String, Resource) + */ + public static RouterFunction resources(String pattern, Resource location) { + return resources(resourceLookupFunction(pattern, location)); + } + + /** + * Returns the resource lookup function used by {@link #resources(String, Resource)}. + * The returned function can be {@linkplain Function#andThen(Function) composed} on, for + * instance to return a default resource when the lookup function does not match: + *
+	 * Mono<Resource> defaultResource = Mono.just(new ClassPathResource("index.html"));
+	 * Function<ServerRequest, Mono<Resource>> lookupFunction =
+	 *   RouterFunctions.resourceLookupFunction("/resources/**", new FileSystemResource("public-resources/"))
+	 *     .andThen(resourceMono -> resourceMono.switchIfEmpty(defaultResource));
+	 * RouterFunction<ServerResponse> resources = RouterFunctions.resources(lookupFunction);
+     * 
+ * @param pattern the pattern to match + * @param location the location directory relative to which resources should be resolved + * @return the default resource lookup function for the given parameters. + */ + public static Function> resourceLookupFunction(String pattern, Resource location) { + return new PathResourceLookupFunction(pattern, location); + } + + /** + * Route to resources using the provided lookup function. If the lookup function provides a + * {@link Resource} for the given request, it will be it will be exposed using a + * {@link HandlerFunction} that handles GET, HEAD, and OPTIONS requests. + * @param lookupFunction the function to provide a {@link Resource} given the {@link ServerRequest} + * @return a router function that routes to resources + */ + public static RouterFunction resources(Function> lookupFunction) { + return new ResourcesRouterFunction(lookupFunction); + } + + /** + * Convert the given {@linkplain RouterFunction router function} into a {@link HttpHandler}. + * This conversion uses {@linkplain HandlerStrategies#builder() default strategies}. + *

The returned handler can be adapted to run in + *

    + *
  • Servlet 3.1+ using the + * {@link org.springframework.http.server.reactive.ServletHttpHandlerAdapter},
  • + *
  • Reactor using the + * {@link org.springframework.http.server.reactive.ReactorHttpHandlerAdapter},
  • + *
  • Undertow using the + * {@link org.springframework.http.server.reactive.UndertowHttpHandlerAdapter}.
  • + *
+ *

Note that {@code HttpWebHandlerAdapter} also implements {@link WebHandler}, allowing + * for additional filter and exception handler registration through + * {@link WebHttpHandlerBuilder}. + * @param routerFunction the router function to convert + * @return an http handler that handles HTTP request using the given router function + */ + public static HttpHandler toHttpHandler(RouterFunction routerFunction) { + return toHttpHandler(routerFunction, HandlerStrategies.withDefaults()); + } + + /** + * Convert the given {@linkplain RouterFunction router function} into a {@link HttpHandler}, + * using the given strategies. + *

The returned {@code HttpHandler} can be adapted to run in + *

    + *
  • Servlet 3.1+ using the + * {@link org.springframework.http.server.reactive.ServletHttpHandlerAdapter},
  • + *
  • Reactor using the + * {@link org.springframework.http.server.reactive.ReactorHttpHandlerAdapter},
  • + *
  • Undertow using the + * {@link org.springframework.http.server.reactive.UndertowHttpHandlerAdapter}.
  • + *
+ * @param routerFunction the router function to convert + * @param strategies the strategies to use + * @return an http handler that handles HTTP request using the given router function + */ + public static HttpHandler toHttpHandler(RouterFunction routerFunction, HandlerStrategies strategies) { + WebHandler webHandler = toWebHandler(routerFunction, strategies); + return WebHttpHandlerBuilder.webHandler(webHandler) + .filters(filters -> filters.addAll(strategies.webFilters())) + .exceptionHandlers(handlers -> handlers.addAll(strategies.exceptionHandlers())) + .localeContextResolver(strategies.localeContextResolver()) + .build(); + } + + /** + * Convert the given {@linkplain RouterFunction router function} into a {@link WebHandler}. + * This conversion uses {@linkplain HandlerStrategies#builder() default strategies}. + * @param routerFunction the router function to convert + * @return a web handler that handles web request using the given router function + */ + public static WebHandler toWebHandler(RouterFunction routerFunction) { + return toWebHandler(routerFunction, HandlerStrategies.withDefaults()); + } + + /** + * Convert the given {@linkplain RouterFunction router function} into a {@link WebHandler}, + * using the given strategies. + * @param routerFunction the router function to convert + * @param strategies the strategies to use + * @return a web handler that handles web request using the given router function + */ + public static WebHandler toWebHandler(RouterFunction routerFunction, HandlerStrategies strategies) { + Assert.notNull(routerFunction, "RouterFunction must not be null"); + Assert.notNull(strategies, "HandlerStrategies must not be null"); + + return exchange -> { + ServerRequest request = new DefaultServerRequest(exchange, strategies.messageReaders()); + addAttributes(exchange, request); + return routerFunction.route(request) + .defaultIfEmpty(notFound()) + .flatMap(handlerFunction -> wrapException(() -> handlerFunction.handle(request))) + .flatMap(response -> wrapException(() -> response.writeTo(exchange, + new HandlerStrategiesResponseContext(strategies)))); + }; + } + + + private static Mono wrapException(Supplier> supplier) { + try { + return supplier.get(); + } + catch (Throwable ex) { + return Mono.error(ex); + } + } + + private static void addAttributes(ServerWebExchange exchange, ServerRequest request) { + Map attributes = exchange.getAttributes(); + attributes.put(REQUEST_ATTRIBUTE, request); + } + + @SuppressWarnings("unchecked") + private static HandlerFunction notFound() { + return (HandlerFunction) NOT_FOUND_HANDLER; + } + + @SuppressWarnings("unchecked") + static HandlerFunction cast(HandlerFunction handlerFunction) { + return (HandlerFunction) handlerFunction; + } + + + /** + * Represents a discoverable builder for router functions. + * Obtained via {@link RouterFunctions#route()}. + * @since 5.1 + */ + public interface Builder { + + /** + * Adds a route to the given handler function that handles all HTTP {@code GET} requests + * that match the given pattern. + * @param pattern the pattern to match to + * @param handlerFunction the handler function to handle all {@code GET} requests that + * match {@code pattern} + * @return this builder + */ + Builder GET(String pattern, HandlerFunction handlerFunction); + + /** + * Adds a route to the given handler function that handles all HTTP {@code GET} requests + * that match the given pattern and predicate. + *

For instance, the following example routes GET requests for "/user" that accept JSON + * to the {@code listUsers} method in {@code userController}: + *

+		 * RouterFunction<ServerResponse> route =
+		 *   RouterFunctions.route()
+		 *     .GET("/user", RequestPredicates.accept(MediaType.APPLICATION_JSON), userController::listUsers)
+		 *     .build();
+		 * 
+ * @param pattern the pattern to match to + * @param predicate additional predicate to match + * @param handlerFunction the handler function to handle all {@code GET} requests that + * match {@code pattern} + * @return this builder + * @see RequestPredicates + */ + Builder GET(String pattern, RequestPredicate predicate, HandlerFunction handlerFunction); + + /** + * Adds a route to the given handler function that handles all HTTP {@code HEAD} requests + * that match the given pattern. + * @param pattern the pattern to match to + * @param handlerFunction the handler function to handle all {@code HEAD} requests that + * match {@code pattern} + * @return this builder + */ + Builder HEAD(String pattern, HandlerFunction handlerFunction); + + /** + * Adds a route to the given handler function that handles all HTTP {@code HEAD} requests + * that match the given pattern and predicate. + * @param pattern the pattern to match to + * @param predicate additional predicate to match + * @param handlerFunction the handler function to handle all {@code HEAD} requests that + * match {@code pattern} + * @return this builder + */ + Builder HEAD(String pattern, RequestPredicate predicate, HandlerFunction handlerFunction); + + /** + * Adds a route to the given handler function that handles all HTTP {@code POST} requests + * that match the given pattern. + * @param pattern the pattern to match to + * @param handlerFunction the handler function to handle all {@code POST} requests that + * match {@code pattern} + * @return this builder + */ + Builder POST(String pattern, HandlerFunction handlerFunction); + + /** + * Adds a route to the given handler function that handles all HTTP {@code POST} requests + * that match the given pattern and predicate. + *

For instance, the following example routes POST requests for "/user" that contain JSON + * to the {@code addUser} method in {@code userController}: + *

+		 * RouterFunction<ServerResponse> route =
+		 *   RouterFunctions.route()
+		 *     .POST("/user", RequestPredicates.contentType(MediaType.APPLICATION_JSON), userController::addUser)
+		 *     .build();
+		 * 
+ * @param pattern the pattern to match to + * @param predicate additional predicate to match + * @param handlerFunction the handler function to handle all {@code POST} requests that + * match {@code pattern} + * @return this builder + */ + Builder POST(String pattern, RequestPredicate predicate, HandlerFunction handlerFunction); + + /** + * Adds a route to the given handler function that handles all HTTP {@code PUT} requests + * that match the given pattern. + * @param pattern the pattern to match to + * @param handlerFunction the handler function to handle all {@code PUT} requests that + * match {@code pattern} + * @return this builder + */ + Builder PUT(String pattern, HandlerFunction handlerFunction); + + /** + * Adds a route to the given handler function that handles all HTTP {@code PUT} requests + * that match the given pattern and predicate. + *

For instance, the following example routes PUT requests for "/user" that contain JSON + * to the {@code editUser} method in {@code userController}: + *

+		 * RouterFunction<ServerResponse> route =
+		 *   RouterFunctions.route()
+		 *     .PUT("/user", RequestPredicates.contentType(MediaType.APPLICATION_JSON), userController::editUser)
+		 *     .build();
+		 * 
+ * @param pattern the pattern to match to + * @param predicate additional predicate to match + * @param handlerFunction the handler function to handle all {@code PUT} requests that + * match {@code pattern} + * @return this builder + */ + Builder PUT(String pattern, RequestPredicate predicate, HandlerFunction handlerFunction); + + /** + * Adds a route to the given handler function that handles all HTTP {@code PATCH} requests + * that match the given pattern. + * @param pattern the pattern to match to + * @param handlerFunction the handler function to handle all {@code PATCH} requests that + * match {@code pattern} + * @return this builder + */ + Builder PATCH(String pattern, HandlerFunction handlerFunction); + + /** + * Adds a route to the given handler function that handles all HTTP {@code PATCH} requests + * that match the given pattern and predicate. + *

For instance, the following example routes PATCH requests for "/user" that contain JSON + * to the {@code editUser} method in {@code userController}: + *

+		 * RouterFunction<ServerResponse> route =
+		 *   RouterFunctions.route()
+		 *     .PATCH("/user", RequestPredicates.contentType(MediaType.APPLICATION_JSON), userController::editUser)
+		 *     .build();
+		 * 
+ * @param pattern the pattern to match to + * @param predicate additional predicate to match + * @param handlerFunction the handler function to handle all {@code PATCH} requests that + * match {@code pattern} + * @return this builder + */ + Builder PATCH(String pattern, RequestPredicate predicate, HandlerFunction handlerFunction); + + /** + * Adds a route to the given handler function that handles all HTTP {@code DELETE} requests + * that match the given pattern. + * @param pattern the pattern to match to + * @param handlerFunction the handler function to handle all {@code DELETE} requests that + * match {@code pattern} + * @return this builder + */ + Builder DELETE(String pattern, HandlerFunction handlerFunction); + + /** + * Adds a route to the given handler function that handles all HTTP {@code DELETE} requests + * that match the given pattern and predicate. + * @param pattern the pattern to match to + * @param predicate additional predicate to match + * @param handlerFunction the handler function to handle all {@code DELETE} requests that + * match {@code pattern} + * @return this builder + */ + Builder DELETE(String pattern, RequestPredicate predicate, HandlerFunction handlerFunction); + + /** + * Adds a route to the given handler function that handles all HTTP {@code OPTIONS} requests + * that match the given pattern. + * @param pattern the pattern to match to + * @param handlerFunction the handler function to handle all {@code OPTIONS} requests that + * match {@code pattern} + * @return this builder + */ + Builder OPTIONS(String pattern, HandlerFunction handlerFunction); + + /** + * Adds a route to the given handler function that handles all HTTP {@code OPTIONS} requests + * that match the given pattern and predicate. + * @param pattern the pattern to match to + * @param predicate additional predicate to match + * @param handlerFunction the handler function to handle all {@code OPTIONS} requests that + * match {@code pattern} + * @return this builder + */ + Builder OPTIONS(String pattern, RequestPredicate predicate, HandlerFunction handlerFunction); + + /** + * Adds the given route to this builder. Can be used to merge externally defined router + * functions into this builder, or can be combined with + * {@link RouterFunctions#route(RequestPredicate, HandlerFunction)} + * to allow for more flexible predicate matching. + *

For instance, the following example adds the router function returned from + * {@code OrderController.routerFunction()}. + * to the {@code changeUser} method in {@code userController}: + *

+		 * RouterFunction<ServerResponse> route =
+		 *   RouterFunctions.route()
+		 *     .GET("/users", userController::listUsers)
+		 *     .add(orderController.routerFunction());
+		 *     .build();
+		 * 
+ * @param routerFunction the router function to be added + * @return this builder + * @see RequestPredicates + */ + Builder add(RouterFunction routerFunction); + + /** + * Route requests that match the given pattern to resources relative to the given root location. + * For instance + *
+		 * Resource location = new FileSystemResource("public-resources/");
+		 * RouterFunction<ServerResponse> resources = RouterFunctions.resources("/resources/**", location);
+	     * 
+ * @param pattern the pattern to match + * @param location the location directory relative to which resources should be resolved + * @return this builder + */ + Builder resources(String pattern, Resource location); + + /** + * Route to resources using the provided lookup function. If the lookup function provides a + * {@link Resource} for the given request, it will be it will be exposed using a + * {@link HandlerFunction} that handles GET, HEAD, and OPTIONS requests. + * @param lookupFunction the function to provide a {@link Resource} given the {@link ServerRequest} + * @return this builder + */ + Builder resources(Function> lookupFunction); + + /** + * Route to the supplied router function if the given request predicate applies. This method + * can be used to create nested routes, where a group of routes share a + * common path (prefix), header, or other request predicate. + *

For instance, the following example creates a nested route with a "/user" path + * predicate, so that GET requests for "/user" will list users, + * and POST request for "/user" will create a new user. + *

+		 * RouterFunction<ServerResponse> nestedRoute =
+		 *   RouterFunctions.route()
+		 *     .nest(RequestPredicates.path("/user"), () ->
+		 *       RouterFunctions.route()
+		 *         .GET(this::listUsers)
+		 *         .POST(this::createUser)
+		 *         .build())
+		 *     .build();
+		 * 
+ * @param predicate the predicate to test + * @param routerFunctionSupplier supplier for the nested router function to delegate to if + * the predicate applies + * @return this builder + * @see RequestPredicates + */ + Builder nest(RequestPredicate predicate, Supplier> routerFunctionSupplier); + + /** + * Route to a built router function if the given request predicate applies. + * This method can be used to create nested routes, where a group of routes + * share a common path (prefix), header, or other request predicate. + *

For instance, the following example creates a nested route with a "/user" path + * predicate, so that GET requests for "/user" will list users, + * and POST request for "/user" will create a new user. + *

+		 * RouterFunction<ServerResponse> nestedRoute =
+		 *   RouterFunctions.route()
+		 *     .nest(RequestPredicates.path("/user"), builder ->
+		 *       builder.GET(this::listUsers)
+		 *              .POST(this::createUser))
+		 *     .build();
+		 * 
+ * @param predicate the predicate to test + * @param builderConsumer consumer for a {@code Builder} that provides the nested router + * function + * @return this builder + * @see RequestPredicates + */ + Builder nest(RequestPredicate predicate, Consumer builderConsumer); + + /** + * Route to the supplied router function if the given path prefix pattern applies. This method + * can be used to create nested routes, where a group of routes share a + * common path prefix. Specifically, this method can be used to merge externally defined + * router functions under a path prefix. + *

For instance, the following example creates a nested route with a "/user" path + * predicate that delegates to the router function defined in {@code userController}, + * and with a "/order" path that delegates to {@code orderController}. + *

+		 * RouterFunction<ServerResponse> nestedRoute =
+		 *   RouterFunctions.route()
+		 *     .path("/user", userController::routerFunction)
+		 *     .path("/order", orderController::routerFunction)
+		 *     .build();
+		 * 
+ * @param pattern the pattern to match to + * @param routerFunctionSupplier supplier for the nested router function to delegate to if + * the pattern matches + * @return this builder + */ + Builder path(String pattern, Supplier> routerFunctionSupplier); + + /** + * Route to a built router function if the given path prefix pattern applies. + * This method can be used to create nested routes, where a group of routes + * share a common path prefix. + *

For instance, the following example creates a nested route with a "/user" path + * predicate, so that GET requests for "/user" will list users, + * and POST request for "/user" will create a new user. + *

+		 * RouterFunction<ServerResponse> nestedRoute =
+		 *   RouterFunctions.route()
+		 *     .path("/user", builder ->
+		 *       builder.GET(this::listUsers)
+		 *              .POST(this::createUser))
+		 *     .build();
+		 * 
+ * @param pattern the pattern to match to + * @param builderConsumer consumer for a {@code Builder} that provides the nested router + * function + * @return this builder + */ + Builder path(String pattern, Consumer builderConsumer); + + /** + * Filters all routes created by this builder with the given filter function. Filter + * functions are typically used to address cross-cutting concerns, such as logging, + * security, etc. + *

For instance, the following example creates a filter that returns a 401 Unauthorized + * response if the request does not contain the necessary authentication headers. + *

+		 * RouterFunction<ServerResponse> filteredRoute =
+		 *   RouterFunctions.route()
+		 *     .GET("/user", this::listUsers)
+		 *     .filter((request, next) -> {
+		 *       // check for authentication headers
+		 *       if (isAuthenticated(request)) {
+		 *         return next.handle(request);
+		 *       }
+		 *       else {
+		 *         return ServerResponse.status(HttpStatus.UNAUTHORIZED).build();
+		 *       }
+		 *     })
+		 *     .build();
+		 * 
+ * @param filterFunction the function to filter all routes built by this builder + * @return this builder + */ + Builder filter(HandlerFilterFunction filterFunction); + + /** + * Filter the request object for all routes created by this builder with the given request + * processing function. Filters are typically used to address cross-cutting concerns, such + * as logging, security, etc. + *

For instance, the following example creates a filter that logs the request before + * the handler function executes. + *

+		 * RouterFunction<ServerResponse> filteredRoute =
+		 *   RouterFunctions.route()
+		 *     .GET("/user", this::listUsers)
+		 *     .before(request -> {
+		 *       log(request);
+		 *       return request;
+		 *     })
+		 *     .build();
+		 * 
+ * @param requestProcessor a function that transforms the request + * @return this builder + */ + Builder before(Function requestProcessor); + + /** + * Filter the response object for all routes created by this builder with the given response + * processing function. Filters are typically used to address cross-cutting concerns, such + * as logging, security, etc. + *

For instance, the following example creates a filter that logs the response after + * the handler function executes. + *

+		 * RouterFunction<ServerResponse> filteredRoute =
+		 *   RouterFunctions.route()
+		 *     .GET("/user", this::listUsers)
+		 *     .after((request, response) -> {
+		 *       log(response);
+		 *       return response;
+		 *     })
+		 *     .build();
+		 * 
+ * @param responseProcessor a function that transforms the response + * @return this builder + */ + Builder after(BiFunction responseProcessor); + + /** + * Filters all exceptions that match the predicate by applying the given response provider + * function. + *

For instance, the following example creates a filter that returns a 500 response + * status when an {@code IllegalStateException} occurs. + *

+		 * RouterFunction<ServerResponse> filteredRoute =
+		 *   RouterFunctions.route()
+		 *     .GET("/user", this::listUsers)
+		 *     .onError(e -> e instanceof IllegalStateException,
+		 *       (e, request) -> ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build())
+		 *     .build();
+		 * 
+ * @param predicate the type of exception to filter + * @param responseProvider a function that creates a response + * @return this builder + */ + Builder onError(Predicate predicate, + BiFunction> responseProvider); + + /** + * Filters all exceptions of the given type by applying the given response provider + * function. + *

For instance, the following example creates a filter that returns a 500 response + * status when an {@code IllegalStateException} occurs. + *

+		 * RouterFunction<ServerResponse> filteredRoute =
+		 *   RouterFunctions.route()
+		 *     .GET("/user", this::listUsers)
+		 *     .onError(IllegalStateException.class,
+		 *       (e, request) -> ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build())
+		 *     .build();
+		 * 
+ * @param exceptionType the type of exception to filter + * @param responseProvider a function that creates a response + * @return this builder + */ + Builder onError(Class exceptionType, + BiFunction> responseProvider); + + /** + * Builds the {@code RouterFunction}. All created routes are + * {@linkplain RouterFunction#and(RouterFunction) composed} with one another, and filters + * (if any) are applied to the result. + * @return the built router function + */ + RouterFunction build(); + } + + + /** + * Receives notifications from the logical structure of router functions. + */ + public interface Visitor { + + /** + * Receive notification of the beginning of a nested router function. + * @param predicate the predicate that applies to the nested router functions + * @see RouterFunctions#nest(RequestPredicate, RouterFunction) + */ + void startNested(RequestPredicate predicate); + + /** + * Receive notification of the end of a nested router function. + * @param predicate the predicate that applies to the nested router functions + * @see RouterFunctions#nest(RequestPredicate, RouterFunction) + */ + void endNested(RequestPredicate predicate); + + /** + * Receive notification of a standard predicated route to a handler function. + * @param predicate the predicate that applies to the handler function + * @param handlerFunction the handler function. + * @see RouterFunctions#route(RequestPredicate, HandlerFunction) + */ + void route(RequestPredicate predicate, HandlerFunction handlerFunction); + + /** + * Receive notification of a resource router function. + * @param lookupFunction the lookup function for the resources + * @see RouterFunctions#resources(Function) + */ + void resources(Function> lookupFunction); + + /** + * Receive notification of an unknown router function. This method is called for router + * functions that were not created via the various {@link RouterFunctions} methods. + * @param routerFunction the router function + */ + void unknown(RouterFunction routerFunction); + } + + + private abstract static class AbstractRouterFunction implements RouterFunction { + + @Override + public String toString() { + ToStringVisitor visitor = new ToStringVisitor(); + accept(visitor); + return visitor.toString(); + } + } + + + /** + * A composed routing function that first invokes one function, and then invokes the + * another function (of the same response type {@code T}) if this route had + * {@linkplain Mono#empty() no result}. + * @param the server response type + */ + static final class SameComposedRouterFunction extends AbstractRouterFunction { + + private final RouterFunction first; + + private final RouterFunction second; + + public SameComposedRouterFunction(RouterFunction first, RouterFunction second) { + this.first = first; + this.second = second; + } + + @Override + public Mono> route(ServerRequest request) { + return Flux.concat(this.first.route(request), Mono.defer(() -> this.second.route(request))) + .next(); + } + + @Override + public void accept(Visitor visitor) { + this.first.accept(visitor); + this.second.accept(visitor); + } + } + + + /** + * A composed routing function that first invokes one function, and then invokes + * another function (of a different response type) if this route had + * {@linkplain Mono#empty() no result}. + */ + static final class DifferentComposedRouterFunction extends AbstractRouterFunction { + + private final RouterFunction first; + + private final RouterFunction second; + + public DifferentComposedRouterFunction(RouterFunction first, RouterFunction second) { + this.first = first; + this.second = second; + } + + @Override + public Mono> route(ServerRequest request) { + return Flux.concat(this.first.route(request), Mono.defer(() -> this.second.route(request))) + .next() + .map(RouterFunctions::cast); + } + + @Override + public void accept(Visitor visitor) { + this.first.accept(visitor); + this.second.accept(visitor); + } + } + + + /** + * Filter the specified {@linkplain HandlerFunction handler functions} with the given + * {@linkplain HandlerFilterFunction filter function}. + * @param the type of the {@linkplain HandlerFunction handler function} to filter + * @param the type of the response of the function + */ + static final class FilteredRouterFunction + implements RouterFunction { + + private final RouterFunction routerFunction; + + private final HandlerFilterFunction filterFunction; + + public FilteredRouterFunction( + RouterFunction routerFunction, + HandlerFilterFunction filterFunction) { + this.routerFunction = routerFunction; + this.filterFunction = filterFunction; + } + + @Override + public Mono> route(ServerRequest request) { + return this.routerFunction.route(request).map(this.filterFunction::apply); + } + + @Override + public void accept(Visitor visitor) { + this.routerFunction.accept(visitor); + } + + @Override + public String toString() { + return this.routerFunction.toString(); + } + } + + + private static final class DefaultRouterFunction extends AbstractRouterFunction { + + private final RequestPredicate predicate; + + private final HandlerFunction handlerFunction; + + public DefaultRouterFunction(RequestPredicate predicate, HandlerFunction handlerFunction) { + Assert.notNull(predicate, "Predicate must not be null"); + Assert.notNull(handlerFunction, "HandlerFunction must not be null"); + this.predicate = predicate; + this.handlerFunction = handlerFunction; + } + + @Override + public Mono> route(ServerRequest request) { + if (this.predicate.test(request)) { + if (logger.isTraceEnabled()) { + String logPrefix = request.exchange().getLogPrefix(); + logger.trace(logPrefix + String.format("Matched %s", this.predicate)); + } + return Mono.just(this.handlerFunction); + } + else { + return Mono.empty(); + } + } + + @Override + public void accept(Visitor visitor) { + visitor.route(this.predicate, this.handlerFunction); + } + } + + + private static final class DefaultNestedRouterFunction extends AbstractRouterFunction { + + private final RequestPredicate predicate; + + private final RouterFunction routerFunction; + + public DefaultNestedRouterFunction(RequestPredicate predicate, RouterFunction routerFunction) { + Assert.notNull(predicate, "Predicate must not be null"); + Assert.notNull(routerFunction, "RouterFunction must not be null"); + this.predicate = predicate; + this.routerFunction = routerFunction; + } + + @Override + public Mono> route(ServerRequest serverRequest) { + return this.predicate.nest(serverRequest) + .map(nestedRequest -> { + if (logger.isTraceEnabled()) { + String logPrefix = serverRequest.exchange().getLogPrefix(); + logger.trace(logPrefix + String.format("Matched nested %s", this.predicate)); + } + return this.routerFunction.route(nestedRequest) + .doOnNext(match -> { + if (nestedRequest != serverRequest) { + serverRequest.attributes().clear(); + serverRequest.attributes() + .putAll(nestedRequest.attributes()); + } + }); + } + ).orElseGet(Mono::empty); + } + + + @Override + public void accept(Visitor visitor) { + visitor.startNested(this.predicate); + this.routerFunction.accept(visitor); + visitor.endNested(this.predicate); + } + } + + + private static class ResourcesRouterFunction extends AbstractRouterFunction { + + private final Function> lookupFunction; + + public ResourcesRouterFunction(Function> lookupFunction) { + Assert.notNull(lookupFunction, "Function must not be null"); + this.lookupFunction = lookupFunction; + } + + @Override + public Mono> route(ServerRequest request) { + return this.lookupFunction.apply(request).map(ResourceHandlerFunction::new); + } + + @Override + public void accept(Visitor visitor) { + visitor.resources(this.lookupFunction); + } + } + + + private static class HandlerStrategiesResponseContext implements ServerResponse.Context { + + private final HandlerStrategies strategies; + + public HandlerStrategiesResponseContext(HandlerStrategies strategies) { + this.strategies = strategies; + } + + @Override + public List> messageWriters() { + return this.strategies.messageWriters(); + } + + @Override + public List viewResolvers() { + return this.strategies.viewResolvers(); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ServerRequest.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ServerRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..d85586bf6fec6ec6c79379fa035af6690e0fba37 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ServerRequest.java @@ -0,0 +1,482 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.nio.charset.Charset; +import java.security.Principal; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.function.Consumer; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRange; +import org.springframework.http.MediaType; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.json.Jackson2CodecSupport; +import org.springframework.http.codec.multipart.Part; +import org.springframework.http.server.PathContainer; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyExtractor; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; +import org.springframework.web.util.UriBuilder; + +/** + * Represents a server-side HTTP request, as handled by a {@code HandlerFunction}. + * + *

Access to headers and body is offered by {@link Headers} and + * {@link #body(BodyExtractor)}, respectively. + * + * @author Arjen Poutsma + * @author Sebastien Deleuze + * @since 5.0 + */ +public interface ServerRequest { + + /** + * Get the HTTP method. + * @return the HTTP method as an HttpMethod enum value, or {@code null} + * if not resolvable (e.g. in case of a non-standard HTTP method) + */ + @Nullable + default HttpMethod method() { + return HttpMethod.resolve(methodName()); + } + + /** + * Get the name of the HTTP method. + * @return the HTTP method as a String + */ + String methodName(); + + /** + * Get the request URI. + */ + URI uri(); + + /** + * Get a {@code UriBuilderComponents} from the URI associated with this + * {@code ServerRequest}. + *

Note: as of 5.1 this method ignores {@code "Forwarded"} + * and {@code "X-Forwarded-*"} headers that specify the + * client-originated address. Consider using the {@code ForwardedHeaderFilter} + * to extract and use, or to discard such headers. + * @return a URI builder + */ + UriBuilder uriBuilder(); + + /** + * Get the request path. + */ + default String path() { + return uri().getRawPath(); + } + + /** + * Get the request path as a {@code PathContainer}. + */ + default PathContainer pathContainer() { + return PathContainer.parsePath(path()); + } + + /** + * Get the headers of this request. + */ + Headers headers(); + + /** + * Get the cookies of this request. + */ + MultiValueMap cookies(); + + /** + * Get the remote address to which this request is connected, if available. + * @since 5.1 + */ + Optional remoteAddress(); + + /** + * Get the readers used to convert the body of this request. + * @since 5.1 + */ + List> messageReaders(); + + /** + * Extract the body with the given {@code BodyExtractor}. + * @param extractor the {@code BodyExtractor} that reads from the request + * @param the type of the body returned + * @return the extracted body + * @see #body(BodyExtractor, Map) + */ + T body(BodyExtractor extractor); + + /** + * Extract the body with the given {@code BodyExtractor} and hints. + * @param extractor the {@code BodyExtractor} that reads from the request + * @param hints the map of hints like {@link Jackson2CodecSupport#JSON_VIEW_HINT} + * to use to customize body extraction + * @param the type of the body returned + * @return the extracted body + */ + T body(BodyExtractor extractor, Map hints); + + /** + * Extract the body to a {@code Mono}. + * @param elementClass the class of element in the {@code Mono} + * @param the element type + * @return the body as a mono + */ + Mono bodyToMono(Class elementClass); + + /** + * Extract the body to a {@code Mono}. + * @param typeReference a type reference describing the expected response request type + * @param the element type + * @return a mono containing the body of the given type {@code T} + */ + Mono bodyToMono(ParameterizedTypeReference typeReference); + + /** + * Extract the body to a {@code Flux}. + * @param elementClass the class of element in the {@code Flux} + * @param the element type + * @return the body as a flux + */ + Flux bodyToFlux(Class elementClass); + + /** + * Extract the body to a {@code Flux}. + * @param typeReference a type reference describing the expected request body type + * @param the element type + * @return a flux containing the body of the given type {@code T} + */ + Flux bodyToFlux(ParameterizedTypeReference typeReference); + + /** + * Get the request attribute value if present. + * @param name the attribute name + * @return the attribute value + */ + default Optional attribute(String name) { + return Optional.ofNullable(attributes().get(name)); + } + + /** + * Get a mutable map of request attributes. + * @return the request attributes + */ + Map attributes(); + + /** + * Get the first query parameter with the given name, if present. + * @param name the parameter name + * @return the parameter value + */ + default Optional queryParam(String name) { + List queryParamValues = queryParams().get(name); + if (CollectionUtils.isEmpty(queryParamValues)) { + return Optional.empty(); + } + else { + String value = queryParamValues.get(0); + if (value == null) { + value = ""; + } + return Optional.of(value); + } + } + + /** + * Get all query parameters for this request. + */ + MultiValueMap queryParams(); + + /** + * Get the path variable with the given name, if present. + * @param name the variable name + * @return the variable value + * @throws IllegalArgumentException if there is no path variable with the given name + */ + default String pathVariable(String name) { + Map pathVariables = pathVariables(); + if (pathVariables.containsKey(name)) { + return pathVariables().get(name); + } + else { + throw new IllegalArgumentException("No path variable with name \"" + name + "\" available"); + } + } + + /** + * Get all path variables for this request. + */ + Map pathVariables(); + + /** + * Get the web session for this request. + *

Always guaranteed to return an instance either matching the session id + * requested by the client, or with a new session id either because the client + * did not specify one or because the underlying session had expired. + *

Use of this method does not automatically create a session. + */ + Mono session(); + + /** + * Get the authenticated user for the request, if any. + */ + Mono principal(); + + /** + * Get the form data from the body of the request if the Content-Type is + * {@code "application/x-www-form-urlencoded"} or an empty map otherwise. + *

Note: calling this method causes the request body to + * be read and parsed in full, and the resulting {@code MultiValueMap} is + * cached so that this method is safe to call more than once. + */ + Mono> formData(); + + /** + * Get the parts of a multipart request if the Content-Type is + * {@code "multipart/form-data"} or an empty map otherwise. + *

Note: calling this method causes the request body to + * be read and parsed in full, and the resulting {@code MultiValueMap} is + * cached so that this method is safe to call more than once. + */ + Mono> multipartData(); + + /** + * Get the web exchange that this request is based on. + *

Note: Manipulating the exchange directly (instead of using the methods provided on + * {@code ServerRequest} and {@code ServerResponse}) can lead to irregular results. + * @since 5.1 + */ + ServerWebExchange exchange(); + + + // Static builder methods + + /** + * Create a new {@code ServerRequest} based on the given {@code ServerWebExchange} and + * message readers. + * @param exchange the exchange + * @param messageReaders the message readers + * @return the created {@code ServerRequest} + */ + static ServerRequest create(ServerWebExchange exchange, List> messageReaders) { + return new DefaultServerRequest(exchange, messageReaders); + } + + /** + * Create a builder with the {@linkplain HttpMessageReader message readers}, + * method name, URI, headers, cookies, and attributes of the given request. + * @param other the request to copy the message readers, method name, URI, + * headers, and attributes from + * @return the created builder + * @since 5.1 + */ + static Builder from(ServerRequest other) { + return new DefaultServerRequestBuilder(other); + } + + + /** + * Represents the headers of the HTTP request. + * @see ServerRequest#headers() + */ + interface Headers { + + /** + * Get the list of acceptable media types, as specified by the {@code Accept} + * header. + *

Returns an empty list if the acceptable media types are unspecified. + */ + List accept(); + + /** + * Get the list of acceptable charsets, as specified by the + * {@code Accept-Charset} header. + */ + List acceptCharset(); + + /** + * Get the list of acceptable languages, as specified by the + * {@code Accept-Language} header. + */ + List acceptLanguage(); + + /** + * Get the length of the body in bytes, as specified by the + * {@code Content-Length} header. + */ + OptionalLong contentLength(); + + /** + * Get the media type of the body, as specified by the + * {@code Content-Type} header. + */ + Optional contentType(); + + /** + * Get the value of the {@code Host} header, if available. + *

If the header value does not contain a port, the + * {@linkplain InetSocketAddress#getPort() port} in the returned address will + * be {@code 0}. + */ + @Nullable + InetSocketAddress host(); + + /** + * Get the value of the {@code Range} header. + *

Returns an empty list when the range is unknown. + */ + List range(); + + /** + * Get the header value(s), if any, for the header with the given name. + *

Returns an empty list if no header values are found. + * @param headerName the header name + */ + List header(String headerName); + + /** + * Get the headers as an instance of {@link HttpHeaders}. + */ + HttpHeaders asHttpHeaders(); + } + + + /** + * Defines a builder for a request. + * @since 5.1 + */ + interface Builder { + + /** + * Set the method of the request. + * @param method the new method + * @return this builder + */ + Builder method(HttpMethod method); + + /** + * Set the URI of the request. + * @param uri the new URI + * @return this builder + */ + Builder uri(URI uri); + + /** + * Add the given header value(s) under the given name. + * @param headerName the header name + * @param headerValues the header value(s) + * @return this builder + * @see HttpHeaders#add(String, String) + */ + Builder header(String headerName, String... headerValues); + + /** + * Manipulate this request's headers with the given consumer. + *

The headers provided to the consumer are "live", so that the consumer can be used to + * {@linkplain HttpHeaders#set(String, String) overwrite} existing header values, + * {@linkplain HttpHeaders#remove(Object) remove} values, or use any of the other + * {@link HttpHeaders} methods. + * @param headersConsumer a function that consumes the {@code HttpHeaders} + * @return this builder + */ + Builder headers(Consumer headersConsumer); + + /** + * Add a cookie with the given name and value(s). + * @param name the cookie name + * @param values the cookie value(s) + * @return this builder + */ + Builder cookie(String name, String... values); + + /** + * Manipulate this request's cookies with the given consumer. + *

The map provided to the consumer is "live", so that the consumer can be used to + * {@linkplain MultiValueMap#set(Object, Object) overwrite} existing cookies, + * {@linkplain MultiValueMap#remove(Object) remove} cookies, or use any of the other + * {@link MultiValueMap} methods. + * @param cookiesConsumer a function that consumes the cookies map + * @return this builder + */ + Builder cookies(Consumer> cookiesConsumer); + + /** + * Set the body of the request. + *

Calling this methods will + * {@linkplain org.springframework.core.io.buffer.DataBufferUtils#release(DataBuffer) release} + * the existing body of the builder. + * @param body the new body + * @return this builder + */ + Builder body(Flux body); + + /** + * Set the body of the request to the UTF-8 encoded bytes of the given string. + *

Calling this methods will + * {@linkplain org.springframework.core.io.buffer.DataBufferUtils#release(DataBuffer) release} + * the existing body of the builder. + * @param body the new body + * @return this builder + */ + Builder body(String body); + + /** + * Add an attribute with the given name and value. + * @param name the attribute name + * @param value the attribute value + * @return this builder + */ + Builder attribute(String name, Object value); + + /** + * Manipulate this request's attributes with the given consumer. + *

The map provided to the consumer is "live", so that the consumer can be used + * to {@linkplain Map#put(Object, Object) overwrite} existing attributes, + * {@linkplain Map#remove(Object) remove} attributes, or use any of the other + * {@link Map} methods. + * @param attributesConsumer a function that consumes the attributes map + * @return this builder + */ + Builder attributes(Consumer> attributesConsumer); + + /** + * Build the request. + * @return the built request + */ + ServerRequest build(); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ServerResponse.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ServerResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..286845d9dac4c7397fcc7fc958f915eacc76fb46 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ServerResponse.java @@ -0,0 +1,471 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.net.URI; +import java.time.Instant; +import java.time.ZonedDateTime; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.CacheControl; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseCookie; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.json.Jackson2CodecSupport; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyInserter; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.result.view.ViewResolver; +import org.springframework.web.server.ServerWebExchange; + +/** + * Represents a typed server-side HTTP response, as returned + * by a {@linkplain HandlerFunction handler function} or + * {@linkplain HandlerFilterFunction filter function}. + * + * @author Arjen Poutsma + * @author Juergen Hoeller + * @author Sebastien Deleuze + * @since 5.0 + */ +public interface ServerResponse { + + /** + * Return the status code of this response. + */ + HttpStatus statusCode(); + + /** + * Return the headers of this response. + */ + HttpHeaders headers(); + + /** + * Return the cookies of this response. + */ + MultiValueMap cookies(); + + /** + * Write this response to the given web exchange. + * @param exchange the web exchange to write to + * @param context the context to use when writing + * @return {@code Mono} to indicate when writing is complete + */ + Mono writeTo(ServerWebExchange exchange, Context context); + + + // Static methods + + /** + * Create a builder with the status code and headers of the given response. + * @param other the response to copy the status and headers from + * @return the created builder + */ + static BodyBuilder from(ServerResponse other) { + return new DefaultServerResponseBuilder(other); + } + + /** + * Create a builder with the given HTTP status. + * @param status the response status + * @return the created builder + */ + static BodyBuilder status(HttpStatus status) { + return new DefaultServerResponseBuilder(status); + } + + /** + * Create a builder with the given HTTP status. + * @param status the response status + * @return the created builder + * @since 5.0.3 + */ + static BodyBuilder status(int status) { + return new DefaultServerResponseBuilder(status); + } + + /** + * Create a builder with the status set to {@linkplain HttpStatus#OK 200 OK}. + * @return the created builder + */ + static BodyBuilder ok() { + return status(HttpStatus.OK); + } + + /** + * Create a new builder with a {@linkplain HttpStatus#CREATED 201 Created} status + * and a location header set to the given URI. + * @param location the location URI + * @return the created builder + */ + static BodyBuilder created(URI location) { + BodyBuilder builder = status(HttpStatus.CREATED); + return builder.location(location); + } + + /** + * Create a builder with an {@linkplain HttpStatus#ACCEPTED 202 Accepted} status. + * @return the created builder + */ + static BodyBuilder accepted() { + return status(HttpStatus.ACCEPTED); + } + + /** + * Create a builder with a {@linkplain HttpStatus#NO_CONTENT 204 No Content} status. + * @return the created builder + */ + static HeadersBuilder noContent() { + return status(HttpStatus.NO_CONTENT); + } + + /** + * Create a builder with a {@linkplain HttpStatus#SEE_OTHER 303 See Other} + * status and a location header set to the given URI. + * @param location the location URI + * @return the created builder + */ + static BodyBuilder seeOther(URI location) { + BodyBuilder builder = status(HttpStatus.SEE_OTHER); + return builder.location(location); + } + + /** + * Create a builder with a {@linkplain HttpStatus#TEMPORARY_REDIRECT 307 Temporary Redirect} + * status and a location header set to the given URI. + * @param location the location URI + * @return the created builder + */ + static BodyBuilder temporaryRedirect(URI location) { + BodyBuilder builder = status(HttpStatus.TEMPORARY_REDIRECT); + return builder.location(location); + } + + /** + * Create a builder with a {@linkplain HttpStatus#PERMANENT_REDIRECT 308 Permanent Redirect} + * status and a location header set to the given URI. + * @param location the location URI + * @return the created builder + */ + static BodyBuilder permanentRedirect(URI location) { + BodyBuilder builder = status(HttpStatus.PERMANENT_REDIRECT); + return builder.location(location); + } + + /** + * Create a builder with a {@linkplain HttpStatus#BAD_REQUEST 400 Bad Request} status. + * @return the created builder + */ + static BodyBuilder badRequest() { + return status(HttpStatus.BAD_REQUEST); + } + + /** + * Create a builder with a {@linkplain HttpStatus#NOT_FOUND 404 Not Found} status. + * @return the created builder + */ + static HeadersBuilder notFound() { + return status(HttpStatus.NOT_FOUND); + } + + /** + * Create a builder with an + * {@linkplain HttpStatus#UNPROCESSABLE_ENTITY 422 Unprocessable Entity} status. + * @return the created builder + */ + static BodyBuilder unprocessableEntity() { + return status(HttpStatus.UNPROCESSABLE_ENTITY); + } + + + /** + * Defines a builder that adds headers to the response. + * @param the builder subclass + */ + interface HeadersBuilder> { + + /** + * Add the given header value(s) under the given name. + * @param headerName the header name + * @param headerValues the header value(s) + * @return this builder + * @see HttpHeaders#add(String, String) + */ + B header(String headerName, String... headerValues); + + /** + * Manipulate this response's headers with the given consumer. The + * headers provided to the consumer are "live", so that the consumer can be used to + * {@linkplain HttpHeaders#set(String, String) overwrite} existing header values, + * {@linkplain HttpHeaders#remove(Object) remove} values, or use any of the other + * {@link HttpHeaders} methods. + * @param headersConsumer a function that consumes the {@code HttpHeaders} + * @return this builder + */ + B headers(Consumer headersConsumer); + + /** + * Add the given cookie to the response. + * @param cookie the cookie to add + * @return this builder + */ + B cookie(ResponseCookie cookie); + + /** + * Manipulate this response's cookies with the given consumer. The + * cookies provided to the consumer are "live", so that the consumer can be used to + * {@linkplain MultiValueMap#set(Object, Object) overwrite} existing cookies, + * {@linkplain MultiValueMap#remove(Object) remove} cookies, or use any of the other + * {@link MultiValueMap} methods. + * @param cookiesConsumer a function that consumes the cookies + * @return this builder + */ + B cookies(Consumer> cookiesConsumer); + + /** + * Set the set of allowed {@link HttpMethod HTTP methods}, as specified + * by the {@code Allow} header. + * + * @param allowedMethods the allowed methods + * @return this builder + * @see HttpHeaders#setAllow(Set) + */ + B allow(HttpMethod... allowedMethods); + + /** + * Set the set of allowed {@link HttpMethod HTTP methods}, as specified + * by the {@code Allow} header. + * @param allowedMethods the allowed methods + * @return this builder + * @see HttpHeaders#setAllow(Set) + */ + B allow(Set allowedMethods); + + /** + * Set the entity tag of the body, as specified by the {@code ETag} header. + * @param eTag the new entity tag + * @return this builder + * @see HttpHeaders#setETag(String) + */ + B eTag(String eTag); + + /** + * Set the time the resource was last changed, as specified by the + * {@code Last-Modified} header. + * @param lastModified the last modified date + * @return this builder + * @see HttpHeaders#setLastModified(long) + */ + B lastModified(ZonedDateTime lastModified); + + /** + * Set the time the resource was last changed, as specified by the + * {@code Last-Modified} header. + * @param lastModified the last modified date + * @return this builder + * @since 5.1.4 + * @see HttpHeaders#setLastModified(long) + */ + B lastModified(Instant lastModified); + + /** + * Set the location of a resource, as specified by the {@code Location} header. + * @param location the location + * @return this builder + * @see HttpHeaders#setLocation(URI) + */ + B location(URI location); + + /** + * Set the caching directives for the resource, as specified by the HTTP 1.1 + * {@code Cache-Control} header. + *

A {@code CacheControl} instance can be built like + * {@code CacheControl.maxAge(3600).cachePublic().noTransform()}. + * @param cacheControl a builder for cache-related HTTP response headers + * @return this builder + * @see RFC-7234 Section 5.2 + */ + B cacheControl(CacheControl cacheControl); + + /** + * Configure one or more request header names (e.g. "Accept-Language") to + * add to the "Vary" response header to inform clients that the response is + * subject to content negotiation and variances based on the value of the + * given request headers. The configured request header names are added only + * if not already present in the response "Vary" header. + * @param requestHeaders request header names + * @return this builder + */ + B varyBy(String... requestHeaders); + + /** + * Build the response entity with no body. + */ + Mono build(); + + /** + * Build the response entity with no body. + * The response will be committed when the given {@code voidPublisher} completes. + * @param voidPublisher publisher publisher to indicate when the response should be committed + */ + Mono build(Publisher voidPublisher); + + /** + * Build the response entity with a custom writer function. + * @param writeFunction the function used to write to the {@link ServerWebExchange} + */ + Mono build(BiFunction> writeFunction); + } + + + /** + * Defines a builder that adds a body to the response. + */ + interface BodyBuilder extends HeadersBuilder { + + /** + * Set the length of the body in bytes, as specified by the + * {@code Content-Length} header. + * @param contentLength the content length + * @return this builder + * @see HttpHeaders#setContentLength(long) + */ + BodyBuilder contentLength(long contentLength); + + /** + * Set the {@linkplain MediaType media type} of the body, as specified by the + * {@code Content-Type} header. + * @param contentType the content type + * @return this builder + * @see HttpHeaders#setContentType(MediaType) + */ + BodyBuilder contentType(MediaType contentType); + + /** + * Add a serialization hint like {@link Jackson2CodecSupport#JSON_VIEW_HINT} + * to customize how the body will be serialized. + * @param key the hint key + * @param value the hint value + */ + BodyBuilder hint(String key, Object value); + + /** + * Customize the serialization hints with the given consumer. + * @param hintsConsumer a function that consumes the hints + * @return this builder + * @since 5.1.6 + */ + BodyBuilder hints(Consumer> hintsConsumer); + + /** + * Set the body of the response to the given asynchronous {@code Publisher} and return it. + * This convenience method combines {@link #body(BodyInserter)} and + * {@link BodyInserters#fromPublisher(Publisher, Class)}. + * @param publisher the {@code Publisher} to write to the response + * @param elementClass the class of elements contained in the publisher + * @param the type of the elements contained in the publisher + * @param

the type of the {@code Publisher} + * @return the built response + */ + > Mono body(P publisher, Class elementClass); + + /** + * Set the body of the response to the given asynchronous {@code Publisher} and return it. + * This convenience method combines {@link #body(BodyInserter)} and + * {@link BodyInserters#fromPublisher(Publisher, Class)}. + * @param publisher the {@code Publisher} to write to the response + * @param typeReference a type reference describing the elements contained in the publisher + * @param the type of the elements contained in the publisher + * @param

the type of the {@code Publisher} + * @return the built response + */ + > Mono body(P publisher, + ParameterizedTypeReference typeReference); + + /** + * Set the body of the response to the given synchronous {@code Object} and return it. + * This convenience method combines {@link #body(BodyInserter)} and + * {@link BodyInserters#fromObject(Object)}. + * @param body the body of the response + * @return the built response + * @throws IllegalArgumentException if {@code body} is a {@link Publisher}, for which + * {@link #body(Publisher, Class)} should be used. + */ + Mono syncBody(Object body); + + /** + * Set the body of the response to the given {@code BodyInserter} and return it. + * @param inserter the {@code BodyInserter} that writes to the response + * @return the built response + */ + Mono body(BodyInserter inserter); + + /** + * Render the template with the given {@code name} using the given {@code modelAttributes}. + * The model attributes are mapped under a + * {@linkplain org.springframework.core.Conventions#getVariableName generated name}. + *

Note: Empty {@link Collection Collections} are not added to + * the model when using this method because we cannot correctly determine + * the true convention name. + * @param name the name of the template to be rendered + * @param modelAttributes the modelAttributes used to render the template + * @return the built response + */ + Mono render(String name, Object... modelAttributes); + + /** + * Render the template with the given {@code name} using the given {@code model}. + * @param name the name of the template to be rendered + * @param model the model used to render the template + * @return the built response + */ + Mono render(String name, Map model); + } + + + /** + * Defines the context used during the {@link #writeTo(ServerWebExchange, Context)}. + */ + interface Context { + + /** + * Return the {@link HttpMessageWriter HttpMessageWriters} to be used for response body conversion. + * @return the list of message writers + */ + List> messageWriters(); + + /** + * Return the {@link ViewResolver ViewResolvers} to be used for view name resolution. + * @return the list of view resolvers + */ + List viewResolvers(); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ToStringVisitor.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ToStringVisitor.java new file mode 100644 index 0000000000000000000000000000000000000000..dcfb74e5691c310b15dc09dc1d4694710a7368ec --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ToStringVisitor.java @@ -0,0 +1,172 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.util.Set; +import java.util.function.Function; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.http.HttpMethod; + +/** + * Implementation of {@link RouterFunctions.Visitor} that creates a formatted + * string representation of router functions. + * + * @author Arjen Poutsma + * @since 5.0 + */ +class ToStringVisitor implements RouterFunctions.Visitor, RequestPredicates.Visitor { + + private final StringBuilder builder = new StringBuilder(); + + private int indent = 0; + + + // RouterFunctions.Visitor + + @Override + public void startNested(RequestPredicate predicate) { + indent(); + predicate.accept(this); + this.builder.append(" => {\n"); + this.indent++; + } + + @Override + public void endNested(RequestPredicate predicate) { + this.indent--; + indent(); + this.builder.append("}\n"); + } + + @Override + public void route(RequestPredicate predicate, HandlerFunction handlerFunction) { + indent(); + predicate.accept(this); + this.builder.append(" -> "); + this.builder.append(handlerFunction).append('\n'); + } + + @Override + public void resources(Function> lookupFunction) { + indent(); + this.builder.append(lookupFunction).append('\n'); + } + + @Override + public void unknown(RouterFunction routerFunction) { + indent(); + this.builder.append(routerFunction); + } + + private void indent() { + for (int i = 0; i < this.indent; i++) { + this.builder.append(' '); + } + } + + + // RequestPredicates.Visitor + + @Override + public void method(Set methods) { + if (methods.size() == 1) { + this.builder.append(methods.iterator().next()); + } + else { + this.builder.append(methods); + } + } + + @Override + public void path(String pattern) { + this.builder.append(pattern); + } + + @Override + public void pathExtension(String extension) { + this.builder.append(String.format("*.%s", extension)); + } + + @Override + public void header(String name, String value) { + this.builder.append(String.format("%s: %s", name, value)); + } + + @Override + public void queryParam(String name, String value) { + this.builder.append(String.format("?%s == %s", name, value)); + } + + @Override + public void startAnd() { + this.builder.append('('); + } + + @Override + public void and() { + this.builder.append(" && "); + } + + @Override + public void endAnd() { + this.builder.append(')'); + } + + @Override + public void startOr() { + this.builder.append('('); + } + + @Override + public void or() { + this.builder.append(" || "); + + } + + @Override + public void endOr() { + this.builder.append(')'); + } + + @Override + public void startNegate() { + this.builder.append("!("); + } + + @Override + public void endNegate() { + this.builder.append(')'); + } + + @Override + public void unknown(RequestPredicate predicate) { + this.builder.append(predicate); + } + + @Override + public String toString() { + String result = this.builder.toString(); + if (result.endsWith("\n")) { + result = result.substring(0, result.length() - 1); + } + return result; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/package-info.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..f790316d726980168f36851108595e6e515544d2 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/package-info.java @@ -0,0 +1,9 @@ +/** + * Provides the types that make up Spring's functional web framework. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.reactive.function.server; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/HandlerFunctionAdapter.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/HandlerFunctionAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..fb4fa3c93e9a2bde4f9a22f1c6d56f42ae15581c --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/HandlerFunctionAdapter.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server.support; + +import java.lang.reflect.Method; + +import reactor.core.publisher.Mono; + +import org.springframework.core.MethodParameter; +import org.springframework.web.reactive.HandlerAdapter; +import org.springframework.web.reactive.HandlerResult; +import org.springframework.web.reactive.function.server.HandlerFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.server.ServerWebExchange; + +/** + * {@code HandlerAdapter} implementation that supports {@link HandlerFunction HandlerFunctions}. + * + * @author Arjen Poutsma + * @since 5.0 + */ +public class HandlerFunctionAdapter implements HandlerAdapter { + + private static final MethodParameter HANDLER_FUNCTION_RETURN_TYPE; + + static { + try { + Method method = HandlerFunction.class.getMethod("handle", ServerRequest.class); + HANDLER_FUNCTION_RETURN_TYPE = new MethodParameter(method, -1); + } + catch (NoSuchMethodException ex) { + throw new IllegalStateException(ex); + } + } + + + @Override + public boolean supports(Object handler) { + return handler instanceof HandlerFunction; + } + + @Override + public Mono handle(ServerWebExchange exchange, Object handler) { + HandlerFunction handlerFunction = (HandlerFunction) handler; + ServerRequest request = exchange.getRequiredAttribute(RouterFunctions.REQUEST_ATTRIBUTE); + return handlerFunction.handle(request) + .map(response -> new HandlerResult(handlerFunction, response, HANDLER_FUNCTION_RETURN_TYPE)); + } +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/RouterFunctionMapping.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/RouterFunctionMapping.java new file mode 100644 index 0000000000000000000000000000000000000000..33baed774934a71f1ccbfcd7bde8538b61e949cb --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/RouterFunctionMapping.java @@ -0,0 +1,175 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server.support; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import reactor.core.publisher.Mono; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; +import org.springframework.web.reactive.function.server.HandlerFunction; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.handler.AbstractHandlerMapping; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.util.pattern.PathPattern; + +/** + * {@code HandlerMapping} implementation that supports {@link RouterFunction RouterFunctions}. + * + *

If no {@link RouterFunction} is provided at + * {@linkplain #RouterFunctionMapping(RouterFunction) construction time}, this mapping + * will detect all router functions in the application context, and consult them in + * {@linkplain org.springframework.core.annotation.Order order}. + * + * @author Arjen Poutsma + * @since 5.0 + */ +public class RouterFunctionMapping extends AbstractHandlerMapping implements InitializingBean { + + @Nullable + private RouterFunction routerFunction; + + private List> messageReaders = Collections.emptyList(); + + + /** + * Create an empty {@code RouterFunctionMapping}. + *

If this constructor is used, this mapping will detect all + * {@link RouterFunction} instances available in the application context. + */ + public RouterFunctionMapping() { + } + + /** + * Create a {@code RouterFunctionMapping} with the given {@link RouterFunction}. + *

If this constructor is used, no application context detection will occur. + * @param routerFunction the router function to use for mapping + */ + public RouterFunctionMapping(RouterFunction routerFunction) { + this.routerFunction = routerFunction; + } + + + /** + * Return the configured {@link RouterFunction}. + *

Note: When router functions are detected from the + * ApplicationContext, this method may return {@code null} if invoked + * prior to {@link #afterPropertiesSet()}. + * @return the router function or {@code null} + */ + @Nullable + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Configure HTTP message readers to de-serialize the request body with. + *

By default this is set to the {@link ServerCodecConfigurer}'s defaults. + */ + public void setMessageReaders(List> messageReaders) { + this.messageReaders = messageReaders; + } + + @Override + public void afterPropertiesSet() throws Exception { + if (CollectionUtils.isEmpty(this.messageReaders)) { + ServerCodecConfigurer codecConfigurer = ServerCodecConfigurer.create(); + this.messageReaders = codecConfigurer.getReaders(); + } + + if (this.routerFunction == null) { + initRouterFunctions(); + } + } + + /** + * Initialized the router functions by detecting them in the application context. + */ + protected void initRouterFunctions() { + List> routerFunctions = routerFunctions(); + this.routerFunction = routerFunctions.stream().reduce(RouterFunction::andOther).orElse(null); + logRouterFunctions(routerFunctions); + } + + private List> routerFunctions() { + List> functions = obtainApplicationContext() + .getBeanProvider(RouterFunction.class) + .orderedStream() + .map(router -> (RouterFunction)router) + .collect(Collectors.toList()); + return (!CollectionUtils.isEmpty(functions) ? functions : Collections.emptyList()); + } + + private void logRouterFunctions(List> routerFunctions) { + if (logger.isDebugEnabled()) { + int total = routerFunctions.size(); + String message = total + " RouterFunction(s) in " + formatMappingName(); + if (logger.isTraceEnabled()) { + if (total > 0) { + routerFunctions.forEach(routerFunction -> logger.trace("Mapped " + routerFunction)); + } + else { + logger.trace(message); + } + } + else if (total > 0) { + logger.debug(message); + } + } + } + + + @Override + protected Mono getHandlerInternal(ServerWebExchange exchange) { + if (this.routerFunction != null) { + ServerRequest request = ServerRequest.create(exchange, this.messageReaders); + return this.routerFunction.route(request) + .doOnNext(handler -> setAttributes(exchange.getAttributes(), request, handler)); + } + else { + return Mono.empty(); + } + } + + @SuppressWarnings("unchecked") + private void setAttributes( + Map attributes, ServerRequest serverRequest, HandlerFunction handlerFunction) { + + attributes.put(RouterFunctions.REQUEST_ATTRIBUTE, serverRequest); + attributes.put(BEST_MATCHING_HANDLER_ATTRIBUTE, handlerFunction); + + PathPattern matchingPattern = (PathPattern) attributes.get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE); + if (matchingPattern != null) { + attributes.put(BEST_MATCHING_PATTERN_ATTRIBUTE, matchingPattern); + } + Map uriVariables = + (Map) attributes.get(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE); + if (uriVariables != null) { + attributes.put(URI_TEMPLATE_VARIABLES_ATTRIBUTE, uriVariables); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/ServerRequestWrapper.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/ServerRequestWrapper.java new file mode 100644 index 0000000000000000000000000000000000000000..3c5284c72615c2c03115132cd7a2388b82685c27 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/ServerRequestWrapper.java @@ -0,0 +1,281 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server.support; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.nio.charset.Charset; +import java.security.Principal; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalLong; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRange; +import org.springframework.http.MediaType; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.multipart.Part; +import org.springframework.http.server.PathContainer; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.BodyExtractor; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; +import org.springframework.web.util.UriBuilder; + +/** + * Implementation of the {@link ServerRequest} interface that can be subclassed + * to adapt the request in a + * {@link org.springframework.web.reactive.function.server.HandlerFilterFunction handler filter function}. + * All methods default to calling through to the wrapped request. + * + * @author Arjen Poutsma + * @since 5.0 + */ +public class ServerRequestWrapper implements ServerRequest { + + private final ServerRequest delegate; + + + /** + * Create a new {@code ServerRequestWrapper} that wraps the given request. + * @param delegate the request to wrap + */ + public ServerRequestWrapper(ServerRequest delegate) { + Assert.notNull(delegate, "Delegate must not be null"); + this.delegate = delegate; + } + + + /** + * Return the wrapped request. + */ + public ServerRequest request() { + return this.delegate; + } + + @Override + public HttpMethod method() { + return this.delegate.method(); + } + + @Override + public String methodName() { + return this.delegate.methodName(); + } + + @Override + public URI uri() { + return this.delegate.uri(); + } + + @Override + public UriBuilder uriBuilder() { + return this.delegate.uriBuilder(); + } + + @Override + public String path() { + return this.delegate.path(); + } + + @Override + public PathContainer pathContainer() { + return this.delegate.pathContainer(); + } + + @Override + public Headers headers() { + return this.delegate.headers(); + } + + @Override + public MultiValueMap cookies() { + return this.delegate.cookies(); + } + + @Override + public Optional remoteAddress() { + return this.delegate.remoteAddress(); + } + + @Override + public List> messageReaders() { + return this.delegate.messageReaders(); + } + + @Override + public T body(BodyExtractor extractor) { + return this.delegate.body(extractor); + } + + @Override + public T body(BodyExtractor extractor, Map hints) { + return this.delegate.body(extractor, hints); + } + + @Override + public Mono bodyToMono(Class elementClass) { + return this.delegate.bodyToMono(elementClass); + } + + @Override + public Mono bodyToMono(ParameterizedTypeReference typeReference) { + return this.delegate.bodyToMono(typeReference); + } + + @Override + public Flux bodyToFlux(Class elementClass) { + return this.delegate.bodyToFlux(elementClass); + } + + @Override + public Flux bodyToFlux(ParameterizedTypeReference typeReference) { + return this.delegate.bodyToFlux(typeReference); + } + + @Override + public Optional attribute(String name) { + return this.delegate.attribute(name); + } + + @Override + public Map attributes() { + return this.delegate.attributes(); + } + + @Override + public Optional queryParam(String name) { + return this.delegate.queryParam(name); + } + + @Override + public MultiValueMap queryParams() { + return this.delegate.queryParams(); + } + + @Override + public String pathVariable(String name) { + return this.delegate.pathVariable(name); + } + + @Override + public Map pathVariables() { + return this.delegate.pathVariables(); + } + + @Override + public Mono session() { + return this.delegate.session(); + } + + @Override + public Mono principal() { + return this.delegate.principal(); + } + + @Override + public Mono> formData() { + return this.delegate.formData(); + } + + @Override + public Mono> multipartData() { + return this.delegate.multipartData(); + } + + @Override + public ServerWebExchange exchange() { + return this.delegate.exchange(); + } + + /** + * Implementation of the {@code Headers} interface that can be subclassed + * to adapt the headers in a + * {@link org.springframework.web.reactive.function.server.HandlerFilterFunction handler filter function}. + * All methods default to calling through to the wrapped headers. + */ + public static class HeadersWrapper implements ServerRequest.Headers { + + private final Headers headers; + + /** + * Create a new {@code HeadersWrapper} that wraps the given request. + * @param headers the headers to wrap + */ + public HeadersWrapper(Headers headers) { + Assert.notNull(headers, "Headers must not be null"); + this.headers = headers; + } + + @Override + public List accept() { + return this.headers.accept(); + } + + @Override + public List acceptCharset() { + return this.headers.acceptCharset(); + } + + @Override + public List acceptLanguage() { + return this.headers.acceptLanguage(); + } + + @Override + public OptionalLong contentLength() { + return this.headers.contentLength(); + } + + @Override + public Optional contentType() { + return this.headers.contentType(); + } + + @Override + public InetSocketAddress host() { + return this.headers.host(); + } + + @Override + public List range() { + return this.headers.range(); + } + + @Override + public List header(String headerName) { + return this.headers.header(headerName); + } + + @Override + public HttpHeaders asHttpHeaders() { + return this.headers.asHttpHeaders(); + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/ServerResponseResultHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/ServerResponseResultHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..1eac334d3f3a7b4102a84d6a32e37f4203ec8d5d --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/ServerResponseResultHandler.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server.support; + +import java.util.Collections; +import java.util.List; + +import reactor.core.publisher.Mono; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.core.Ordered; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.web.reactive.HandlerResult; +import org.springframework.web.reactive.HandlerResultHandler; +import org.springframework.web.reactive.function.server.ServerResponse; +import org.springframework.web.reactive.result.view.ViewResolver; +import org.springframework.web.server.ServerWebExchange; + +/** + * {@code HandlerResultHandler} implementation that supports {@link ServerResponse ServerResponses}. + * + * @author Arjen Poutsma + * @since 5.0 + */ +public class ServerResponseResultHandler implements HandlerResultHandler, InitializingBean, Ordered { + + private List> messageWriters = Collections.emptyList(); + + private List viewResolvers = Collections.emptyList(); + + private int order = 0; + + + /** + * Configure HTTP message writers to serialize the request body with. + *

By default this is set to {@link ServerCodecConfigurer}'s default writers. + */ + public void setMessageWriters(List> configurer) { + this.messageWriters = configurer; + } + + public void setViewResolvers(List viewResolvers) { + this.viewResolvers = viewResolvers; + } + + /** + * Set the order for this result handler relative to others. + *

By default set to 0. It is generally safe to place it early in the + * order as it looks for a concrete return type. + */ + public void setOrder(int order) { + this.order = order; + } + + @Override + public int getOrder() { + return this.order; + } + + + @Override + public void afterPropertiesSet() throws Exception { + if (CollectionUtils.isEmpty(this.messageWriters)) { + throw new IllegalArgumentException("Property 'messageWriters' is required"); + } + } + + @Override + public boolean supports(HandlerResult result) { + return (result.getReturnValue() instanceof ServerResponse); + } + + @Override + public Mono handleResult(ServerWebExchange exchange, HandlerResult result) { + ServerResponse response = (ServerResponse) result.getReturnValue(); + Assert.state(response != null, "No ServerResponse"); + return response.writeTo(exchange, new ServerResponse.Context() { + @Override + public List> messageWriters() { + return messageWriters; + } + @Override + public List viewResolvers() { + return viewResolvers; + } + }); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/package-info.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..2f4f6d65044294579da29f3bebe58c4151fc7a6e --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/package-info.java @@ -0,0 +1,12 @@ +/** + * Classes supporting the {@code org.springframework.web.reactive.function.server} package. + * Contains a {@code HandlerAdapter} that supports {@code HandlerFunction}s, + * a {@code HandlerResultHandler} that supports {@code ServerResponse}s, and + * a {@code ServerRequest} wrapper to adapt a request. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.reactive.function.server.support; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/handler/AbstractHandlerMapping.java b/spring-webflux/src/main/java/org/springframework/web/reactive/handler/AbstractHandlerMapping.java new file mode 100644 index 0000000000000000000000000000000000000000..3a736d50ffc14f2983ba8b77e76b7f754fd65f9c --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/handler/AbstractHandlerMapping.java @@ -0,0 +1,217 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.handler; + +import java.util.Map; + +import reactor.core.publisher.Mono; + +import org.springframework.beans.factory.BeanNameAware; +import org.springframework.context.support.ApplicationObjectSupport; +import org.springframework.core.Ordered; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.cors.reactive.CorsConfigurationSource; +import org.springframework.web.cors.reactive.CorsProcessor; +import org.springframework.web.cors.reactive.CorsUtils; +import org.springframework.web.cors.reactive.DefaultCorsProcessor; +import org.springframework.web.cors.reactive.UrlBasedCorsConfigurationSource; +import org.springframework.web.reactive.HandlerMapping; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebHandler; +import org.springframework.web.util.pattern.PathPatternParser; + +/** + * Abstract base class for {@link org.springframework.web.reactive.HandlerMapping} + * implementations. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @author Brian Clozel + * @since 5.0 + */ +public abstract class AbstractHandlerMapping extends ApplicationObjectSupport + implements HandlerMapping, Ordered, BeanNameAware { + + private static final WebHandler REQUEST_HANDLED_HANDLER = exchange -> Mono.empty(); + + + private final PathPatternParser patternParser; + + private CorsConfigurationSource corsConfigurationSource; + + private CorsProcessor corsProcessor = new DefaultCorsProcessor(); + + private int order = Ordered.LOWEST_PRECEDENCE; // default: same as non-Ordered + + @Nullable + private String beanName; + + + public AbstractHandlerMapping() { + this.patternParser = new PathPatternParser(); + this.corsConfigurationSource = new UrlBasedCorsConfigurationSource(this.patternParser); + } + + + /** + * Shortcut method for setting the same property on the underlying pattern + * parser in use. For more details see: + *

    + *
  • {@link #getPathPatternParser()} -- the underlying pattern parser + *
  • {@link PathPatternParser#setCaseSensitive(boolean)} -- the case + * sensitive slash option, including its default value. + *
+ *

Note: aside from + */ + public void setUseCaseSensitiveMatch(boolean caseSensitiveMatch) { + this.patternParser.setCaseSensitive(caseSensitiveMatch); + } + + /** + * Shortcut method for setting the same property on the underlying pattern + * parser in use. For more details see: + *

    + *
  • {@link #getPathPatternParser()} -- the underlying pattern parser + *
  • {@link PathPatternParser#setMatchOptionalTrailingSeparator(boolean)} -- + * the trailing slash option, including its default value. + *
+ */ + public void setUseTrailingSlashMatch(boolean trailingSlashMatch) { + this.patternParser.setMatchOptionalTrailingSeparator(trailingSlashMatch); + } + + /** + * Return the {@link PathPatternParser} instance that is used for + * {@link #setCorsConfigurations(Map) CORS configuration checks}. + * Sub-classes can also use this pattern parser for their own request + * mapping purposes. + */ + public PathPatternParser getPathPatternParser() { + return this.patternParser; + } + + /** + * Set the "global" CORS configurations based on URL patterns. By default the + * first matching URL pattern is combined with handler-level CORS configuration if any. + * @see #setCorsConfigurationSource(CorsConfigurationSource) + */ + public void setCorsConfigurations(Map corsConfigurations) { + Assert.notNull(corsConfigurations, "corsConfigurations must not be null"); + this.corsConfigurationSource = new UrlBasedCorsConfigurationSource(this.patternParser); + ((UrlBasedCorsConfigurationSource) this.corsConfigurationSource).setCorsConfigurations(corsConfigurations); + } + + /** + * Set the "global" CORS configuration source. By default the first matching URL + * pattern is combined with the CORS configuration for the handler, if any. + * @since 5.1 + * @see #setCorsConfigurations(Map) + */ + public void setCorsConfigurationSource(CorsConfigurationSource corsConfigurationSource) { + Assert.notNull(corsConfigurationSource, "corsConfigurationSource must not be null"); + this.corsConfigurationSource = corsConfigurationSource; + } + + /** + * Configure a custom {@link CorsProcessor} to use to apply the matched + * {@link CorsConfiguration} for a request. + *

By default an instance of {@link DefaultCorsProcessor} is used. + */ + public void setCorsProcessor(CorsProcessor corsProcessor) { + Assert.notNull(corsProcessor, "CorsProcessor must not be null"); + this.corsProcessor = corsProcessor; + } + + /** + * Return the configured {@link CorsProcessor}. + */ + public CorsProcessor getCorsProcessor() { + return this.corsProcessor; + } + + /** + * Specify the order value for this HandlerMapping bean. + *

The default value is {@code Ordered.LOWEST_PRECEDENCE}, meaning non-ordered. + * @see org.springframework.core.Ordered#getOrder() + */ + public void setOrder(int order) { + this.order = order; + } + + @Override + public int getOrder() { + return this.order; + } + + @Override + public void setBeanName(String name) { + this.beanName = name; + } + + protected String formatMappingName() { + return this.beanName != null ? "'" + this.beanName + "'" : ""; + } + + + @Override + public Mono getHandler(ServerWebExchange exchange) { + return getHandlerInternal(exchange).map(handler -> { + if (logger.isDebugEnabled()) { + logger.debug(exchange.getLogPrefix() + "Mapped to " + handler); + } + if (CorsUtils.isCorsRequest(exchange.getRequest())) { + CorsConfiguration configA = this.corsConfigurationSource.getCorsConfiguration(exchange); + CorsConfiguration configB = getCorsConfiguration(handler, exchange); + CorsConfiguration config = (configA != null ? configA.combine(configB) : configB); + if (!getCorsProcessor().process(config, exchange) || + CorsUtils.isPreFlightRequest(exchange.getRequest())) { + return REQUEST_HANDLED_HANDLER; + } + } + return handler; + }); + } + + /** + * Look up a handler for the given request, returning an empty {@code Mono} + * if no specific one is found. This method is called by {@link #getHandler}. + *

On CORS pre-flight requests this method should return a match not for + * the pre-flight request but for the expected actual request based on the URL + * path, the HTTP methods from the "Access-Control-Request-Method" header, and + * the headers from the "Access-Control-Request-Headers" header. + * @param exchange current exchange + * @return {@code Mono} for the matching handler, if any + */ + protected abstract Mono getHandlerInternal(ServerWebExchange exchange); + + /** + * Retrieve the CORS configuration for the given handler. + * @param handler the handler to check (never {@code null}) + * @param exchange the current exchange + * @return the CORS configuration for the handler, or {@code null} if none + */ + @Nullable + protected CorsConfiguration getCorsConfiguration(Object handler, ServerWebExchange exchange) { + if (handler instanceof CorsConfigurationSource) { + return ((CorsConfigurationSource) handler).getCorsConfiguration(exchange); + } + return null; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/handler/AbstractUrlHandlerMapping.java b/spring-webflux/src/main/java/org/springframework/web/reactive/handler/AbstractUrlHandlerMapping.java new file mode 100644 index 0000000000000000000000000000000000000000..10296b8d4b83bad8d269356409d6c818ee919c91 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/handler/AbstractUrlHandlerMapping.java @@ -0,0 +1,228 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.handler; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import reactor.core.publisher.Mono; + +import org.springframework.beans.BeansException; +import org.springframework.http.server.PathContainer; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.util.pattern.PathPattern; + +/** + * Abstract base class for URL-mapped + * {@link org.springframework.web.reactive.HandlerMapping} implementations. + * + *

Supports direct matches, e.g. a registered "/test" matches "/test", and + * various path pattern matches, e.g. a registered "/t*" pattern matches + * both "/test" and "/team", "/test/*" matches all paths under "/test", + * "/test/**" matches all paths below "/test". For details, see the + * {@link org.springframework.web.util.pattern.PathPattern} javadoc. + * + *

Will search all path patterns to find the most specific match for the + * current request path. The most specific pattern is defined as the longest + * path pattern with the fewest captured variables and wildcards. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @author Brian Clozel + * @since 5.0 + */ +public abstract class AbstractUrlHandlerMapping extends AbstractHandlerMapping { + + private boolean lazyInitHandlers = false; + + private final Map handlerMap = new LinkedHashMap<>(); + + + /** + * Set whether to lazily initialize handlers. Only applicable to + * singleton handlers, as prototypes are always lazily initialized. + * Default is "false", as eager initialization allows for more efficiency + * through referencing the controller objects directly. + *

If you want to allow your controllers to be lazily initialized, + * make them "lazy-init" and set this flag to true. Just making them + * "lazy-init" will not work, as they are initialized through the + * references from the handler mapping in this case. + */ + public void setLazyInitHandlers(boolean lazyInitHandlers) { + this.lazyInitHandlers = lazyInitHandlers; + } + + /** + * Return a read-only view of registered path patterns and handlers which may + * may be an actual handler instance or the bean name of lazily initialized + * handler. + */ + public final Map getHandlerMap() { + return Collections.unmodifiableMap(this.handlerMap); + } + + + @Override + public Mono getHandlerInternal(ServerWebExchange exchange) { + PathContainer lookupPath = exchange.getRequest().getPath().pathWithinApplication(); + Object handler; + try { + handler = lookupHandler(lookupPath, exchange); + } + catch (Exception ex) { + return Mono.error(ex); + } + return Mono.justOrEmpty(handler); + } + + /** + * Look up a handler instance for the given URL lookup path. + *

Supports direct matches, e.g. a registered "/test" matches "/test", + * and various path pattern matches, e.g. a registered "/t*" matches + * both "/test" and "/team". For details, see the PathPattern class. + * @param lookupPath the URL the handler is mapped to + * @param exchange the current exchange + * @return the associated handler instance, or {@code null} if not found + * @see org.springframework.web.util.pattern.PathPattern + */ + @Nullable + protected Object lookupHandler(PathContainer lookupPath, ServerWebExchange exchange) throws Exception { + + List matches = this.handlerMap.keySet().stream() + .filter(key -> key.matches(lookupPath)) + .collect(Collectors.toList()); + + if (matches.isEmpty()) { + return null; + } + + if (matches.size() > 1) { + matches.sort(PathPattern.SPECIFICITY_COMPARATOR); + if (logger.isTraceEnabled()) { + logger.debug(exchange.getLogPrefix() + "Matching patterns " + matches); + } + } + + PathPattern pattern = matches.get(0); + PathContainer pathWithinMapping = pattern.extractPathWithinPattern(lookupPath); + return handleMatch(this.handlerMap.get(pattern), pattern, pathWithinMapping, exchange); + } + + private Object handleMatch(Object handler, PathPattern bestMatch, PathContainer pathWithinMapping, + ServerWebExchange exchange) { + + // Bean name or resolved handler? + if (handler instanceof String) { + String handlerName = (String) handler; + handler = obtainApplicationContext().getBean(handlerName); + } + + validateHandler(handler, exchange); + + exchange.getAttributes().put(BEST_MATCHING_HANDLER_ATTRIBUTE, handler); + exchange.getAttributes().put(BEST_MATCHING_PATTERN_ATTRIBUTE, bestMatch); + exchange.getAttributes().put(PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, pathWithinMapping); + + return handler; + } + + /** + * Validate the given handler against the current request. + *

The default implementation is empty. Can be overridden in subclasses, + * for example to enforce specific preconditions expressed in URL mappings. + * @param handler the handler object to validate + * @param exchange current exchange + */ + @SuppressWarnings("UnusedParameters") + protected void validateHandler(Object handler, ServerWebExchange exchange) { + } + + /** + * Register the specified handler for the given URL paths. + * @param urlPaths the URLs that the bean should be mapped to + * @param beanName the name of the handler bean + * @throws BeansException if the handler couldn't be registered + * @throws IllegalStateException if there is a conflicting handler registered + */ + protected void registerHandler(String[] urlPaths, String beanName) throws BeansException, IllegalStateException { + Assert.notNull(urlPaths, "URL path array must not be null"); + for (String urlPath : urlPaths) { + registerHandler(urlPath, beanName); + } + } + + /** + * Register the specified handler for the given URL path. + * @param urlPath the URL the bean should be mapped to + * @param handler the handler instance or handler bean name String + * (a bean name will automatically be resolved into the corresponding handler bean) + * @throws BeansException if the handler couldn't be registered + * @throws IllegalStateException if there is a conflicting handler registered + */ + protected void registerHandler(String urlPath, Object handler) throws BeansException, IllegalStateException { + Assert.notNull(urlPath, "URL path must not be null"); + Assert.notNull(handler, "Handler object must not be null"); + Object resolvedHandler = handler; + + // Parse path pattern + urlPath = prependLeadingSlash(urlPath); + PathPattern pattern = getPathPatternParser().parse(urlPath); + if (this.handlerMap.containsKey(pattern)) { + Object existingHandler = this.handlerMap.get(pattern); + if (existingHandler != null && existingHandler != resolvedHandler) { + throw new IllegalStateException( + "Cannot map " + getHandlerDescription(handler) + " to [" + urlPath + "]: " + + "there is already " + getHandlerDescription(existingHandler) + " mapped."); + } + } + + // Eagerly resolve handler if referencing singleton via name. + if (!this.lazyInitHandlers && handler instanceof String) { + String handlerName = (String) handler; + if (obtainApplicationContext().isSingleton(handlerName)) { + resolvedHandler = obtainApplicationContext().getBean(handlerName); + } + } + + // Register resolved handler + this.handlerMap.put(pattern, resolvedHandler); + if (logger.isTraceEnabled()) { + logger.trace("Mapped [" + urlPath + "] onto " + getHandlerDescription(handler)); + } + } + + private String getHandlerDescription(Object handler) { + return (handler instanceof String ? "'" + handler + "'" : handler.toString()); + } + + + private static String prependLeadingSlash(String pattern) { + if (StringUtils.hasLength(pattern) && !pattern.startsWith("/")) { + return "/" + pattern; + } + else { + return pattern; + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/handler/SimpleUrlHandlerMapping.java b/spring-webflux/src/main/java/org/springframework/web/reactive/handler/SimpleUrlHandlerMapping.java new file mode 100644 index 0000000000000000000000000000000000000000..131c4f99e2fb2f3a7b4711506d432d0c68d439f6 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/handler/SimpleUrlHandlerMapping.java @@ -0,0 +1,135 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.handler; + +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Properties; + +import org.springframework.beans.BeansException; +import org.springframework.util.CollectionUtils; + +/** + * Implementation of the {@link org.springframework.web.reactive.HandlerMapping} + * interface to map from URLs to request handler beans. Supports both mapping + * to bean instances and mapping to bean names; the latter is required for + * non-singleton handlers. + * + *

The "urlMap" property is suitable for populating the handler map with + * bean instances. Mappings to bean names can be set via the "mappings" + * property, in a form accepted by the {@code java.util.Properties} class, + * like as follows: + * + *

+ * /welcome.html=ticketController
+ * /show.html=ticketController
+ * 
+ * + *

The syntax is {@code PATH=HANDLER_BEAN_NAME}. If the path doesn't begin + * with a slash, one is prepended. + * + *

Supports direct matches, e.g. a registered "/test" matches "/test", and + * various Ant-style pattern matches, e.g. a registered "/t*" pattern matches + * both "/test" and "/team", "/test/*" matches all paths under "/test", + * "/test/**" matches all paths below "/test". For details, see the + * {@link org.springframework.web.util.pattern.PathPattern} javadoc. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class SimpleUrlHandlerMapping extends AbstractUrlHandlerMapping { + + private final Map urlMap = new LinkedHashMap<>(); + + + /** + * Map URL paths to handler bean names. + * This is the typical way of configuring this HandlerMapping. + *

Supports direct URL matches and Ant-style pattern matches. For syntax details, + * see the {@link org.springframework.web.util.pattern.PathPattern} javadoc. + * @param mappings properties with URLs as keys and bean names as values + * @see #setUrlMap + */ + public void setMappings(Properties mappings) { + CollectionUtils.mergePropertiesIntoMap(mappings, this.urlMap); + } + + /** + * Set a Map with URL paths as keys and handler beans (or handler bean names) + * as values. Convenient for population with bean references. + *

Supports direct URL matches and Ant-style pattern matches. For syntax details, + * see the {@link org.springframework.web.util.pattern.PathPattern} javadoc. + * @param urlMap map with URLs as keys and beans as values + * @see #setMappings + */ + public void setUrlMap(Map urlMap) { + this.urlMap.putAll(urlMap); + } + + /** + * Allow Map access to the URL path mappings, with the option to add or + * override specific entries. + *

Useful for specifying entries directly, for example via "urlMap[myKey]". + * This is particularly useful for adding or overriding entries in child + * bean definitions. + */ + public Map getUrlMap() { + return this.urlMap; + } + + + /** + * Calls the {@link #registerHandlers} method in addition to the + * superclass's initialization. + */ + @Override + public void initApplicationContext() throws BeansException { + super.initApplicationContext(); + registerHandlers(this.urlMap); + } + + /** + * Register all handlers specified in the URL map for the corresponding paths. + * @param urlMap a Map with URL paths as keys and handler beans or bean names as values + * @throws BeansException if a handler couldn't be registered + * @throws IllegalStateException if there is a conflicting handler registered + */ + protected void registerHandlers(Map urlMap) throws BeansException { + if (urlMap.isEmpty()) { + logger.trace("No patterns in " + formatMappingName()); + } + else { + for (Map.Entry entry : urlMap.entrySet()) { + String url = entry.getKey(); + Object handler = entry.getValue(); + // Prepend with slash if not already present. + if (!url.startsWith("/")) { + url = "/" + url; + } + // Remove whitespace from handler bean name. + if (handler instanceof String) { + handler = ((String) handler).trim(); + } + registerHandler(url, handler); + } + if (logger.isDebugEnabled()) { + logger.debug("Patterns " + getHandlerMap().keySet() + " in " + formatMappingName()); + } + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/handler/WebFluxResponseStatusExceptionHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/handler/WebFluxResponseStatusExceptionHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..4ecd685d1840a9aac8a7d22050ea7b311db96a36 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/handler/WebFluxResponseStatusExceptionHandler.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.handler; + +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; +import org.springframework.web.bind.annotation.ResponseStatus; +import org.springframework.web.server.handler.ResponseStatusExceptionHandler; + +/** + * Common WebFlux exception handler that detects instances of + * {@link org.springframework.web.server.ResponseStatusException} + * (inherited from the base class) as well as exceptions annotated with + * {@link ResponseStatus @ResponseStatus} by determining the HTTP status + * for them and updating the status of the response accordingly. + * + *

If the response is already committed, the error remains unresolved + * and is propagated. + * + * @author Juergen Hoeller + * @author Rossen Stoyanchev + * @since 5.0.5 + */ +public class WebFluxResponseStatusExceptionHandler extends ResponseStatusExceptionHandler { + + @Override + @Nullable + protected HttpStatus determineStatus(Throwable ex) { + HttpStatus status = super.determineStatus(ex); + if (status == null) { + ResponseStatus ann = AnnotatedElementUtils.findMergedAnnotation(ex.getClass(), ResponseStatus.class); + if (ann != null) { + status = ann.code(); + } + } + return status; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/handler/package-info.java b/spring-webflux/src/main/java/org/springframework/web/reactive/handler/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..061a4179c41fb863ed1bb67d6b6557bf692885c9 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/handler/package-info.java @@ -0,0 +1,9 @@ +/** + * Provides HandlerMapping implementations including abstract base classes. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.reactive.handler; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/package-info.java b/spring-webflux/src/main/java/org/springframework/web/reactive/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..549793e0e97b39c2f28f867af5c5587e03a4d0ef --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/package-info.java @@ -0,0 +1,17 @@ +/** + * Top-level package for the {@code spring-webflux} module that contains + * {@link org.springframework.web.reactive.DispatcherHandler}, the main entry + * point for WebFlux server endpoint processing including key contracts used to + * map requests to handlers, invoke them, and process the result. + * + *

The module provides two programming models for reactive server endpoints. + * One based on annotated {@code @Controller}'s and another based on functional + * routing and handling. The module also contains a functional, reactive + * {@code WebClient} as well as client and server, reactive WebSocket support. + */ +@NonNullApi +@NonNullFields +package org.springframework.web.reactive; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/AbstractFileNameVersionStrategy.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/AbstractFileNameVersionStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..7872195efd55d3a5db86700355fc4e8cee1a5b86 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/AbstractFileNameVersionStrategy.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.util.StringUtils; + +/** + * Abstract base class for filename suffix based {@link VersionStrategy} + * implementations, e.g. "static/myresource-version.js" + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + */ +public abstract class AbstractFileNameVersionStrategy implements VersionStrategy { + + protected final Log logger = LogFactory.getLog(getClass()); + + private static final Pattern pattern = Pattern.compile("-(\\S*)\\."); + + + @Override + public String extractVersion(String requestPath) { + Matcher matcher = pattern.matcher(requestPath); + if (matcher.find()) { + String match = matcher.group(1); + return (match.contains("-") ? match.substring(match.lastIndexOf('-') + 1) : match); + } + else { + return null; + } + } + + @Override + public String removeVersion(String requestPath, String version) { + return StringUtils.delete(requestPath, "-" + version); + } + + @Override + public String addVersion(String requestPath, String version) { + String baseFilename = StringUtils.stripFilenameExtension(requestPath); + String extension = StringUtils.getFilenameExtension(requestPath); + return (baseFilename + '-' + version + '.' + extension); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/AbstractPrefixVersionStrategy.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/AbstractPrefixVersionStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..a387e2766f68004fb74371c0e2d7a9f70db44578 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/AbstractPrefixVersionStrategy.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.util.Assert; + +/** + * Abstract base class for {@link VersionStrategy} implementations that insert + * a prefix into the URL path, e.g. "version/static/myresource.js". + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + */ +public abstract class AbstractPrefixVersionStrategy implements VersionStrategy { + + protected final Log logger = LogFactory.getLog(getClass()); + + + private final String prefix; + + + protected AbstractPrefixVersionStrategy(String version) { + Assert.hasText(version, "Version must not be empty"); + this.prefix = version; + } + + + @Override + public String extractVersion(String requestPath) { + return (requestPath.startsWith(this.prefix) ? this.prefix : null); + } + + @Override + public String removeVersion(String requestPath, String version) { + return requestPath.substring(this.prefix.length()); + } + + @Override + public String addVersion(String path, String version) { + if (path.startsWith(".")) { + return path; + } + else if (this.prefix.endsWith("/") || path.startsWith("/")) { + return this.prefix + path; + } + else { + return this.prefix + '/' + path; + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/AbstractResourceResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/AbstractResourceResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..c862345a2c4e70677e69aa8e210e3457098b95d0 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/AbstractResourceResolver.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; +import org.springframework.web.server.ServerWebExchange; + +/** + * Base {@link ResourceResolver} providing consistent logging. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public abstract class AbstractResourceResolver implements ResourceResolver { + + protected final Log logger = LogFactory.getLog(getClass()); + + + @Override + public Mono resolveResource(@Nullable ServerWebExchange exchange, String requestPath, + List locations, ResourceResolverChain chain) { + + return resolveResourceInternal(exchange, requestPath, locations, chain); + } + + @Override + public Mono resolveUrlPath(String resourceUrlPath, List locations, + ResourceResolverChain chain) { + + return resolveUrlPathInternal(resourceUrlPath, locations, chain); + } + + + protected abstract Mono resolveResourceInternal(@Nullable ServerWebExchange exchange, + String requestPath, List locations, ResourceResolverChain chain); + + protected abstract Mono resolveUrlPathInternal(String resourceUrlPath, + List locations, ResourceResolverChain chain); + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/AppCacheManifestTransformer.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/AppCacheManifestTransformer.java new file mode 100644 index 0000000000000000000000000000000000000000..58ea4ab7fccb9015fb0e11eedeb8e6dd76f178c4 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/AppCacheManifestTransformer.java @@ -0,0 +1,249 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.CharBuffer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collection; +import java.util.Scanner; +import java.util.function.Consumer; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.SynchronousSink; + +import org.springframework.core.io.Resource; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.DigestUtils; +import org.springframework.util.StreamUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; + +/** + * A {@link ResourceTransformer} HTML5 AppCache manifests. + * + *

This transformer: + *

    + *
  • modifies links to match the public URL paths that should be exposed to + * clients, using configured {@code ResourceResolver} strategies + *
  • appends a comment in the manifest, containing a Hash + * (e.g. "# Hash: 9de0f09ed7caf84e885f1f0f11c7e326"), thus changing the content + * of the manifest in order to trigger an appcache reload in the browser. + *
+ * + *

All files with an ".appcache" file extension (or the extension given + * to the constructor) will be transformed by this class. The hash is computed + * using the content of the appcache manifest so that changes in the manifest + * should invalidate the browser cache. This should also work with changes in + * referenced resources whose links are also versioned. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + * @see HTML5 offline applications spec + */ +public class AppCacheManifestTransformer extends ResourceTransformerSupport { + + private static final Collection MANIFEST_SECTION_HEADERS = + Arrays.asList("CACHE MANIFEST", "NETWORK:", "FALLBACK:", "CACHE:"); + + private static final String MANIFEST_HEADER = "CACHE MANIFEST"; + + private static final String CACHE_HEADER = "CACHE:"; + + private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + + private static final Log logger = LogFactory.getLog(AppCacheManifestTransformer.class); + + + private final String fileExtension; + + + /** + * Create an AppCacheResourceTransformer that transforms files with extension ".appcache". + */ + public AppCacheManifestTransformer() { + this("appcache"); + } + + /** + * Create an AppCacheResourceTransformer that transforms files with the extension + * given as a parameter. + */ + public AppCacheManifestTransformer(String fileExtension) { + this.fileExtension = fileExtension; + } + + + @Override + public Mono transform(ServerWebExchange exchange, Resource inputResource, + ResourceTransformerChain chain) { + + return chain.transform(exchange, inputResource) + .flatMap(outputResource -> { + String name = outputResource.getFilename(); + if (!this.fileExtension.equals(StringUtils.getFilenameExtension(name))) { + return Mono.just(outputResource); + } + DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory(); + Flux flux = DataBufferUtils + .read(outputResource, bufferFactory, StreamUtils.BUFFER_SIZE); + return DataBufferUtils.join(flux) + .flatMap(dataBuffer -> { + CharBuffer charBuffer = DEFAULT_CHARSET.decode(dataBuffer.asByteBuffer()); + DataBufferUtils.release(dataBuffer); + String content = charBuffer.toString(); + return transform(content, outputResource, chain, exchange); + }); + }); + } + + private Mono transform(String content, Resource resource, + ResourceTransformerChain chain, ServerWebExchange exchange) { + + if (!content.startsWith(MANIFEST_HEADER)) { + if (logger.isTraceEnabled()) { + logger.trace(exchange.getLogPrefix() + + "Skipping " + resource + ": Manifest does not start with 'CACHE MANIFEST'"); + } + return Mono.just(resource); + } + return Flux.generate(new LineInfoGenerator(content)) + .concatMap(info -> processLine(info, exchange, resource, chain)) + .reduce(new ByteArrayOutputStream(), (out, line) -> { + writeToByteArrayOutputStream(out, line + "\n"); + return out; + }) + .map(out -> { + String hash = DigestUtils.md5DigestAsHex(out.toByteArray()); + writeToByteArrayOutputStream(out, "\n" + "# Hash: " + hash); + return new TransformedResource(resource, out.toByteArray()); + }); + } + + private static void writeToByteArrayOutputStream(ByteArrayOutputStream out, String toWrite) { + try { + byte[] bytes = toWrite.getBytes(DEFAULT_CHARSET); + out.write(bytes); + } + catch (IOException ex) { + throw Exceptions.propagate(ex); + } + } + + private Mono processLine(LineInfo info, ServerWebExchange exchange, + Resource resource, ResourceTransformerChain chain) { + + if (!info.isLink()) { + return Mono.just(info.getLine()); + } + + String link = toAbsolutePath(info.getLine(), exchange); + return resolveUrlPath(link, exchange, resource, chain); + } + + + private static class LineInfoGenerator implements Consumer> { + + private final Scanner scanner; + + @Nullable + private LineInfo previous; + + + LineInfoGenerator(String content) { + this.scanner = new Scanner(content); + } + + + @Override + public void accept(SynchronousSink sink) { + if (this.scanner.hasNext()) { + String line = this.scanner.nextLine(); + LineInfo current = new LineInfo(line, this.previous); + sink.next(current); + this.previous = current; + } + else { + sink.complete(); + } + } + } + + + private static class LineInfo { + + private final String line; + + private final boolean cacheSection; + + private final boolean link; + + + LineInfo(String line, @Nullable LineInfo previousLine) { + this.line = line; + this.cacheSection = initCacheSectionFlag(line, previousLine); + this.link = iniLinkFlag(line, this.cacheSection); + } + + + private static boolean initCacheSectionFlag(String line, @Nullable LineInfo previousLine) { + if (MANIFEST_SECTION_HEADERS.contains(line.trim())) { + return line.trim().equals(CACHE_HEADER); + } + else if (previousLine != null) { + return previousLine.isCacheSection(); + } + throw new IllegalStateException( + "Manifest does not start with " + MANIFEST_HEADER + ": " + line); + } + + private static boolean iniLinkFlag(String line, boolean isCacheSection) { + return (isCacheSection && StringUtils.hasText(line) && !line.startsWith("#") + && !line.startsWith("//") && !hasScheme(line)); + } + + private static boolean hasScheme(String line) { + int index = line.indexOf(':'); + return (line.startsWith("//") || (index > 0 && !line.substring(0, index).contains("/"))); + } + + + public String getLine() { + return this.line; + } + + public boolean isCacheSection() { + return this.cacheSection; + } + + public boolean isLink() { + return this.link; + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/CachingResourceResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/CachingResourceResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..6125120b69a461a5e86b6984547c9bc43247add8 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/CachingResourceResolver.java @@ -0,0 +1,168 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import reactor.core.publisher.Mono; + +import org.springframework.cache.Cache; +import org.springframework.cache.CacheManager; +import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; + +/** + * A {@link ResourceResolver} that resolves resources from a {@link Cache} or + * otherwise delegates to the resolver chain and caches the result. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + */ +public class CachingResourceResolver extends AbstractResourceResolver { + + /** + * The prefix used for resolved resource cache keys. + */ + public static final String RESOLVED_RESOURCE_CACHE_KEY_PREFIX = "resolvedResource:"; + + /** + * The prefix used for resolved URL path cache keys. + */ + public static final String RESOLVED_URL_PATH_CACHE_KEY_PREFIX = "resolvedUrlPath:"; + + + private final Cache cache; + + private final List contentCodings = new ArrayList<>(EncodedResourceResolver.DEFAULT_CODINGS); + + + public CachingResourceResolver(Cache cache) { + Assert.notNull(cache, "Cache is required"); + this.cache = cache; + } + + public CachingResourceResolver(CacheManager cacheManager, String cacheName) { + Cache cache = cacheManager.getCache(cacheName); + if (cache == null) { + throw new IllegalArgumentException("Cache '" + cacheName + "' not found"); + } + this.cache = cache; + } + + + /** + * Return the configured {@code Cache}. + */ + public Cache getCache() { + return this.cache; + } + + /** + * Configure the supported content codings from the + * {@literal "Accept-Encoding"} header for which to cache resource variations. + *

The codings configured here are generally expected to match those + * configured on {@link EncodedResourceResolver#setContentCodings(List)}. + *

By default this property is set to {@literal ["br", "gzip"]} based on + * the value of {@link EncodedResourceResolver#DEFAULT_CODINGS}. + * @param codings one or more supported content codings + * @since 5.1 + */ + public void setContentCodings(List codings) { + Assert.notEmpty(codings, "At least one content coding expected"); + this.contentCodings.clear(); + this.contentCodings.addAll(codings); + } + + /** + * Return a read-only list with the supported content codings. + * @since 5.1 + */ + public List getContentCodings() { + return Collections.unmodifiableList(this.contentCodings); + } + + + @Override + protected Mono resolveResourceInternal(@Nullable ServerWebExchange exchange, + String requestPath, List locations, ResourceResolverChain chain) { + + String key = computeKey(exchange, requestPath); + Resource cachedResource = this.cache.get(key, Resource.class); + + if (cachedResource != null) { + String logPrefix = exchange != null ? exchange.getLogPrefix() : ""; + logger.trace(logPrefix + "Resource resolved from cache"); + return Mono.just(cachedResource); + } + + return chain.resolveResource(exchange, requestPath, locations) + .doOnNext(resource -> this.cache.put(key, resource)); + } + + protected String computeKey(@Nullable ServerWebExchange exchange, String requestPath) { + StringBuilder key = new StringBuilder(RESOLVED_RESOURCE_CACHE_KEY_PREFIX); + key.append(requestPath); + if (exchange != null) { + String codingKey = getContentCodingKey(exchange); + if (StringUtils.hasText(codingKey)) { + key.append("+encoding=").append(codingKey); + } + } + return key.toString(); + } + + @Nullable + private String getContentCodingKey(ServerWebExchange exchange) { + String header = exchange.getRequest().getHeaders().getFirst("Accept-Encoding"); + if (!StringUtils.hasText(header)) { + return null; + } + return Arrays.stream(StringUtils.tokenizeToStringArray(header, ",")) + .map(token -> { + int index = token.indexOf(';'); + return (index >= 0 ? token.substring(0, index) : token).trim().toLowerCase(); + }) + .filter(this.contentCodings::contains) + .sorted() + .collect(Collectors.joining(",")); + } + + @Override + protected Mono resolveUrlPathInternal(String resourceUrlPath, + List locations, ResourceResolverChain chain) { + + String key = RESOLVED_URL_PATH_CACHE_KEY_PREFIX + resourceUrlPath; + String cachedUrlPath = this.cache.get(key, String.class); + + if (cachedUrlPath != null) { + logger.trace("Path resolved from cache"); + return Mono.just(cachedUrlPath); + } + + return chain.resolveUrlPath(resourceUrlPath, locations) + .doOnNext(resolvedPath -> this.cache.put(key, resolvedPath)); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/CachingResourceTransformer.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/CachingResourceTransformer.java new file mode 100644 index 0000000000000000000000000000000000000000..29492a0e23c2216c27a06e0852cfa019e68f13de --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/CachingResourceTransformer.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import reactor.core.publisher.Mono; + +import org.springframework.cache.Cache; +import org.springframework.cache.CacheManager; +import org.springframework.core.io.Resource; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; + +/** + * A {@link ResourceTransformer} that checks a {@link Cache} to see if a + * previously transformed resource exists in the cache and returns it if found, + * or otherwise delegates to the resolver chain and caches the result. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class CachingResourceTransformer implements ResourceTransformer { + + private static final Log logger = LogFactory.getLog(CachingResourceTransformer.class); + + private final Cache cache; + + + public CachingResourceTransformer(Cache cache) { + Assert.notNull(cache, "Cache is required"); + this.cache = cache; + } + + public CachingResourceTransformer(CacheManager cacheManager, String cacheName) { + Cache cache = cacheManager.getCache(cacheName); + if (cache == null) { + throw new IllegalArgumentException("Cache '" + cacheName + "' not found"); + } + this.cache = cache; + } + + + /** + * Return the configured {@code Cache}. + */ + public Cache getCache() { + return this.cache; + } + + + @Override + public Mono transform(ServerWebExchange exchange, Resource resource, + ResourceTransformerChain transformerChain) { + + Resource cachedResource = this.cache.get(resource, Resource.class); + if (cachedResource != null) { + if (logger.isTraceEnabled()) { + logger.trace(exchange.getLogPrefix() + "Resource resolved from cache"); + } + return Mono.just(cachedResource); + } + + return transformerChain.transform(exchange, resource) + .doOnNext(transformed -> this.cache.put(resource, transformed)); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ContentVersionStrategy.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ContentVersionStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..823cc19fdafef05febefae9c4cff0952a579d424 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ContentVersionStrategy.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.util.DigestUtils; +import org.springframework.util.StreamUtils; + +/** + * A {@code VersionStrategy} that calculates an Hex MD5 hashes from the content + * of the resource and appends it to the file name, e.g. + * {@code "styles/main-e36d2e05253c6c7085a91522ce43a0b4.css"}. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + * @see VersionResourceResolver + */ +public class ContentVersionStrategy extends AbstractFileNameVersionStrategy { + + private static final DataBufferFactory dataBufferFactory = new DefaultDataBufferFactory(); + + + @Override + public Mono getResourceVersion(Resource resource) { + Flux flux = + DataBufferUtils.read(resource, dataBufferFactory, StreamUtils.BUFFER_SIZE); + return DataBufferUtils.join(flux) + .map(buffer -> { + byte[] result = new byte[buffer.readableByteCount()]; + buffer.read(result); + DataBufferUtils.release(buffer); + return DigestUtils.md5DigestAsHex(result); + }); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/CssLinkResourceTransformer.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/CssLinkResourceTransformer.java new file mode 100644 index 0000000000000000000000000000000000000000..715943b1247d07b430a842213a0796af3b9ea32f --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/CssLinkResourceTransformer.java @@ -0,0 +1,304 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.io.StringWriter; +import java.nio.CharBuffer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.util.StreamUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; + +/** + * A {@link ResourceTransformer} implementation that modifies links in a CSS + * file to match the public URL paths that should be exposed to clients (e.g. + * with an MD5 content-based hash inserted in the URL). + * + *

The implementation looks for links in CSS {@code @import} statements and + * also inside CSS {@code url()} functions. All links are then passed through the + * {@link ResourceResolverChain} and resolved relative to the location of the + * containing CSS file. If successfully resolved, the link is modified, otherwise + * the original link is preserved. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class CssLinkResourceTransformer extends ResourceTransformerSupport { + + private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + + private static final Log logger = LogFactory.getLog(CssLinkResourceTransformer.class); + + private final List linkParsers = new ArrayList<>(2); + + + public CssLinkResourceTransformer() { + this.linkParsers.add(new ImportLinkParser()); + this.linkParsers.add(new UrlFunctionLinkParser()); + } + + + @SuppressWarnings("deprecation") + @Override + public Mono transform(ServerWebExchange exchange, Resource inputResource, + ResourceTransformerChain transformerChain) { + + return transformerChain.transform(exchange, inputResource) + .flatMap(outputResource -> { + String filename = outputResource.getFilename(); + if (!"css".equals(StringUtils.getFilenameExtension(filename)) || + inputResource instanceof EncodedResourceResolver.EncodedResource || + inputResource instanceof GzipResourceResolver.GzippedResource) { + return Mono.just(outputResource); + } + + DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory(); + Flux flux = DataBufferUtils + .read(outputResource, bufferFactory, StreamUtils.BUFFER_SIZE); + return DataBufferUtils.join(flux) + .flatMap(dataBuffer -> { + CharBuffer charBuffer = DEFAULT_CHARSET.decode(dataBuffer.asByteBuffer()); + DataBufferUtils.release(dataBuffer); + String cssContent = charBuffer.toString(); + return transformContent(cssContent, outputResource, transformerChain, exchange); + }); + }); + } + + private Mono transformContent(String cssContent, Resource resource, + ResourceTransformerChain chain, ServerWebExchange exchange) { + + List contentChunkInfos = parseContent(cssContent); + if (contentChunkInfos.isEmpty()) { + return Mono.just(resource); + } + + return Flux.fromIterable(contentChunkInfos) + .concatMap(contentChunkInfo -> { + String contentChunk = contentChunkInfo.getContent(cssContent); + if (contentChunkInfo.isLink() && !hasScheme(contentChunk)) { + String link = toAbsolutePath(contentChunk, exchange); + return resolveUrlPath(link, exchange, resource, chain).defaultIfEmpty(contentChunk); + } + else { + return Mono.just(contentChunk); + } + }) + .reduce(new StringWriter(), (writer, chunk) -> { + writer.write(chunk); + return writer; + }) + .map(writer -> { + byte[] newContent = writer.toString().getBytes(DEFAULT_CHARSET); + return new TransformedResource(resource, newContent); + }); + } + + private List parseContent(String cssContent) { + SortedSet links = new TreeSet<>(); + this.linkParsers.forEach(parser -> parser.parse(cssContent, links)); + if (links.isEmpty()) { + return Collections.emptyList(); + } + int index = 0; + List result = new ArrayList<>(); + for (ContentChunkInfo link : links) { + result.add(new ContentChunkInfo(index, link.getStart(), false)); + result.add(link); + index = link.getEnd(); + } + if (index < cssContent.length()) { + result.add(new ContentChunkInfo(index, cssContent.length(), false)); + } + return result; + } + + private boolean hasScheme(String link) { + int schemeIndex = link.indexOf(':'); + return (schemeIndex > 0 && !link.substring(0, schemeIndex).contains("/")) || link.indexOf("//") == 0; + } + + + /** + * Extract content chunks that represent links. + */ + @FunctionalInterface + protected interface LinkParser { + + void parse(String cssContent, SortedSet result); + + } + + + /** + * Abstract base class for {@link LinkParser} implementations. + */ + protected abstract static class AbstractLinkParser implements LinkParser { + + /** Return the keyword to use to search for links, e.g. "@import", "url(" */ + protected abstract String getKeyword(); + + @Override + public void parse(String content, SortedSet result) { + int position = 0; + while (true) { + position = content.indexOf(getKeyword(), position); + if (position == -1) { + return; + } + position += getKeyword().length(); + while (Character.isWhitespace(content.charAt(position))) { + position++; + } + if (content.charAt(position) == '\'') { + position = extractLink(position, '\'', content, result); + } + else if (content.charAt(position) == '"') { + position = extractLink(position, '"', content, result); + } + else { + position = extractUnquotedLink(position, content, result); + + } + } + } + + protected int extractLink(int index, char endChar, String content, Set result) { + int start = index + 1; + int end = content.indexOf(endChar, start); + result.add(new ContentChunkInfo(start, end, true)); + return end + 1; + } + + /** + * Invoked after a keyword match, after whitespaces removed, and when + * the next char is neither a single nor double quote. + */ + protected abstract int extractUnquotedLink(int position, String content, + Set linksToAdd); + + } + + + private static class ImportLinkParser extends AbstractLinkParser { + + @Override + protected String getKeyword() { + return "@import"; + } + + @Override + protected int extractUnquotedLink(int position, String content, Set result) { + if (content.substring(position, position + 4).equals("url(")) { + // Ignore, UrlFunctionContentParser will take care + } + else if (logger.isTraceEnabled()) { + logger.trace("Unexpected syntax for @import link at index " + position); + } + return position; + } + } + + + private static class UrlFunctionLinkParser extends AbstractLinkParser { + + @Override + protected String getKeyword() { + return "url("; + } + + @Override + protected int extractUnquotedLink(int position, String content, Set result) { + // A url() function without unquoted + return extractLink(position - 1, ')', content, result); + } + } + + + private static class ContentChunkInfo implements Comparable { + + private final int start; + + private final int end; + + private final boolean isLink; + + + ContentChunkInfo(int start, int end, boolean isLink) { + this.start = start; + this.end = end; + this.isLink = isLink; + } + + + public int getStart() { + return this.start; + } + + public int getEnd() { + return this.end; + } + + public boolean isLink() { + return this.isLink; + } + + public String getContent(String fullContent) { + return fullContent.substring(this.start, this.end); + } + + @Override + public int compareTo(ContentChunkInfo other) { + return Integer.compare(this.start, other.start); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof ContentChunkInfo)) { + return false; + } + ContentChunkInfo otherCci = (ContentChunkInfo) other; + return (this.start == otherCci.start && this.end == otherCci.end); + } + + @Override + public int hashCode() { + return this.start * 31 + this.end; + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/DefaultResourceResolverChain.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/DefaultResourceResolverChain.java new file mode 100644 index 0000000000000000000000000000000000000000..be05220f9341988359bb30d74b4cdcc0e98e687b --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/DefaultResourceResolverChain.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.ListIterator; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; + +/** + * Default immutable implementation of {@link ResourceResolverChain}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +class DefaultResourceResolverChain implements ResourceResolverChain { + + @Nullable + private final ResourceResolver resolver; + + @Nullable + private final ResourceResolverChain nextChain; + + + public DefaultResourceResolverChain(@Nullable List resolvers) { + resolvers = (resolvers != null ? resolvers : Collections.emptyList()); + DefaultResourceResolverChain chain = initChain(new ArrayList<>(resolvers)); + this.resolver = chain.resolver; + this.nextChain = chain.nextChain; + } + + private static DefaultResourceResolverChain initChain(ArrayList resolvers) { + DefaultResourceResolverChain chain = new DefaultResourceResolverChain(null, null); + ListIterator it = resolvers.listIterator(resolvers.size()); + while (it.hasPrevious()) { + chain = new DefaultResourceResolverChain(it.previous(), chain); + } + return chain; + } + + private DefaultResourceResolverChain(@Nullable ResourceResolver resolver, @Nullable ResourceResolverChain chain) { + Assert.isTrue((resolver == null && chain == null) || (resolver != null && chain != null), + "Both resolver and resolver chain must be null, or neither is"); + this.resolver = resolver; + this.nextChain = chain; + } + + + @Override + public Mono resolveResource(@Nullable ServerWebExchange exchange, String requestPath, + List locations) { + + return (this.resolver != null && this.nextChain != null ? + this.resolver.resolveResource(exchange, requestPath, locations, this.nextChain) : + Mono.empty()); + } + + @Override + public Mono resolveUrlPath(String resourcePath, List locations) { + return (this.resolver != null && this.nextChain != null ? + this.resolver.resolveUrlPath(resourcePath, locations, this.nextChain) : + Mono.empty()); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/DefaultResourceTransformerChain.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/DefaultResourceTransformerChain.java new file mode 100644 index 0000000000000000000000000000000000000000..dc79e5d6aad8d48e1b44cf54b782ecaec137ce3a --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/DefaultResourceTransformerChain.java @@ -0,0 +1,93 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.ListIterator; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; + +/** + * Default immutable implementation of {@link ResourceTransformerChain}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +class DefaultResourceTransformerChain implements ResourceTransformerChain { + + private final ResourceResolverChain resolverChain; + + @Nullable + private final ResourceTransformer transformer; + + @Nullable + private final ResourceTransformerChain nextChain; + + + public DefaultResourceTransformerChain( + ResourceResolverChain resolverChain, @Nullable List transformers) { + + Assert.notNull(resolverChain, "ResourceResolverChain is required"); + this.resolverChain = resolverChain; + transformers = (transformers != null ? transformers : Collections.emptyList()); + DefaultResourceTransformerChain chain = initTransformerChain(resolverChain, new ArrayList<>(transformers)); + this.transformer = chain.transformer; + this.nextChain = chain.nextChain; + } + + private DefaultResourceTransformerChain initTransformerChain(ResourceResolverChain resolverChain, + ArrayList transformers) { + + DefaultResourceTransformerChain chain = new DefaultResourceTransformerChain(resolverChain, null, null); + ListIterator it = transformers.listIterator(transformers.size()); + while (it.hasPrevious()) { + chain = new DefaultResourceTransformerChain(resolverChain, it.previous(), chain); + } + return chain; + } + + public DefaultResourceTransformerChain(ResourceResolverChain resolverChain, + @Nullable ResourceTransformer transformer, @Nullable ResourceTransformerChain chain) { + + Assert.isTrue((transformer == null && chain == null) || (transformer != null && chain != null), + "Both transformer and transformer chain must be null, or neither is"); + this.resolverChain = resolverChain; + this.transformer = transformer; + this.nextChain = chain; + } + + + @Override + public ResourceResolverChain getResolverChain() { + return this.resolverChain; + } + + @Override + public Mono transform(ServerWebExchange exchange, Resource resource) { + return (this.transformer != null && this.nextChain != null ? + this.transformer.transform(exchange, resource, this.nextChain) : + Mono.just(resource)); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/EncodedResourceResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/EncodedResourceResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..e56af089603a0c80c460173c6a8bb4e9e843ceab --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/EncodedResourceResolver.java @@ -0,0 +1,287 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.URL; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.AbstractResource; +import org.springframework.core.io.Resource; +import org.springframework.http.HttpHeaders; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; + +/** + * Resolver that delegates to the chain, and if a resource is found, it then + * attempts to find an encoded (e.g. gzip, brotli) variant that is acceptable + * based on the "Accept-Encoding" request header. + * + *

The list of supported {@link #setContentCodings(List) contentCodings} can + * be configured, in order of preference, and each coding must be associated + * with {@link #setExtensions(Map) extensions}. + * + *

Note that this resolver must be ordered ahead of a + * {@link VersionResourceResolver} with a content-based, version strategy to + * ensure the version calculation is not impacted by the encoding. + * + * @author Rossen Stoyanchev + * @since 5.1 + */ +public class EncodedResourceResolver extends AbstractResourceResolver { + + /** + * The default content codings. + */ + public static final List DEFAULT_CODINGS = Arrays.asList("br", "gzip"); + + + private final List contentCodings = new ArrayList<>(DEFAULT_CODINGS); + + private final Map extensions = new LinkedHashMap<>(); + + + public EncodedResourceResolver() { + this.extensions.put("gzip", ".gz"); + this.extensions.put("br", ".br"); + } + + + /** + * Configure the supported content codings in order of preference. The first + * coding that is present in the {@literal "Accept-Encoding"} header for a + * given request, and that has a file present with the associated extension, + * is used. + *

Note: Each coding must be associated with a file + * extension via {@link #registerExtension} or {@link #setExtensions}. Also + * customizations to the list of codings here should be matched by + * customizations to the same list in {@link CachingResourceResolver} to + * ensure encoded variants of a resource are cached under separate keys. + *

By default this property is set to {@literal ["br", "gzip"]}. + * @param codings one or more supported content codings + */ + public void setContentCodings(List codings) { + Assert.notEmpty(codings, "At least one content coding expected"); + this.contentCodings.clear(); + this.contentCodings.addAll(codings); + } + + /** + * Return a read-only list with the supported content codings. + */ + public List getContentCodings() { + return Collections.unmodifiableList(this.contentCodings); + } + + /** + * Configure mappings from content codings to file extensions. A dot "." + * will be prepended in front of the extension value if not present. + *

By default this is configured with {@literal ["br" -> ".br"]} and + * {@literal ["gzip" -> ".gz"]}. + * @param extensions the extensions to use. + * @see #registerExtension(String, String) + */ + public void setExtensions(Map extensions) { + extensions.forEach(this::registerExtension); + } + + /** + * Return a read-only map with coding-to-extension mappings. + */ + public Map getExtensions() { + return Collections.unmodifiableMap(this.extensions); + } + + /** + * Java config friendly alternative to {@link #setExtensions(Map)}. + * @param coding the content coding + * @param extension the associated file extension + */ + public void registerExtension(String coding, String extension) { + this.extensions.put(coding, (extension.startsWith(".") ? extension : "." + extension)); + } + + + @Override + protected Mono resolveResourceInternal(@Nullable ServerWebExchange exchange, + String requestPath, List locations, ResourceResolverChain chain) { + + return chain.resolveResource(exchange, requestPath, locations).map(resource -> { + + if (exchange == null) { + return resource; + } + + String acceptEncoding = getAcceptEncoding(exchange); + if (acceptEncoding == null) { + return resource; + } + + for (String coding : this.contentCodings) { + if (acceptEncoding.contains(coding)) { + try { + String extension = getExtension(coding); + Resource encoded = new EncodedResource(resource, coding, extension); + if (encoded.exists()) { + return encoded; + } + } + catch (IOException ex) { + logger.trace(exchange.getLogPrefix() + + "No " + coding + " resource for [" + resource.getFilename() + "]", ex); + } + } + } + + return resource; + }); + } + + @Nullable + private String getAcceptEncoding(ServerWebExchange exchange) { + ServerHttpRequest request = exchange.getRequest(); + String header = request.getHeaders().getFirst(HttpHeaders.ACCEPT_ENCODING); + return (header != null ? header.toLowerCase() : null); + } + + private String getExtension(String coding) { + String extension = this.extensions.get(coding); + if (extension == null) { + throw new IllegalStateException("No file extension associated with content coding " + coding); + } + return extension; + } + + @Override + protected Mono resolveUrlPathInternal(String resourceUrlPath, + List locations, ResourceResolverChain chain) { + + return chain.resolveUrlPath(resourceUrlPath, locations); + } + + + /** + * An encoded {@link HttpResource}. + */ + static final class EncodedResource extends AbstractResource implements HttpResource { + + private final Resource original; + + private final String coding; + + private final Resource encoded; + + EncodedResource(Resource original, String coding, String extension) throws IOException { + this.original = original; + this.coding = coding; + this.encoded = original.createRelative(original.getFilename() + extension); + } + + @Override + public InputStream getInputStream() throws IOException { + return this.encoded.getInputStream(); + } + + @Override + public boolean exists() { + return this.encoded.exists(); + } + + @Override + public boolean isReadable() { + return this.encoded.isReadable(); + } + + @Override + public boolean isOpen() { + return this.encoded.isOpen(); + } + + @Override + public boolean isFile() { + return this.encoded.isFile(); + } + + @Override + public URL getURL() throws IOException { + return this.encoded.getURL(); + } + + @Override + public URI getURI() throws IOException { + return this.encoded.getURI(); + } + + @Override + public File getFile() throws IOException { + return this.encoded.getFile(); + } + + @Override + public long contentLength() throws IOException { + return this.encoded.contentLength(); + } + + @Override + public long lastModified() throws IOException { + return this.encoded.lastModified(); + } + + @Override + public Resource createRelative(String relativePath) throws IOException { + return this.encoded.createRelative(relativePath); + } + + @Override + @Nullable + public String getFilename() { + return this.original.getFilename(); + } + + @Override + public String getDescription() { + return this.encoded.getDescription(); + } + + @Override + public HttpHeaders getResponseHeaders() { + HttpHeaders headers; + if (this.original instanceof HttpResource) { + headers = ((HttpResource) this.original).getResponseHeaders(); + } + else { + headers = new HttpHeaders(); + } + headers.add(HttpHeaders.CONTENT_ENCODING, this.coding); + headers.add(HttpHeaders.VARY, HttpHeaders.ACCEPT_ENCODING); + return headers; + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/FixedVersionStrategy.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/FixedVersionStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..ad3daecbd1f9d7154051449da81266f51f16045c --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/FixedVersionStrategy.java @@ -0,0 +1,56 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; + +/** + * A {@code VersionStrategy} that relies on a fixed version applied as a request + * path prefix, e.g. reduced SHA, version name, release date, etc. + * + *

This is useful for example when {@link ContentVersionStrategy} cannot be + * used such as when using JavaScript module loaders which are in charge of + * loading the JavaScript resources and need to know their relative paths. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + * @see VersionResourceResolver + */ +public class FixedVersionStrategy extends AbstractPrefixVersionStrategy { + + private final Mono versionMono; + + + /** + * Create a new FixedVersionStrategy with the given version string. + * @param version the fixed version string to use + */ + public FixedVersionStrategy(String version) { + super(version); + this.versionMono = Mono.just(version); + } + + + @Override + public Mono getResourceVersion(Resource resource) { + return this.versionMono; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/GzipResourceResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/GzipResourceResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..74a7a628bb9b238d0550b6b394c45cc55f3552f6 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/GzipResourceResolver.java @@ -0,0 +1,173 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.URL; +import java.util.List; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.AbstractResource; +import org.springframework.core.io.Resource; +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.web.server.ServerWebExchange; + +/** + * A {@code ResourceResolver} that delegates to the chain to locate a resource + * and then attempts to find a variation with the ".gz" extension. + * + *

The resolver gets involved only if the "Accept-Encoding" request header + * contains the value "gzip" indicating the client accepts gzipped responses. + * + * @author Rossen Stoyanchev + * @since 5.0 + * @deprecated as of 5.1, in favor of using {@link EncodedResourceResolver} + */ +@Deprecated +public class GzipResourceResolver extends AbstractResourceResolver { + + @Override + protected Mono resolveResourceInternal(@Nullable ServerWebExchange exchange, + String requestPath, List locations, ResourceResolverChain chain) { + + return chain.resolveResource(exchange, requestPath, locations) + .map(resource -> { + if (exchange == null || isGzipAccepted(exchange)) { + try { + Resource gzipped = new GzippedResource(resource); + if (gzipped.exists()) { + resource = gzipped; + } + } + catch (IOException ex) { + String logPrefix = exchange != null ? exchange.getLogPrefix() : ""; + logger.trace(logPrefix + "No gzip resource for [" + resource.getFilename() + "]", ex); + } + } + return resource; + }); + } + + private boolean isGzipAccepted(ServerWebExchange exchange) { + String value = exchange.getRequest().getHeaders().getFirst("Accept-Encoding"); + return (value != null && value.toLowerCase().contains("gzip")); + } + + @Override + protected Mono resolveUrlPathInternal(String resourceUrlPath, + List locations, ResourceResolverChain chain) { + + return chain.resolveUrlPath(resourceUrlPath, locations); + } + + + /** + * A gzipped {@link HttpResource}. + */ + static final class GzippedResource extends AbstractResource implements HttpResource { + + private final Resource original; + + private final Resource gzipped; + + public GzippedResource(Resource original) throws IOException { + this.original = original; + this.gzipped = original.createRelative(original.getFilename() + ".gz"); + } + + @Override + public InputStream getInputStream() throws IOException { + return this.gzipped.getInputStream(); + } + + @Override + public boolean exists() { + return this.gzipped.exists(); + } + + @Override + public boolean isReadable() { + return this.gzipped.isReadable(); + } + + @Override + public boolean isOpen() { + return this.gzipped.isOpen(); + } + + @Override + public boolean isFile() { + return this.gzipped.isFile(); + } + + @Override + public URL getURL() throws IOException { + return this.gzipped.getURL(); + } + + @Override + public URI getURI() throws IOException { + return this.gzipped.getURI(); + } + + @Override + public File getFile() throws IOException { + return this.gzipped.getFile(); + } + + @Override + public long contentLength() throws IOException { + return this.gzipped.contentLength(); + } + + @Override + public long lastModified() throws IOException { + return this.gzipped.lastModified(); + } + + @Override + public Resource createRelative(String relativePath) throws IOException { + return this.gzipped.createRelative(relativePath); + } + + @Override + @Nullable + public String getFilename() { + return this.original.getFilename(); + } + + @Override + public String getDescription() { + return this.gzipped.getDescription(); + } + + @Override + public HttpHeaders getResponseHeaders() { + HttpHeaders headers = (this.original instanceof HttpResource ? + ((HttpResource) this.original).getResponseHeaders() : new HttpHeaders()); + headers.add(HttpHeaders.CONTENT_ENCODING, "gzip"); + headers.add(HttpHeaders.VARY, HttpHeaders.ACCEPT_ENCODING); + return headers; + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/HttpResource.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/HttpResource.java new file mode 100644 index 0000000000000000000000000000000000000000..bd2570e057a910047b9c4a745e093a82bf5e0857 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/HttpResource.java @@ -0,0 +1,37 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import org.springframework.core.io.Resource; +import org.springframework.http.HttpHeaders; + +/** + * Extended interface for a {@link Resource} to be written to an + * HTTP response. + * + * @author Brian Clozel + * @since 5.0 + */ +public interface HttpResource extends Resource { + + /** + * The HTTP headers to be contributed to the HTTP response + * that serves the current resource. + * @return the HTTP response headers + */ + HttpHeaders getResponseHeaders(); +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/PathResourceResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/PathResourceResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..beafe65d0d238b31c20c6e41d23e30ec69a9aba2 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/PathResourceResolver.java @@ -0,0 +1,216 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.core.io.UrlResource; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.util.UriUtils; + +/** + * A simple {@code ResourceResolver} that tries to find a resource under the given + * locations matching to the request path. + * + *

This resolver does not delegate to the {@code ResourceResolverChain} and is + * expected to be configured at the end in a chain of resolvers. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class PathResourceResolver extends AbstractResourceResolver { + + @Nullable + private Resource[] allowedLocations; + + + /** + * By default when a Resource is found, the path of the resolved resource is + * compared to ensure it's under the input location where it was found. + * However sometimes that may not be the case, e.g. when + * {@link CssLinkResourceTransformer} + * resolves public URLs of links it contains, the CSS file is the location + * and the resources being resolved are css files, images, fonts and others + * located in adjacent or parent directories. + *

This property allows configuring a complete list of locations under + * which resources must be so that if a resource is not under the location + * relative to which it was found, this list may be checked as well. + *

By default {@link ResourceWebHandler} initializes this property + * to match its list of locations. + * @param locations the list of allowed locations + */ + public void setAllowedLocations(@Nullable Resource... locations) { + this.allowedLocations = locations; + } + + @Nullable + public Resource[] getAllowedLocations() { + return this.allowedLocations; + } + + + @Override + protected Mono resolveResourceInternal(@Nullable ServerWebExchange exchange, + String requestPath, List locations, ResourceResolverChain chain) { + + return getResource(requestPath, locations); + } + + @Override + protected Mono resolveUrlPathInternal(String path, List locations, + ResourceResolverChain chain) { + + if (StringUtils.hasText(path)) { + return getResource(path, locations).map(resource -> path); + } + else { + return Mono.empty(); + } + } + + private Mono getResource(String resourcePath, List locations) { + return Flux.fromIterable(locations) + .concatMap(location -> getResource(resourcePath, location)) + .next(); + } + + /** + * Find the resource under the given location. + *

The default implementation checks if there is a readable + * {@code Resource} for the given path relative to the location. + * @param resourcePath the path to the resource + * @param location the location to check + * @return the resource, or empty {@link Mono} if none found + */ + protected Mono getResource(String resourcePath, Resource location) { + try { + if (location instanceof ClassPathResource) { + resourcePath = UriUtils.decode(resourcePath, StandardCharsets.UTF_8); + } + Resource resource = location.createRelative(resourcePath); + if (resource.isReadable()) { + if (checkResource(resource, location)) { + return Mono.just(resource); + } + else if (logger.isWarnEnabled()) { + Resource[] allowedLocations = getAllowedLocations(); + logger.warn("Resource path \"" + resourcePath + "\" was successfully resolved " + + "but resource \"" + resource.getURL() + "\" is neither under the " + + "current location \"" + location.getURL() + "\" nor under any of the " + + "allowed locations " + (allowedLocations != null ? Arrays.asList(allowedLocations) : "[]")); + } + } + return Mono.empty(); + } + catch (IOException ex) { + if (logger.isDebugEnabled()) { + String error = "Skip location [" + location + "] due to error"; + if (logger.isTraceEnabled()) { + logger.trace(error, ex); + } + else { + logger.debug(error + ": " + ex.getMessage()); + } + } + return Mono.error(ex); + } + } + + /** + * Perform additional checks on a resolved resource beyond checking whether the + * resources exists and is readable. The default implementation also verifies + * the resource is either under the location relative to which it was found or + * is under one of the {@link #setAllowedLocations allowed locations}. + * @param resource the resource to check + * @param location the location relative to which the resource was found + * @return "true" if resource is in a valid location, "false" otherwise. + */ + protected boolean checkResource(Resource resource, Resource location) throws IOException { + if (isResourceUnderLocation(resource, location)) { + return true; + } + if (getAllowedLocations() != null) { + for (Resource current : getAllowedLocations()) { + if (isResourceUnderLocation(resource, current)) { + return true; + } + } + } + return false; + } + + private boolean isResourceUnderLocation(Resource resource, Resource location) throws IOException { + if (resource.getClass() != location.getClass()) { + return false; + } + + String resourcePath; + String locationPath; + + if (resource instanceof UrlResource) { + resourcePath = resource.getURL().toExternalForm(); + locationPath = StringUtils.cleanPath(location.getURL().toString()); + } + else if (resource instanceof ClassPathResource) { + resourcePath = ((ClassPathResource) resource).getPath(); + locationPath = StringUtils.cleanPath(((ClassPathResource) location).getPath()); + } + else { + resourcePath = resource.getURL().getPath(); + locationPath = StringUtils.cleanPath(location.getURL().getPath()); + } + + if (locationPath.equals(resourcePath)) { + return true; + } + locationPath = (locationPath.endsWith("/") || locationPath.isEmpty() ? locationPath : locationPath + "/"); + return (resourcePath.startsWith(locationPath) && !isInvalidEncodedPath(resourcePath)); + } + + private boolean isInvalidEncodedPath(String resourcePath) { + if (resourcePath.contains("%")) { + // Use URLDecoder (vs UriUtils) to preserve potentially decoded UTF-8 chars... + try { + String decodedPath = URLDecoder.decode(resourcePath, "UTF-8"); + if (decodedPath.contains("../") || decodedPath.contains("..\\")) { + logger.warn("Resolved resource path contains encoded \"../\" or \"..\\\": " + resourcePath); + return true; + } + } + catch (IllegalArgumentException ex) { + // May not be possible to decode... + } + catch (UnsupportedEncodingException ex) { + // Should never happen... + } + } + return false; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..610e052a78244e121868515e96f4a30be3e87acb --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceResolver.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.util.List; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.http.server.RequestPath; +import org.springframework.lang.Nullable; +import org.springframework.web.server.ServerWebExchange; + +/** + * A strategy for resolving a request to a server-side resource. + * + *

Provides mechanisms for resolving an incoming request to an actual + * {@link Resource} and for obtaining the + * public URL path that clients should use when requesting the resource. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface ResourceResolver { + + /** + * Resolve the supplied request and request path to a {@link Resource} that + * exists under one of the given resource locations. + * @param exchange the current exchange + * @param requestPath the portion of the request path to use. This is + * expected to be the encoded path, i.e. {@link RequestPath#value()}. + * @param locations the locations to search in when looking up resources + * @param chain the chain of remaining resolvers to delegate to + * @return the resolved resource or an empty {@code Mono} if unresolved + */ + Mono resolveResource(@Nullable ServerWebExchange exchange, String requestPath, + List locations, ResourceResolverChain chain); + + /** + * Resolve the externally facing public URL path for clients to use + * to access the resource that is located at the given internal + * resource path. + *

This is useful when rendering URL links to clients. + * @param resourcePath the "internal" resource path to resolve a path for + * public use. This is expected to be the encoded path. + * @param locations the locations to search in when looking up resources + * @param chain the chain of resolvers to delegate to + * @return the resolved public URL path or an empty {@code Mono} if unresolved + */ + Mono resolveUrlPath(String resourcePath, List locations, + ResourceResolverChain chain); + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceResolverChain.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceResolverChain.java new file mode 100644 index 0000000000000000000000000000000000000000..53f3fe4991227fd5f9013312193adc51d49b2a74 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceResolverChain.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.util.List; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; +import org.springframework.web.server.ServerWebExchange; + +/** + * A contract for invoking a chain of {@link ResourceResolver ResourceResolvers} where each resolver + * is given a reference to the chain allowing it to delegate when necessary. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface ResourceResolverChain { + + /** + * Resolve the supplied request and request path to a {@link Resource} that + * exists under one of the given resource locations. + * @param exchange the current exchange + * @param requestPath the portion of the request path to use + * @param locations the locations to search in when looking up resources + * @return the resolved resource; or an empty {@code Mono} if unresolved + */ + Mono resolveResource(@Nullable ServerWebExchange exchange, String requestPath, + List locations); + + /** + * Resolve the externally facing public URL path for clients to use + * to access the resource that is located at the given internal + * resource path. + *

This is useful when rendering URL links to clients. + * @param resourcePath the internal resource path + * @param locations the locations to search in when looking up resources + * @return the resolved public URL path; or an empty {@code Mono} if unresolved + */ + Mono resolveUrlPath(String resourcePath, List locations); + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceTransformer.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceTransformer.java new file mode 100644 index 0000000000000000000000000000000000000000..ccb9ac6d0f423028d52d774374c8cf6288f3a24f --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceTransformer.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.web.server.ServerWebExchange; + +/** + * An abstraction for transforming the content of a resource. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +@FunctionalInterface +public interface ResourceTransformer { + + /** + * Transform the given resource. + * @param exchange the current exchange + * @param resource the resource to transform + * @param transformerChain the chain of remaining transformers to delegate to + * @return the transformed resource (never empty) + */ + Mono transform(ServerWebExchange exchange, Resource resource, + ResourceTransformerChain transformerChain); + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceTransformerChain.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceTransformerChain.java new file mode 100644 index 0000000000000000000000000000000000000000..c40c5e6bbc520387a09159f1115298a2fa02b59b --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceTransformerChain.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.web.server.ServerWebExchange; + +/** + * A contract for invoking a chain of {@link ResourceTransformer ResourceTransformers} where each resolver + * is given a reference to the chain allowing it to delegate when necessary. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface ResourceTransformerChain { + + /** + * Return the {@code ResourceResolverChain} that was used to resolve the + * {@code Resource} being transformed. This may be needed for resolving + * related resources, e.g. links to other resources. + */ + ResourceResolverChain getResolverChain(); + + /** + * Transform the given resource. + * @param exchange the current exchange + * @param resource the candidate resource to transform + * @return the transformed or the same resource, never empty + */ + Mono transform(ServerWebExchange exchange, Resource resource); + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceTransformerSupport.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceTransformerSupport.java new file mode 100644 index 0000000000000000000000000000000000000000..628e12fe90f7d9f6d1f14b184e1b8e2df6da3bde --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceTransformerSupport.java @@ -0,0 +1,103 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.util.Collections; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; + +/** + * A base class for a {@code ResourceTransformer} with an optional helper method + * for resolving public links within a transformed resource. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + */ +public abstract class ResourceTransformerSupport implements ResourceTransformer { + + @Nullable + private ResourceUrlProvider resourceUrlProvider; + + + /** + * Configure a {@link ResourceUrlProvider} to use when resolving the public + * URL of links in a transformed resource (e.g. import links in a CSS file). + * This is required only for links expressed as full paths and not for + * relative links. + * @param resourceUrlProvider the URL provider to use + */ + public void setResourceUrlProvider(@Nullable ResourceUrlProvider resourceUrlProvider) { + this.resourceUrlProvider = resourceUrlProvider; + } + + /** + * Return the configured {@code ResourceUrlProvider}. + */ + @Nullable + public ResourceUrlProvider getResourceUrlProvider() { + return this.resourceUrlProvider; + } + + + /** + * A transformer can use this method when a resource being transformed + * contains links to other resources. Such links need to be replaced with the + * public facing link as determined by the resource resolver chain (e.g. the + * public URL may have a version inserted). + * @param resourcePath the path to a resource that needs to be re-written + * @param exchange the current exchange + * @param resource the resource being transformed + * @param transformerChain the transformer chain + * @return the resolved URL or an empty {@link Mono} + */ + protected Mono resolveUrlPath(String resourcePath, ServerWebExchange exchange, + Resource resource, ResourceTransformerChain transformerChain) { + + if (resourcePath.startsWith("/")) { + // full resource path + ResourceUrlProvider urlProvider = getResourceUrlProvider(); + return (urlProvider != null ? urlProvider.getForUriString(resourcePath, exchange) : Mono.empty()); + } + else { + // try resolving as relative path + return transformerChain.getResolverChain() + .resolveUrlPath(resourcePath, Collections.singletonList(resource)); + } + } + + /** + * Transform the given relative request path to an absolute path, + * taking the path of the given request as a point of reference. + * The resulting path is also cleaned from sequences like "path/..". + * + * @param path the relative path to transform + * @param exchange the current exchange + * @return the absolute request path for the given resource path + */ + protected String toAbsolutePath(String path, ServerWebExchange exchange) { + String requestPath = exchange.getRequest().getURI().getPath(); + String absolutePath = (path.startsWith("/") ? path : StringUtils.applyRelativePath(requestPath, path)); + return StringUtils.cleanPath(absolutePath); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceUrlProvider.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceUrlProvider.java new file mode 100644 index 0000000000000000000000000000000000000000..1494f3870ebc721e68aa51f9edc0c89b6c31a49d --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceUrlProvider.java @@ -0,0 +1,177 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import reactor.core.publisher.Mono; + +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationListener; +import org.springframework.context.event.ContextRefreshedEvent; +import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.http.server.PathContainer; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.util.StringUtils; +import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.util.pattern.PathPattern; +import org.springframework.web.util.pattern.PathPatternParser; + +/** + * A central component to use to obtain the public URL path that clients should + * use to access a static resource. + * + *

This class is aware of Spring WebFlux handler mappings used to serve static + * resources and uses the {@code ResourceResolver} chains of the configured + * {@code ResourceHttpRequestHandler}s to make its decisions. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class ResourceUrlProvider implements ApplicationListener { + + private static final Log logger = LogFactory.getLog(ResourceUrlProvider.class); + + + private final PathPatternParser patternParser = new PathPatternParser(); + + private final Map handlerMap = new LinkedHashMap<>(); + + + /** + * Return a read-only view of the resource handler mappings either manually + * configured or auto-detected from Spring configuration. + */ + public Map getHandlerMap() { + return Collections.unmodifiableMap(this.handlerMap); + } + + + /** + * Manually configure resource handler mappings. + *

Note: by default resource mappings are auto-detected + * from the Spring {@code ApplicationContext}. If this property is used, + * auto-detection is turned off. + */ + public void registerHandlers(Map handlerMap) { + this.handlerMap.clear(); + handlerMap.forEach((rawPattern, resourceWebHandler) -> { + rawPattern = prependLeadingSlash(rawPattern); + PathPattern pattern = this.patternParser.parse(rawPattern); + this.handlerMap.put(pattern, resourceWebHandler); + }); + } + + @Override + public void onApplicationEvent(ContextRefreshedEvent event) { + if (this.handlerMap.isEmpty()) { + detectResourceHandlers(event.getApplicationContext()); + } + } + + private void detectResourceHandlers(ApplicationContext context) { + Map beans = context.getBeansOfType(SimpleUrlHandlerMapping.class); + List mappings = new ArrayList<>(beans.values()); + AnnotationAwareOrderComparator.sort(mappings); + + mappings.forEach(mapping -> + mapping.getHandlerMap().forEach((pattern, handler) -> { + if (handler instanceof ResourceWebHandler) { + ResourceWebHandler resourceHandler = (ResourceWebHandler) handler; + this.handlerMap.put(pattern, resourceHandler); + } + })); + + if (this.handlerMap.isEmpty()) { + logger.trace("No resource handling mappings found"); + } + } + + + /** + * Get the public resource URL for the given URI string. + *

The URI string is expected to be a path and if it contains a query or + * fragment those will be preserved in the resulting public resource URL. + * @param uriString the URI string to transform + * @param exchange the current exchange + * @return the resolved public resource URL path, or empty if unresolved + */ + public final Mono getForUriString(String uriString, ServerWebExchange exchange) { + ServerHttpRequest request = exchange.getRequest(); + int queryIndex = getQueryIndex(uriString); + String lookupPath = uriString.substring(0, queryIndex); + String query = uriString.substring(queryIndex); + PathContainer parsedLookupPath = PathContainer.parsePath(lookupPath); + + return resolveResourceUrl(exchange, parsedLookupPath).map(resolvedPath -> + request.getPath().contextPath().value() + resolvedPath + query); + } + + private int getQueryIndex(String path) { + int suffixIndex = path.length(); + int queryIndex = path.indexOf('?'); + if (queryIndex > 0) { + suffixIndex = queryIndex; + } + int hashIndex = path.indexOf('#'); + if (hashIndex > 0) { + suffixIndex = Math.min(suffixIndex, hashIndex); + } + return suffixIndex; + } + + private Mono resolveResourceUrl(ServerWebExchange exchange, PathContainer lookupPath) { + return this.handlerMap.entrySet().stream() + .filter(entry -> entry.getKey().matches(lookupPath)) + .min((entry1, entry2) -> + PathPattern.SPECIFICITY_COMPARATOR.compare(entry1.getKey(), entry2.getKey())) + .map(entry -> { + PathContainer path = entry.getKey().extractPathWithinPattern(lookupPath); + int endIndex = lookupPath.elements().size() - path.elements().size(); + PathContainer mapping = lookupPath.subPath(0, endIndex); + ResourceWebHandler handler = entry.getValue(); + List resolvers = handler.getResourceResolvers(); + ResourceResolverChain chain = new DefaultResourceResolverChain(resolvers); + return chain.resolveUrlPath(path.value(), handler.getLocations()) + .map(resolvedPath -> mapping.value() + resolvedPath); + }) + .orElseGet(() ->{ + if (logger.isTraceEnabled()) { + logger.trace(exchange.getLogPrefix() + "No match for \"" + lookupPath + "\""); + } + return Mono.empty(); + }); + } + + + private static String prependLeadingSlash(String pattern) { + if (StringUtils.hasLength(pattern) && !pattern.startsWith("/")) { + return "/" + pattern; + } + else { + return pattern; + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceWebHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceWebHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..9da249f7c0eda1f23c200246e674bb01d25adb2a --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/ResourceWebHandler.java @@ -0,0 +1,554 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.EnumSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import reactor.core.publisher.Mono; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Hints; +import org.springframework.core.io.Resource; +import org.springframework.core.io.ResourceLoader; +import org.springframework.http.CacheControl; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.MediaTypeFactory; +import org.springframework.http.codec.ResourceHttpMessageWriter; +import org.springframework.http.server.PathContainer; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.ObjectUtils; +import org.springframework.util.ResourceUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.reactive.HandlerMapping; +import org.springframework.web.server.MethodNotAllowedException; +import org.springframework.web.server.ResponseStatusException; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebHandler; + +/** + * {@code HttpRequestHandler} that serves static resources in an optimized way + * according to the guidelines of Page Speed, YSlow, etc. + * + *

The {@linkplain #setLocations "locations"} property takes a list of Spring + * {@link Resource} locations from which static resources are allowed to + * be served by this handler. Resources could be served from a classpath location, + * e.g. "classpath:/META-INF/public-web-resources/", allowing convenient packaging + * and serving of resources such as .js, .css, and others in jar files. + * + *

This request handler may also be configured with a + * {@link #setResourceResolvers(List) resourcesResolver} and + * {@link #setResourceTransformers(List) resourceTransformer} chains to support + * arbitrary resolution and transformation of resources being served. By default a + * {@link PathResourceResolver} simply finds resources based on the configured + * "locations". An application can configure additional resolvers and + * transformers such as the {@link VersionResourceResolver} which can resolve + * and prepare URLs for resources with a version in the URL. + * + *

This handler also properly evaluates the {@code Last-Modified} header (if + * present) so that a {@code 304} status code will be returned as appropriate, + * avoiding unnecessary overhead for resources that are already cached by the + * client. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + */ +public class ResourceWebHandler implements WebHandler, InitializingBean { + + private static final Set SUPPORTED_METHODS = EnumSet.of(HttpMethod.GET, HttpMethod.HEAD); + + private static final Log logger = LogFactory.getLog(ResourceWebHandler.class); + + + private final List locationValues = new ArrayList<>(4); + + private final List locations = new ArrayList<>(4); + + private final List resourceResolvers = new ArrayList<>(4); + + private final List resourceTransformers = new ArrayList<>(4); + + @Nullable + private ResourceResolverChain resolverChain; + + @Nullable + private ResourceTransformerChain transformerChain; + + @Nullable + private CacheControl cacheControl; + + @Nullable + private ResourceHttpMessageWriter resourceHttpMessageWriter; + + @Nullable + private ResourceLoader resourceLoader; + + + /** + * Accepts a list of String-based location values to be resolved into + * {@link Resource} locations. + * @since 5.1 + */ + public void setLocationValues(List locationValues) { + Assert.notNull(locationValues, "Location values list must not be null"); + this.locationValues.clear(); + this.locationValues.addAll(locationValues); + } + + /** + * Return the configured location values. + * @since 5.1 + */ + public List getLocationValues() { + return this.locationValues; + } + + /** + * Set the {@code List} of {@code Resource} paths to use as sources + * for serving static resources. + */ + public void setLocations(@Nullable List locations) { + this.locations.clear(); + if (locations != null) { + this.locations.addAll(locations); + } + } + + /** + * Return the {@code List} of {@code Resource} paths to use as sources + * for serving static resources. + *

Note that if {@link #setLocationValues(List) locationValues} are provided, + * instead of loaded Resource-based locations, this method will return + * empty until after initialization via {@link #afterPropertiesSet()}. + * @see #setLocationValues + * @see #setLocations + */ + public List getLocations() { + return this.locations; + } + + /** + * Configure the list of {@link ResourceResolver ResourceResolvers} to use. + *

By default {@link PathResourceResolver} is configured. If using this property, + * it is recommended to add {@link PathResourceResolver} as the last resolver. + */ + public void setResourceResolvers(@Nullable List resourceResolvers) { + this.resourceResolvers.clear(); + if (resourceResolvers != null) { + this.resourceResolvers.addAll(resourceResolvers); + } + } + + /** + * Return the list of configured resource resolvers. + */ + public List getResourceResolvers() { + return this.resourceResolvers; + } + + /** + * Configure the list of {@link ResourceTransformer ResourceTransformers} to use. + *

By default no transformers are configured for use. + */ + public void setResourceTransformers(@Nullable List resourceTransformers) { + this.resourceTransformers.clear(); + if (resourceTransformers != null) { + this.resourceTransformers.addAll(resourceTransformers); + } + } + + /** + * Return the list of configured resource transformers. + */ + public List getResourceTransformers() { + return this.resourceTransformers; + } + + /** + * Set the {@link org.springframework.http.CacheControl} instance to build + * the Cache-Control HTTP response header. + */ + public void setCacheControl(@Nullable CacheControl cacheControl) { + this.cacheControl = cacheControl; + } + + /** + * Return the {@link org.springframework.http.CacheControl} instance to build + * the Cache-Control HTTP response header. + */ + @Nullable + public CacheControl getCacheControl() { + return this.cacheControl; + } + + /** + * Configure the {@link ResourceHttpMessageWriter} to use. + *

By default a {@link ResourceHttpMessageWriter} will be configured. + */ + public void setResourceHttpMessageWriter(@Nullable ResourceHttpMessageWriter httpMessageWriter) { + this.resourceHttpMessageWriter = httpMessageWriter; + } + + /** + * Return the configured resource message writer. + */ + @Nullable + public ResourceHttpMessageWriter getResourceHttpMessageWriter() { + return this.resourceHttpMessageWriter; + } + + /** + * Provide the ResourceLoader to load {@link #setLocationValues(List) + * location values} with. + * @since 5.1 + */ + public void setResourceLoader(ResourceLoader resourceLoader) { + this.resourceLoader = resourceLoader; + } + + + @Override + public void afterPropertiesSet() throws Exception { + resolveResourceLocations(); + + if (logger.isWarnEnabled() && CollectionUtils.isEmpty(this.locations)) { + logger.warn("Locations list is empty. No resources will be served unless a " + + "custom ResourceResolver is configured as an alternative to PathResourceResolver."); + } + + if (this.resourceResolvers.isEmpty()) { + this.resourceResolvers.add(new PathResourceResolver()); + } + + initAllowedLocations(); + + if (getResourceHttpMessageWriter() == null) { + this.resourceHttpMessageWriter = new ResourceHttpMessageWriter(); + } + + // Initialize immutable resolver and transformer chains + this.resolverChain = new DefaultResourceResolverChain(this.resourceResolvers); + this.transformerChain = new DefaultResourceTransformerChain(this.resolverChain, this.resourceTransformers); + } + + private void resolveResourceLocations() { + if (CollectionUtils.isEmpty(this.locationValues)) { + return; + } + else if (!CollectionUtils.isEmpty(this.locations)) { + throw new IllegalArgumentException("Please set either Resource-based \"locations\" or " + + "String-based \"locationValues\", but not both."); + } + + Assert.notNull(this.resourceLoader, + "ResourceLoader is required when \"locationValues\" are configured."); + + for (String location : this.locationValues) { + Resource resource = this.resourceLoader.getResource(location); + this.locations.add(resource); + } + } + + /** + * Look for a {@code PathResourceResolver} among the configured resource + * resolvers and set its {@code allowedLocations} property (if empty) to + * match the {@link #setLocations locations} configured on this class. + */ + protected void initAllowedLocations() { + if (CollectionUtils.isEmpty(this.locations)) { + if (logger.isInfoEnabled()) { + logger.info("Locations list is empty. No resources will be served unless a " + + "custom ResourceResolver is configured as an alternative to PathResourceResolver."); + } + return; + } + for (int i = getResourceResolvers().size() - 1; i >= 0; i--) { + if (getResourceResolvers().get(i) instanceof PathResourceResolver) { + PathResourceResolver resolver = (PathResourceResolver) getResourceResolvers().get(i); + if (ObjectUtils.isEmpty(resolver.getAllowedLocations())) { + resolver.setAllowedLocations(getLocations().toArray(new Resource[0])); + } + break; + } + } + } + + + /** + * Processes a resource request. + *

Checks for the existence of the requested resource in the configured list of locations. + * If the resource does not exist, a {@code 404} response will be returned to the client. + * If the resource exists, the request will be checked for the presence of the + * {@code Last-Modified} header, and its value will be compared against the last-modified + * timestamp of the given resource, returning a {@code 304} status code if the + * {@code Last-Modified} value is greater. If the resource is newer than the + * {@code Last-Modified} value, or the header is not present, the content resource + * of the resource will be written to the response with caching headers + * set to expire one year in the future. + */ + @Override + public Mono handle(ServerWebExchange exchange) { + return getResource(exchange) + .switchIfEmpty(Mono.defer(() -> { + logger.debug(exchange.getLogPrefix() + "Resource not found"); + return Mono.error(new ResponseStatusException(HttpStatus.NOT_FOUND)); + })) + .flatMap(resource -> { + try { + if (HttpMethod.OPTIONS.matches(exchange.getRequest().getMethodValue())) { + exchange.getResponse().getHeaders().add("Allow", "GET,HEAD,OPTIONS"); + return Mono.empty(); + } + + // Supported methods and required session + HttpMethod httpMethod = exchange.getRequest().getMethod(); + if (!SUPPORTED_METHODS.contains(httpMethod)) { + return Mono.error(new MethodNotAllowedException( + exchange.getRequest().getMethodValue(), SUPPORTED_METHODS)); + } + + // Header phase + if (exchange.checkNotModified(Instant.ofEpochMilli(resource.lastModified()))) { + logger.trace(exchange.getLogPrefix() + "Resource not modified"); + return Mono.empty(); + } + + // Apply cache settings, if any + CacheControl cacheControl = getCacheControl(); + if (cacheControl != null) { + exchange.getResponse().getHeaders().setCacheControl(cacheControl); + } + + // Check the media type for the resource + MediaType mediaType = MediaTypeFactory.getMediaType(resource).orElse(null); + setHeaders(exchange, resource, mediaType); + + // Content phase + ResourceHttpMessageWriter writer = getResourceHttpMessageWriter(); + Assert.state(writer != null, "No ResourceHttpMessageWriter"); + return writer.write(Mono.just(resource), + null, ResolvableType.forClass(Resource.class), mediaType, + exchange.getRequest(), exchange.getResponse(), + Hints.from(Hints.LOG_PREFIX_HINT, exchange.getLogPrefix())); + } + catch (IOException ex) { + return Mono.error(ex); + } + }); + } + + protected Mono getResource(ServerWebExchange exchange) { + String name = HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE; + PathContainer pathWithinHandler = exchange.getRequiredAttribute(name); + + String path = processPath(pathWithinHandler.value()); + if (!StringUtils.hasText(path) || isInvalidPath(path)) { + return Mono.empty(); + } + if (isInvalidEncodedPath(path)) { + return Mono.empty(); + } + + Assert.state(this.resolverChain != null, "ResourceResolverChain not initialized"); + Assert.state(this.transformerChain != null, "ResourceTransformerChain not initialized"); + + return this.resolverChain.resolveResource(exchange, path, getLocations()) + .flatMap(resource -> this.transformerChain.transform(exchange, resource)); + } + + /** + * Process the given resource path. + *

The default implementation replaces: + *

    + *
  • Backslash with forward slash. + *
  • Duplicate occurrences of slash with a single slash. + *
  • Any combination of leading slash and control characters (00-1F and 7F) + * with a single "/" or "". For example {@code " / // foo/bar"} + * becomes {@code "/foo/bar"}. + *
+ * @since 3.2.12 + */ + protected String processPath(String path) { + path = StringUtils.replace(path, "\\", "/"); + path = cleanDuplicateSlashes(path); + return cleanLeadingSlash(path); + } + + private String cleanDuplicateSlashes(String path) { + StringBuilder sb = null; + char prev = 0; + for (int i = 0; i < path.length(); i++) { + char curr = path.charAt(i); + try { + if (curr == '/' && prev == '/') { + if (sb == null) { + sb = new StringBuilder(path.substring(0, i)); + } + continue; + } + if (sb != null) { + sb.append(path.charAt(i)); + } + } + finally { + prev = curr; + } + } + return (sb != null ? sb.toString() : path); + } + + private String cleanLeadingSlash(String path) { + boolean slash = false; + for (int i = 0; i < path.length(); i++) { + if (path.charAt(i) == '/') { + slash = true; + } + else if (path.charAt(i) > ' ' && path.charAt(i) != 127) { + if (i == 0 || (i == 1 && slash)) { + return path; + } + return (slash ? "/" + path.substring(i) : path.substring(i)); + } + } + return (slash ? "/" : ""); + } + + /** + * Check whether the given path contains invalid escape sequences. + * @param path the path to validate + * @return {@code true} if the path is invalid, {@code false} otherwise + */ + private boolean isInvalidEncodedPath(String path) { + if (path.contains("%")) { + try { + // Use URLDecoder (vs UriUtils) to preserve potentially decoded UTF-8 chars + String decodedPath = URLDecoder.decode(path, "UTF-8"); + if (isInvalidPath(decodedPath)) { + return true; + } + decodedPath = processPath(decodedPath); + if (isInvalidPath(decodedPath)) { + return true; + } + } + catch (IllegalArgumentException ex) { + // May not be possible to decode... + } + catch (UnsupportedEncodingException ex) { + // Should never happen... + } + } + return false; + } + + /** + * Identifies invalid resource paths. By default rejects: + *
    + *
  • Paths that contain "WEB-INF" or "META-INF" + *
  • Paths that contain "../" after a call to + * {@link StringUtils#cleanPath}. + *
  • Paths that represent a {@link ResourceUtils#isUrl + * valid URL} or would represent one after the leading slash is removed. + *
+ *

Note: this method assumes that leading, duplicate '/' + * or control characters (e.g. white space) have been trimmed so that the + * path starts predictably with a single '/' or does not have one. + * @param path the path to validate + * @return {@code true} if the path is invalid, {@code false} otherwise + */ + protected boolean isInvalidPath(String path) { + if (path.contains("WEB-INF") || path.contains("META-INF")) { + if (logger.isWarnEnabled()) { + logger.warn("Path with \"WEB-INF\" or \"META-INF\": [" + path + "]"); + } + return true; + } + if (path.contains(":/")) { + String relativePath = (path.charAt(0) == '/' ? path.substring(1) : path); + if (ResourceUtils.isUrl(relativePath) || relativePath.startsWith("url:")) { + if (logger.isWarnEnabled()) { + logger.warn("Path represents URL or has \"url:\" prefix: [" + path + "]"); + } + return true; + } + } + if (path.contains("..") && StringUtils.cleanPath(path).contains("../")) { + if (logger.isWarnEnabled()) { + logger.warn("Path contains \"../\" after call to StringUtils#cleanPath: [" + path + "]"); + } + return true; + } + return false; + } + + /** + * Set headers on the response. Called for both GET and HEAD requests. + * @param exchange current exchange + * @param resource the identified resource (never {@code null}) + * @param mediaType the resource's media type (never {@code null}) + */ + protected void setHeaders(ServerWebExchange exchange, Resource resource, @Nullable MediaType mediaType) + throws IOException { + + HttpHeaders headers = exchange.getResponse().getHeaders(); + + long length = resource.contentLength(); + headers.setContentLength(length); + + if (mediaType != null) { + headers.setContentType(mediaType); + } + + if (resource instanceof HttpResource) { + HttpHeaders resourceHeaders = ((HttpResource) resource).getResponseHeaders(); + exchange.getResponse().getHeaders().putAll(resourceHeaders); + } + } + + + @Override + public String toString() { + return "ResourceWebHandler " + formatLocations(); + } + + private Object formatLocations() { + if (!this.locationValues.isEmpty()) { + return this.locationValues.stream().collect(Collectors.joining("\", \"", "[\"", "\"]")); + } + else if (!this.locations.isEmpty()) { + return "[" + this.locations + "]"; + } + return Collections.emptyList(); + } +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/TransformedResource.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/TransformedResource.java new file mode 100644 index 0000000000000000000000000000000000000000..220dd7b56dcf751694ccec739573cbd0b672e887 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/TransformedResource.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.io.IOException; + +import org.springframework.core.io.ByteArrayResource; +import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; + +/** + * An extension of {@link ByteArrayResource} that a {@link ResourceTransformer} + * can use to represent an original resource preserving all other information + * except the content. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class TransformedResource extends ByteArrayResource { + + @Nullable + private final String filename; + + private final long lastModified; + + + public TransformedResource(Resource original, byte[] transformedContent) { + super(transformedContent); + this.filename = original.getFilename(); + try { + this.lastModified = original.lastModified(); + } + catch (IOException ex) { + // should never happen + throw new IllegalArgumentException(ex); + } + } + + + @Override + @Nullable + public String getFilename() { + return this.filename; + } + + @Override + public long lastModified() { + return this.lastModified; + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/VersionResourceResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/VersionResourceResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..acb0ce4f657aea19d158d64cffba1e71faf34b44 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/VersionResourceResolver.java @@ -0,0 +1,325 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.URL; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.AbstractResource; +import org.springframework.core.io.Resource; +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.AntPathMatcher; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; + +/** + * Resolves request paths containing a version string that can be used as part + * of an HTTP caching strategy in which a resource is cached with a date in the + * distant future (e.g. 1 year) and cached until the version, and therefore the + * URL, is changed. + * + *

Different versioning strategies exist, and this resolver must be configured + * with one or more such strategies along with path mappings to indicate which + * strategy applies to which resources. + * + *

{@code ContentVersionStrategy} is a good default choice except in cases + * where it cannot be used. Most notably the {@code ContentVersionStrategy} + * cannot be combined with JavaScript module loaders. For such cases the + * {@code FixedVersionStrategy} is a better choice. + * + *

Note that using this resolver to serve CSS files means that the + * {@link CssLinkResourceTransformer} should also be used in order to modify + * links within CSS files to also contain the appropriate versions generated + * by this resolver. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + * @see VersionStrategy + */ +public class VersionResourceResolver extends AbstractResourceResolver { + + private AntPathMatcher pathMatcher = new AntPathMatcher(); + + /** Map from path pattern -> VersionStrategy. */ + private final Map versionStrategyMap = new LinkedHashMap<>(); + + + /** + * Set a Map with URL paths as keys and {@code VersionStrategy} as values. + *

Supports direct URL matches and Ant-style pattern matches. For syntax + * details, see the {@link AntPathMatcher} javadoc. + * @param map map with URLs as keys and version strategies as values + */ + public void setStrategyMap(Map map) { + this.versionStrategyMap.clear(); + this.versionStrategyMap.putAll(map); + } + + /** + * Return the map with version strategies keyed by path pattern. + */ + public Map getStrategyMap() { + return this.versionStrategyMap; + } + + /** + * Insert a content-based version in resource URLs that match the given path + * patterns. The version is computed from the content of the file, e.g. + * {@code "css/main-e36d2e05253c6c7085a91522ce43a0b4.css"}. This is a good + * default strategy to use except when it cannot be, for example when using + * JavaScript module loaders, use {@link #addFixedVersionStrategy} instead + * for serving JavaScript files. + * @param pathPatterns one or more resource URL path patterns, + * relative to the pattern configured with the resource handler + * @return the current instance for chained method invocation + * @see ContentVersionStrategy + */ + public VersionResourceResolver addContentVersionStrategy(String... pathPatterns) { + addVersionStrategy(new ContentVersionStrategy(), pathPatterns); + return this; + } + + /** + * Insert a fixed, prefix-based version in resource URLs that match the given + * path patterns, for example: "{version}/js/main.js". This is useful (vs. + * content-based versions) when using JavaScript module loaders. + *

The version may be a random number, the current date, or a value + * fetched from a git commit sha, a property file, or environment variable + * and set with SpEL expressions in the configuration (e.g. see {@code @Value} + * in Java config). + *

If not done already, variants of the given {@code pathPatterns}, prefixed with + * the {@code version} will be also configured. For example, adding a {@code "/js/**"} path pattern + * will also cofigure automatically a {@code "/v1.0.0/js/**"} with {@code "v1.0.0"} the + * {@code version} String given as an argument. + * @param version a version string + * @param pathPatterns one or more resource URL path patterns, + * relative to the pattern configured with the resource handler + * @return the current instance for chained method invocation + * @see FixedVersionStrategy + */ + public VersionResourceResolver addFixedVersionStrategy(String version, String... pathPatterns) { + List patternsList = Arrays.asList(pathPatterns); + List prefixedPatterns = new ArrayList<>(pathPatterns.length); + String versionPrefix = "/" + version; + for (String pattern : patternsList) { + prefixedPatterns.add(pattern); + if (!pattern.startsWith(versionPrefix) && !patternsList.contains(versionPrefix + pattern)) { + prefixedPatterns.add(versionPrefix + pattern); + } + } + return addVersionStrategy(new FixedVersionStrategy(version), StringUtils.toStringArray(prefixedPatterns)); + } + + /** + * Register a custom VersionStrategy to apply to resource URLs that match the + * given path patterns. + * @param strategy the custom strategy + * @param pathPatterns one or more resource URL path patterns, + * relative to the pattern configured with the resource handler + * @return the current instance for chained method invocation + * @see VersionStrategy + */ + public VersionResourceResolver addVersionStrategy(VersionStrategy strategy, String... pathPatterns) { + for (String pattern : pathPatterns) { + getStrategyMap().put(pattern, strategy); + } + return this; + } + + + @Override + protected Mono resolveResourceInternal(@Nullable ServerWebExchange exchange, + String requestPath, List locations, ResourceResolverChain chain) { + + return chain.resolveResource(exchange, requestPath, locations) + .switchIfEmpty(Mono.defer(() -> + resolveVersionedResource(exchange, requestPath, locations, chain))); + } + + private Mono resolveVersionedResource(@Nullable ServerWebExchange exchange, + String requestPath, List locations, ResourceResolverChain chain) { + + VersionStrategy versionStrategy = getStrategyForPath(requestPath); + if (versionStrategy == null) { + return Mono.empty(); + } + + String candidate = versionStrategy.extractVersion(requestPath); + if (!StringUtils.hasLength(candidate)) { + return Mono.empty(); + } + + String simplePath = versionStrategy.removeVersion(requestPath, candidate); + return chain.resolveResource(exchange, simplePath, locations) + .filterWhen(resource -> versionStrategy.getResourceVersion(resource) + .map(actual -> { + if (candidate.equals(actual)) { + return true; + } + else { + if (logger.isTraceEnabled()) { + String logPrefix = exchange != null ? exchange.getLogPrefix() : ""; + logger.trace(logPrefix + "Found resource for \"" + requestPath + + "\", but version [" + candidate + "] does not match"); + } + return false; + } + })) + .map(resource -> new FileNameVersionedResource(resource, candidate)); + } + + @Override + protected Mono resolveUrlPathInternal(String resourceUrlPath, + List locations, ResourceResolverChain chain) { + + return chain.resolveUrlPath(resourceUrlPath, locations) + .flatMap(baseUrl -> { + if (StringUtils.hasText(baseUrl)) { + VersionStrategy strategy = getStrategyForPath(resourceUrlPath); + if (strategy == null) { + return Mono.just(baseUrl); + } + return chain.resolveResource(null, baseUrl, locations) + .flatMap(resource -> strategy.getResourceVersion(resource) + .map(version -> strategy.addVersion(baseUrl, version))); + } + return Mono.empty(); + }); + } + + /** + * Find a {@code VersionStrategy} for the request path of the requested resource. + * @return an instance of a {@code VersionStrategy} or null if none matches that request path + */ + @Nullable + protected VersionStrategy getStrategyForPath(String requestPath) { + String path = "/".concat(requestPath); + List matchingPatterns = new ArrayList<>(); + for (String pattern : this.versionStrategyMap.keySet()) { + if (this.pathMatcher.match(pattern, path)) { + matchingPatterns.add(pattern); + } + } + if (!matchingPatterns.isEmpty()) { + Comparator comparator = this.pathMatcher.getPatternComparator(path); + matchingPatterns.sort(comparator); + return this.versionStrategyMap.get(matchingPatterns.get(0)); + } + return null; + } + + + private class FileNameVersionedResource extends AbstractResource implements HttpResource { + + private final Resource original; + + private final String version; + + public FileNameVersionedResource(Resource original, String version) { + this.original = original; + this.version = version; + } + + @Override + public boolean exists() { + return this.original.exists(); + } + + @Override + public boolean isReadable() { + return this.original.isReadable(); + } + + @Override + public boolean isOpen() { + return this.original.isOpen(); + } + + @Override + public boolean isFile() { + return this.original.isFile(); + } + + @Override + public URL getURL() throws IOException { + return this.original.getURL(); + } + + @Override + public URI getURI() throws IOException { + return this.original.getURI(); + } + + @Override + public File getFile() throws IOException { + return this.original.getFile(); + } + + @Override + @Nullable + public String getFilename() { + return this.original.getFilename(); + } + + @Override + public long contentLength() throws IOException { + return this.original.contentLength(); + } + + @Override + public long lastModified() throws IOException { + return this.original.lastModified(); + } + + @Override + public Resource createRelative(String relativePath) throws IOException { + return this.original.createRelative(relativePath); + } + + @Override + public String getDescription() { + return this.original.getDescription(); + } + + @Override + public InputStream getInputStream() throws IOException { + return this.original.getInputStream(); + } + + @Override + public HttpHeaders getResponseHeaders() { + HttpHeaders headers = (this.original instanceof HttpResource ? + ((HttpResource) this.original).getResponseHeaders() : new HttpHeaders()); + headers.setETag("\"" + this.version + "\""); + return headers; + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/VersionStrategy.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/VersionStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..f7269bc1766be0b6fea5e79d06839d6d3ca0d7f7 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/VersionStrategy.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; + +/** + * A strategy to determine the version of a static resource and to apply and/or + * extract it from the URL path. + * + * @author Rossen Stoyanchev + * @author Brian Clozel + * @since 5.0 + * @see VersionResourceResolver +*/ +public interface VersionStrategy { + + /** + * Extract the resource version from the request path. + * @param requestPath the request path to check + * @return the version string or {@code null} if none was found + */ + @Nullable + String extractVersion(String requestPath); + + /** + * Remove the version from the request path. It is assumed that the given + * version was extracted via {@link #extractVersion(String)}. + * @param requestPath the request path of the resource being resolved + * @param version the version obtained from {@link #extractVersion(String)} + * @return the request path with the version removed + */ + String removeVersion(String requestPath, String version); + + /** + * Add a version to the given request path. + * @param requestPath the requestPath + * @param version the version + * @return the requestPath updated with a version string + */ + String addVersion(String requestPath, String version); + + /** + * Determine the version for the given resource. + * @param resource the resource to check + * @return the resource version + */ + Mono getResourceVersion(Resource resource); + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/resource/WebJarsResourceResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/WebJarsResourceResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..ca5bd89b78e128f9202f527d0b95d8a338277a3c --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/resource/WebJarsResourceResolver.java @@ -0,0 +1,120 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.resource; + +import java.util.List; + +import org.webjars.WebJarAssetLocator; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; +import org.springframework.web.server.ServerWebExchange; + +/** + * A {@code ResourceResolver} that delegates to the chain to locate a resource and then + * attempts to find a matching versioned resource contained in a WebJar JAR file. + * + *

This allows WebJars.org users to write version agnostic paths in their templates, + * like {@code "; + request.setQueryString(xssQueryString); + tag.doStartTag(); + assertEquals("

", getOutput()); + } + + @Test + public void get() throws Exception { + this.tag.setMethod("get"); + + this.tag.doStartTag(); + this.tag.doEndTag(); + this.tag.doFinally(); + + String output = getOutput(); + String formOutput = getFormTag(output); + String inputOutput = getInputTag(output); + + assertContainsAttribute(formOutput, "method", "get"); + assertEquals("", inputOutput); + } + + @Test + public void post() throws Exception { + this.tag.setMethod("post"); + + this.tag.doStartTag(); + this.tag.doEndTag(); + this.tag.doFinally(); + + String output = getOutput(); + String formOutput = getFormTag(output); + String inputOutput = getInputTag(output); + + assertContainsAttribute(formOutput, "method", "post"); + assertEquals("", inputOutput); + } + + @Test + public void put() throws Exception { + this.tag.setMethod("put"); + + this.tag.doStartTag(); + this.tag.doEndTag(); + this.tag.doFinally(); + + String output = getOutput(); + String formOutput = getFormTag(output); + String inputOutput = getInputTag(output); + + assertContainsAttribute(formOutput, "method", "post"); + assertContainsAttribute(inputOutput, "name", "_method"); + assertContainsAttribute(inputOutput, "value", "put"); + assertContainsAttribute(inputOutput, "type", "hidden"); + } + + @Test + public void delete() throws Exception { + this.tag.setMethod("delete"); + + this.tag.doStartTag(); + this.tag.doEndTag(); + this.tag.doFinally(); + + String output = getOutput(); + String formOutput = getFormTag(output); + String inputOutput = getInputTag(output); + + assertContainsAttribute(formOutput, "method", "post"); + assertContainsAttribute(inputOutput, "name", "_method"); + assertContainsAttribute(inputOutput, "value", "delete"); + assertContainsAttribute(inputOutput, "type", "hidden"); + } + + @Test + public void customMethodParameter() throws Exception { + this.tag.setMethod("put"); + this.tag.setMethodParam("methodParameter"); + + this.tag.doStartTag(); + this.tag.doEndTag(); + this.tag.doFinally(); + + String output = getOutput(); + String formOutput = getFormTag(output); + String inputOutput = getInputTag(output); + + assertContainsAttribute(formOutput, "method", "post"); + assertContainsAttribute(inputOutput, "name", "methodParameter"); + assertContainsAttribute(inputOutput, "value", "put"); + assertContainsAttribute(inputOutput, "type", "hidden"); + } + + @Test + public void clearAttributesOnFinally() throws Exception { + this.tag.setModelAttribute("model"); + getPageContext().setAttribute("model", "foo bar"); + assertNull(getPageContext().getAttribute(FormTag.MODEL_ATTRIBUTE_VARIABLE_NAME, PageContext.REQUEST_SCOPE)); + this.tag.doStartTag(); + assertNotNull(getPageContext().getAttribute(FormTag.MODEL_ATTRIBUTE_VARIABLE_NAME, PageContext.REQUEST_SCOPE)); + this.tag.doFinally(); + assertNull(getPageContext().getAttribute(FormTag.MODEL_ATTRIBUTE_VARIABLE_NAME, PageContext.REQUEST_SCOPE)); + } + + @Test + public void requestDataValueProcessorHooks() throws Exception { + String action = "/my/form?foo=bar"; + RequestDataValueProcessor processor = getMockRequestDataValueProcessor(); + given(processor.processAction(this.request, action, "post")).willReturn(action); + given(processor.getExtraHiddenFields(this.request)).willReturn(Collections.singletonMap("key", "value")); + + this.tag.doStartTag(); + this.tag.doEndTag(); + this.tag.doFinally(); + + String output = getOutput(); + + assertEquals("
\n\n
", getInputTag(output)); + assertFormTagOpened(output); + assertFormTagClosed(output); + } + + @Test + public void defaultActionEncoded() throws Exception { + + this.request.setRequestURI("/a b c"); + request.setQueryString(""); + + this.tag.doStartTag(); + this.tag.doEndTag(); + this.tag.doFinally(); + + String output = getOutput(); + String formOutput = getFormTag(output); + + assertContainsAttribute(formOutput, "action", "/a%20b%20c"); + } + + private String getFormTag(String output) { + int inputStart = output.indexOf("<", 1); + int inputEnd = output.lastIndexOf(">", output.length() - 2); + return output.substring(0, inputStart) + output.substring(inputEnd + 1); + } + + private String getInputTag(String output) { + int inputStart = output.indexOf("<", 1); + int inputEnd = output.lastIndexOf(">", output.length() - 2); + return output.substring(inputStart, inputEnd + 1); + } + + + private static void assertFormTagOpened(String output) { + assertTrue(output.startsWith("")); + } + +} diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/tags/form/HiddenInputTagTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/tags/form/HiddenInputTagTests.java new file mode 100644 index 0000000000000000000000000000000000000000..60699b102dba4c2a8cf2d656da875958f38c18b9 --- /dev/null +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/tags/form/HiddenInputTagTests.java @@ -0,0 +1,142 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.servlet.tags.form; + +import javax.servlet.jsp.JspException; +import javax.servlet.jsp.tagext.Tag; + +import org.junit.Test; + +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.validation.BeanPropertyBindingResult; + +import static org.junit.Assert.*; + +/** + * @author Rob Harrop + */ +public class HiddenInputTagTests extends AbstractFormTagTests { + + private HiddenInputTag tag; + + private TestBean bean; + + @Override + @SuppressWarnings("serial") + protected void onSetUp() { + this.tag = new HiddenInputTag() { + @Override + protected TagWriter createTagWriter() { + return new TagWriter(getWriter()); + } + }; + this.tag.setPageContext(getPageContext()); + } + + @Test + public void render() throws Exception { + this.tag.setPath("name"); + int result = this.tag.doStartTag(); + assertEquals(Tag.SKIP_BODY, result); + + String output = getOutput(); + + assertTagOpened(output); + assertTagClosed(output); + + assertContainsAttribute(output, "type", "hidden"); + assertContainsAttribute(output, "value", "Sally Greenwood"); + assertAttributeNotPresent(output, "disabled"); + } + + @Test + public void withCustomBinder() throws Exception { + this.tag.setPath("myFloat"); + + BeanPropertyBindingResult errors = new BeanPropertyBindingResult(this.bean, COMMAND_NAME); + errors.getPropertyAccessor().registerCustomEditor(Float.class, new SimpleFloatEditor()); + exposeBindingResult(errors); + + assertEquals(Tag.SKIP_BODY, this.tag.doStartTag()); + + String output = getOutput(); + + assertTagOpened(output); + assertTagClosed(output); + + assertContainsAttribute(output, "type", "hidden"); + assertContainsAttribute(output, "value", "12.34f"); + } + + @Test + public void dynamicTypeAttribute() throws JspException { + try { + this.tag.setDynamicAttribute(null, "type", "email"); + fail("Expected exception"); + } + catch (IllegalArgumentException e) { + assertEquals("Attribute type=\"email\" is not allowed", e.getMessage()); + } + } + + @Test + public void disabledTrue() throws Exception { + this.tag.setDisabled(true); + + this.tag.doStartTag(); + this.tag.doEndTag(); + + String output = getOutput(); + assertTagOpened(output); + assertTagClosed(output); + + assertContainsAttribute(output, "disabled", "disabled"); + } + + // SPR-8661 + + @Test + public void disabledFalse() throws Exception { + this.tag.setDisabled(false); + + this.tag.doStartTag(); + this.tag.doEndTag(); + + String output = getOutput(); + assertTagOpened(output); + assertTagClosed(output); + + assertAttributeNotPresent(output, "disabled"); + } + + private void assertTagClosed(String output) { + assertTrue(output.endsWith("/>")); + } + + private void assertTagOpened(String output) { + assertTrue(output.startsWith("")); + } + + protected final void assertTagOpened(String output) { + assertTrue("Tag not opened properly", output.startsWith(" + assertAttributeNotPresent(output, "name"); + // id attribute is supported, but we don't want it + assertAttributeNotPresent(output, "id"); + assertTrue(output.startsWith("